Saving SentencepieceTokenizer in Keras model throws TypeError: Failed to convert elements of [None, None] to Tensor
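I am trying to save a Keras model that uses a SentencepieceTokenizer.
Everything works so far, except that I cannot save the Keras model.
After training the sentencepiece model, I create the Keras model, call it once with a few examples, and then try to save it like this: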
proto = tf.io.gfile.GFile(model_path, "rb").read()
model = Model(tokenizer=proto)
embed = model(examples)
assert embed.shape[0] == len(examples)
model.save("embed_model")
The model itself is straightforward and looks like this:
class Model(keras.Model):
    def __init__(self, tokenizer: spm.SentencePieceProcessor, embed_size: int = 32, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tokenizer = tf_text.SentencepieceTokenizer(model=tokenizer, nbest_size=1)
        self.embeddings = layers.Embedding(input_dim=self.tokenizer.vocab_size(), output_dim=embed_size)

    def call(self, inputs, training=None, mask=None):
        x = self.tokenizer.tokenize(inputs)
        if isinstance(x, tf.RaggedTensor):
            x = x.to_tensor()
        x = self.embeddings(x)
        return x
The error I get is:
TypeError: Failed to convert elements of [None, None] to Tensor.
Consider casting elements to a supported type.
See https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.
It looks to me as if, after model.save() is called, the model is actually invoked as model([None, None]).
More precisely, the error seems to occur in ragged_tensor.convert_to_tensor_or_ragged_tensor(input):
E TypeError: Exception encountered when calling layer "model" (type Model).
E
E in user code:
E
E File "/home/sfalk/workspaces/technical-depth/ris-ml/tests/ris/ml/text/test_tokenizer.py", line 20, in call *
E x = self.tokenizer.tokenize(inputs)
E File "/home/sfalk/miniconda3/envs/ris-ml/lib/python3.10/site-packages/tensorflow_text/python/ops/sentencepiece_tokenizer.py", line 133, in tokenize *
E input_tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(input)
E
E TypeError: Failed to convert elements of [None, None] to Tensor. Consider casting elements to a supported type. See https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.
E
E
E Call arguments received by layer "model" (type Model):
E • inputs=['None', 'None']
E • training=False
E • mask=None
/tmp/__autograph_generated_file99ftv9jw.py:22: TypeError
Maybe try defining an input_signature for the call method. Also call self.tokenizer.vocab_size().numpy() instead of self.tokenizer.vocab_size(), because eager tensors are not serializable:
import tensorflow as tf
import tensorflow_text as tf_text
import requests

url = "https://github.com/tensorflow/text/blob/master/tensorflow_text/python/ops/test_data/test_oss_model.model?raw=true"
sp_model = requests.get(url).content

class Model(tf.keras.Model):
    def __init__(self, tokenizer, embed_size: int = 32, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tokenizer = tf_text.SentencepieceTokenizer(model=tokenizer, nbest_size=1)
        self.embeddings = tf.keras.layers.Embedding(input_dim=self.tokenizer.vocab_size().numpy(), output_dim=embed_size)

    @tf.function(input_signature=(tf.TensorSpec([None], tf.string), tf.TensorSpec([None], tf.int32)))
    def call(self, inputs, mask=None):
        x = self.tokenizer.tokenize(inputs)
        if isinstance(x, tf.RaggedTensor):
            x = x.to_tensor()
        x = self.embeddings(x)
        return x

model = Model(sp_model)
embed = model(["What you know you can't explain, but you feel it."], training=False, mask=[1, 1, 1, 1, 0])
model.save("embed_model")
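The .numpy() suggestion above can be checked in isolation. The snippet below is a minimal sketch, not part of the original answer: it uses a constant tensor as a hypothetical stand-in for tokenizer.vocab_size(), and it additionally wraps the result in int(), which is an assumption on my part to get a plain Python integer rather than a NumPy scalar.

```python
import json

import tensorflow as tf

# Hypothetical stand-in for self.tokenizer.vocab_size(), which returns a scalar tensor.
vocab_size = tf.constant(1000, dtype=tf.int32)

# An eager tensor cannot be written out by the standard JSON encoder.
try:
    json.dumps(vocab_size)
    tensor_is_json_serializable = True
except TypeError:
    tensor_is_json_serializable = False

# .numpy() yields a NumPy scalar; int() (an extra step, not in the answer)
# turns it into a plain Python integer.
as_int = int(vocab_size.numpy())

# The layer config now holds an ordinary integer.
layer = tf.keras.layers.Embedding(input_dim=as_int, output_dim=32)
input_dim_json = json.dumps(layer.get_config()["input_dim"])
```

Whether the int() wrapper is needed in practice depends on how the saving code serializes the config, so treat it as a defensive choice rather than a requirement.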
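Note that I removed the training parameter from the call method because it already exists. Also, if you set self.built = True in the constructor, you do not have to call the model on real data, but that is up to you: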
class Model(tf.keras.Model):
    def __init__(self, tokenizer, embed_size: int = 32, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tokenizer = tf_text.SentencepieceTokenizer(model=tokenizer, nbest_size=1)
        self.embeddings = tf.keras.layers.Embedding(input_dim=self.tokenizer.vocab_size().numpy(), output_dim=embed_size)
        self.built = True

    @tf.function(input_signature=(tf.TensorSpec([None], tf.string), tf.TensorSpec([None], tf.int32)))
    def call(self, inputs, mask=None):
        ...
        return x

model = Model(sp_model)
model.save("embed_model")
Oh, and you may need to change the input_signature depending on the tokenizer model you are using.
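To illustrate what adapting the input_signature can look like, here is a minimal sketch with a signature that accepts only a batch of strings and no mask. Everything in it is hypothetical: ToyEmbedder is a made-up tf.Module, and whitespace splitting plus hashing stands in for the SentencePiece tokenizer so the example runs with plain TensorFlow.

```python
import tensorflow as tf


class ToyEmbedder(tf.Module):
    """Hypothetical stand-in: whitespace split + hashing, no SentencePiece."""

    def __init__(self, num_buckets: int = 100, embed_size: int = 8):
        super().__init__()
        self.num_buckets = num_buckets
        self.table = tf.Variable(tf.random.normal([num_buckets, embed_size]))

    # One rank-1 string tensor with unknown batch size; no mask argument here.
    @tf.function(input_signature=[tf.TensorSpec([None], tf.string)])
    def __call__(self, inputs):
        tokens = tf.strings.split(inputs)           # RaggedTensor of words
        dense = tokens.to_tensor(default_value="")  # pad to [batch, max_len]
        ids = tf.strings.to_hash_bucket_fast(dense, self.num_buckets)
        return tf.gather(self.table, ids)           # [batch, max_len, embed_size]


embedder = ToyEmbedder()

# With an input_signature, the function can be traced without example data.
concrete = embedder.__call__.get_concrete_function()

out = embedder(tf.constant(["hello world again", "hi"]))
```

The point of the sketch is only that the signature fixes dtype and rank up front, so tracing does not depend on whatever placeholder inputs the saving code would otherwise supply. The right shapes and dtypes for a real tokenizer are something you would need to check against its documentation.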