[英]How to get a flattened view of PyTorch model parameters?
我想要屬於我的 Pytorch 模型的參數的平面視圖。 請注意,它應該是一個視圖而不是參數的副本。 換句話說,當我修改視圖中的參數時,它也應該修改模型參數。 我可以得到模型參數如下:
import torch
model = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.Tanh(),
torch.nn.Linear(10, 1)
)
params = list(model.parameters())
for p in params:
print(p)
這里的params
是一個張量列表。 我需要它是所有參數的一維張量。 做起來很簡單
params = torch.cat([p.flatten() for p in model.parameters()])
print(params.shape) # torch.Size([31])
但是,現在修改params
中的參數不會更改實際的模型參數(因為torch.cat()
復制內存)。 是否可以獲得模型參數的一維張量視圖?
您應該能夠通過首先在 1d 中構造參數張量然后將視圖復制到模型中來實現此目的:
import torch
model = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.Tanh(),
torch.nn.Linear(10, 1)
)
def fuse_parameters(model):
"""Move model parameters to a contiguous tensor, and return that tensor."""
n = sum(p.numel() for p in model.parameters())
params = torch.zeros(n)
i = 0
for p in model.parameters():
params_slice = params[i:i + p.numel()]
params_slice.copy_(p.flatten())
p.data = params_slice.view(p.shape)
i += p.numel()
return params
print("before fusing parameters")
with torch.no_grad(): print(model(torch.ones(3, 1)).flatten())
params = fuse_parameters(model)
print("after fusing parameters")
with torch.no_grad(): print(model(torch.ones(3, 1)).flatten());
params.mul_(2)
print("after modifying fused parameters")
with torch.no_grad(): print(model(torch.ones(3, 1)).flatten())
這打印:
before fusing parameters
tensor([-0.3356, -0.3356, -0.3356])
after fusing parameters
tensor([-0.3356, -0.3356, -0.3356])
after modifying fused parameters
tensor([-0.7728, -0.7728, -0.7728])
(事后做這種事情——從多個不同的張量中創建一個張量視圖——從 2022 年到 12 年,PyTorch 似乎不支持它。對嵌套張量有初步支持,但嵌套張量仍然復制數據。 )
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.