简体   繁体   English

TensorFlow-keras 2.X 多 GPU 预测

[英]Tensorflow-keras 2.X multi gpu prediction

I have 4GPU(rtx 3090) in one pc.我在一台电脑上有 4GPU(rtx 3090)。

I used only 1GPU for training and prediction, but now I'm going to use 4GPU.我只使用 1GPU 进行训练和预测,但现在我将使用 4GPU。

During training, 4gpu activation was successful, but only 1GPU is active for prediction.在训练期间,4gpu 激活成功,但只有 1GPU 用于预测。

mirrored_strategy = tf.distribute.MirroredStrategy()
with mirrored_strategy.scope():
    model = MyModel()

model_checkpoint = ModelCheckpoint(model_path, monitor='loss', verbose=1, save_best_only=True)
history = model.fit(trainData, steps_per_epoch=num_images//batch_size, verbose=1)

It Use 4 GPU它使用 4 个 GPU

model.predict(testData, verbose=1)
#for x in testData:
#   model.predict_on_batch(x)

It use only 1 GPU它仅使用 1 个 GPU

How can I use my all GPU?如何使用我的所有 GPU?

I solved!我解决了! changed my test dataset!改变了我的测试数据集!

I'm using test dataset with generator.我正在使用带有生成器的测试数据集。

def testGenerator(image_path):
for file in image_path:
    img = file
    img = img / 255
    img = np.reshape(img,img.shape+(1,)) if (not flag_multi_class) else img 
    yield img

I wanted to load the dataset into GPU memory.我想将数据集加载到 GPU 内存中。

So I used "tf.data.Dataset.from_generator" and my GPU worked!所以我使用了“tf.data.Dataset.from_generator”并且我的 GPU 工作了!

Load saved model加载保存的模型

model = tf.keras.models.load_model(model_path)

Make tensorflow dataset制作张量流数据集

testGene = testGenerator(img_data)
test_dataset = tf.data.Dataset.from_generator(
    lambda: testGene,
    output_types = tf.float64,
    #output_shapes = tf.TensorShape([None]),
    output_shapes = tf.TensorShape([512, 512, 1])
    ).batch(global_batch_size)

result = model.predict(
    test_dataset,
    verbose=1,
    steps=math.ceil(max / global_batch_size),
    callbacks=[CustomCallback(root_file_list, int(max / len(root_file_list) / global_batch_size))])

It worked using my all GPU!它使用我的所有 GPU 工作!

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM