[英]Why my training speed in Keras with multi_gpu_model is worse than single gpu?
我的Keras版本是2.0.9,並使用tensorflow后端。
我試圖在keras中實現multi_gpu_model 。 但是,在實踐中使用4 gpu進行訓練甚至比1 gpu還要糟糕。 1 gpu的時間為25秒,4 gpu的時間為50秒。 你能告訴我為什么會這樣嗎?
/ log for multi_gpu_model
我用這個推薦1 gpu
CUDA_VISIBLE_DEVICES=0 python gpu_test.py
4 gpus,
python gpu_test.py
-這里是培訓的源代碼。
from keras.datasets import mnist
from keras.layers import Input, Dense, merge
from keras.layers.core import Lambda
from keras.models import Model
from keras.utils import to_categorical
from keras.utils.training_utils import multi_gpu_model
import time
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(60000, 784)
x_test = x_test.reshape(10000, 784)
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)
inputs = Input(shape=(784,))
x = Dense(4096, activation='relu')(inputs)
x = Dense(2048, activation='relu')(x)
x = Dense(512, activation='relu')(x)
x = Dense(64, activation='relu')(x)
predictions = Dense(10, activation='softmax')(x)
model = Model(inputs=inputs, outputs=predictions)
'''
m_model = multi_gpu_model(model, 4)
m_model.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy'])
m_model.summary()
a=time.time()
m_model.fit(x_train, y_train, batch_size=128, epochs=5)
print time.time() - a
a=time.time()
m_model.predict(x=x_test, batch_size=128)
print time.time() - a
'''
model.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy'])
model.summary()
a=time.time()
model.fit(x_train, y_train, batch_size=128, epochs=5)
print time.time() - a
a=time.time()
model.predict(x=x_test, batch_size=128)
print time.time() - a
我可以給您我認為的答案,但是我自己無法完全解決問題。 一個錯誤報告提示了我這一點,但是在multi_gpu_model的源代碼中它說:
# Instantiate the base model (or "template" model).
# We recommend doing this with under a CPU device scope,
# so that the model's weights are hosted on CPU memory.
# Otherwise they may end up hosted on a GPU, which would
# complicate weight sharing.
with tf.device('/cpu:0'):
model = Xception(weights=None,
input_shape=(height, width, 3),
classes=num_classes)
我認為這是問題所在。 不過,我仍在努力使它自己工作。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.