Printing inside a custom loss function in a Jupyter notebook with Keras/TF

In Keras, if you write a custom loss function in a Jupyter notebook, printing from inside it does not work. For instance, if you have:

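from keras import backend as K
from keras.losses import binary_crossentropy

def loss_func(true_label, NN_output):
    true_cat = true_label[:,0]
    pred_cat = NN_output[:,0]
    indicator = NN_output[:,1]
    print("Hi!")  # never shows up while the loss is being evaluated
    custom_term = K.mean(K.abs(indicator))
    return binary_crossentropy(true_cat, pred_cat) + custom_term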

Nothing will print when the function is evaluated.
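A likely explanation, assuming Keras compiles the loss into a TensorFlow graph: a plain Python print() runs only while the Python function is being traced into the graph, not each time the graph is evaluated. The sketch below assumes TensorFlow 2.x, uses tf.function as a stand-in for Keras' graph compilation and tf.print as the TF 2 replacement for tf.Print; the function double and the list trace_log are illustrative only.

```python
import tensorflow as tf

trace_log = []  # records Python-level side effects (illustrative)

@tf.function
def double(x):
    # Ordinary Python: runs only while TensorFlow traces this function into a graph.
    trace_log.append("traced")
    print("Python print(): trace time only")
    # A graph op: runs on every call of the compiled function.
    tf.print("tf.print(): x =", x)
    return x * 2

for _ in range(3):
    result = double(tf.constant([1.0, 2.0]))

print("times traced:", len(trace_log))
```

Running it should show the Python print() once and the tf.print() line three times (the latter on standard error). That matches the symptom above: print("Hi!") may fire once when the model is compiled, but never during training.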

As a debugging workaround, I found that I can write to a file from inside the loss function, which is useful for plain values such as an int or a string.

However, trying to write out a tensor like indicator to a file gives the unbelievably helpful output:

Tensor("loss_103/model_105_loss/Print:0", shape=(512,), dtype=float32)
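That string is the text form of a symbolic tensor: when the loss function's Python code runs, the tensor is a node in a graph that has not been executed yet, so it only carries a name, shape and dtype. A small sketch, assuming TensorFlow 2.x, contrasts this with an eager tensor, whose values are available through .numpy(); the function show_str and the list seen_inside_graph are hypothetical names.

```python
import tensorflow as tf

seen_inside_graph = []  # str() of the tensor as seen during tracing

@tf.function
def show_str(x):
    # While tracing, x is a symbolic placeholder: name, shape and dtype only,
    # so str(x) has no numbers to show.
    seen_inside_graph.append(str(x))
    return x

eager = tf.constant([1.0, 2.0])
show_str(eager)

print(seen_inside_graph[0])  # symbolic: name, shape, dtype
print(repr(eager))           # eager: includes the actual values
print(eager.numpy())         # plain NumPy array, easy to write to a file
```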

I know TF provides a tf.Print() op to print the value of a tensor, but I don't understand how that interacts with Jupyter. Other answers say that tf.Print() writes to standard error, so redirecting it with

sys.stderr = open('test.txt', 'w')

should in theory let me read the output from a file, but unfortunately this doesn't work (at least in Jupyter).
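Assuming, as stated above, that tf.Print writes from native (C++) code to the process's standard error, replacing sys.stderr cannot catch it: sys.stderr is only a Python object, while native code writes straight to file descriptor 2. The standard-library sketch below uses os.write(2, ...) as a stand-in for that native code (an assumption, not TensorFlow itself) and shows that swapping sys.stderr misses the output, whereas redirecting descriptor 2 with os.dup2 captures it. The helper names are hypothetical.

```python
import io
import os
import sys
import tempfile


def native_style_write(msg: str) -> None:
    """Stand-in for native (C++) code: writes straight to file descriptor 2."""
    os.write(2, msg.encode())


def capture_fd2(func) -> str:
    """Point file descriptor 2 at a temp file while func runs; return what was written."""
    sys.stderr.flush()
    saved_fd = os.dup(2)  # keep a handle on the real stderr
    with tempfile.TemporaryFile(mode="w+b") as tmp:
        os.dup2(tmp.fileno(), 2)  # fd 2 now refers to the temp file
        try:
            func()
        finally:
            os.dup2(saved_fd, 2)  # restore the original stderr
            os.close(saved_fd)
        tmp.seek(0)
        return tmp.read().decode()


# 1) Replacing the Python-level object does not intercept fd-2 writes.
fake_stderr = io.StringIO()
real_stderr = sys.stderr
sys.stderr = fake_stderr
try:
    native_style_write("missed by sys.stderr\n")
finally:
    sys.stderr = real_stderr
python_level = fake_stderr.getvalue()

# 2) Redirecting the file descriptor itself does capture it.
fd_level = capture_fd2(lambda: native_style_write("captured at fd level\n"))

print(repr(python_level))
print(repr(fd_level))
```

In the setup described in this thread, this is roughly why the output ends up in the terminal that launched Jupyter: the kernel's sys.stderr is a Python object that forwards to the notebook, while descriptor 2 still points at the terminal.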

Is there any general method to get a representation of my tensor as a string? How do people usually get around this barrier to seeing what their code does? If I write something fancier than a mean, I want to see exactly what happens at each step of the calculation to verify that it works as intended.
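One common way to inspect intermediate values, assuming TensorFlow 2.x with tf.keras, is to compile the model with run_eagerly=True: the loss then runs as ordinary Python on concrete tensors, so print() and .numpy() work and their output appears in the notebook. The tiny model, the random data and the list seen_pred_cat are hypothetical; the tf.executing_eagerly() guard is there because some Keras versions may also call the loss once symbolically while building the model.

```python
import numpy as np
import tensorflow as tf

seen_pred_cat = []  # values captured from inside the loss (illustrative)

def loss_func(true_label, NN_output):
    true_cat = true_label[:, 0]
    pred_cat = NN_output[:, 0]
    indicator = NN_output[:, 1]
    if tf.executing_eagerly():
        # Only true when running eagerly: tensors hold concrete values here.
        print("pred_cat:", pred_cat.numpy())
        seen_pred_cat.append(pred_cat.numpy())
    custom_term = tf.reduce_mean(tf.abs(indicator))
    return tf.keras.losses.binary_crossentropy(true_cat, pred_cat) + custom_term

model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(2, activation="sigmoid"),
])
# run_eagerly=True skips graph compilation, so Python print() and .numpy()
# work inside the loss (at the cost of slower training).
model.compile(optimizer="sgd", loss=loss_func, run_eagerly=True)

x = np.random.rand(8, 3).astype("float32")
y = np.random.randint(0, 2, size=(8, 2)).astype("float32")
model.fit(x, y, batch_size=4, epochs=1, verbose=0)
```

Eager execution is considerably slower than graph mode, so this is best treated as a debugging switch to turn off once the loss behaves as intended.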

Thanks!

You can do something like this:

import tensorflow as tf  # TF 1.x API; in TF 2.x tf.Print is tf.compat.v1.Print, superseded by tf.print
from keras import backend as K
from keras.losses import binary_crossentropy

def loss_func(true_label, NN_output):
    true_cat = true_label[:,0]
    true_cat = tf.Print(true_cat, [true_cat], message="true_cat: ") # added line
    pred_cat = NN_output[:,0]
    pred_cat = tf.Print(pred_cat, [pred_cat], message="pred_cat: ") # added line
    indicator = NN_output[:,1]
    custom_term = K.mean(K.abs(indicator))
    return binary_crossentropy(true_cat, pred_cat) + custom_term

I added two lines that print the values of true_cat and pred_cat. tf.Print returns its first argument unchanged and prints the listed tensors as a side effect whenever that returned tensor is evaluated, so you have to use its output in place of the original tensor; that is what puts the print into the TF graph.
However, the output appears in the console (terminal) from which you started the Jupyter notebook server, not in the notebook itself.
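If you are on a recent TensorFlow 2.x release, where tf.Print was replaced by tf.print, the output_stream argument of tf.print also accepts a "file://" path. That gives the file-based workaround from the question without touching sys.stderr, and the file can be read back inside the notebook. A sketch under that assumption; the log path and the sample tensors are hypothetical.

```python
import os
import tempfile
import tensorflow as tf

# Hypothetical log location; tf.print appends to the file given as "file://<path>".
LOG_PATH = os.path.join(tempfile.gettempdir(), "loss_debug.log")
LOG_STREAM = "file://" + LOG_PATH

def loss_func(true_label, NN_output):
    true_cat = true_label[:, 0]
    pred_cat = NN_output[:, 0]
    indicator = NN_output[:, 1]
    # tf.print is a graph op: it runs every time the loss is evaluated.
    tf.print("true_cat:", true_cat, output_stream=LOG_STREAM)
    tf.print("pred_cat:", pred_cat, output_stream=LOG_STREAM)
    custom_term = tf.reduce_mean(tf.abs(indicator))
    return tf.keras.losses.binary_crossentropy(true_cat, pred_cat) + custom_term

if os.path.exists(LOG_PATH):
    os.remove(LOG_PATH)  # start from an empty log

# Evaluate the loss as a compiled graph, similar to what Keras does in training.
y_true = tf.constant([[1.0, 0.0], [0.0, 0.0]])
y_pred = tf.constant([[0.9, 0.1], [0.2, 0.3]])
loss = tf.function(loss_func)(y_true, y_pred)

with open(LOG_PATH) as f:
    log_text = f.read()
print(log_text)  # the debug output, read back inside the notebook
```

Because tf.print appends, the log grows with every evaluated batch during training; deleting it between runs keeps it readable.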

References:

How to print the value of a Tensor object in TensorFlow?

Printing the loss during TensorFlow training

https://www.tensorflow.org/api_docs/python/tf/Print
