
Saving BERT Sentence Embedding

I'm currently working on an information retrieval task. I'm using SBERT to perform a semantic search, and I have already followed the documentation here.

The model I use is:

model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-mpnet-base-v2')

The outline is:

  1. You have a corpus (a list of sentences) like this:
    data = ['A man is eating food.',
          'A man is eating a piece of bread.',
          'The girl is carrying a baby.',
          'A man is riding a horse.',
          'A woman is playing violin.',
          'Two men pushed carts through the woods.',
          'A man is riding a white horse on an enclosed ground.',
          'A monkey is playing drums.',
          'A cheetah is running behind its prey.'
          ]
  2. You have a query like this:
queries = ['A man is eating pasta.']
  3. Encode both the query and the corpus:
query_embedding = model.encode(queries)
doc_embedding = model.encode(data)

The encode function outputs a numpy.ndarray, with one embedding vector per sentence in data.

  4. Calculate the similarity using cosine similarity like this:
similarity = util.cos_sim(query_embedding, doc_embedding)
  5. If you print the similarity, you'll get a torch.Tensor containing the similarity scores, like this:
tensor([[0.4389, 0.4288, 0.6079, 0.5571, 0.4063, 0.4432, 0.5467, 0.3392, 0.4293]])
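
To turn those scores back into ranked sentences, a minimal follow-up sketch (using torch.topk on the similarity tensor and the data list from above; the choice of k=3 is arbitrary) could look like this:

import torch

# Take the 3 highest-scoring corpus sentences for the (single) query.
scores, indices = torch.topk(similarity[0], k=3)
for score, idx in zip(scores, indices):
    print(f'{score:.4f}  {data[idx.item()]}')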

And it works fine and fast. But of course, that is with only a small corpus. With a large corpus, the encoding will take time.

Note: encoding the query takes almost no time because it is only one sentence, but encoding the corpus will take some time.

So, the question is: can we save the doc_embedding locally and reuse it later, especially when working with a large corpus?

Is there any built-in class/function in transformers to do this?

When dealing with a big corpus you need to use a vector database. I wrote a few guides on Faiss here and here that you might find useful. Faiss does require a lot of learning to get reasonable performance, and it only stores the vectors (not any other information such as IDs or text), so you would need to set up another database, like SQL, to handle that.
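
For concreteness, here is a minimal Faiss sketch (assuming the doc_embedding and query_embedding arrays from the question; the index filename is made up) that builds, persists, and searches an index:

import faiss
import numpy as np

# Flat inner-product index; with L2-normalized vectors,
# inner product is equivalent to cosine similarity.
doc_embedding = np.ascontiguousarray(doc_embedding, dtype='float32')
faiss.normalize_L2(doc_embedding)
index = faiss.IndexFlatIP(doc_embedding.shape[1])
index.add(doc_embedding)

# Persist the index to disk, then load it back in a later session.
faiss.write_index(index, 'corpus.index')
index = faiss.read_index('corpus.index')

# Normalize the query the same way and take the top 3 hits.
query_embedding = np.ascontiguousarray(query_embedding, dtype='float32').reshape(1, -1)
faiss.normalize_L2(query_embedding)
scores, ids = index.search(query_embedding, 3)
print(scores, ids)  # ids index into the original `data` list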

From experience, that can be super annoying, so I'd recommend looking into a managed vector database like Pinecone. You can get that up and running from where you are now in ~10 minutes; it's free up to around 1M vectors, and performance is incredible.
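
The Pinecone client API has changed between versions, so treat this as a rough sketch against the older pinecone-client v2 interface; the API key, environment, and index name are placeholders:

import pinecone

# Placeholders: supply your own API key and environment.
pinecone.init(api_key='YOUR_API_KEY', environment='YOUR_ENVIRONMENT')

# One-time setup: a cosine index matching the 768-dim mpnet embeddings.
if 'sbert-corpus' not in pinecone.list_indexes():
    pinecone.create_index('sbert-corpus', dimension=768, metric='cosine')
index = pinecone.Index('sbert-corpus')

# Upsert the vectors with IDs, keeping the raw text as metadata.
index.upsert(vectors=[
    (str(i), emb.tolist(), {'text': sentence})
    for i, (emb, sentence) in enumerate(zip(doc_embedding, data))
])

# Query with the encoded query vector.
result = index.query(vector=query_embedding[0].tolist(), top_k=3, include_metadata=True)
print(result)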

Save them as pickle files and load them later =]

import pickle

# Save the embeddings to disk ...
with open('doc_embedding.pickle', 'wb') as pkl:
    pickle.dump(doc_embedding, pkl)

# ... and load them back in a later session.
with open('doc_embedding.pickle', 'rb') as pkl:
    doc_embedding = pickle.load(pkl)
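
Since model.encode returns a numpy.ndarray by default, numpy's native serialization is an equally simple alternative (a small sketch; the .npy filename is arbitrary):

import numpy as np

np.save('doc_embedding.npy', doc_embedding)   # writes doc_embedding.npy
doc_embedding = np.load('doc_embedding.npy')  # restores the array later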
