
TensorFlow functional too slow

I created a TensorFlow model that processes 50 different images belonging to a single observation, so a batch has the shape (32, 50, 128, 128, 1). The model is defined as:

import tensorflow as tf
from tensorflow.keras import layers

input = layers.Input((50, 128, 128, 1))
sub_models = []
# One branch, with its own weights, for each of the 50 images in an observation
for mcol in range(50):
    x = layers.Conv2D(32, kernel_size=(3, 3), input_shape=(128, 128, 1))(input[:, mcol, :, :])
    x = layers.MaxPool2D(pool_size=(2, 2))(x)
    x = layers.Dropout(0.25)(x)
    x = layers.Flatten()(x)
    x = layers.Dense(128)(x)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(32)(x)
    sub_models.append(x)
# Join the 50 branch outputs: 50 * 32 = 1600 features
combined = tf.keras.layers.concatenate(sub_models)
z = layers.Dense(1024)(combined)
z = layers.Dense(512)(z)
z = layers.Dense(512)(z)
z = layers.Dense(2, activation="softmax")(z)
model = tf.keras.Model(input, z)
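To get a sense of scale, the parameter count of this model can be worked out by hand. The sketch below assumes Keras defaults ('valid' padding, one bias per filter or unit) and counts weights and biases only. It suggests that almost all of the roughly 815 million parameters sit in the 50 per-branch Dense(128) layers, each fed a 63 * 63 * 32 = 127,008-element flattened feature map, which may weigh more on step time than the Python loop does.

```python
def conv2d_params(kernel, in_ch, filters, groups=1):
    # Keras Conv2D kernel shape is (kh, kw, in_ch // groups, filters), plus one bias per filter
    kh, kw = kernel
    return kh * kw * (in_ch // groups) * filters + filters


def dense_params(in_features, units):
    # Weight matrix plus one bias per unit
    return in_features * units + units


# One branch: 128x128x1 -> Conv2D(32, 3x3, 'valid') -> 126x126x32
#             -> MaxPool2D(2) -> 63x63x32 -> Flatten -> Dense(128) -> Dense(32)
flat = 63 * 63 * 32
branch = conv2d_params((3, 3), 1, 32) + dense_params(flat, 128) + dense_params(128, 32)

# Head on the concatenated 50 * 32 features
head = (dense_params(50 * 32, 1024) + dense_params(1024, 512)
        + dense_params(512, 512) + dense_params(512, 2))

total = 50 * branch + head
print(f"per branch: {branch:,}  all 50 branches: {50 * branch:,}  total: {total:,}")
```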

The model looks like this (a simplified version with fewer inputs): (image: Model)

My training step looks like this:

with tf.GradientTape() as tape:
    # Keep the first 50 images and add a trailing channel axis: (batch, 50, 128, 128, 1)
    logits = model(x_batch_train[:, :50, :, :, None], training=True)
    loss_value = loss(Y, logits)
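A complete custom training step usually also computes the gradients and applies them. The sketch below is an illustration, not the asker's code: the tiny model, the loss, the optimizer and the integer labels are hypothetical stand-ins for the question's model, loss and Y. The @tf.function decorator is likewise an addition not in the question; it traces the step into a graph once so that later calls are not executed op by op in eager mode.

```python
import tensorflow as tf

# Hypothetical stand-ins so the sketch runs on its own
model = tf.keras.Sequential([
    tf.keras.Input((8,)),
    tf.keras.layers.Dense(2, activation="softmax"),
])
loss = tf.keras.losses.SparseCategoricalCrossentropy()  # expects integer labels
optimizer = tf.keras.optimizers.Adam()


@tf.function  # trace once into a graph instead of running eagerly on every call
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss(y, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return loss_value


x = tf.random.normal((32, 8))
y = tf.random.uniform((32,), maxval=2, dtype=tf.int32)
loss_value = train_step(x, y)
print(float(loss_value))
```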

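The problem is that training is very slow: each step takes several seconds on a V100 GPU. I suspect the for loop is the problem. Is there a smarter, more time-efficient way to define the model?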

If you reformat your data to (32, 128, 128, 50), i.e. 50 channels, one per image, you can use the groups keyword argument (https://www.tensorflow.org/api_docs/python/tf/keras/layers/Conv2D#args):
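Why groups=50 reproduces 50 separate convolutions can be checked with a naive NumPy sketch (plain loops, not TensorFlow's implementation). A grouped convolution behaves like an ordinary convolution whose kernel is block-diagonal across channel groups, so output group g only ever sees input channel g, i.e. image g. The kernel layout (kh, kw, in_channels // groups, filters) follows Keras's Conv2D convention; the sizes (4 images, 3 filters each, 8x8 inputs) are small stand-ins for 50, 32 and 128x128.

```python
import numpy as np


def conv2d_valid(x, w, b):
    """Naive 'valid' 2-D cross-correlation, as computed by Conv2D.
    x: (H, W, C_in), w: (kh, kw, C_in, C_out), b: (C_out,)."""
    kh, kw, _, c_out = w.shape
    H, W, _ = x.shape
    out = np.empty((H - kh + 1, W - kw + 1, c_out))
    for i in range(H - kh + 1):
        for j in range(W - kw + 1):
            patch = x[i:i + kh, j:j + kw, :]           # (kh, kw, C_in)
            out[i, j] = np.tensordot(patch, w, axes=3) + b
    return out


def grouped_conv2d(x, w, b, groups):
    """Input channel group g feeds only output group g.
    w has the grouped shape (kh, kw, C_in // groups, C_out)."""
    gi, go = x.shape[-1] // groups, w.shape[-1] // groups
    outs = [conv2d_valid(x[..., g * gi:(g + 1) * gi],
                         w[..., g * go:(g + 1) * go],
                         b[g * go:(g + 1) * go]) for g in range(groups)]
    return np.concatenate(outs, axis=-1)


def block_diagonal_kernel(w, groups):
    """Expand a grouped kernel into a full (kh, kw, C_in, C_out) kernel
    whose off-diagonal channel blocks are zero."""
    kh, kw, gi, c_out = w.shape
    go = c_out // groups
    full = np.zeros((kh, kw, gi * groups, c_out))
    for g in range(groups):
        full[:, :, g * gi:(g + 1) * gi, g * go:(g + 1) * go] = w[..., g * go:(g + 1) * go]
    return full


rng = np.random.default_rng(0)
n_images, per_image = 4, 3                                   # stand-ins for 50 and 32
x = rng.standard_normal((8, 8, n_images))                    # one image per channel
w = rng.standard_normal((3, 3, 1, n_images * per_image))     # grouped kernel
b = rng.standard_normal(n_images * per_image)

grouped = grouped_conv2d(x, w, b, groups=n_images)
dense = conv2d_valid(x, block_diagonal_kernel(w, n_images), b)
print(grouped.shape, np.allclose(grouped, dense))
```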

import tensorflow as tf
from tensorflow.keras import layers

input = layers.Input((50, 128, 128, 1))

# Reshape data (50, 128, 128, 1) -> (50, 128, 128)
x = layers.Reshape((50, 128, 128))(input)
# Transpose (50, 128, 128) -> (128, 128, 50)
x = layers.Permute((2, 3, 1), input_shape=(50, 128, 128))(x)

# NOTE! groups=50 is what splits the network: each of the 50 input channels
# (one per image) gets its own 32 filters, i.e. 50 independent convolutions
x = layers.Conv2D(32 * 50, kernel_size=(3, 3),
                  input_shape=(128, 128, 50), groups=50)(x)
# Reshape to max pool 3D
# 126, 126, 50 * 32 -> 126, 126, 50, 32
x = layers.Reshape((126, 126, 50, 32))(x)

x = layers.MaxPool3D(pool_size=(2, 2, 1))(x)
x = layers.Dropout(0.25)(x)

# Change (63, 63, 50, 32) -> (50, 63 * 63 * 32)
x = layers.Permute((3, 1, 2, 4), input_shape=(63, 63, 50, 32))(x)
x = layers.Reshape((50, 63 * 63 * 32))(x)

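# NOTE: Dense acts on the last axis, so these layers share one set of weights
# across all 50 images (the question's loop gives each branch its own Dense layers)
x = layers.Dense(128)(x)
x = layers.Dropout(0.5)(x)
x = layers.Dense(32)(x)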

# Join everything together as per the spec
combined = layers.Flatten()(x)

z = layers.Dense(1024)(combined)
z = layers.Dense(512)(z)
z = layers.Dense(512)(z)
z = layers.Dense(2, activation="softmax")(z)
model = tf.keras.Model(input, z)
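The Reshape, MaxPool3D((2, 2, 1)), Permute((3, 1, 2, 4)) and Reshape sequence above is meant to line each image's features up the same way the question's per-branch MaxPool2D and Flatten did. A small NumPy check of that claim, with 6x6 maps, 4 images and 3 filters standing in for 126x126, 50 and 32, assumes the grouped convolution places image g's filters in channels g * 32 to (g + 1) * 32 - 1. Note that Keras Permute indices are 1-based and exclude the batch axis, so Permute((3, 1, 2, 4)) corresponds to transpose(2, 0, 1, 3) here.

```python
import numpy as np


def max_pool_2x2(a):
    """2x2 max pooling over the first two (spatial) axes; sizes assumed even."""
    h, w = a.shape[:2]
    return a.reshape(h // 2, 2, w // 2, 2, *a.shape[2:]).max(axis=(1, 3))


rng = np.random.default_rng(0)
H = W = 6
n_images, filters = 4, 3                                     # stand-ins for 50 and 32
conv_out = rng.standard_normal((H, W, n_images * filters))   # grouped Conv2D output

# Question's route: one branch per image, pooled and flattened separately
per_branch = np.stack([
    max_pool_2x2(conv_out[..., g * filters:(g + 1) * filters]).reshape(-1)
    for g in range(n_images)
])

# Answer's route: Reshape -> MaxPool3D((2, 2, 1)) -> Permute((3, 1, 2, 4)) -> Reshape
a = conv_out.reshape(H, W, n_images, filters)
a = max_pool_2x2(a)            # pools height and width only, like pool_size=(2, 2, 1)
a = a.transpose(2, 0, 1, 3)    # (h, w, images, filters) -> (images, h, w, filters)
answer = a.reshape(n_images, -1)

print(answer.shape, np.array_equal(per_branch, answer))
```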

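That said, the loop itself should not be the speed bottleneck (make sure you are actually running on the GPU), because it runs only once, while the computation graph is being built, not on every training step. Still, this version should speed things up.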
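One quick way to check whether TensorFlow sees a GPU at all is sketched below, using tf.config.list_physical_devices and the device attribute of an eager tensor. The 1024x1024 matmul is an arbitrary test operation; an empty GPU list, or a device string ending in CPU:0, would mean the training steps are running on the CPU.

```python
import tensorflow as tf

# An empty list means TensorFlow cannot see any GPU
gpus = tf.config.list_physical_devices("GPU")
print("GPUs visible to TensorFlow:", gpus)

# Run a small op and report which device it was placed on
a = tf.random.normal((1024, 1024))
b = tf.linalg.matmul(a, a)
print("matmul ran on:", b.device)
```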


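Disclaimer: technical posts on this site are licensed under CC BY-SA 4.0. If you republish this content, please credit this site's URL or the original address. For any questions, contact yoyou2525@163.com.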

 
© 2020-2024 STACKOOM.COM