
Find PyTorch model parameters that don't contribute to loss

In PyTorch (v1.10) DistributedDataParallel, model parameters that don't contribute to the final loss can raise a RuntimeError (as mentioned in this other question and this PyTorch forums thread):

"RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument find_unused_parameters=True to torch.nn.parallel.DistributedDataParallel , and by making sure all forward function outputs participate in calculating loss."

Although it's possible to see which parameters are affected when the error occurs (as mentioned above, or by setting the environment variable TORCH_DISTRIBUTED_DEBUG="INFO"), it seems there should be a way to statically inspect a model and locate (and presumably prune, or disable gradients on) the parameters that aren't contributing to the current loss objective.
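For reference, here is a minimal single-process sketch of the runtime workaround the error message suggests, i.e. wrapping the model with find_unused_parameters=True. It assumes the gloo backend on CPU, world_size=1 and a rendezvous file in a fresh temporary directory so it can run without a cluster; PartlyUnusedNet and run_single_process_ddp are hypothetical names for illustration only. Note that the flag only tolerates unused parameters while training, at some extra cost per iteration; it doesn't tell you up front which parameters they are.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class PartlyUnusedNet(nn.Module):
    """Hypothetical toy model: `unused` is registered but never called."""

    def __init__(self):
        super().__init__()
        self.used = nn.Linear(2, 1)
        self.unused = nn.Linear(2, 1)  # never called in forward()

    def forward(self, x):
        return self.used(x)


def run_single_process_ddp(steps=2):
    """Train a few steps under DDP with find_unused_parameters=True; return the losses."""
    # Assumption: a one-process "cluster" (gloo backend on CPU, world_size=1)
    # that rendezvouses through a file in a fresh temporary directory.
    store_path = os.path.join(tempfile.mkdtemp(), "ddp_store")
    dist.init_process_group(
        backend="gloo", init_method=f"file://{store_path}", rank=0, world_size=1
    )
    try:
        # With find_unused_parameters=True, DDP traverses the autograd graph of the
        # forward outputs and marks parameters outside it as ready for reduction.
        model = DDP(PartlyUnusedNet(), find_unused_parameters=True)
        opt = torch.optim.SGD(model.parameters(), lr=0.1)
        losses = []
        # At least two steps: the error quoted above is raised at the start
        # of the iteration that follows one with unused parameters.
        for _ in range(steps):
            opt.zero_grad()
            loss = model(torch.ones(4, 2)).sum()
            loss.backward()
            opt.step()
            losses.append(loss.item())
        return losses
    finally:
        dist.destroy_process_group()
```

Calling run_single_process_ddp() should complete both steps and return two loss values.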

So, given a torch.nn.Module-based model whose forward() function returns some loss tensor (possibly alongside other outputs), how can we programmatically find, before training starts, all parameters (including those in nested modules) that don't contribute to the loss?

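By default, PyTorch tensors that result from a computation involving tensors that require gradients (such as model parameters) record their history, that is, their ancestors, through their grad_fn attribute. This history is needed for the backward pass to compute the gradients.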
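A tiny illustration of that recorded history (w and x are made-up example tensors): each result's grad_fn points to the operation that produced it, next_functions links that node to the nodes of its inputs, and a leaf tensor that requires gradients shows up as an AccumulateGrad node whose variable attribute is the leaf itself.

```python
import torch

# Hypothetical example tensors: w requires grad (like a parameter), x does not.
w = torch.tensor([2.0, 3.0], requires_grad=True)
x = torch.tensor([1.0, 4.0])

y = (w * x).sum()

# The result remembers the operation that produced it...
print(type(y.grad_fn).__name__)  # SumBackward0

# ...and next_functions holds (node, index) pairs for that operation's inputs.
mul_node, _ = y.grad_fn.next_functions[0]
print(type(mul_node).__name__)  # MulBackward0

for node, _ in mul_node.next_functions:
    # w ends in an AccumulateGrad node whose .variable is w itself;
    # x needs no gradient, so its slot holds None.
    print(node, getattr(node, "variable", None))
```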

We can use this to find all tensors that contribute to a given output tensor by walking through its entire history.

Note that this only works for a static network that always has the same architecture, because the graph records only the operations that actually ran for the input used in the forward pass. As soon as you have conditionals that, e.g., depend on some intermediate value, this won't work, and I claim that in that case it is impossible to determine in advance which tensors are involved (it's similar to the halting problem).
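To make the caveat concrete, here is a hypothetical two-branch model (Branchy and used_param_names are made-up names). Instead of walking grad_fn by hand, this sketch swaps in torch.autograd.grad(..., allow_unused=True), which returns None for parameters that are not part of the output's graph. Either way, the result depends on the input fed to the forward pass, so a single trial run cannot prove that a parameter is never used.

```python
import torch
import torch.nn as nn


class Branchy(nn.Module):
    """Hypothetical model whose forward() chooses a layer based on the data."""

    def __init__(self):
        super().__init__()
        self.pos = nn.Linear(1, 1)
        self.neg = nn.Linear(1, 1)

    def forward(self, x):
        # Data-dependent control flow: which layer runs depends on x's values.
        return self.pos(x) if x.sum() > 0 else self.neg(x)


def used_param_names(model, x):
    """Names of parameters the output depends on, for this particular input."""
    out = model(x).sum()
    names, params = zip(*model.named_parameters())
    # allow_unused=True returns None (instead of raising) for inputs
    # that are not part of out's autograd graph.
    grads = torch.autograd.grad(out, params, allow_unused=True)
    return {name for name, g in zip(names, grads) if g is not None}


model = Branchy()
print(used_param_names(model, torch.ones(1, 1)))   # only the pos.* parameters
print(used_param_names(model, -torch.ones(1, 1)))  # only the neg.* parameters
```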

import torch
import torch.nn as nn
# Example of a simple network
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.x = nn.Parameter(torch.tensor([999999.0]))  # never used in forward(), so it doesn't contribute
        self.layers = nn.ModuleList([nn.Sequential(nn.Linear(1, 4), nn.Linear(4, 1)) for _ in range(3)])
    def forward(self, x):
        # residual blocks: each block's output is added back to its input
        for m in self.layers:
            x = m(x) + x
        return x

net = Net()
x = torch.ones((1, 1))
# compute the forward pass to create the computation graph
y = net(x)

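# use the computation graph to find all contributing tensors
def get_contributing_params(y, top_level=True):
    # at the top level y is a tensor; in recursive calls it is an autograd node
    nf = y.grad_fn.next_functions if top_level else y.next_functions
    for f, _ in nf:
        try:
            # only AccumulateGrad nodes (one per leaf tensor) have .variable
            yield f.variable
        except AttributeError:
            pass  # f is None, or an intermediate node with no tensor attached
        if f is not None:
            yield from get_contributing_params(f, top_level=False)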

contributing_parameters = set(get_contributing_params(y))
all_parameters = set(net.parameters())
non_contributing = all_parameters - contributing_parameters
print(non_contributing)  # prints a set containing only the self.x parameter ([999999.0])
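The recursive generator above follows every path through the graph, so a node shared by several paths (such as the residual additions in Net) is expanded once per path, and the number of visits can double with every residual block; very deep graphs can also hit Python's recursion limit. Below is a sketch of the same idea with an explicit stack and a visited set. It also maps the result back to parameter names and, for the "disable gradient" part of the question, freezes what it finds. contributing_params, freeze_non_contributing and TwoHeadNet are hypothetical names, and the sketch assumes forward() returns the loss alongside other outputs, so only the loss tensor is used as the starting point.

```python
import torch
import torch.nn as nn


def contributing_params(output):
    """Return the set of leaf tensors that `output` depends on through autograd."""
    found, seen = set(), set()
    stack = [output.grad_fn]
    while stack:
        node = stack.pop()
        if node is None or node in seen:
            continue  # input without grad, or a node already expanded
        seen.add(node)
        # AccumulateGrad nodes (one per leaf tensor) expose the leaf as .variable
        leaf = getattr(node, "variable", None)
        if leaf is not None:
            found.add(leaf)
        stack.extend(next_node for next_node, _ in node.next_functions)
    return found


def freeze_non_contributing(model, loss):
    """Hypothetical helper: turn off grads for parameters that `loss` doesn't use.

    Returns the names of the frozen parameters.
    """
    used = contributing_params(loss)
    frozen = []
    for name, param in model.named_parameters():
        if param not in used:  # membership is by identity: Tensor.__hash__ is id-based
            param.requires_grad_(False)
            frozen.append(name)
    return frozen


class TwoHeadNet(nn.Module):
    """Hypothetical model: forward() returns (loss, aux); aux_head only feeds aux."""

    def __init__(self):
        super().__init__()
        self.body = nn.Linear(3, 3)
        self.loss_head = nn.Linear(3, 1)
        self.aux_head = nn.Linear(3, 1)

    def forward(self, x):
        h = torch.relu(self.body(x))
        loss = self.loss_head(h).pow(2).mean()
        aux = self.aux_head(h)
        return loss, aux


model = TwoHeadNet()
loss, _aux = model(torch.randn(8, 3))
print(freeze_non_contributing(model, loss))  # ['aux_head.weight', 'aux_head.bias']
```

As far as I understand, DDP only registers parameters that have requires_grad=True when it is constructed, so freezing them before wrapping the model should keep them out of the gradient reduction. This still relies on the graph being static, as noted above.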
