[英]ValueError with ReLU function in python
我這樣聲明 ReLU function:
def relu(x):
return (x if x > 0 else 0)
並且發生了 ValueError 並且其回溯消息是
ValueError:具有多個元素的數組的真值不明確。 使用 a.any() 或 a.all()
但是,如果我用 numpy 更改 ReLU function,它可以工作:
def relu_np(x):
return np.maximum(0, x)
為什么這個函數( relu(x)
)不起作用? 我不明白...
=================================
使用的代碼:
>>> x = np.arange(-5.0, 5.0, 0.1)
>>> y = relu(x)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "filename", line, in relu
return (x if x > 0 else 0)
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
TLDR; 您的第一個 function 沒有使用矢量化方法,這意味着它需要一個浮點/整數值作為輸入,而您的第二個 function 則利用了 Numpy 的矢量化。
您的第二個 function 使用 numpy 函數,這些函數被矢量化並在數組的每個單獨元素上運行。
import numpy as np
arr = np.arange(-5.0, 5.0, 0.5)
def relu_np(x):
return np.maximum(0, x)
relu_np(arr)
# array([0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.5, 1. ,
# 1.5, 2. , 2.5, 3. , 3.5, 4. , 4.5])
但是,您的第二個 function 使用三元運算符(x if x > 0 else 0)
,它需要單個值輸入並輸出單個值。 這就是為什么當您傳遞單個元素時,它會起作用,但在傳遞數組時,它無法在每個元素上獨立運行 function。
def relu(x):
return (x if x > 0 else 0)
relu(-8)
## 0
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
注意:此錯誤的原因是由於您使用的三元運算符
(x if x > 0 else 0)
。 對於給定的整數/浮點值,條件x>0
只能取值True
或False
。 但是,當您傳遞一個數組時,您需要使用類似any()
或all()
將 boolean 值列表聚合為一個值,然后才能應用 if, else 子句。
有幾種方法可以使這項工作 -
import numpy as np
arr = np.arange(-5.0, 5.0, 0.5)
def relu(x):
return (x if x > 0.0 else 0.0)
relu_vec = np.vectorize(relu)
relu_vec(arr)
# array([0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.5, 1. ,
# 1.5, 2. , 2.5, 3. , 3.5, 4. , 4.5])
import numpy as np
arr = np.arange(-5.0, 5.0, 0.5)
def relu(x):
return (x if x > 0 else 0)
arr = np.array(arr)
np.array([relu(i) for i in arr])
# array([0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.5, 1. ,
# 1.5, 2. , 2.5, 3. , 3.5, 4. , 4.5])
請記住, x > 0
是一個布爾數組,如果您願意,可以使用掩碼:
array([False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False,
False, False, False, False, False, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True,
True])
所以if x>0
沒有意義,因為 x 包含多個元素,可以是True
或False
。 這是你錯誤的根源。
numpy 你的第二個實現好不好? 另一個實現(也許更清楚:)可能是:
def relu(x):
return x * (x > 0)
在這個實現中,如果 x 的元素低於 0,我們對 x(沿 x 軸的一系列值)進行元素乘法,如果 x 的元素高於 0,則乘以 1。
免責聲明:如果我錯了,請有人糾正我,我不能 100% 確定 numpy 是如何做事的。
您的 function relu
需要一個數值並將其與 0 進行比較並返回較大的值。 x if x > 0 else 0
將等於max(x, 0)
其中max
是內置 Python function。
另一方面, relu_np
使用 numpy function maximum
,它接受 2 個數字或arrays或iterables。 這意味着您可以傳遞您的 numpy 數組x
並將最大 function 自動應用於每個項目。 我相信這被稱為“矢量化”。
要使relu
function 以它的方式工作,您需要以不同的方式調用它。 您必須手動將 function 應用於每個元素。 你可以做類似y = np.array(list(map(relu, x)))
的事情。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.