![](/img/trans.png)
[英]key dataset lost during training using the Hugging Face Trainer
[英]Getting KeyErrors when training Hugging Face Transformer
我通常遵循本教程( https://huggingface.co/docs/transformers/training#:~:text=%F0%9F%A4%97%20Transformers%20provides%20access%20to,an%20incredibly%20powerful%20training %20technique. ) 在预训练的变压器上实现微调。 主要区别在于我使用的是我自己的自定义数据集,该数据集来源于一个 JSON 文件,该文件具有文档的文本和它应该属于的标签。 为了能够做到这一点,我需要创建自己的类,该类基于 pytorch 的 Dataset 类。 这就是那个类的样子:
class PDFsDataset(torch.utils.data.Dataset):
def __init__(self, encodings, labels):
self.encodings = encodings
self.labels = labels
def __getitem__(self, idx):
print("\n\n\n\nindex",idx)
item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
item['labels'] = torch.tensor(self.labels[idx])
return item
def __len__(self):
return len(self.labels)
我在训练变压器时遇到错误,上面写着
Traceback (most recent call last):
File "C:\Users\e417922\AppData\Roaming\Python\Python39\site-packages\pandas\core\indexes\base.py", line 3621, in get_loc
return self._engine.get_loc(casted_key)
File "pandas\_libs\index.pyx", line 136, in pandas._libs.index.IndexEngine.get_loc
File "pandas\_libs\index.pyx", line 163, in pandas._libs.index.IndexEngine.get_loc
File "pandas\_libs\hashtable_class_helper.pxi", line 2131, in pandas._libs.hashtable.Int64HashTable.get_item
File "pandas\_libs\hashtable_class_helper.pxi", line 2140, in pandas._libs.hashtable.Int64HashTable.get_item
KeyError: 19
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "c:\Users\e417922\Downloads\enwiki-20220601-pages-meta-history1.xml-p1p857\HF_Transformer.py", line 147, in <module>
transformer.train_transformer()
File "c:\Users\e417922\Downloads\enwiki-20220601-pages-meta-history1.xml-p1p857\HF_Transformer.py", line 135, in train_transformer
trainer.train()
File "C:\Users\e417922\AppData\Roaming\Python\Python39\site-packages\transformers\trainer.py", line 1409, in train
return inner_training_loop(
File "C:\Users\e417922\AppData\Roaming\Python\Python39\site-packages\transformers\trainer.py", line 1625, in _inner_training_loop
for step, inputs in enumerate(epoch_iterator):
File "C:\Users\e417922\AppData\Roaming\Python\Python39\site-packages\torch\utils\data\dataloader.py", line 530, in __next__
data = self._next_data()
File "C:\Users\e417922\AppData\Roaming\Python\Python39\site-packages\torch\utils\data\dataloader.py", line 570, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
File "C:\Users\e417922\AppData\Roaming\Python\Python39\site-packages\torch\utils\data\_utils\fetch.py", line 49, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "C:\Users\e417922\AppData\Roaming\Python\Python39\site-packages\torch\utils\data\_utils\fetch.py", line 49, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "c:\Users\e417922\Downloads\enwiki-20220601-pages-meta-history1.xml-p1p857\HF_Transformer.py", line 42, in __getitem__
for key in self.encodings[idx]:
File "C:\Users\e417922\AppData\Roaming\Python\Python39\site-packages\pandas\core\series.py", line 958, in __getitem__
return self._get_value(key)
File "C:\Users\e417922\AppData\Roaming\Python\Python39\site-packages\pandas\core\series.py", line 1069, in _get_value
loc = self.index.get_loc(label)
File "C:\Users\e417922\AppData\Roaming\Python\Python39\site-packages\pandas\core\indexes\base.py", line 3623, in get_loc
raise KeyError(key) from err
KeyError: 19
每次我运行它时它失败的 KeyError 都会发生变化。 我是 Transformers 和 HuggingFace 的初学者,所以我不知道是什么导致了这个问题。
编辑:示例输入是一个 JSON 文件,其中元素看起来像这样:{“text_clean”:[“有几百个单词的文章”,另一个有很多单词的文章”,“又一篇文章”],“most_similar_label”: [《量子》《人工智能》《材料》]}
完整代码:
import tkinter as tk
from tkinter import filedialog
import json
import pandas as pd
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from transformers import TrainingArguments
from transformers import TrainingArguments, Trainer
import numpy as np
from datasets import load_metric
from sklearn.model_selection import train_test_split
import torch
class PDFsDataset(torch.utils.data.Dataset):
def __init__(self, encodings, labels):
self.encodings = encodings
self.labels = labels
def __getitem__(self, idx):
print("\n\n\n\nindex",idx)
item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
item['labels'] = torch.tensor(self.labels[idx])
return item
def __len__(self):
return len(self.labels)
class HFTransformer:
def __init__ (self):
pass
def import_from_json(self):
#Prompts user to select json file
root = tk.Tk()
root.withdraw()
self.json_file_path = filedialog.askopenfile().name
#opens json file and loads data
with open(self.json_file_path, "r") as json_file:
try:
json_load = json.load(json_file)
except:
raise ValueError("No PDFs to convert to JSON")
self.pdfs = json_load
#converts json file data to dataframe for easier manipulation
self.pdfs = pd.DataFrame.from_dict(self.pdfs)
for index in range(len(self.pdfs["new_tags"])):
if self.pdfs["new_tags"][index] == "":
self.pdfs["new_tags"][index] = self.pdfs["most_similar_label"][index]
self.pdfs["labels"] = self.pdfs["new_tags"].apply(lambda val: self.change_tag_to_num(val))
# for label in self.data["labels"]:
def change_tag_to_num(self, value):
if value == "Quantum":
return 0
elif value == "Artificial intelligence":
return 1
elif value == "Materials":
return 2
elif value == "Energy":
return 3
elif value == "Defense":
return 4
elif value == "Satellite":
return 5
elif value == "Other":
return 6
def tokenize_dataset(self):
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
X_train, X_test, y_train, y_test = train_test_split(self.pdfs["text_clean"], self.pdfs["labels"],test_size=0.2)
train_encodings = X_train.apply(lambda string: tokenizer(string, truncation=True, padding=True,max_length=10))
test_encodings = X_test.apply(lambda string: tokenizer(string, truncation=True, padding=True,max_length=10))
self.train_dataset = PDFsDataset(train_encodings, y_train)
data_to_add = {"input_ids": [], "token_type_ids": [], "attention_mask": []}
for i in self.train_dataset.encodings:
data_to_add["input_ids"].append(i["input_ids"])
data_to_add["token_type_ids"].append(i["token_type_ids"])
data_to_add["attention_mask"].append(i["attention_mask"])
self.train_dataset.encodings = data_to_add
self.eval_dataset = PDFsDataset(test_encodings,y_test)
data_to_add = {"input_ids": [], "token_type_ids": [], "attention_mask": []}
for i in self.eval_dataset.encodings:
data_to_add["input_ids"].append(i["input_ids"])
data_to_add["token_type_ids"].append(i["token_type_ids"])
data_to_add["attention_mask"].append(i["attention_mask"])
self.eval_dataset.encodings = data_to_add
def train_transformer(self):
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=7)
training_args = TrainingArguments(output_dir="test_trainer")
self.metric = load_metric("accuracy")
training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch")
trainer = Trainer(
model=model,
args=training_args,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
compute_metrics=self.compute_metrics
)
trainer.train()
def compute_metrics(self, eval_pred):
logits, labels = eval_pred
predictions = np.argmax(logits, axis=-1)
return self.metric.compute(predictions=predictions, references=labels)
if __name__ == "__main__":
transformer = HFTransformer()
transformer.import_from_json()
transformer.tokenize_dataset()
transformer.train_transformer()
将pandas.Series
转换为一个简单的 python 列表并去掉一些额外的材料可以解决这个问题
class PDFsDataset(torch.utils.data.Dataset):
def __init__(self, encodings, labels):
self.encodings = encodings
self.labels = labels
def __getitem__(self, idx):
item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
item['labels'] = torch.tensor(self.labels[idx])
return item
def __len__(self):
return len(self.labels)
class HFTransformer:
def __init__ (self):
pass
def import_from_json(self):
#Prompts user to select json file
self.json_file_path = '/content/truncated_data.json'
#opens json file and loads data
with open(self.json_file_path, "r") as json_file:
try:
json_load = json.load(json_file)
except:
raise ValueError("No PDFs to convert to JSON")
self.pdfs = json_load
#converts json file data to dataframe for easier manipulation
self.pdfs = pd.DataFrame.from_dict(self.pdfs)
for index in range(len(self.pdfs["new_tags"])):
if self.pdfs["new_tags"][index] == "":
self.pdfs["new_tags"][index] = self.pdfs["most_similar_label"][index]
self.pdfs["labels"] = self.pdfs["new_tags"].apply(lambda val: self.change_tag_to_num(val))
# for label in self.data["labels"]:
def change_tag_to_num(self, value):
if value == "Quantum":
return 0
elif value == "Artificial intelligence":
return 1
elif value == "Materials":
return 2
elif value == "Energy":
return 3
elif value == "Defense":
return 4
elif value == "Satellite":
return 5
elif value == "Other":
return 6
def tokenize_dataset(self):
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
X_train, X_test, y_train, y_test = train_test_split(self.pdfs["text_clean"].to_list(), self.pdfs["labels"].to_list(),test_size=0.2)
train_encodings = tokenizer(X_train, truncation=True, padding=True,max_length=100)
test_encodings = tokenizer(X_test, truncation=True, padding=True,max_length=100)
self.train_dataset = PDFsDataset(train_encodings, y_train)
self.eval_dataset = PDFsDataset(test_encodings,y_test)
def train_transformer(self):
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=7)
training_args = TrainingArguments(output_dir="test_trainer")
self.metric = load_metric("accuracy")
training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch", )
trainer = Trainer(
model=model,
args=training_args,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
compute_metrics=self.compute_metrics
)
trainer.train()
def compute_metrics(self, eval_pred):
logits, labels = eval_pred
predictions = np.argmax(logits, axis=-1)
return self.metric.compute(predictions=predictions, references=labels)
if __name__ == "__main__":
tr = HFTransformer()
tr.import_from_json()
tr.tokenize_dataset()
tr.train_transformer()
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.