簡體   English   中英

如何使用tf.where()根據條件替換特定值

[英]How to replace particular values based on condition by using tf.where()

我想替換條件下的值。
NumPy版本會像這樣

intensity=np.where(
  np.abs(intensity)<1e-4,
  1e-4,
  intensity)

但是TensorFlow對tf.where()的用法有些不同
當我嘗試這個

intensity=tf.where(
  tf.math.abs(intensity)<1e-4,
  1e-4,
  intensity)

我得到這個錯誤

ValueError: Shapes must be equal rank, but are 0 and 4 for 'Select' (op: 'Select') with input shapes: [?,512,512,1], [], [?,512,512,1].

這是否意味着我應該為1e-4 4維張量?

以下代碼傳遞了錯誤

# Create an array which has small value (1e-4),  
# whose shape is (2,512,512,1)
small_val=np.full((2,512,512,1),1e-4).astype("float32")

# Convert numpy array to tf.constant
small_val=tf.constant(small_val)

# Use tf.where()
intensity=tf.where(
  tf.math.abs(intensity)<1e-4,
  small_val,
  intensity)

# Error doesn't occur
print(intensity.shape)
# (2, 512, 512, 1)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM