[英](Keras) My CNN model training progress get stuck
我基於存儲庫 [ https://github.com/matterport/Mask_RCNN]開發了我的 CNN model。 當我運行程序時(使用 cmd: coco.py train --dataset=/DATASET/COCO/2017 --model=None,我建議加載語句跳過 model 權重加載),該過程經歷了 Z20F35E630DAF39DDB 構建 model加載然后開始調用 model.train()。
# Create model
if args.command == "train":
model = modellib.MeshMask_RCNN(mode="training", config=config,
model_dir=args.logs)
else:
model = modellib.MeshMask_RCNN(mode="inference", config=config,
model_dir=args.logs)
# Select weights file to load
if args.model.lower() == "coco":
model_path = COCO_MODEL_PATH
elif args.model.lower() == "last":
# Find last trained weights
model_path = model.find_last()
elif args.model.lower() == "imagenet":
# Start from ImageNet trained weights
model_path = IMAGENET_MODEL_PATH()
else:
model_path = args.model
# Load weights
print("Loading weights ", model_path)
# model.load_weights(model_path, by_name=True)
# Train or evaluate
if args.command == "train":
# Training dataset. Use the training set and 35K from the
# validation set, as as in the Mask RCNN paper.
dataset_train = CocoDataset()
dataset_train.load_coco(args.dataset, "train", year=args.year, auto_download=args.download)
if args.year in '2014':
dataset_train.load_coco(args.dataset, "valminusminival", year=args.year, auto_download=args.download)
dataset_train.prepare()
# Validation dataset
dataset_val = CocoDataset()
val_type = "val" if args.year in '2017' else "minival"
dataset_val.load_coco(args.dataset, val_type, year=args.year, auto_download=args.download)
dataset_val.prepare()
# Image Augmentation
# Right/Left flip 50% of the time
augmentation = imgaug.augmenters.Fliplr(0.5)
# *** This training schedule is an example. Update to your needs ***
# Training - Stage 0
print("Fine tune all layers")
# get stuck when invoking this function #
> model.train(dataset_train, dataset_val,
> learning_rate=config.LEARNING_RATE,
> epochs=160,
> layers='all',
> augmentation=augmentation)
在 model.train() 中,它開始從磁盤加載圖像,memory 使用量開始增加到大約 80GB,然后所有進度都卡住了(沒有訓練消息,Cpu/Gpu 使用率很低)。 我暫停了一下,發現 multiprocessing/pool.py 的 404~406 行之間的程序循環。
@staticmethod
def _handle_workers(pool):
thread = threading.current_thread()
# Keep maintaining workers until the cache gets drained, unless the pool
# is terminated.
404 while thread._state == RUN or (pool._cache and thread._state != TERMINATE):
405 pool._maintain_pool()
406 time.sleep(0.1)
# send sentinel to stop workers
pool._taskqueue.put(None)
util.debug('worker handler exiting')
這是否意味着有一些資源沒有滿足需求,所以卡住了? 我是 keras 和 tensorflow 的新人。 任何人都可以幫忙嗎? 謝謝。
修正:當我追查時,我找到了程序卡住的確切語句。
# tensorflow_core/python/client/session.py
class _Callable(object):
def __init__(self, session, callable_options):
self._session = session
self._handle = None
options_ptr = tf_session.TF_NewBufferFromString(
compat.as_bytes(callable_options.SerializeToString()))
try:
> slef._handle = tf_session.TF_SessionMakeCallable(
> session._session, options_ptr)
finally:
tf_session.TF_DeleteBuffer(options_ptr)
確保您使用的是 Tenorflow gpu:
import tensorflow-gpu
另外,為 tensorflow session 添加一個設備
with tf.device('/gpu:0'):
實際上,它並沒有卡住,它只是消耗了太多時間。 我沒有意識到我正在建造的 model 有多大。 我認為它卡住了,因為 tf 在打印“epoch 1/160”后花了將近一個小時才准備好繼續進行(我意識到在讓它運行了一整夜之后)。
model 本身絕對不能訓練,之后會拋出 OOM 錯誤,所以我需要重新設計我的 model。 對不起,我錯了。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.