簡體   English   中英

如何修復 pytorch conv2d function 的錯誤?

[英]How to fix error with pytorch conv2d function?

我正在嘗試在這兩個張量上使用 conv2d function :

Z = np.random.choice([0,1],size=(100,100))
Z = torch.from_numpy(Z).type(torch.FloatTensor)

print(Z)

tensor([[0., 0., 1.,  ..., 1., 0., 0.],
        [1., 0., 1.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 0., 1., 1.],
        ...,
        [1., 0., 1.,  ..., 1., 1., 1.],
        [1., 0., 1.,  ..., 0., 0., 0.],
        [0., 1., 1.,  ..., 1., 0., 0.]

filters = torch.tensor(np.array([[1,1,1],
                        [1,0,1],
                        [1,1,1]]), dtype=torch.float32)

print(filters)

tensor([[1., 1., 1.],
        [1., 0., 1.],
        [1., 1., 1.]])

但是當我嘗試做torch.nn.functional.conv2d(Z,filters)這個錯誤返回:

RuntimeError: weight should have at least three dimensions

我真的不明白這里有什么問題。 如何解決?

torch.nn.functional.conv2d(input, weight)的輸入應該是

在此處輸入圖像描述

您可以使用unsqueeze()添加虛假批次和通道尺寸,從而具有尺寸:輸入: (1, 1, 100, 100)和重量: (1, 1, 3, 3)

torch.nn.functional.conv2d(Z.unsqueeze(0).unsqueeze(0), filters.unsqueeze(0).unsqueeze(0))

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM