繁体   English   中英

使用@tf.function 将输入转换为 UTF-8

[英]Converting an input to UTF-8 with @tf.function

我正在使用 TensorFlow 2.4.1。 我尝试使用 tf.strings.unicode_decode 来解码带有@tf.function 的 base64 编码字符串,但是发生了错误,其中 ValueError: input的等级必须是静态已知的。 我检查了 tf.strings.unicode_decode 在没有@tf.function 的情况下可以正常工作。 有没有办法用@tf.function 解码 base64 编码字符串? 我会很感激你的回答。

我加载了 SavedModel 并想更改serving_default 但我被困在将输入转换为UTF-8 这是我尝试过的代码。

class CustomTransformer(tf.keras.Model):
    def __init__(self):
        super(CustomTransformer, self).__init__()
        self.model = tf.saved_model.load('./models/transformer/1')
  
    @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.string)])
    def call(self, input):

        # Error occurred. ValueError: Rank of `input` must be statically known.
        _input_str = tf.strings.unicode_decode(input_data, 'UTF-8')

        return _input_str

这是错误消息。

ValueError: Rank of `input` must be statically known.

尝试从加载的 SavedModel 更改UTF-8时,有没有办法将输入转换为serving_default

使用tf.strings.unicode_decode时,您需要指定一个形状。 (请参阅文档)。 在这种情况下,因为您使用的是没有任何维度(简单字符串)的张量,所以只需提供一个空元组作为形状:

@tf.function(input_signature=[tf.TensorSpec(shape=(), dtype=tf.string)])

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM