簡體   English   中英

張量值的 function 生成此錯誤:“false_fn”必須是可調用的

[英]The function for tensor value generates this Error: 'false_fn' must be callable

我正在創建一個 function,它采用張量值並通過應用以下公式返回結果,有 3 個條件,所以我使用 @tf.functions。

def Spa(x):
    x= tf.convert_to_tensor(float(x), dtype=tf.float32)
    p= tf.convert_to_tensor(float(0.05), dtype=tf.float32)
    
    p_dash=x
    K = p*logp_dash
    Ku=K.sum(Ku)
    
    Ku= tf.convert_to_tensor(float(Ku), dtype=tf.float32)
    

    y= tf.convert_to_tensor(float(0), dtype=tf.float32)
    def a(): return tf.constant(0)

    r = tf.case([(tf.less(x, y), a), (tf.greater(x, Ku), a)], default=x, exclusive=False)
    return r

該代碼生成以下錯誤: “false_fn”必須是可調用的。 我做了很多轉換,int 到 float 和 float 到 int 但不知道是什么問題。

tf.case x 應該是 function 而不是張量。 TF 頁面中所述

def f1(): return tf.constant(17)
def f2(): return tf.constant(23)
def f3(): return tf.constant(-1)
r = tf.case([(tf.less(x, y), f1), (tf.greater(x, z), f2)], default=f3, exclusive=True)

可以看到這里的f3也是一個function。

它應該是這樣的。

x = tf.where(x < 0, tf.zeros_like(x), x)
p = 0.05
p_hat = x
KL_divergence = p * (tf.math.log(p / p_hat)) + (1 - p) * (tf.math.log(1 - p / 1 - p_hat))
x = tf.where(x < KL_divergence, tf.zeros_like(x), x)
return x

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM