簡體   English   中英

Keras、Tensorflow:評估時如何在自定義層中設置斷點(調試)?

[英]Keras, Tensorflow: How to set breakpoint (debug) in custom layer when evaluating?

我只想在自定義圖層內進行一些數字驗證。

假設我們有一個非常簡單的自定義層:

class test_layer(keras.layers.Layer):
    def __init__(self, **kwargs):
        super(test_layer, self).__init__(**kwargs)

    def build(self, input_shape):
        self.w = K.variable(1.)
        self._trainable_weights.append(self.w)
        super(test_layer, self).build(input_shape)

    def call(self, x, **kwargs):
        m = x * x            # Set break point here
        n = self.w * K.sqrt(x)
        return m + n

和主程序:

import tensorflow as tf
import keras
import keras.backend as K

input = keras.layers.Input((100,1))
y = test_layer()(input)

model = keras.Model(input,y)
model.predict(np.ones((100,1)))

如果我在m = x * x行設置斷點調試,則程序在執行y = test_layer()(input)時會在此處暫停,這是因為構建了圖形, call()call()方法。

但是當我使用model.predict()來賦予它真正的價值,並且想看看它是否正常工作時,它不會停在m = x * x

我的問題是:

  1. call()方法是否僅在構建計算圖時call() (喂實值時不會被調用?)

  2. 如何在層內調試(或插入斷點的位置)以在給它實際值輸入時查看變量的值?

在 TensorFlow 2 中,您現在可以向 TensorFlow Keras 模型/層添加斷點,包括在使用擬合、評估和預測方法時。 但是,您必須調用model.compile()之后添加model.run_eagerly = True以使張量的值在斷點處的調試器中可用。 例如,

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam


class SimpleModel(Model):

    def __init__(self):
        super().__init__()
        self.dense0 = Dense(2)
        self.dense1 = Dense(1)

    def call(self, inputs):
        z = self.dense0(inputs)
        z = self.dense1(z)  # Breakpoint in IDE here. =====
        return z

x = tf.convert_to_tensor([[1, 2, 3], [4, 5, 6]], dtype=tf.float32)

model0 = SimpleModel()
y0 = model0.call(x)  # Values of z shown at breakpoint. =====

model1 = SimpleModel()
model1.run_eagerly = True
model1.compile(optimizer=Adam(), loss=BinaryCrossentropy())
y1 = model1.predict(x)  # Values of z *not* shown at breakpoint. =====

model2 = SimpleModel()
model2.compile(optimizer=Adam(), loss=BinaryCrossentropy())
model2.run_eagerly = True
y2 = model2.predict(x)  # Values of z shown at breakpoint. =====

注意:這是在 TensorFlow 2.0.0-rc0測試的。

  1. 是的。 call()方法僅用於構建計算圖。

  2. 至於調試。 我更喜歡使用TFDBG ,這是TFDBG的推薦調試工具,盡管它不提供斷點功能。

對於 Keras,您可以將這些行添加到您的腳本中以使用 TFDBG

import tf.keras.backend as K
from tensorflow.python import debug as tf_debug
sess = K.get_session()
sess = tf_debug.LocalCLIDebugWrapperSession(sess)
K.set_session(sess)

您可以參考推薦的方法來調試AutoGraph代碼

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM