簡體   English   中英

tf.py_func,自定義tensorflow函數僅應用於張量中的第一個元素

[英]tf.py_func , custom tensorflow function getting applied to only the first element in the tensor

我是tensorflow的新手,正在玩一個深度學習網絡。 我想在每次迭代后對所有權重進行自定義四舍五入。 由於tensorflow庫中的舍入函數無法為您提供將值舍入到小數位數的選項。 所以我寫了這個

import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops

np_prec = lambda x: np.round(x,3).astype(np.float32)
def tf_prec(x,name=None):
     with ops.name_scope( "d_spiky", name,[x]) as name:
          y = tf.py_func(np_prec,
                         [x],
                         [tf.float32],
                         name=name,
                         stateful=False)
          return y[0]
with tf.Session() as sess:

     x = tf.constant([0.234567,0.712,1.2,1.7])
     y = tf_prec(x)
     y = tf_prec(x)
     tf.global_variables_initializer

     print(x.eval(), y.eval())

我得到的輸出是這個

[ 0.234567    0.71200001  1.20000005  1.70000005] [ 0.235       0.71200001  1.20000005  1.70000005]

因此,自定義舍入僅在張量中的第一項上起作用,並且我不確定自己做錯了什么。 提前致謝。

由於以下行,此處出現錯誤,

np_prec = lambda x: np.round(x,3).astype(np.float32)

您正在將輸出轉換為np.float32 您可以通過以下代碼驗證錯誤,

print(np.round([0.234567,0.712,1.2,1.7], 3).astype(np.float32)) #prints [ 0.235       0.71200001  1.20000005  1.70000005]

np.round的默認輸出是float64 此外,還必須將tf.py_func中Tout參數更改float64

我已通過以下修復程序提供了以下代碼,並在必要時進行了注釋。

import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops

np_prec = lambda x: np.round(x,3)
def tf_prec(x,name=None):
     with ops.name_scope( "d_spiky", name,[x]) as name:
          y = tf.py_func(np_prec,
                         [x],
                         [tf.float64], #changed this line to tf.float64
                         name=name,
                         stateful=False)
          return y[0]
with tf.Session() as sess:

     x = tf.constant([0.234567,0.712,1.2,1.7],dtype=np.float64) #specify the input data type np.float64
     y = tf_prec(x)
     y = tf_prec(x)
     tf.global_variables_initializer

     print(x.eval(), y.eval())

希望這可以幫助。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM