I have a model in PyTorch whose trainable weights I want to access and change manually.
How would that be done correctly?
Ideally, I would like a tensor of those weights.
It seems to me that
for parameter in model.parameters():
    do_something_to_parameter(parameter)
wouldn't be the right way to go, because
What is the correct way of accessing a model's weights manually (not through loss.backward and optimizer.step)?
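For reference, here is a minimal sketch of a commonly used approach, assuming the goal is to edit weights outside of the optimizer. It iterates over named_parameters() and modifies each tensor in place inside torch.no_grad(), then uses torch.nn.utils.parameters_to_vector and vector_to_parameters to get all weights as one flat tensor and write them back. The three-layer model and the specific edits (halving weights, zeroing biases, adding 0.01) are hypothetical placeholders.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector, vector_to_parameters

# Hypothetical example model; any nn.Module works the same way.
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

# 1) Read and modify individual parameters in place.
#    torch.no_grad() keeps autograd from recording these edits.
with torch.no_grad():
    for name, param in model.named_parameters():
        if name.endswith("weight"):
            param.mul_(0.5)  # placeholder edit: scale every weight matrix
        else:
            param.zero_()    # placeholder edit: reset every bias

# 2) Get all parameters as a single 1-D tensor.
flat = parameters_to_vector(model.parameters())
print(flat.shape)  # torch.Size([23]), i.e. 4*3 + 3 + 3*2 + 2

# 3) Edit the flat tensor and copy it back into the model's parameters.
with torch.no_grad():
    new_flat = flat + 0.01
    vector_to_parameters(new_flat, model.parameters())
```

In this sketch the plain loop over parameters works for manual edits because the changes are made in place under torch.no_grad(), so they neither build an autograd graph nor replace the Parameter objects the optimizer holds.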
Here's my method: you can pass in almost any model and it will return a flat list of all of its leaf torch.nn.* modules. Just add a wrapper around it to return each module's weights instead of the module itself.
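import torch.nn as nn

def flatten_model(modules):
    """Return a flat list of the leaf torch.nn modules inside `modules`.

    `modules` may be an nn.Module or an iterable of (name, module) pairs,
    such as model.named_children().
    """
    def flatten_list(_2d_list):
        flat_list = []
        # Iterate through the outer list
        for element in _2d_list:
            if type(element) is list:
                # If the element is of type list, iterate through the sublist
                for item in element:
                    flat_list.append(item)
            else:
                flat_list.append(element)
        return flat_list

    ret = []
    if isinstance(modules, nn.Module):
        children = list(modules.children())
        if not children:
            # Leaf module (e.g. nn.Linear, nn.ReLU): keep it as-is
            ret.append(modules)
        else:
            # Recurse into each direct submodule
            for n in children:
                ret.append(flatten_model(n))
    else:
        # Iterable of (name, module) pairs
        for _, n in modules:
            ret.append(flatten_model(n))
    return flatten_list(ret)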
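The answer leaves the weight-returning wrapper to the reader. One possible sketch, assuming you want each leaf module paired with the parameters registered on it: the helper name leaf_weights and the example model are hypothetical, and the helper walks the model with the built-in model.modules() rather than flatten_model so that it runs on its own.

```python
import torch.nn as nn

def leaf_weights(model: nn.Module):
    """Hypothetical wrapper: return (module, [its parameters]) for each leaf module."""
    result = []
    for module in model.modules():  # pre-order walk over every submodule
        # A leaf module has no child modules of its own (e.g. nn.Linear, nn.ReLU).
        if next(module.children(), None) is None:
            # recurse=False: only the tensors registered directly on this module.
            params = list(module.parameters(recurse=False))
            if params:  # skip parameter-free layers such as nn.ReLU
                result.append((module, params))
    return result

model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Sequential(nn.Linear(3, 2)))
for module, params in leaf_weights(model):
    print(type(module).__name__, [tuple(p.shape) for p in params])
# Linear [(3, 4), (3,)]
# Linear [(2, 3), (2,)]
```

The returned tensors are the model's own Parameter objects, so editing them in place under torch.no_grad() changes the model directly.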