[英]Pydantic: Type hinting tensorflow tensor
關於如何使用 pydantic 鍵入提示 tf 張量的任何想法??。 試過默認的 tf.Tensor
RuntimeError: no validator found for <class 'tensorflow.python.framework.ops.Tensor'>, see `arbitrary_types_allowed` in Config
和 tf.flaot32
RuntimeError: error checking inheritance of tf.float32 (type: DType)
查看 pydantic 中的文檔,我相信需要定義像這樣任意的 class 之類的東西......
class Tensor:
def __init__(self, Tensor):
self.Tensor = Union[
tensorflow.python.framework.ops.Tensor,
tensorflow.python.framework.sparse_tensor.SparseTensor,
tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor,
tensorflow.python.framework.ops.EagerTensor,
]
主要有以下..
class Main(BaseModel):
tensor : Tensor
class Config:
arbitary_types_allowed = True
我會試試這個作為一個最小的例子:
from pydantic import BaseModel
import tensorflow as tf
class MyTensor(BaseModel):
tensor: tf.Tensor
class Config:
arbitrary_types_allowed = True
請注意, Config
類實際上是MyTensor
中的子類。
工作代碼:
from pydantic import BaseModel
import tensorflow as tf
class MyTensor(BaseModel):
tensor: tf.Variable
class Config:
arbitrary_types_allowed = True
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.