
Why don't the weights of my NN model change much during training?

I am training a neural network, and it fits the training data well: the training loss decreases steadily and everything seems to work fine.
However, when I inspect the weights of my model, I find that they have barely changed since random initialization (I didn't use any pretrained weights; all weights use PyTorch's default initialization). Every dimension of the weights changed by only about 1%, while the accuracy on the training data climbed from 50% to 90%. What could account for this? Is the dimension of the weights too high, so that I need to reduce the size of my model, or are there other possible explanations?

I understand this is quite a broad question, but it's impractical for me to post my model and analyze it mathematically here. So I just want to know what the general or common causes of this behavior could be.
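A figure like "about 1%" depends on how it is measured (per element, per tensor, whole model), so it can help to compute it explicitly for each parameter tensor. The sketch below assumes NumPy and uses the Frobenius-norm ratio ||W - W0|| / ||W0||; the helper name `relative_weight_change` and the `fc.weight` layer in the demo are hypothetical, not taken from the question.

```python
import numpy as np


def relative_weight_change(initial, current):
    """Per-tensor relative change ||W - W0||_F / ||W0||_F.

    `initial` and `current` map parameter names to array-likes of equal shape.
    """
    changes = {}
    for name, w0 in initial.items():
        w0 = np.asarray(w0, dtype=np.float64)
        w = np.asarray(current[name], dtype=np.float64)
        diff = np.linalg.norm(w - w0)  # Frobenius norm for matrices
        norm0 = np.linalg.norm(w0)
        if norm0 > 0:
            changes[name] = float(diff / norm0)
        else:
            # Zero-initialized tensors (e.g. some biases): the ratio is undefined.
            changes[name] = 0.0 if diff == 0 else float("inf")
    return changes


if __name__ == "__main__":
    # Hypothetical layer: a random 256x512 matrix nudged by noise at 1% of its scale.
    rng = np.random.default_rng(0)
    w_init = rng.normal(size=(256, 512))
    w_trained = w_init + 0.01 * rng.normal(size=(256, 512))
    print(relative_weight_change({"fc.weight": w_init}, {"fc.weight": w_trained}))
```

With PyTorch, one way to take the snapshot is `init = {k: v.detach().cpu().clone().numpy() for k, v in model.state_dict().items()}` before training, then building the same dictionary again after training. The `clone()` matters: `state_dict()` tensors share storage with the live parameters, which the optimizer updates in place.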

There are almost always many local optima in a problem like this, so one thing you can't predict, especially in high-dimensional feature spaces, is which optimum your model parameters will end up in. One important point is that, because the weights are real-valued, for any optimum your model reaches there are infinitely many sets of weights that reach it: what matters is the proportion of the weights to each other, because you are trying to minimize the cost, not to find one unique set of weights with a loss of 0 on every sample. Every time you train, you may get a different result depending on the initial weights.

When the weights change very little and keep almost the same ratios to each other, this means your features are highly correlated (i.e., redundant). And since you are getting very high accuracy with only a small change in the weights, the only thing I can think of is that the classes in your dataset are far away from each other.

Try removing features one at a time: train, check the results, and if the accuracy is still good, remove another one, until you hopefully reach a 2- or 3-dimensional space in which you can plot your data, see how the data points are distributed, and make some sense of it.
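The point that only the ratios between weights matter can be illustrated in the simplest case, a linear classifier: multiplying every weight and the bias by the same positive constant leaves every predicted label unchanged. This is a NumPy sketch on made-up random data; `predict` is a hypothetical helper, not something from the original post.

```python
import numpy as np


def predict(w, b, X):
    """Hard 0/1 labels of the linear decision rule X @ w + b > 0."""
    return (X @ w + b > 0).astype(int)


rng = np.random.default_rng(1)
X = rng.normal(size=(1000, 20))  # hypothetical data: 1000 samples, 20 features
w = rng.normal(size=20)
b = 0.3
base = predict(w, b, X)

for c in (0.01, 1.0, 7.5, 1000.0):
    # Scaling w and b by the same c > 0 keeps every ratio w_i / w_j (and w_i / b),
    # so the decision boundary, and therefore every predicted label, stays the same.
    same = np.array_equal(predict(c * w, c * b, X), base)
    print(f"scale {c:>7}: identical predictions = {same}")
```

This invariance concerns predictions, not the loss: a cross-entropy loss on the logits `X @ w + b` does change when they are scaled. ReLU networks have a similar freedom, since multiplying a hidden unit's incoming weights and bias by c > 0 and dividing its outgoing weights by c leaves the network's output unchanged.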

EDIT: A better approach is to use PCA for dimensionality reduction instead of removing features one by one.
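The PCA suggestion can be sketched in a few lines of NumPy: center the data, take its SVD, and project onto the first two right singular vectors. The function name `pca_project` and the synthetic two-class data are illustrative assumptions, not part of the original answer; if scikit-learn is available, `sklearn.decomposition.PCA(n_components=2).fit_transform(X)` does the same job.

```python
import numpy as np


def pca_project(X, n_components=2):
    """Project the rows of X onto their top principal components (PCA via SVD).

    Returns the projected data and the fraction of total variance explained
    by each kept component.
    """
    X = np.asarray(X, dtype=np.float64)
    Xc = X - X.mean(axis=0)  # PCA works on mean-centered features
    # Rows of Vt are the principal directions, sorted by decreasing singular value.
    _, s, Vt = np.linalg.svd(Xc, full_matrices=False)
    explained = s**2 / np.sum(s**2)
    return Xc @ Vt[:n_components].T, explained[:n_components]


if __name__ == "__main__":
    # Hypothetical data: two Gaussian classes in 50 dimensions, offset from each other.
    rng = np.random.default_rng(2)
    X = np.vstack([rng.normal(size=(200, 50)), rng.normal(size=(200, 50)) + 5.0])
    y = np.array([0] * 200 + [1] * 200)
    Z, ratio = pca_project(X)
    print("projected shape:", Z.shape, "variance explained:", ratio)
    # The sign of a principal component is arbitrary, so only compare class means.
    print("class means on PC1:", Z[y == 0, 0].mean(), Z[y == 1, 0].mean())
```

Plotting `Z[:, 0]` against `Z[:, 1]` colored by label (for example with `matplotlib.pyplot.scatter(Z[:, 0], Z[:, 1], c=y)`) gives the kind of picture the answer describes. If the first two components explain only a small fraction of the variance, the 2-D plot may hide structure, so the `ratio` output is worth checking.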

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint, please cite the site URL or the original address. For any questions, contact: yoyou2525@163.com.

 
Guangdong ICP Filing No. 18138465  © 2020-2024 STACKOOM.COM