繁体   English   中英

IndexError用1D数组索引2D数组(NumPy)

[英]IndexError indexing a 2D array with a 1D array (NumPy)

我有一个NumPy标签数组:

labels = np.ndarray(10000, dtype=np.float32)

数组中的元素如下所示:

print(labels[1:5])
Output: [ 9.  9.  4.  1.]

我想将它们转换为一个热编码标签,并且使用了以下代码:

one_hot_labels = np.eye(10)[labels]

我收到以下错误:

IndexError     Traceback (most recent call last)
<ipython-input-21-dccf85afc031> in <module>()
  1 
----> 2 s=np.eye(10)[labels]

IndexError: arrays used as indices must be of integer (or boolean) type

我该如何解决?

您已将标签定义为np.float32 如果将它们用作数组或矩阵的索引,则它们必须是整数。 要转换np.float32使用.astype(int)

 one_hot_labels=np.eye(10)[labels.astype(int)]

或直接将标签定义为int:

labels=np.ndarray(10000,dtype=int)

如果labelsfloat而你不想改变它的dtype ,你可以简单地使用MultiLabelBinarizer 此代码段应完成工作:

from sklearn.preprocessing import MultiLabelBinarizer

mlb = MultiLabelBinarizer()
one_hot_labels = mlb.fit_transform(labels[:, None])

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM