
Optimizer problem with TensorFlow's distribution strategy and tf.keras models

When training with the MirroredStrategy distribution strategy, we get poor validation accuracy. Running the same training on a single GPU without a distribution strategy, both training and validation accuracy are above 95%.

The problem arises with the ResNet50 model from tf.keras.applications. With small "self-built" CNNs, the distribution strategy works fine.

It seems like the optimizer has problems with tf.keras models.

Does anyone have an idea what could be the problem and how to solve it? We ran out of ideas.

General setup

  • CUDA 10.0
  • tf-nightly-gpu 2.1.0.dev20191029
  • 2x RTX 2080 Ti
  • custom grayscale images (270, 270) with a well tested input pipeline based on tf.data.Dataset. Small CNNs are performing above 95% accuracy.

What we already tried, all leading to similar behaviour:

  • self-built TensorFlow 2.0 with CUDA 10.1
  • pip package tensorflow-gpu (V2.0.0) with CUDA 10.0
  • different optimizers

Setup A:

ResNet50 on a single GPU, reaching above 95% validation accuracy.

Setup B:

ResNet50 in the MirroredStrategy scope with 2 GPUs, reaching less than 70% validation accuracy.

Batch sizes are a multiple of two since we use two GPUs.
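
To make the batch-size remark concrete: with MirroredStrategy and Keras model.fit, the batch size set on the tf.data pipeline is the global batch size, and each step's batch is divided among the replicas. A minimal sketch of that arithmetic in plain Python (the helper name per_replica_batch_size is hypothetical, not a TensorFlow function, and it assumes an even split):

```python
def per_replica_batch_size(global_batch_size, num_replicas):
    """Return how many samples each replica processes per step.

    Hypothetical helper for illustration only. It assumes the global
    batch is split evenly and raises if that is not possible.
    """
    if num_replicas < 1:
        raise ValueError("num_replicas must be at least 1")
    if global_batch_size % num_replicas != 0:
        raise ValueError("global batch size is not divisible by the replica count")
    return global_batch_size // num_replicas


BATCH_SIZE = 32  # global batch size, as in the question
NUM_GPUS = 2     # two RTX 2080 Ti

print(per_replica_batch_size(BATCH_SIZE, NUM_GPUS))  # 16
```

So with the same BATCH_SIZE, each GPU in Setup B sees 16 images per step, whereas the single GPU in Setup A sees 32.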

Code:

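import numpy as np
import tensorflow as tf
from tensorflow.keras.applications import ResNet50

# df_trainset, df_validset, df_testset and target_label are defined elsewhere.
df_trainset[target_label] = df_trainset[target_label].astype('int')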
df_validset[target_label] = df_validset[target_label].astype('int')
df_testset[target_label] = df_testset[target_label].astype('int')

list_labels_train = df_trainset[target_label]
list_paths_train = df_trainset['sample_path']

list_labels_valid = df_validset[target_label]
list_paths_valid = df_validset['sample_path']

list_labels_test = df_testset[target_label]
list_paths_test = df_testset['sample_path']

def parse_img(label, path):
    img = tf.io.read_file(path)
    img = tf.image.decode_png(img, channels=1, dtype=tf.uint8)
    img = tf.image.convert_image_dtype(img, tf.float32)
    return img, label

BATCH_SIZE = 32

ds_train = tf.data.Dataset.from_tensor_slices((list_labels_train,
                                               list_paths_train))

ds_train = ds_train.map(parse_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(buffer_size=len(list_paths_train), seed=42,
                            reshuffle_each_iteration=True)
ds_train = ds_train.batch(batch_size=BATCH_SIZE, drop_remainder=True).repeat()
ds_train = ds_train.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

# np.ceil returns a float; cast to int for steps_per_epoch.
train_steps = int(np.ceil(len(list_paths_train) / BATCH_SIZE))

# valid
ds_valid = tf.data.Dataset.from_tensor_slices((list_labels_valid,
                                               list_paths_valid))
ds_valid = ds_valid.map(parse_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_valid = ds_valid.cache()
ds_valid = ds_valid.shuffle(buffer_size=len(list_paths_valid), seed=42,
                            reshuffle_each_iteration=True)
ds_valid = ds_valid.batch(batch_size=BATCH_SIZE, drop_remainder=True).repeat()
ds_valid = ds_valid.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

# np.ceil returns a float; cast to int for validation_steps.
valid_steps = int(np.ceil(len(list_paths_valid) / BATCH_SIZE))


strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = ResNet50(include_top=True,
                     weights=None,
                     input_tensor=None,
                     input_shape=(270, 270, 1),
                     pooling=None,
                     classes=3)

    model.compile(optimizer=tf.optimizers.Adam(),
                  loss='sparse_categorical_crossentropy',
                  metrics=["accuracy"],
                  )

model.summary()


history = model.fit(
        x=ds_train,
        epochs=10,
        verbose=1,
        validation_data=ds_valid,
        steps_per_epoch=train_steps,
        validation_steps=valid_steps,
        use_multiprocessing=False)
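
One detail in the step counts above: the pipeline batches with drop_remainder=True, which yields only the full batches of each pass over the data, while train_steps and valid_steps are computed with np.ceil. Because of .repeat() the dataset never runs dry, so an extra step simply pulls its batch from the next pass. A small sketch of the two counts in plain Python, using a made-up sample count (1000 is an assumption, not the size of the dataset in the question):

```python
import math


def full_batches(num_samples, batch_size):
    """Batches per pass when the incomplete last batch is dropped."""
    return num_samples // batch_size


def ceil_steps(num_samples, batch_size):
    """Step count as computed in the question (rounding up)."""
    return math.ceil(num_samples / batch_size)


num_samples = 1000  # hypothetical dataset size, for illustration only
batch_size = 32

print(full_batches(num_samples, batch_size))  # 31
print(ceil_steps(num_samples, batch_size))    # 32
```

When the sample count is not a multiple of the batch size, the two numbers differ by one, so an "epoch" as counted by Keras is slightly longer than one pass over the data.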

Results

Training on one device, without distribution strategy, everything looks fine.

Full console-dump on pastebin

563/563 [==============================] - 141s 250ms/step - loss: 0.1185 - accuracy: 0.9616 - val_loss: 0.5751 - val_accuracy: 0.8078
Epoch 2/10
563/563 [==============================] - 130s 231ms/step - loss: 0.0400 - accuracy: 0.9865 - val_loss: 0.8953 - val_accuracy: 0.7119
Epoch 3/10
563/563 [==============================] - 130s 231ms/step - loss: 0.0478 - accuracy: 0.9870 - val_loss: 25.3537 - val_accuracy: 0.3367
Epoch 4/10
563/563 [==============================] - 130s 230ms/step - loss: 0.0309 - accuracy: 0.9906 - val_loss: 0.0576 - val_accuracy: 0.9946
Epoch 5/10
563/563 [==============================] - 129s 230ms/step - loss: 0.0210 - accuracy: 0.9940 - val_loss: 0.0780 - val_accuracy: 0.9916
Epoch 6/10
563/563 [==============================] - 130s 230ms/step - loss: 0.0227 - accuracy: 0.9937 - val_loss: 0.0595 - val_accuracy: 0.9887
Epoch 7/10
563/563 [==============================] - 129s 230ms/step - loss: 0.0160 - accuracy: 0.9949 - val_loss: 0.0536 - val_accuracy: 0.9946
Epoch 8/10
 81/563 [===>..........................] - ETA: 1:39 - loss: 0.0222 - accuracy: 0.9945

Training with distribution strategy

Full console-dump on pastebin

563/563 [==============================] - 119s 211ms/step - loss: 1.0535 - accuracy: 0.5099 - val_loss: 1.0735 - val_accuracy: 0.6682
Epoch 2/10
563/563 [==============================] - 95s 169ms/step - loss: 1.0123 - accuracy: 0.5277 - val_loss: 1.0721 - val_accuracy: 0.6682
Epoch 3/10
563/563 [==============================] - 95s 169ms/step - loss: 1.0121 - accuracy: 0.5277 - val_loss: 1.0709 - val_accuracy: 0.6682
Epoch 4/10
563/563 [==============================] - 95s 169ms/step - loss: 1.0124 - accuracy: 0.5277 - val_loss: 1.0667 - val_accuracy: 0.6682
Epoch 5/10
563/563 [==============================] - 95s 169ms/step - loss: 1.0121 - accuracy: 0.5277 - val_loss: 1.0687 - val_accuracy: 0.6682
Epoch 6/10
563/563 [==============================] - 95s 168ms/step - loss: 1.0125 - accuracy: 0.5277 - val_loss: 1.0638 - val_accuracy: 0.6682
Epoch 7/10
563/563 [==============================] - 94s 167ms/step - loss: 1.0125 - accuracy: 0.5277 - val_loss: 1.0639 - val_accuracy: 0.6682
Epoch 8/10
400/563 [====================>.........] - ETA: 24s - loss: 1.0135 - accuracy: 0.5268

I actually solved this problem by increasing the max_queue_size argument in model.fit().
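
For reference, a sketch of where that argument goes. In the TF 2.0/2.1 signature, model.fit accepts max_queue_size (default 10), which the documentation describes as applying to generator or keras.utils.Sequence input. The value 64 below is an arbitrary assumption, since the answer does not say which value was used; the other arguments mirror the question. The actual call is left as a comment because it needs the model and datasets from the question.

```python
# Keyword arguments for model.fit(), collected in a dict so the change
# relative to the question is easy to see.
fit_kwargs = dict(
    epochs=10,
    verbose=1,
    max_queue_size=64,  # assumed value for illustration; the default is 10
    use_multiprocessing=False,
)

# With model, ds_train, ds_valid, train_steps and valid_steps from the question:
# history = model.fit(x=ds_train,
#                     validation_data=ds_valid,
#                     steps_per_epoch=train_steps,
#                     validation_steps=valid_steps,
#                     **fit_kwargs)

print(fit_kwargs["max_queue_size"])  # 64
```

Check the fit() signature of your installed version before relying on this, as the accepted arguments have changed across releases.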
