![](/img/trans.png)
[英]Replace all the value except for the one a specified index of a numpy matrix with zero
[英]Efficent way of constructing a matrix with all elements zero except one in numpy
我想通过比较输出信号及其真实输出值来计算每个输入的神经网络输出误差,因此我需要两个矩阵来计算此任务。
我有一个形状为(n * 1)的输出矩阵,但是在标签中,我只有应该激活的神经元的索引,因此我需要一个形状相同的矩阵,所有元素都等于零,但索引为等于标签。 我可以用一个函数来做到这一点,但我想知道在numpy
python中有一个内置的方法可以帮我吗?
您可以使用numpy或标准库以多种方式执行此操作,一种方法是创建零数组,并将index对应的值设置为1。
n = len(result)
a = np.zeros((n,));
a[id] = 1
它可能也是最快的一种:
>> %timeit a = np.zeros((n,)); a[id] = 1
1000000 loops, best of 3: 634 ns per loop
另外,您可以使用numpy.pad将零填充[1]数组。 但这几乎肯定会由于填充逻辑而变慢。
np.lib.pad([1],(id,n-id),'constant', constant_values=(0))
如预期的数量级变慢:
>> %timeit np.lib.pad([1],(id,n-id),'constant', constant_values=(0))
10000 loops, best of 3: 47.4 µs per loop
您可以尝试按照注释的建议进行列表理解:
results = [7]
np.matrix([1 if x == id else 0 for x in results])
但这也比第一种方法慢得多:
>> %timeit np.matrix([1 if x == id else 0 for x in results])
100000 loops, best of 3: 7.25 µs per loop
编辑:但是在我看来,如果要计算神经网络误差。 您应该只使用np.argmax并计算它是否成功。 错误计算可能会给您带来更多杂讯,而不是有用的。 如果您觉得自己的网络容易相似,则可以创建一个混淆矩阵。
其他一些方法似乎也比上述@umutto慢:
%timeit a = np.zeros((n,)); a[id] = 1 #umutto's method
The slowest run took 45.34 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 3: 1.53 µs per loop
布尔构造:
%timeit a = np.arange(n) == id
The slowest run took 13.98 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 3.76 µs per loop
布尔构造为整数:
%timeit a = (np.arange(n) == id).astype(int)
The slowest run took 15.31 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 5.47 µs per loop
清单结构:
%timeit a = [0]*n; a[id] = 1; a=np.asarray(a)
10000 loops, best of 3: 77.3 µs per loop
使用scipy.sparse
%timeit a = sparse.coo_matrix(([1], ([id],[0])), shape=(n,1))
10000 loops, best of 3: 51.1 µs per loop
现在实际的速度可能取决于所缓存的内容,但似乎构造零数组可能最快,尤其是如果您可以使用np.zeros_like(result)
而不是np.zeros(len(result))
一班轮:
x = np.identity(n)[id]
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.