
Error when using viterbi_decode of TensorFlow

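I'm using github.com/Determined22/zh-NER-TF; the only change is that I swapped in a different train_data file in the same format. The code itself is not the problem, because it runs fine with the original train_data. What is causing this error?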

Traceback (most recent call last):
  File "main.py", line 83, in <module>
    model.train(train=train_data, dev=dev_data)
  File "/home/mengyuguang/NER/model.py", line 161, in train
    self.run_one_epoch(sess, train, dev, self.tag2label, epoch, saver)
  File "/home/mengyuguang/NER/model.py", line 221, in run_one_epoch
    label_list_dev, seq_len_list_dev = self.dev_one_epoch(sess, dev)
  File "/home/mengyuguang/NER/model.py", line 256, in dev_one_epoch
    label_list_, seq_len_list_ = self.predict_one_batch(sess, seqs)
  File "/home/mengyuguang/NER/model.py", line 277, in predict_one_batch
    viterbi_seq, _ = viterbi_decode(logit[:seq_len], transition_params)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/crf/python/ops/crf.py", line 333, in viterbi_decode
    trellis[0] = score[0]
IndexError: index 0 is out of bounds for axis 0 with size 0
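To see why a zero-length sample ends in exactly this error, here is a minimal pure-Python sketch of Viterbi decoding. It follows the same dynamic-programming steps as the NumPy function in the traceback, starting with trellis[0] = score[0], but it is an illustration, not the TensorFlow source: plain lists stand in for NumPy arrays, so an empty input raises "IndexError: list index out of range" rather than NumPy's "index 0 is out of bounds" message.

```python
def viterbi_decode(score, transition_params):
    """Return (best_tag_path, best_path_score) for one sequence.

    score: seq_len x num_tags unary scores (list of lists).
    transition_params: num_tags x num_tags; entry [i][j] scores the move
        from tag i to tag j.
    """
    num_tags = len(transition_params)
    # First step, as in the traceback: seed the trellis with score[0].
    # An empty score (seq_len == 0) raises IndexError right here.
    trellis = [list(score[0])]
    backpointers = [[0] * num_tags]
    for t in range(1, len(score)):
        row, bp = [], []
        for j in range(num_tags):
            # Best previous tag i for ending in tag j at position t.
            best_i = max(range(num_tags),
                         key=lambda i: trellis[t - 1][i] + transition_params[i][j])
            bp.append(best_i)
            row.append(score[t][j] + trellis[t - 1][best_i]
                       + transition_params[best_i][j])
        trellis.append(row)
        backpointers.append(bp)
    # Backtrack from the highest-scoring final tag.
    last = trellis[-1]
    best_last = max(range(num_tags), key=lambda j: last[j])
    path = [best_last]
    for bp in reversed(backpointers[1:]):
        path.append(bp[path[-1]])
    path.reverse()
    return path, last[best_last]


transitions = [[0.0, 0.0], [0.0, 0.0]]
logit = [[1.0, 0.0], [0.0, 1.0]]
print(viterbi_decode(logit[:2], transitions))  # ([0, 1], 2.0)
try:
    viterbi_decode(logit[:0], transitions)  # seq_len == 0: an empty sample
except IndexError as err:
    print("IndexError:", err)
```

In predict_one_batch, logit[:seq_len] is empty whenever a sample has no tokens, so score[0] is the first operation to fail.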
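
def read_corpus(self, corpus_path):
    """Read a "char<TAB>tag" corpus into a list of (chars, tags) samples."""
    data = []
    with open(corpus_path, 'r') as r_file:
        sent_, tag_ = [], []
        for line in r_file:
            line = line.strip()
            if len(line) != 0 and line != '-DOCSTART-':
                # Token line: first field is the character, last field is the tag.
                ls = line.split('\t')
                char, tag = ls[0], ls[-1]
                sent_.append(char)
                tag_.append(tag)
            elif sent_ and tag_:
                # Separator line that closes a non-empty sentence. Without this
                # check, repeated blank lines or a '-DOCSTART-' line would each
                # add an empty sample here.
                data.append((sent_, tag_))
                sent_, tag_ = [], []
        # Bug-fix: an empty (sent_, tag_) pair has sequence length 0, which
        # makes viterbi_decode raise IndexError. Also keep the last sentence
        # when the file does not end with a blank line.
        if sent_ and tag_:
            data.append((sent_, tag_))
    self.data = data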
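A loader that appends a sample on every separator line (as an unguarded else: branch does) turns several blank lines in a row, or a '-DOCSTART-' line followed by a blank line, into empty (sent_, tag_) pairs. Because the new train_data is "in the same format", a quick way to check whether this is the difference is to scan the file for separator lines that close an empty sample. find_empty_samples below is a hypothetical helper, not part of zh-NER-TF; it assumes a UTF-8 file and the same separator convention as this read_corpus (blank or whitespace-only lines and '-DOCSTART-').

```python
import io


def find_empty_samples(corpus_path, encoding="utf-8"):
    """Return 1-based line numbers of separator lines that close an empty sample.

    Separators are blank or whitespace-only lines and '-DOCSTART-' lines.
    Each reported line would make an unguarded loader append an empty
    (sent_, tag_) pair, i.e. a sequence of length 0.
    """
    hits = []
    in_sentence = False
    with io.open(corpus_path, encoding=encoding) as f:
        for lineno, line in enumerate(f, start=1):
            stripped = line.strip()
            if stripped and stripped != "-DOCSTART-":
                in_sentence = True   # token line such as "char<TAB>tag"
            elif in_sentence:
                in_sentence = False  # separator that ends a real sentence
            else:
                hits.append(lineno)  # separator with no tokens before it
    return hits
```

For example, find_empty_samples('path/to/train_data') (a placeholder path) returns an empty list when no such lines exist; any numbers it returns point at extra blank lines or '-DOCSTART-' markers that would create zero-length samples. Running it on the dev file is worthwhile too.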

The read_corpus function should be changed to the following, so that separator lines never add an empty sample:

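import io

def read_corpus(corpus_path):
    """
    read corpus and return the list of samples
    :param corpus_path: path to a file with one "char label" pair per line
        and a blank line between sentences
    :return: data, a list of (chars, labels) tuples
    """
    data = []
    # io.open accepts `encoding` on both Python 2 and Python 3
    # (the traceback above comes from Python 2.7).
    with io.open(corpus_path, encoding='utf-8') as fr:
        lines = fr.readlines()
    sent_, tag_ = [], []
    for line in lines:
        if line.strip():  # token line; '\n', '\t\n' and other whitespace-only lines are separators
            [char, label] = line.strip().split()
            sent_.append(char)
            tag_.append(label)
        elif sent_ and tag_:  # separator: store only non-empty sentences
            data.append((sent_, tag_))
            sent_, tag_ = [], []
    # Keep the last sentence if the file does not end with a blank line.
    if sent_ and tag_:
        data.append((sent_, tag_))

    return data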
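The traceback fails inside dev_one_epoch, so the dev file can trigger the error just as the training file can. If you prefer not to change the loader, or want an extra safeguard, a post-load filter that drops zero-length samples is sketched below. drop_empty_samples is a hypothetical helper; applying it to train_data and dev_data (the names passed to model.train in the traceback) before training is an assumption about how main.py is organized.

```python
def drop_empty_samples(data):
    """Remove (sent_, tag_) pairs that have no tokens.

    Returns (kept, n_dropped). An empty pair becomes logit[:0] in
    predict_one_batch, and viterbi_decode cannot decode a zero-length sequence.
    """
    kept = [(sent_, tag_) for sent_, tag_ in data if sent_ and tag_]
    return kept, len(data) - len(kept)


# Hypothetical usage: filter both splits before calling model.train(...).
samples = [(["中", "國"], ["B-LOC", "I-LOC"]), ([], []), (["人"], ["O"])]
kept, n_dropped = drop_empty_samples(samples)
print(len(kept), n_dropped)  # 2 1
```

Filtering after loading leaves read_corpus untouched, but fixing the loader is still preferable, since it also stops malformed separator lines from being misread.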
