
BERT model bug encountered during training

So, I made a custom dataset consisting of reviews from several e-learning sites. What I am trying to do is build a model that can recognize emotions in text, and for training I am using the dataset I made by scraping. While working on BERT, I encountered this error:

normalize() argument 2 must be str, not float

Here's my code:

import numpy as np
import pandas as pd

import tensorflow as tf
print(tf.__version__)
import ktrain
from ktrain import text

from sklearn.model_selection import train_test_split
import pickle


#class_names = ["Frustration", "Not satisfied", "Satisfied", "Happy", "Excitement"]

data = pd.read_csv("Final_scraped_dataset.csv")
print(data.head())

X = data['Text']
y = data['Emotions']

class_names = np.unique(data['Emotions'])
print(class_names)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

print(X_train.shape)
print(y_train.shape)

print(X_test.shape)
print(y_test.shape)
print(X_train.head(10))

encoding = {
    'Frustration': 0,
    'Not satisfied': 1,
    'Satisfied': 2,
    'Happy': 3,
    'Excitement' : 4
}

y_train = [encoding[x] for x in y_train]
y_test = [encoding[x] for x in y_test]

X_train = X_train.tolist()
X_test = X_test.tolist()

#print(X_train)

(x_train,  y_train), (x_test, y_test), preproc = text.texts_from_array(x_train=X_train, y_train=y_train,
                                                                       x_test=X_test, y_test=y_test,
                                                                       class_names=list(encoding.keys()),  # same order as the 0-4 labels; np.unique() sorts alphabetically
                                                                       preprocess_mode='bert',
                                                                       maxlen=200,
                                                                       max_features=15000) #I've encountered the error here


'''model = text.text_classifier('bert', train_data=(x_train, y_train), preproc=preproc)

learner = ktrain.get_learner(model, train_data=(x_train, y_train), 
                             val_data=(x_test, y_test),
                             batch_size=4)

learner.fit_onecycle(2e-5, 3)


learner.validate(val_data=(x_test, y_test))

predictor = ktrain.get_predictor(learner.model, preproc)
predictor.get_classes()

import time 

message = 'I hate you a lot'

start_time = time.time() 
prediction = predictor.predict(message)

print('predicted: {} ({:.2f})'.format(prediction, (time.time() - start_time)))

# let's save the predictor for later use
predictor.save("new_model/bert_model")


print("SAVED  _______")'''

Here's the complete error:


  File "D:\Sentiment analysis\BERT_model_new_dataset.py", line 73, in <module>
    max_features=15000)

  File "D:\Anaconda3\envs\pythy37\lib\site-packages\ktrain\text\data.py", line 373, in texts_from_array
    trn = preproc.preprocess_train(x_train, y_train, verbose=verbose)

  File "D:\Anaconda3\envs\pythy37\lib\site-packages\ktrain\text\preprocessor.py", line 796, in preprocess_train
    x = bert_tokenize(texts, self.tok, self.maxlen, verbose=verbose)

  File "D:\Anaconda3\envs\pythy37\lib\site-packages\ktrain\text\preprocessor.py", line 166, in bert_tokenize
    ids, segments = tokenizer.encode(doc, max_len=max_length)

  File "D:\Anaconda3\envs\pythy37\lib\site-packages\keras_bert\tokenizer.py", line 73, in encode
    first_tokens = self._tokenize(first)

  File "D:\Anaconda3\envs\pythy37\lib\site-packages\keras_bert\tokenizer.py", line 103, in _tokenize
    text = unicodedata.normalize('NFD', text)

TypeError: normalize() argument 2 must be str, not float

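It sounds like you have a float value somewhere in your data['Text'] column. The most likely culprit is an empty cell in the CSV: pandas.read_csv stores missing values as NaN, which is a float, and as the traceback shows, the BERT tokenizer passes every document straight to unicodedata.normalize(), which accepts only str.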
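
A minimal reproduction, assuming the problem really is an empty Text cell (the two-row CSV below is made-up illustration data, not taken from the question's dataset): reading it with pandas yields NaN for the empty field, and the same unicodedata.normalize('NFD', ...) call that keras_bert's tokenizer makes raises a TypeError on it.

```python
import io
import unicodedata

import pandas as pd

# Hypothetical two-row CSV; the second review has an empty Text field
csv_text = "Text,Emotions\nGreat course!,Happy\n,Satisfied\n"
data = pd.read_csv(io.StringIO(csv_text))

missing = data.loc[1, 'Text']
print(type(missing), pd.isna(missing))  # the empty cell was read in as NaN, a float

try:
    # Same call keras_bert's tokenizer makes on every document
    unicodedata.normalize('NFD', missing)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)
```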

You can try something like this to shed more light on what's happening:

for i, s in enumerate(data['Text']):
    if not isinstance(s, str):
        print('Text in row %s is not a string: %s' % (i, s))
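
Once the offending rows are found, one possible fix is to clean the column right after pd.read_csv(...) and before train_test_split. This is a sketch that assumes rows without review text (or without a label) can simply be discarded; the small DataFrame below is a hypothetical stand-in for Final_scraped_dataset.csv.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for Final_scraped_dataset.csv
data = pd.DataFrame({
    'Text': ['Great course!', np.nan, 'Too slow', 42.0],
    'Emotions': ['Happy', 'Satisfied', 'Not satisfied', 'Frustration'],
})

# Drop rows with no review text or no label; NaN is what breaks the tokenizer
data = data.dropna(subset=['Text', 'Emotions']).reset_index(drop=True)

# Any remaining non-string values (e.g. a review that is just a number) become text
data['Text'] = data['Text'].astype(str)

# Sanity check before handing the column to ktrain
bad_rows = [i for i, s in enumerate(data['Text']) if not isinstance(s, str)]
print(len(data), bad_rows)
```

The order of the two steps matters: calling astype(str) on the raw column first would turn every NaN into the literal string 'nan', and the model would silently train on it.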
