
Keras, Tensorflow: How to set breakpoint (debug) in custom layer when evaluating?

I just want to do some numerical validation inside the custom layer.

Suppose we have a very simple custom layer:

class test_layer(keras.layers.Layer):
    def __init__(self, **kwargs):
        super(test_layer, self).__init__(**kwargs)

    def build(self, input_shape):
        self.w = K.variable(1.)
        self._trainable_weights.append(self.w)
        super(test_layer, self).build(input_shape)

    def call(self, x, **kwargs):
        m = x * x            # Set break point here
        n = self.w * K.sqrt(x)
        return m + n

And the main program:

import numpy as np
import tensorflow as tf
import keras
import keras.backend as K

input = keras.layers.Input((100,1))
y = test_layer()(input)

model = keras.Model(input,y)
model.predict(np.ones((100,1)))

If I set a breakpoint at the line m = x * x , the program pauses there when executing y = test_layer()(input) . This is because the graph is being built, and the call() method is invoked.

But when I use model.predict() to feed it real values and want to check whether the layer works properly, it doesn't pause at the line m = x * x .

My question is:

  1. Is the call() method only invoked while the computational graph is being built? (Is it not called when real values are fed in?)

  2. How can I debug (or where should I insert a breakpoint) inside a layer, to see the values of its variables when real input is given?

In TensorFlow 2, you can now add breakpoints to TensorFlow Keras models/layers, including when using the fit, evaluate, and predict methods. However, you must set model.run_eagerly = True after calling model.compile() for the tensor values to be available in the debugger at the breakpoint. For example,

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam


class SimpleModel(Model):

    def __init__(self):
        super().__init__()
        self.dense0 = Dense(2)
        self.dense1 = Dense(1)

    def call(self, inputs):
        z = self.dense0(inputs)
        z = self.dense1(z)  # Breakpoint in IDE here. =====
        return z

x = tf.convert_to_tensor([[1, 2, 3], [4, 5, 6]], dtype=tf.float32)

model0 = SimpleModel()
y0 = model0.call(x)  # Values of z shown at breakpoint. =====

model1 = SimpleModel()
model1.run_eagerly = True
model1.compile(optimizer=Adam(), loss=BinaryCrossentropy())
y1 = model1.predict(x)  # Values of z *not* shown at breakpoint. =====

model2 = SimpleModel()
model2.compile(optimizer=Adam(), loss=BinaryCrossentropy())
model2.run_eagerly = True
y2 = model2.predict(x)  # Values of z shown at breakpoint. =====

Note: this was tested in TensorFlow 2.0.0-rc0 .
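On recent TF 2.x releases, run_eagerly can also be passed directly as a compile() argument instead of being set as an attribute afterwards. A minimal sketch (the model and data here are illustrative, not the answer's exact code):

```python
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Model


class SimpleModel(Model):

    def __init__(self):
        super().__init__()
        self.dense0 = Dense(2)
        self.dense1 = Dense(1)

    def call(self, inputs):
        z = self.dense0(inputs)
        return self.dense1(z)  # A breakpoint here is hit during predict().


x = tf.ones((2, 3))
model = SimpleModel()
# run_eagerly as a compile() keyword argument (available in recent TF 2.x)
model.compile(optimizer="adam", loss="mse", run_eagerly=True)
y = model.predict(x)
```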

  1. Yes. The call() method is only used to build the computational graph.
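This graph-building behavior is easy to see with tf.function, which underlies Keras graph execution in TF 2: the Python body runs only while the graph is being traced. In this sketch (a hypothetical example, not from the answer), a Python-side counter increments once even though the function is called twice:

```python
import tensorflow as tf

trace_count = 0

@tf.function
def square(x):
    global trace_count
    trace_count += 1  # Python side effect: executes only during tracing
    return x * x

x = tf.constant(2.0)
square(x)
square(x)  # Same input signature: the cached graph is reused, no retrace
print(trace_count)  # 1
```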

  2. As for debugging: I prefer using TFDBG , which is a recommended debugging tool for TensorFlow, although it doesn't provide breakpoint functionality.

For Keras, you can add these lines to your script to use TFDBG:

from tensorflow.keras import backend as K
from tensorflow.python import debug as tf_debug
sess = K.get_session()
sess = tf_debug.LocalCLIDebugWrapperSession(sess)
K.set_session(sess)
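If a full TFDBG session is more than you need, tf.print is a lighter-weight way to inspect values inside a layer: unlike Python's print, it executes every time the graph runs, so it also fires inside predict() and fit(). A minimal sketch (the layer and shapes here are illustrative, not the asker's exact model):

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


class PrintingLayer(keras.layers.Layer):
    def call(self, x):
        m = x * x
        # tf.print runs at graph execution time, so the values appear
        # during predict()/fit(), not just while the graph is built.
        tf.print("m =", m, summarize=4)
        return m


inp = keras.layers.Input((4,))
out = PrintingLayer()(inp)
model = keras.Model(inp, out)
y = model.predict(np.ones((2, 4)))
```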

You can refer to the recommended way to debug AutoGraph code.
