[英]Modify Tensorflow Code to place preprocessing on CPU and training on GPU
[英]Tensorflow 2.0-GPU Windows is running training code on the CPU
目前我在Tensorflow 2.0上使用。 这是我的电脑:
我的项目是一个使用ResNet50和CIFAR100 Dataset的图像分类项目。
我使用子类构建网络(代码片段太长,所以我没有在这个问题上附加它)并使用tf.data.Dataset.from_tensor_slices
加载数据:
def load_cifar100(batch_size, num_classes=100):
(x_train, y_train), (x_test, y_test) = cifar100.load_data()
x_train, x_test = x_train.astype('float32') / 255, x_test.astype('float32') / 255
x_val, y_val = x_train[-10000:], y_train[-10000:]
x_train, y_train = x_train[:-10000], y_train[:-10000]
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)
y_val = to_categorical(y_val, num_classes)
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_dataset = test_dataset.batch(batch_size)
return train_dataset, val_dataset, test_dataset
我使用GradientTape
来设置我的训练过程:
def training(x_batch, y_batch):
with tf.GradientTape() as tape:
logits = model(x_batch, training=True)
loss_val = loss(y_batch, logits)
grads = tape.gradient(loss_val, model.trainable_weights)
optimizer.apply_gradients(zip(grads, model.trainable_weights))
train_acc_metric(y_batch, logits)
for epoch in range(epochs):
train_acc_metric.reset_states()
for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
training(x_batch_train, y_batch_train)
train_acc = train_acc_metric.result()
template = 'Epoch {}, Train_Acc: {}'
print(template.format(epoch + 1,
train_acc))
在训练过程中,我看到我的GPU
根本无法工作 [Pic1],即使我打开了调试配置tf.debugging.set_log_device_placement(True)
并且所有的层似乎都在,但所有的训练过程都被放入了CPU
加载到GPU
[Pic2]。
更新:
这是当我更改为model.fit
函数时 GPU 的样子。 每个 epoch 的训练时间比GradientTape
快得多:
在开始训练过程之前检查 Tensorflow (TF2) 是否使用 GPU 会很有帮助,方法是:
assert tf.test.is_gpu_available()
assert tf.test.is_built_with_cuda()
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.