How to convert embedding matrix to torch.Tensor

I am new to PyTorch and not sure how to convert an embedding matrix to the torch.Tensor type.

I have 240 rows of input text data that I convert to embeddings using the Sentence Transformers library, like below:

from sentence_transformers import SentenceTransformer

embedding_model = SentenceTransformer('bert-base-nli-mean-tokens')
features = embedding_model.encode(df.features.values)

Now features is a numpy.ndarray of shape (240, 768).

I have defined the model as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F


class NClassifier(nn.Module):

    def __init__(self, input_dim, embedding_dim, hidden_dim, tagset_size):
        super(NClassifier, self).__init__()
        self.hidden_dim = hidden_dim

        # nn.Embedding is a lookup table: its forward() expects integer (Long) indices
        self.word_embeddings = nn.Embedding(input_dim, embedding_dim)

        # The LSTM takes word embeddings as inputs, and outputs hidden states
        # with dimensionality hidden_dim.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)

        # The linear layer that maps from hidden state space to code space (output classes)
        self.hidden2code = nn.Linear(hidden_dim, tagset_size)

    def forward(self, features):
        embeds = self.word_embeddings(features)
        lstm_out, _ = self.lstm(embeds.view(len(features), 1, -1))
        code_space = self.hidden2code(lstm_out.view(len(features), -1))
        code_scores = F.log_softmax(code_space, dim=1)
        return code_scores


INPUT_DIM = 240
EMBEDDING_DIM = 768
HIDDEN_DIM = 256
OUTPUT_DIM = 34

model = NClassifier(INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM)

Now when I call scores = model(features), I get an error because features is NOT a tensor. I have seen the example of converting the input to a tensor here, but it is not clear to me.

Can anyone please help?

A numpy.ndarray can be converted to a torch.Tensor with the torch.tensor() function, like this:

features_tensor = torch.tensor(features)
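A minimal runnable sketch of the conversion, assuming features is the float32 array that SentenceTransformer.encode usually returns (a random stand-in array is used here so the snippet runs without downloading a model). It also shows torch.from_numpy, which builds a tensor that shares memory with the NumPy array instead of copying it:

```python
import numpy as np
import torch

# Stand-in for the SentenceTransformer output: a float32 array of shape (240, 768).
# The values are random here; in the question they come from embedding_model.encode.
features = np.random.rand(240, 768).astype(np.float32)

# torch.tensor copies the data into a new tensor (dtype float32 is preserved).
features_tensor = torch.tensor(features)

# torch.from_numpy wraps the same memory: later changes to features show up here too.
shared_tensor = torch.from_numpy(features)

print(features_tensor.shape, features_tensor.dtype)  # torch.Size([240, 768]) torch.float32
```

Use torch.tensor when you want an independent copy, and torch.from_numpy when you want to avoid the copy and are fine with the two objects sharing data.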

Does that work for you?
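One thing to watch for: nn.Embedding is a lookup table that expects integer (Long or Int) indices, so passing the float tensor produced by torch.tensor(features) into self.word_embeddings will still raise an error. Because the 768-d sentence vectors are already embeddings, one option is to skip nn.Embedding and feed them straight into the LSTM. The sketch below is a hypothetical variant (the class name SentenceLSTMClassifier is made up for illustration) that keeps the original's choice of treating the 240 rows as one sequence with batch size 1, and uses a random stand-in tensor:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SentenceLSTMClassifier(nn.Module):
    """Hypothetical variant of NClassifier without nn.Embedding: the inputs
    are already float sentence embeddings, so they go straight into the LSTM."""

    def __init__(self, embedding_dim, hidden_dim, tagset_size):
        super().__init__()
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.hidden2code = nn.Linear(hidden_dim, tagset_size)

    def forward(self, features):
        # features: float tensor of shape (num_sentences, embedding_dim).
        # Reshape to (seq_len, batch=1, embedding_dim), the LSTM's default layout.
        lstm_out, _ = self.lstm(features.view(len(features), 1, -1))
        # lstm_out: (seq_len, 1, hidden_dim) -> (seq_len, hidden_dim)
        code_space = self.hidden2code(lstm_out.view(len(features), -1))
        return F.log_softmax(code_space, dim=1)


features_tensor = torch.randn(240, 768)  # stand-in for torch.tensor(features)
model = SentenceLSTMClassifier(embedding_dim=768, hidden_dim=256, tagset_size=34)
scores = model(features_tensor)
print(scores.shape)  # torch.Size([240, 34])
```

If the 240 rows are independent samples rather than an ordered sequence, a plain feed-forward classifier on the 768-d vectors may fit better; that is a modeling choice the question does not settle.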