簡體   English   中英

Pydantic:類型提示 tensorflow 張量

[英]Pydantic: Type hinting tensorflow tensor

關於如何使用 pydantic 鍵入提示 tf 張量的任何想法??。 試過默認的 tf.Tensor

RuntimeError: no validator found for <class 'tensorflow.python.framework.ops.Tensor'>, see `arbitrary_types_allowed` in Config

和 tf.flaot32

RuntimeError: error checking inheritance of tf.float32 (type: DType)

查看 pydantic 中的文檔,我相信需要定義像這樣任意的 class 之類的東西......

class Tensor:
    def __init__(self, Tensor):

        self.Tensor = Union[
            tensorflow.python.framework.ops.Tensor,
            tensorflow.python.framework.sparse_tensor.SparseTensor,
            tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor,
            tensorflow.python.framework.ops.EagerTensor,
        ]

主要有以下..

 class Main(BaseModel):
     tensor : Tensor


 class Config:
    arbitary_types_allowed = True

我會試試這個作為一個最小的例子:

from pydantic import BaseModel
import tensorflow as tf


class MyTensor(BaseModel):

    tensor: tf.Tensor

    class Config:
        arbitrary_types_allowed = True

請注意, Config類實際上是MyTensor中的子類。

工作代碼:

from pydantic import BaseModel
import tensorflow as tf


class MyTensor(BaseModel):

    tensor: tf.Variable

    class Config:
        arbitrary_types_allowed = True

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM