
tf.keras manual device placement

I'm migrating to TF 2.0 and trying to use the tf.keras approach for everything. In standard TF, I can use with tf.device(...) to control where ops are placed.
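In plain TF that looks roughly like this (made-up ops, purely to illustrate the mechanism I mean):

import tensorflow as tf

with tf.device("/CPU:0"):
    x = tf.constant([[1.0, 2.0], [3.0, 4.0]])  # this op is pinned to the CPU
with tf.device("/GPU:0"):
    y = tf.matmul(x, x)  # this op runs on the first GPU, if one is visible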

For example, I might have a model that looks something like this:


model = tf.keras.Sequential([tf.keras.layers.Input(..),
                             tf.keras.layers.Embedding(...),
                             tf.keras.layers.LSTM(...),
                             ...])

Assuming I want everything up to and including the Embedding layer on the CPU, and everything from there on on the GPU, how would I go about that? (This is just an example; the layers could have nothing to do with embeddings.)

If the solution involves subclassing tf.keras.Model, that is OK too; I don't mind not using Sequential.

You can use the Keras functional API and open a tf.device scope around each group of layers:

inputs = tf.keras.layers.Input(..)
with tf.device("/CPU:0"):
    # everything up to and including the embedding stays on the CPU
    x = tf.keras.layers.Embedding(...)(inputs)
with tf.device("/GPU:0"):
    # the rest of the network runs on the GPU
    outputs = tf.keras.layers.LSTM(...)(x)

model = tf.keras.Model(inputs=inputs, outputs=outputs)
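Since the question says subclassing tf.keras.Model is acceptable too, the same device scopes can go inside call(). A rough sketch, where SplitDeviceModel and all layer sizes are made up for illustration, not taken from the original post:

import tensorflow as tf

class SplitDeviceModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        # illustrative hyperparameters; substitute your own
        self.embedding = tf.keras.layers.Embedding(input_dim=10000, output_dim=128)
        self.lstm = tf.keras.layers.LSTM(64)
        self.head = tf.keras.layers.Dense(1, activation="sigmoid")

    def call(self, inputs):
        with tf.device("/CPU:0"):
            x = self.embedding(inputs)  # embedding lookup stays on the CPU
        with tf.device("/GPU:0"):
            x = self.lstm(x)            # recurrent part and head run on the GPU
            return self.head(x)

model = SplitDeviceModel()

To confirm where ops actually execute, you can call tf.debugging.set_log_device_placement(True) before running the model.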
