
Pytorch: How to get all model's parameters that require grad?

I have a model in PyTorch whose trainable weights I want to access and modify manually.

How would that be done correctly?

Ideally, I would like a tensor of those weights.

It seems to me that

for parameter in model.parameters():
    do_something_to_parameter(parameter)

wouldn't be the right way to go, because

  1. a per-parameter Python loop can't take advantage of the GPU, and
  2. it doesn't use any low-level, vectorized implementation

What is the correct way of accessing a model's weights manually (not through loss.backward() and optimizer.step())?
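One common pattern for this is to update the parameters in place under torch.no_grad(), so autograd doesn't record the change and the arithmetic runs on whatever device the model lives on. A minimal sketch (the nn.Linear model here is just illustrative):

```python
import torch
import torch.nn as nn

# Toy model; any nn.Module works the same way.
model = nn.Linear(4, 2)
# model = model.to("cuda")  # in-place updates would then run on the GPU

# torch.no_grad() keeps autograd from tracking the manual update;
# the in-place ops stay on the parameters' device.
with torch.no_grad():
    for p in model.parameters():
        if p.requires_grad:
            p.add_(0.01)  # in-place update of the trainable weight
```

The loop here is only Python-level bookkeeping; each `p.add_` call is a single vectorized tensor operation, so the actual arithmetic is not done element by element in Python.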

Here's my method. You can generally pass any model to it and it will return a list of all the torch.nn.* leaf modules; just add a wrapper around it to return not the modules but their weights.

def flatten_model(module):
    """Recursively collect every leaf module (a module with no children)."""
    leaves = []
    children = list(module.children())
    if not children:
        # No submodules: this is a leaf layer (e.g. nn.Linear, nn.Conv2d).
        leaves.append(module)
    else:
        # Recurse into each submodule and flatten the results.
        for child in children:
            leaves.extend(flatten_model(child))
    return leaves
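If what you ultimately want is a single tensor of all trainable weights (as the question asks), PyTorch's own utilities can flatten and restore them directly, without walking modules by hand. A hedged sketch, with an illustrative model:

```python
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector, vector_to_parameters

# Illustrative model; substitute your own.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Keep only the parameters that require grad.
trainable = [p for p in model.parameters() if p.requires_grad]

# One flat 1-D tensor holding every trainable weight,
# on the same device as the model.
flat = parameters_to_vector(trainable)

# ...modify `flat` as a single tensor...
flat = flat * 0.5

# Write the modified values back into the model's parameters.
vector_to_parameters(flat, trainable)
```

This keeps the heavy lifting in two vectorized calls, and the flat tensor view is convenient for things like custom optimizers or weight perturbation.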
