簡體   English   中英

在 Tensorflow 2.0 的 tf.function input_signature 中使用字典

[英]Use dictionary in tf.function input_signature in Tensorflow 2.0

我正在使用 Tensorflow 2.0 並面臨以下情況:

@tf.function
def my_fn(items):
    .... #do stuff
    return

如果 items 是張量的字典,例如:

item1 = tf.zeros([1, 1])
item2 = tf.zeros(1)
items = {"item1": item1, "item2": item2}

有沒有辦法使用 tf.function 的 input_signature 參數,以便當 item1 是例如tf.zeros([2,1])時,我可以強制 tf2 避免創建多個圖形?

輸入簽名必須是一個列表,但列表中的元素可以是字典或張量規范列表。 在你的情況下,我會嘗試:( name屬性是可選的)

signature_dict = { "item1": tf.TensorSpec(shape=[2], dtype=tf.int32, name="item1"),
                   "item2": tf.TensorSpec(shape=[], dtype=tf.int32, name="item2") } 
              

# don't forget the brackets around the 'signature_dict'
@tf.function(input_signature = [signature_dict])
def my_fn(items):
    .... # do stuff
    return

# calling the TensorFlow function
my_fun(items)

但是,如果要調用由my_fn創建的特定具體函數,則必須解壓縮字典。 您還必須在tf.TensorSpec提供name屬性。

# creating a concrete function with an input signature as before but without
# brackets and with mandatory 'name' attributes in the TensorSpecs 
my_concrete_fn = my_fn.get_concrete_function(signature_dict)
                                             
# calling the concrete function with the unpacking operator
my_concrete_fn(**items)

這很煩人,但應該在 TensorFlow 2.3 中解決。 (參見“具體函數”TF 指南的結尾)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM