I would like to train the spaCy text classifier using labels and words from a dataframe, but I can't get the training data into the right format to pass it to the training loop.
Dataframe example:
category word score
0 anger fasten 0.0
1 anger morals 1.0
2 anger tributary 0.0
3 anger changer 0.0
4 anger morality 0.0
... ... ... ...
184125 trust amber 0.0
184126 trust pulmonary 0.0
184127 trust ambient 0.0
184128 trust amaze 0.0
184129 trust zoom 0.0
SAMPLE CODE
# Training examples, one per word. Each item is a (text, annotation) pair where
# the annotation is either a bare {label: 0 or 1} dict or the usual spaCy
# {"cats": {label: 0 or 1}} wrapper; both shapes are accepted below.
TRAIN_DATA = [
    # HERE THE TRAIN DATA FROM THE DATAFRAME
    # anger : words related with anger
    # trust : words related with trust
]


def _cats_of(annotation):
    """Return the {label: score} dict from a bare or {"cats": ...}-wrapped annotation."""
    wrapped = annotation.get("cats")
    return wrapped if isinstance(wrapped, dict) else annotation


nlp = spacy.load("en_core_web_sm")
category = nlp.create_pipe("textcat", config={"exclusive_classes": True})
nlp.add_pipe(category)

# Register exactly the labels that occur in the training data. Hard-coded
# labels that never appear in the annotations cannot be learned, and the real
# categories (e.g. "trust") could never be predicted.
# NOTE: TRAIN_DATA must be filled in before this point, otherwise no label is
# registered and there is nothing to train.
for label in sorted({name for _, entities in TRAIN_DATA for name in _cats_of(entities)}):
    category.add_label(label)

# Train only the text classifier: begin_training() initialises every enabled
# pipe, so the pretrained tagger/parser/ner are switched off while training to
# keep their weights and to avoid running them on every example.
other_pipes = [name for name in nlp.pipe_names if name != "textcat"]
with nlp.disable_pipes(*other_pipes):
    optimizer = nlp.begin_training()
    for i in range(100):
        # Start from an empty dict so the printed loss is per epoch rather
        # than a running total over all epochs so far.
        losses = {}
        random.shuffle(TRAIN_DATA)
        for batch in minibatch(TRAIN_DATA, size=8):
            # Pass raw strings: nlp.update() tokenizes them itself, which is
            # far cheaper than running the whole pipeline via nlp(text).
            texts = [text for text, _ in batch]
            annotations = [{"cats": _cats_of(entities)} for _, entities in batch]
            nlp.update(texts, annotations, sgd=optimizer, losses=losses)
        print(i, losses)
EXPECTED OUTPUT:
doc = nlp(u'confidence') --> prediction : trust
So I built the training data in the required format with the code below, but each training run seems to take about 4 hours.
def cat_dict_funct(cat_dict, n):
    """Fill *cat_dict* in place with a one-hot encoding over ``lexicon_labels``.

    Every label in the module-level ``lexicon_labels`` sequence becomes a key
    of *cat_dict*; the label at position *n* is set to 1 and all others to 0.
    Keys already present in *cat_dict* that are not labels are left untouched.

    Args:
        cat_dict: Dictionary to update in place.
        n: Index into ``lexicon_labels`` of the label to mark as active.

    Returns:
        None. The result is written into *cat_dict*.
    """
    # Iterate over the labels themselves instead of a fixed count so the
    # function keeps working when the label list is not exactly 8 entries long.
    for i, label in enumerate(lexicon_labels):
        cat_dict[label] = 1 if i == n else 0
# Build spaCy text-classification training data from the dataframe:
# one (word, {"cats": {label: 0 or 1}}) pair per row.
train_data = df
train_texts = train_data['word'].tolist()
train_cats = train_data['category'].tolist()
# NOTE(review): the 'score' column is not used here, so every row is treated
# as a positive example of its category -- presumably rows with score 0.0
# should be excluded or labelled differently; confirm against the lexicon.

# Position of each category's label as passed to cat_dict_funct; any category
# not listed here falls back to index 7.
_CATEGORY_INDEX = {
    'trust': 0,
    'fear': 1,
    'disgust': 2,
    'surprise': 3,
    'anticipation': 4,
    'anger': 5,
    'joy': 6,
}
_FALLBACK_INDEX = 7

final_train_cats, cat_dict = [], {}
for cat in train_cats:
    # A fresh dict per row is essential: appending one shared dict would make
    # every entry of final_train_cats alias the same object, so all examples
    # would end up carrying the labels of the last row.
    cat_dict = {}
    cat_dict_funct(cat_dict, _CATEGORY_INDEX.get(cat, _FALLBACK_INDEX))
    final_train_cats.append(cat_dict)

TRAIN_DATA = list(zip(train_texts, [{"cats": cats} for cats in final_train_cats]))
The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please cite the site URL or the original address. For any questions, please contact: yoyou2525@163.com.