[英]retrieving data from lmdb in python is throwing error
這就是數據集的創建方式
def createDataset1(
outputPath,
imagePathList,
labelList,
lexiconList=None,
validset_percent=10,
testset_percent=0,
random_seed=1111,
checkValid=True,
):
"""
Create LMDB dataset for CRNN training.
ARGS:
outputPath : LMDB output path
imagePathList : list of image path
labelList : list of corresponding groundtruth texts
lexiconList : (optional) list of lexicon lists
checkValid : if true, check the validity of every image
"""
train_path = os.path.join(outputPath, "training", "9M")
valid_path = os.path.join(outputPath, "validation", "9M")
# CAUTION: if train_path (lmdb) already exists, this function add dataset
# into it. so remove former one and re-create lmdb.
if os.path.exists(train_path):
os.system(f"rm -r {train_path}")
if os.path.exists(valid_path):
os.system(f"rm -r {valid_path}")
os.makedirs(train_path, exist_ok=True)
os.makedirs(valid_path, exist_ok=True)
gt_train_path = gt_file.replace(".txt", "_train.txt")
gt_valid_path = gt_file.replace(".txt", "_valid.txt")
data_log = open(gt_train_path, "w", encoding="utf-8")
if testset_percent != 0:
test_path = os.path.join(outputPath, "evaluation", dataset_name)
if os.path.exists(test_path):
os.system(f"rm -r {test_path}")
os.makedirs(test_path, exist_ok=True)
gt_test_path = gtFile.replace(".txt", "_test.txt")
assert(len(imagePathList) == len(labelList))
nSamples = len(imagePathList)
num_valid_dataset = int(nSamples * validset_percent / 100.0)
num_test_dataset = int(nSamples * testset_percent / 100.0)
num_train_dataset = nSamples - num_valid_dataset - num_test_dataset
print("validation datasets: ",num_valid_dataset,"\n", "test datasets: ", num_test_dataset, " \n training datasets: ", num_train_dataset)
env = lmdb.open(outputPath, map_size=1099511627776)
cache = {}
cnt = 1
random.seed(random_seed)
random.shuffle(imagePathList)
for i in tqdm(range(nSamples)):
data_log.write(imagePathList[i])
imagePath = imagePathList[i]
label = labelList[i]
if len(label) == 0:
continue
if not os.path.exists(imagePath):
print('%s does not exist' % imagePath)
continue
with open(imagePath, 'rb') as f:
imageBin = f.read()
if checkValid:
if not checkImageIsValid(imageBin):
print('%s is not a valid image' % imagePath)
continue
embed_vec = fasttext_model[label]
imageKey = 'image-%09d' % cnt
labelKey = 'label-%09d' % cnt
embedKey = 'embed-%09d' % cnt
cache[imageKey] = imageBin
cache[labelKey] = label.encode()
cache[embedKey] = ' '.join(str(v) for v in embed_vec.tolist()).encode()
if lexiconList:
lexiconKey = 'lexicon-%09d' % cnt
cache[lexiconKey] = ' '.join(lexiconList[i])
if cnt % 1000 == 0:
writeCache(env, cache)
cache = {}
print('Written %d / %d' % (cnt, nSamples))
#finish train dataset and start validation dataset
if i + 1 == num_train_dataset:
print(f"# Train dataset: {num_train_dataset} is finished")
cache["num-samples".encode()] = str(num_train_dataset).encode()
writeCache(env, cache)
data_log.close()
#start validation set
env = lmdb.open(valid_path, map_size=30 * 2 ** 30)
cache = {}
cnt = 0
data_log = open(gt_valid_path, "w", encoding="utf-8")
# Finish train/valid dataset and Start test dataset
if (i + 1 == num_train_dataset + num_valid_dataset) and num_test_dataset != 0:
print(f"# Valid dataset: {num_valid_dataset} is finished")
cache["num-samples".encode()] = str(num_valid_dataset).encode()
writeCache(env, cache)
data_log.close()
# start test set
env = lmdb.open(test_path, map_size=30 * 2 ** 30)
cache = {}
cnt = 0 # not 1 at this time
data_log = open(gt_test_path, "w", encoding="utf-8")
cnt += 1
if testset_percent == 0:
cache["num-samples".encode()] = str(num_valid_dataset).encode()
writeCache(env, cache)
print(f"# Valid datast: {num_valid_dataset} is finished")
else:
cache["num-samples".encode()] = str(num_test_dataset).encode()
writeCache(env, cache)
print(f"# Test datast: {num_test_dataset} is finished")
這就是我試圖檢索數據的方式
class LmdbDataset(data.Dataset):
def __init__(self, root, voc_type, max_len, num_samples, transform=None):
super(LmdbDataset, self).__init__()
if global_args.run_on_remote:
dataset_name = os.path.basename(root)
data_cache_url = "/cache/%s" % dataset_name
if not os.path.exists(data_cache_url):
os.makedirs(data_cache_url)
if mox.file.exists(root):
mox.file.copy_parallel(root, data_cache_url)
else:
raise ValueError("%s not exists!" % root)
self.env = lmdb.open(data_cache_url, max_readers=32, readonly=True)
else:
self.env = lmdb.open(root, max_readers=32, readonly=True)
assert self.env is not None, "cannot create lmdb from %s" % root
self.txn = self.env.begin()
self.voc_type = voc_type
self.transform = transform
self.max_len = max_len
# nums = b"num-samples"
# print('NUM SAMPLES ------ \n',nums)
nSamples = self.txn.get('num-samples'.encode())
print("STRING nSamples :", nSamples)
self.nSamples = int(self.txn.get(b"num-samples"))
self.nSamples = min(self.nSamples, num_samples)
assert voc_type in ['LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS']
self.EOS = 'EOS'
self.PADDING = 'PADDING'
self.UNKNOWN = 'UNKNOWN'
self.voc = get_vocabulary(voc_type, EOS=self.EOS, PADDING=self.PADDING, UNKNOWN=self.UNKNOWN)
self.char2id = dict(zip(self.voc, range(len(self.voc))))
self.id2char = dict(zip(range(len(self.voc)), self.voc))
self.rec_num_classes = len(self.voc)
self.lowercase = (voc_type == 'LOWERCASE')
每當代碼嘗試調用 elf.txn.get(b"num-samples") 時,我都會收到以下錯誤
Traceback (most recent call last):
File "main.py", line 268, in <module>
main(args)
File "main.py", line 157, in main
train_dataset, train_loader = get_data_lmdb(args.synthetic_train_data_dir, args.voc_type, args.max_len, args.num_train,
File "main.py", line 66, in get_data_lmdb
dataset_list.append(LmdbDataset(data_dir_, voc_type, max_len, num_samples))
File "/Users/SEED/lib/datasets/dataset.py", line 189, in __init__
self.nSamples = int(self.txn.get(b"num-samples"))
TypeError: int() argument must be a string, a bytes-like object or a number, not 'NoneType'
我在網上嘗試了許多不同的建議和一些 stackoverflow 線程,但無法找出問題所在。
是什么導致了這個錯誤,我該如何解決?
您的代碼密集、復雜且難以跟蹤數據流。 它將大量計算與 IO、副作用,甚至系統調用混合在一起( os.system(f"rm -r {test_path}")
,您應該改用shutil.rmtree )。
嘗試分解您希望執行的每個操作,以便它主要做一件特定的事情:
在每個階段,您都應該進行驗證,並使用最小功率規則。 如果你希望self.txn
總是有 'num-samples , then you should use
self.txn[b'num-samples'] rather
.get , which defaults to
None 。 這使得更容易在鏈中更早地捕獲錯誤。
我也不知道lmbd
模塊是什么。 那是一個庫,還是代碼庫中的另一個文件? 您應該鏈接到您正在使用的任何庫,如果它們不是眾所周知的話。 我看到幾個 python 包與 lmbd 有關。
從代碼示例中,我可以為您剖析。 可能會幫助您解決問題。
日志說..
self.nSamples = int(self.txn.get(b"num-samples"))
TypeError: int() argument must be a string, a bytes-like object or a number, not 'NoneType'
self.txn 應該是一個字典,在那里使用 get 方法。 當你試圖得到它時,可能有兩種情況
它們都可能導致無法通過類型轉換(到 int)處理的 NoneType。 因此,您必須使用 self.txn.keys() 檢查該字段是否存在於數據中,並查看該鍵是否存在
此外,如果
self.nSamples = int(self.txn.get(b"num-samples"))
失敗而不是上述聲明
nSamples = self.txn.get('num-samples'.encode())
然后你可以簡單地解碼變量並在那里使用它
喜歡
Samples = nSamples.decode("utf-8")
try:
int_samples = int(Samples)
except TypeError:
print(Samples)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.