繁体   English   中英

使用 __getitem__ 方法赋值时出现 KeyError

[英]Getting KeyError when assigning value with __getitem__ method

我想实现bert model。

所以我构建了一个 class ,里面有__getitem__

我可以打印类似test[0]的东西,但是当我分配一个值时,比如data = test[0] ,就会发生KeyError


import random
"""
corpus_file = 'vocab'
vocab_size = 6
vocab_freq = 1
save_path = 'obj/'
max_sentence = 16

corpus -> org_line -> ope_line
corpus -> org_line -> token_list -> idx_to_token + token_to_idx
"""
class vocab():
    def __init__(self, corpus_file, vocab_size, vocab_freq,save_path,max_sentence):
        self.max_sentence = max_sentence
        self.special_labels = ['PAD', 'UNK', 'SEP', 'CLS', 'MASK']

        # output
        self.data = []
        self.idx_to_token = []
        self.token_to_idx = {}

        # ope
        self.pre_ope(corpus_file,vocab_size,vocab_freq)
        #self.save_data(save_path)
        #self.print_data()

    def pre_ope(self,corpus_file,vocab_size,vocab_freq):
        token_list = {}
        with open(corpus_file, 'r') as f:
            while 1:
                new_org_line = f.readline()
                if new_org_line != '':
                    new_org_line = new_org_line.strip('\n')
                    new_sentence = new_org_line.split('\t')
                    sentence = []
                    for tmp in new_sentence:
                        token_sentence = tmp.split()
                        sentence.append(token_sentence)
                        for token in token_sentence:
                            if token_list.get(token):
                                token_list[token] += 1
                            else:
                                new_token = {token: 1}
                                token_list.update(new_token)
                    self.data.append(sentence)
                else:
                    break
            f.close()
        token_list = sorted(token_list.items(), key=lambda i: (-i[1], i[0]))

        self.build_dictionary(token_list,vocab_freq,vocab_size)

    '''
    Special labels:
    PAD
    UNK
    SEP sentence separator
    CLS classifier token
    MASK
    '''
    def build_dictionary(self,token_list,vocab_freq,vocab_size):

        for idx, label in enumerate(self.special_labels):
            self.idx_to_token.append(label)
            self.token_to_idx[label] = idx

        for idx, (token, freq) in enumerate(token_list):
            if freq >= vocab_freq :
                self.idx_to_token.append(token)
                self.token_to_idx[token] = idx + len(self.special_labels)
                if len(self.idx_to_token) >= vocab_size + len(self.special_labels) and vocab_size != 0 :
                    break

    def __len__(self):
        return len(self.data)

    def print_data(self):
        print(self.data)
        print(self.idx_to_token)
        print(self.token_to_idx)

    def __getitem__(self, item):
        s1,s2,is_next_sentence = self.get_random_next_sentence(item)
        s1,s1_label = self.get_random_sentence(s1)
        s2,s2_label = self.get_random_sentence(s2)
        sentence = [self.token_to_idx['CLS']] +s1 +[self.token_to_idx['SEP']] +s2 +[self.token_to_idx['SEP']]
        label = [-1] +s1_label +[-1] +s2_label +[-1]
        if len(sentence) > self.max_sentence :
            print('sentence is greater than the setting of max sentence')
        for pos in range(len(sentence),self.max_sentence):
            sentence.append(self.token_to_idx['PAD'])
            label.append(-1)
        return {
            'token' : sentence,
            'label' : label,
            'is_next' : is_next_sentence
        }

    def get_random_next_sentence(self,item):
        s1 = self.data[item][0]
        s2 = self.data[item][1]
        if random.random() < 0.5 :
            is_next = 0
            s2 = self.data[self.get_random_line(item)][1]
        else:
            is_next = 1
        return s1,s2,is_next

    def get_random_line(self,item):
        rand = random.randint(0,len(self.data)-1)
        while rand == item :
            rand = random.randint(0,len(self.data)-1)
        return rand


    def get_random_sentence(self,sentence):
        label = []
        for idx,token in enumerate(sentence):
            rand = random.random()
            if rand < 0.15:
                rand = rand/0.15
                if rand < 0.8: #mask
                    sentence[idx] = self.token_to_idx['MASK']
                elif rand < 0.9: #rand
                    sentence[idx] = random.randint(len(self.special_labels),len(self.token_to_idx)-1)
                else: # still
                    sentence[idx] = self.token_to_idx[token]
                label.append(self.token_to_idx[token])
            else:
                sentence[idx] = self.token_to_idx[token]
                label.append(-1)
        return sentence,label

if __name__ == '__main__':
    test = vocab('vocab', 0, 1,'obj/',16)
    print(len(test))
    print(test[0])
    print(test[1])

    data = test[0]


结果:

2
{'token': [3, 4, 18, 12, 15, 11, 2, 7, 9, 13, 2, 0, 0, 0, 0, 0], 'label': [-1, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 'is_next': 0}
{'token': [3, 6, 4, 5, 8, 5, 17, 2, 16, 5, 14, 20, 2, 0, 0, 0], 'label': [-1, -1, 19, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 'is_next': 0}

Traceback (most recent call last):
  File "vocab.py", line 146, in <module>
    data = test[0]
  File "vocab.py", line 90, in ```__getitem__```
    s1,s1_label = self.get_random_sentence(s1)
  File "vocab.py", line 136, in get_random_sentence
    sentence[idx] = self.token_to_idx[token]
KeyError: 4

vocab文件:

hello this is my home   nice to meet you
I want to go to school  and have lunch

更改代码:

    def get_random_next_sentence(self,item):
        s1 = self.data[item][0]
        s2 = self.data[item][1]
        if random.random() < 0.5 :
            is_next = 0
            s2 = self.data[self.get_random_line(item)][1]
        else:
            is_next = 1
        return s1,s2,is_next

至:

    def get_random_next_sentence(self,item):
        s1 = copy.deepcopy(self.data[item][0])
        s2 = copy.deepcopy(self.data[item][1])
        if random.random() < 0.5 :
            is_next = 0
            s2 = copy.deepcopy(self.data[self.get_random_line(item)][1])
            print(s2)
        else:
            is_next = 1
        return s1,s2,is_next

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM