Loading a pretrained model in PyTorch
First, I'd like to apologize; this question may sound stupid, but I'm new to deep learning. Can anybody explain the following lines of code, which were used to load a pretrained model in PyTorch?
```python
# Retrieving model parameters from checkpoint.
vocab_size = checkpoint["model"]["_word_embedding.weight"].size(0)
embedding_dim = checkpoint["model"]["_word_embedding.weight"].size(1)
hidden_size = checkpoint["model"]["_projection.0.weight"].size(0)
num_classes = checkpoint["model"]["_classification.4.weight"].size(0)
```
I can't understand what `_projection`, `weight`, `_classification`, `size(0)` and `size(1)` mean in the code above.
```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        vocab_size = 10000
        embed_size = 100
        # word embedding layer
        self._word_embedding = nn.Embedding(vocab_size, embed_size)
        # linear transformation layers (no bias)
        self._projection = nn.ModuleList([nn.Linear(100, 50, bias=False)
                                          for i in range(2)])
        # linear transformation layers (no bias)
        self._classification = nn.ModuleList([nn.Linear(50, 50, bias=False)
                                              for i in range(4)])

    def forward(self):
        return

model = Model()
checkpoint = {
    'model': model.state_dict()  # OrderedDict
}

# _word_embedding.weight --> torch.Size([10000, 100])
# _projection.0.weight --> torch.Size([50, 100])
# _projection.1.weight --> torch.Size([50, 100])
# _classification.0.weight --> torch.Size([50, 50])
# _classification.1.weight --> torch.Size([50, 50])
# _classification.2.weight --> torch.Size([50, 50])
# _classification.3.weight --> torch.Size([50, 50])
for name, param in checkpoint['model'].items():
    print(name, '-->', param.size())  # see above

# similarly, we can print as follows
print(checkpoint["model"]["_word_embedding.weight"].size(0))  # 10000
print(checkpoint["model"]["_word_embedding.weight"].size(1))  # 100
print(checkpoint["model"]["_projection.0.weight"].size(0))  # 50
print(checkpoint["model"]["_classification.0.weight"].size(0))  # 50
```
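The printed numbers follow from the fixed shapes PyTorch uses for layer weights. `nn.Embedding(num_embeddings, embedding_dim).weight` has shape `(num_embeddings, embedding_dim)`. `nn.Linear(in_features, out_features).weight` has shape `(out_features, in_features)`, because the layer computes `x @ weight.T + bias`. `tensor.size(i)` returns the length of dimension `i` (the same as `tensor.shape[i]`). In a state-dict key, `_projection` and `_classification` are attribute names chosen by the model's author. The number after the name is the submodule's index in its container, and `weight` or `bias` is the parameter name. Here is a small check; the layer sizes are arbitrary and chosen only for illustration:

```python
import torch.nn as nn

# nn.Embedding keeps one row per token: weight shape is (num_embeddings, embedding_dim)
embedding = nn.Embedding(num_embeddings=10000, embedding_dim=100)
print(embedding.weight.size())   # torch.Size([10000, 100])
print(embedding.weight.size(0))  # 10000 -> vocabulary size
print(embedding.weight.size(1))  # 100   -> embedding dimension

# nn.Linear computes x @ weight.T + bias, so weight shape is (out_features, in_features)
linear = nn.Linear(in_features=100, out_features=50)
print(linear.weight.size())      # torch.Size([50, 100])
print(linear.weight.size(0))     # 50  -> number of output features
print(linear.weight.size(1))     # 100 -> number of input features

# Inside nn.Sequential / nn.ModuleList, keys are "<index>.<parameter>";
# layers without parameters (ReLU here) still take up an index
block = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 3))
print(list(block.state_dict().keys()))  # ['0.weight', '0.bias', '2.weight', '2.bias']
```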
I prepared this example to help you understand the meaning of those four lines.
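The four lines in the question read hyperparameters back from tensor shapes. This lets the same architecture be rebuilt before the saved weights are loaded into it. The sketch below shows that round trip. The question's real model is not shown, so `Classifier` is hypothetical. It assumes `_projection` and `_classification` are `nn.Sequential` containers with a `Linear` at index 0 and index 4, respectively. That layout is a guess that merely reproduces the key names `_projection.0.weight` and `_classification.4.weight`.

```python
import io

import torch
import torch.nn as nn


class Classifier(nn.Module):
    """Hypothetical model whose parameter names match the question's keys."""

    def __init__(self, vocab_size, embedding_dim, hidden_size, num_classes):
        super().__init__()
        self._word_embedding = nn.Embedding(vocab_size, embedding_dim)
        # Linear at index 0 -> key "_projection.0.weight", shape (hidden_size, embedding_dim)
        self._projection = nn.Sequential(nn.Linear(embedding_dim, hidden_size), nn.ReLU())
        # Linear at index 4 -> key "_classification.4.weight", shape (num_classes, hidden_size)
        self._classification = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Dropout(0.5),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, token_ids):
        # (batch, seq_len) -> (batch, embedding_dim) by averaging token embeddings
        x = self._word_embedding(token_ids).mean(dim=1)
        x = self._projection(x)
        return self._classification(x)


# Stand-in for a pretrained model: save its weights into an in-memory checkpoint.
pretrained = Classifier(vocab_size=10000, embedding_dim=100, hidden_size=50, num_classes=3)
buffer = io.BytesIO()
torch.save({"model": pretrained.state_dict()}, buffer)
buffer.seek(0)

checkpoint = torch.load(buffer, map_location="cpu")

# The question's four lines: recover hyperparameters from weight shapes.
vocab_size = checkpoint["model"]["_word_embedding.weight"].size(0)
embedding_dim = checkpoint["model"]["_word_embedding.weight"].size(1)
hidden_size = checkpoint["model"]["_projection.0.weight"].size(0)
num_classes = checkpoint["model"]["_classification.4.weight"].size(0)

# Rebuild the architecture with those sizes, then copy the saved weights in.
model = Classifier(vocab_size, embedding_dim, hidden_size, num_classes)
model.load_state_dict(checkpoint["model"])
model.eval()  # disable dropout for inference

print(vocab_size, embedding_dim, hidden_size, num_classes)  # 10000 100 50 3
```

With a checkpoint on disk, the only change is passing a path instead of the buffer, e.g. `torch.load("model.pth", map_location="cpu")`, where the file name is hypothetical. If the sizes read from the checkpoint did not match the rebuilt model, `load_state_dict` would raise an error. That is why the sizes are read from the checkpoint rather than hard-coded.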
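Statement: the technical posts on this site are licensed under CC BY-SA 4.0. If you repost, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.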