How do I access Tensor values (e.g. Metrics) which are updated within a tf.function?

I have been working on a model whose training loop is wrapped in a tf.function (I get OOM errors when running eagerly), and training seems to be running fine. However, I am not able to access the tensor values returned by my custom training function (below):

def train_step(inputs, target):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        curr_loss = lovasz_softmax_flat(predictions, target)

    gradients = tape.gradient(curr_loss, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))
    
    # Need to access this value
    return curr_loss

A simplified version of my 'umbrella' training loop is as follows:

@tf.function
def train_loop():
    for epoch in range(EPOCHS):
        for tr_file in train_files:

            tr_inputs = preprocess(tr_file)

            tr_loss = train_step(tr_inputs, target)
            print(tr_loss.numpy())

When I do try to print out the loss value, I end up with the following error:

AttributeError: 'Tensor' object has no attribute 'numpy'

I also tried using tf.print() as follows:

tf.print("Loss: ", tr_loss, output_stream=sys.stdout)

But nothing seems to appear on the terminal. Any suggestions?

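Code inside a tf.function runs in graph mode, where tensors are symbolic and have no .numpy() method, so you can't convert them to NumPy arrays there. Instead, create a tf.metrics object outside of the function and update it inside the function: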

mean_loss_values = tf.metrics.Mean()

def train_step(inputs, target):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        curr_loss = lovasz_softmax_flat(predictions, target)

    gradients = tape.gradient(curr_loss, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))

    # Accumulate this step's loss in the metric defined above
    mean_loss_values(curr_loss)
    # or, equivalently: mean_loss_values.update_state(curr_loss)
    
    # Need to access this value
    return curr_loss

Then, later in your code, outside the tf.function, read the accumulated mean:

mean_loss_values.result()
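For reference, here is a minimal self-contained sketch of the same pattern. The tiny Dense model, the SGD optimizer, the squared-error loss and the random data are hypothetical stand-ins for the question's model, opt, lovasz_softmax_flat and preprocess; reset_state() assumes a reasonably recent TensorFlow 2.x release (older versions call it reset_states()). Here only the training step is wrapped in tf.function and the Python loop stays eager, which may or may not suit your memory constraints.

```python
import tensorflow as tf

# Hypothetical stand-ins for the question's model and optimizer.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.build((None, 4))
opt = tf.keras.optimizers.SGD(learning_rate=0.01)

# The metric is created once, outside the tf.function.
mean_loss_values = tf.keras.metrics.Mean()


@tf.function
def train_step(inputs, target):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        # Placeholder loss (mean squared error), not lovasz_softmax_flat.
        curr_loss = tf.reduce_mean(tf.square(predictions - target))

    gradients = tape.gradient(curr_loss, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))

    # Update the metric's state inside the graph.
    mean_loss_values.update_state(curr_loss)
    return curr_loss


# Made-up data standing in for preprocess(tr_file) and target.
inputs = tf.random.normal((8, 4))
target = tf.zeros((8, 1))

for epoch in range(2):
    mean_loss_values.reset_state()  # start a fresh average each epoch
    for _ in range(3):
        # Called from eager code, the tf.function returns an eager tensor.
        last_loss = train_step(inputs, target)

    # Outside the tf.function, both values can be converted to Python numbers.
    epoch_mean = float(mean_loss_values.result())
    print("epoch", epoch, "mean loss", epoch_mean, "last loss", float(last_loss.numpy()))
```

The point of the design is that the metric's variables persist across calls to the graph function, so the value can be read from eager code whenever you need it.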
