[英]Pytorch copy a neuron in a layer
I am using pytorch 0.3.0.我正在使用 pytorch 0.3.0。 I'm trying to selectively copy a neuron and it's weights within the same layer, then replace the original neuron with an another set of weights.
我试图有选择地复制一个神经元和它在同一层内的权重,然后用另一组权重替换原始神经元。 Here's my attempt at that:
这是我的尝试:
reshaped_data2 = data2.unsqueeze(0)
new_layer_data = torch.cat([new_layer.data, reshaped_data2], dim=0)
new_layer_data[i] = data1
new_layer.data.copy_(new_layer_data)
First I unsqueezed data2
to make it a 1*X
tensor instead of 0*X
.首先,我解压
data2
使其成为1*X
张量而不是0*X
。 Then I concatenate my layer's tensor with the reshaped data2
along dimension 0. I then replace the original data2
located at index i
with data1
.然后我将我的层的张量与沿维度 0 重新整形的
data2
连接起来。然后我将位于索引i
的原始data2
替换为data1
。 Finally, I copy all of that into my layer.最后,我将所有这些复制到我的图层中。
The error I get is:我得到的错误是:
RuntimeError: inconsistent tensor size, expected tensor [10 x 128] and src [11 x 128] to have the same number of elements, but got 1280 and 1408 elements respectively at /Users/soumith/code/builder/wheel/pytorch-src/torch/lib/TH/generic/THTensorCopy.c:86
If I do a simple assignment instead of copy I get如果我做一个简单的任务而不是复制,我会得到
RuntimeError: The expanded size of the tensor (11) must match the existing size (10) at non-singleton dimension 1. at /Users/soumith/code/builder/wheel/pytorch-src/torch/lib/TH/generic/THTensor.c:309
I understand the error, but what is the right way to go about this?我理解错误,但是解决此问题的正确方法是什么?
You're trying to replace a 10x128
tensor with a 11x128
tensor, which the model doesn't allow.你试图更换
10x128
与张11x128
张量,该模型不允许。 Is new_layer
initialised with the size (11, 128)
? new_layer
是否用大小(11, 128)
初始化? If not, try creating your new layer with your desired size (11, 128)
and then copy/assign your new_layer_data
.如果没有,请尝试使用所需大小
(11, 128)
创建新图层,然后复制/分配new_layer_data
。
The solution here is to create a new model with the correct size and pass in weights as default values.这里的解决方案是创建一个具有正确大小的新模型,并将权重作为默认值传递。 No dynamic expansion solution was found.
没有找到动态扩展解决方案。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.