Chatbot using Huggingface Transformers

I would like to use Huggingface Transformers to implement a chatbot. Currently, I have the code shown below. The transformer model already takes into account the history of past user input.

Is there something else (additional code) I have to take into account for building the chatbot?

Second, how can I modify my code to run with TensorFlow instead of PyTorch?

Later on, I also plan to fine-tune the model on other data and to test different models such as BlenderBot and GPT2. I think testing these different models should be as easy as replacing the model name in AutoTokenizer.from_pretrained("microsoft/DialoGPT-small") and AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small").

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")

for step in range(5):
    # encode the new user input, append the eos_token, and return a PyTorch tensor
    new_user_input_ids = tokenizer.encode(input(">> User:") + tokenizer.eos_token, return_tensors='pt')

    # append the new user input tokens to the chat history
    bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids

    # generate a response while limiting the total chat history to 1000 tokens
    chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)

    # pretty-print the last output tokens from the bot
    print("DialoGPT: {}".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))

Here is an example of using the DialoGPT model with TensorFlow:

from transformers import TFAutoModelForCausalLM, AutoTokenizer, BlenderbotTokenizer, TFBlenderbotForConditionalGeneration
import tensorflow as tf

chat_bots = {
    'BlenderBot': [BlenderbotTokenizer.from_pretrained('facebook/blenderbot-400M-distill'), TFBlenderbotForConditionalGeneration.from_pretrained('facebook/blenderbot-400M-distill')],
    'DialoGPT': [AutoTokenizer.from_pretrained("microsoft/DialoGPT-small"), TFAutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")],
} 
key = 'DialoGPT'
tokenizer, model = chat_bots[key]

for step in range(5):
    new_user_input_ids = tokenizer.encode(input(">> User:") + tokenizer.eos_token, return_tensors='tf')
    if step > 0:
        bot_input_ids = tf.concat([chat_history_ids, new_user_input_ids], axis=-1)
    else:
        bot_input_ids = new_user_input_ids

    chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)

    print(key + ": {}".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))

>> User:How are you?
DialoGPT: I'm here
>> User:Why are you here
DialoGPT: I'm here
>> User:But why
DialoGPT: I'm here
>> User:Where is here
DialoGPT: Where is where?
>> User:Here
DialoGPT: Where is here?
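
On the question of what else to take into account: max_length caps the prompt plus the reply, so once the concatenated history approaches that limit there is little or no room left for a new response. A minimal sketch of trimming the history before each call to generate is shown below. The helper name truncate_history and the token budget are hypothetical; it works on a plain list of token ids and assumes each turn ends with the eos token, as in the loop above.

```python
def truncate_history(token_ids, eos_token_id, max_tokens):
    """Drop the oldest complete turns until at most max_tokens ids remain.

    Hypothetical helper: assumes every turn in token_ids ends with eos_token_id.
    """
    ids = list(token_ids)
    while len(ids) > max_tokens:
        try:
            cut = ids.index(eos_token_id) + 1  # index just past the oldest turn
        except ValueError:
            # No turn boundary left: keep only the most recent tokens.
            return ids[-max_tokens:]
        if cut >= len(ids):
            # Only one (over-long) turn remains: keep its tail.
            return ids[-max_tokens:]
        ids = ids[cut:]
    return ids


# Toy example: 9 stands in for the eos id, three turns in the history.
history = [1, 2, 9, 3, 4, 5, 9, 6, 9]
print(truncate_history(history, eos_token_id=9, max_tokens=6))
```

In the chat loop this could be applied to bot_input_ids[0] (converted to a list) with tokenizer.eos_token_id, and the result wrapped back into a tensor of shape (1, n) before calling model.generate.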

If you want to compare different chatbots, you may need to adjust their decoding parameters, because they are not always identical across models. For example, with BlenderBot and a max_length of 50, the current code produces this kind of response:

>> User:How are you?
BlenderBot: ! I am am great! how how how are are are???
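
One difference worth checking when swapping models (offered as a possible contributing factor, not a diagnosis of the output above): for a decoder-only model such as DialoGPT, generate returns the prompt followed by the reply, whereas for an encoder-decoder model such as BlenderBot it returns only the generated sequence. Slicing off bot_input_ids.shape[-1] tokens therefore suits the former but can cut into the reply for the latter. The helper below is a hypothetical sketch on plain lists of token ids; the flag corresponds to model.config.is_encoder_decoder.

```python
def extract_reply_ids(output_ids, prompt_length, is_encoder_decoder):
    """Return only the newly generated token ids from one generated sequence.

    Hypothetical helper, not part of the transformers API.
    """
    if is_encoder_decoder:
        # Encoder-decoder models: the output holds only decoder tokens,
        # so nothing from the prompt needs to be stripped.
        return list(output_ids)
    # Decoder-only models: the output starts with a copy of the prompt.
    return list(output_ids[prompt_length:])


# Toy ids: a 3-token prompt followed by a 2-token reply (decoder-only case).
print(extract_reply_ids([11, 12, 13, 21, 22], prompt_length=3, is_encoder_decoder=False))
```

In the loop above, the decoded text would then come from tokenizer.decode(extract_reply_ids(...), skip_special_tokens=True), with the first generated sequence converted to a list first.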

In general, you should ask yourself which special characters are important for a chatbot (depending on your domain) and which characters can or should be omitted.

You should also experiment with different decoding methods such as greedy search, beam search, random sampling, top-k sampling, and nucleus sampling, and find out what works best for your use case. For more information on this topic, check out this post.
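
The decoding methods listed above map onto keyword arguments of generate. Here is a minimal sketch of keeping them as named presets so they are easy to switch between. The preset names, the helper generation_kwargs, and all numeric values are assumptions chosen for illustration; only the argument names (do_sample, num_beams, early_stopping, top_k, top_p) are generate parameters.

```python
# Illustrative presets; the numeric values are starting points to tune,
# not recommendations.
DECODING_PRESETS = {
    "greedy": {"do_sample": False, "num_beams": 1},
    "beam": {"do_sample": False, "num_beams": 5, "early_stopping": True},
    "sampling": {"do_sample": True, "top_k": 0},   # top_k=0 disables top-k filtering
    "top_k": {"do_sample": True, "top_k": 50},
    "nucleus": {"do_sample": True, "top_k": 0, "top_p": 0.92},
}


def generation_kwargs(method, **overrides):
    """Return a copy of the preset for `method`, with optional overrides applied."""
    if method not in DECODING_PRESETS:
        raise ValueError("unknown decoding method: {}".format(method))
    kwargs = dict(DECODING_PRESETS[method])  # copy so the preset stays unchanged
    kwargs.update(overrides)
    return kwargs


print(generation_kwargs("nucleus"))
```

These could be passed to the loop above as model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id, **generation_kwargs("top_k")). Note that the sampling-based presets give different replies on each run unless a random seed is fixed.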

