簡體   English   中英

如何獲得 PyTorch 模型參數的平面視圖?

[英]How to get a flattened view of PyTorch model parameters?

我想要屬於我的 Pytorch 模型的參數的平面視圖。 請注意,它應該是一個視圖而不是參數的副本。 換句話說,當我修改視圖中的參數時,它也應該修改模型參數。 我可以得到模型參數如下:

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(1, 10), 
    torch.nn.Tanh(), 
    torch.nn.Linear(10, 1)
)

params = list(model.parameters())

for p in params:
    print(p)

這里的params是一個張量列表。 我需要它是所有參數的一維張量。 做起來很簡單

params = torch.cat([p.flatten() for p in model.parameters()])
print(params.shape) # torch.Size([31])

但是,現在修改params中的參數不會更改實際的模型參數(因為torch.cat()復制內存)。 是否可以獲得模型參數的一維張量視圖?

您應該能夠通過首先在 1d 中構造參數張量然后將視圖復制到模型中來實現此目的:

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(1, 10), 
    torch.nn.Tanh(), 
    torch.nn.Linear(10, 1)
)

def fuse_parameters(model):
    """Move model parameters to a contiguous tensor, and return that tensor."""
    n = sum(p.numel() for p in model.parameters())
    params = torch.zeros(n)
    i = 0
    for p in model.parameters():
        params_slice = params[i:i + p.numel()]
        params_slice.copy_(p.flatten())
        p.data = params_slice.view(p.shape)
        i += p.numel()
    return params

print("before fusing parameters")
with torch.no_grad(): print(model(torch.ones(3, 1)).flatten())
params = fuse_parameters(model)
print("after fusing parameters")
with torch.no_grad(): print(model(torch.ones(3, 1)).flatten());
params.mul_(2)
print("after modifying fused parameters")
with torch.no_grad(): print(model(torch.ones(3, 1)).flatten())

這打印:

before fusing parameters
tensor([-0.3356, -0.3356, -0.3356])
after fusing parameters
tensor([-0.3356, -0.3356, -0.3356])
after modifying fused parameters
tensor([-0.7728, -0.7728, -0.7728])

(事后做這種事情——從多個不同的張量中創建一個張量視圖——從 2022 年到 12 年,PyTorch 似乎不支持它。對嵌套張量有初步支持,但嵌套張量仍然復制數據。 )

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM