简体   繁体   English

正确获取 numpy 数组的维度以绘制转换后的灰度图像

[英]Getting the dimensions of a numpy array right to plot converted greyscale image

as part of Unity's ML Agents images fed to a reinforcement learning agent can be converted to greyscale like so:作为 Unity 的 ML Agents 的一部分,馈送到强化学习代理的图像可以转换为灰度,如下所示:

def _process_pixels(image_bytes=None, bw=False):
    s = bytearray(image_bytes)
    image = Image.open(io.BytesIO(s))
    s = np.array(image) / 255.0
    if bw:
        s = np.mean(s, axis=2)
        s = np.reshape(s, [s.shape[0], s.shape[1], 1])
    return s

As I'm not familiar enough with Python and especially numpy, how can I get the dimensions right for plotting the reshaped numpy array?由于我对 Python 尤其是 numpy 不够熟悉,我怎样才能获得正确的尺寸来绘制重塑的 numpy 数组? To my understanding, the shape is based on the image's width, height and number of channels.据我了解,形状基于图像的宽度、高度和通道数。 So after reshaping there is only one channel to determine the greyscale value.所以整形后只有一个通道来确定灰度值。 I just didn't find a way yet to plot it yet.我只是还没有找到一种方法来绘制它。

Here is a link to the mentioned code of the Unity ML Agents repository .这是Unity ML Agents 存储库中提到的代码的链接。

That's how I wanted to plot it:这就是我想绘制它的方式:

plt.imshow(s)
plt.show()

Won't just doing this work?不就是做这个工作吗?

plt.imshow(s[..., 0])
plt.show()

Explanation说明

plt.imshow expects either a 2-D array with shape (x, y) , and treats it like grayscale, or dimensions (x, y, 3) (treated like RGB) or (x, y, 4) (treated as RGBA). plt.imshow需要形状为(x, y)的二维数组,并将其视为灰度,或维度(x, y, 3) (视为 RGB)或(x, y, 4) (视为 RGBA )。 The array you had was (x, y, 1) .您拥有的数组是(x, y, 1) To get rid of the last dimension we can do Numpy indexing to remove the last dimension.为了摆脱最后一个维度,我们可以使用 Numpy 索引来删除最后一个维度。 s[..., 0] says, "take all other dimensions as-is, but along the last dimension, get the slice at index 0". s[..., 0]表示,“按原样获取所有其他维度,但沿着最后一个维度,获取索引 0 处的切片”。

It looks like the grayscale version has an extra single dimension at the end.看起来灰度版本最后有一个额外的单一维度。 To plot, you just need to collapse it, eg with np.squeeze :要绘制,您只需要折叠它,例如使用np.squeeze

plt.imshow(np.squeeze(s))

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM