
How to trigger a Python function inside a tf.keras custom loss function?

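Inside my custom loss function, I need to call a pure Python function, passing it the computed TD errors and some indexes. The function does not need to return anything or be differentiated. This is the function I want to call: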

def update_priorities(self, traces_idxs, td_errors):
    """Updates the priorities of the traces with specified indexes."""
    self.priorities[traces_idxs] = td_errors + eps
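
For context, the method above can be exercised on its own with plain NumPy arrays. The following is only a minimal sketch: the `ReplayMemory` class, its `capacity` argument, and the value of `EPS` are assumptions made for illustration, since the question does not show the owning class or how `eps` is defined.

```python
import numpy as np

EPS = 1e-6  # assumed small constant; the question's `eps` is not shown


class ReplayMemory:
    """Hypothetical stand-in for the object that owns `priorities`."""

    def __init__(self, capacity):
        self.priorities = np.zeros(capacity, dtype=np.float32)

    def update_priorities(self, traces_idxs, td_errors):
        """Updates the priorities of the traces with specified indexes."""
        # NumPy integer-array indexing: one assignment per listed index.
        self.priorities[traces_idxs] = td_errors + EPS


mem = ReplayMemory(capacity=5)
mem.update_priorities(np.array([0, 3]), np.array([0.5, 2.0], dtype=np.float32))
```

The assignment relies on NumPy integer-array indexing, so the method expects NumPy arrays (or values NumPy can convert) for both arguments.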

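I have tried using tf.py_function to call a wrapper function, but it only gets called if it is embedded in the graph, i.e. if it has inputs and outputs and the outputs are used. So I tried passing some tensors through it without performing any operation on them, and now the function does get called. Here is my entire custom loss function: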

def masked_q_loss(data, y_pred):
    """Computes the MSE between the Q-values of the actions that were taken and the cumulative
    discounted rewards obtained after taking those actions. Updates trace priorities.
    """
    action_batch, target_qvals, traces_idxs = data[:,0], data[:,1], data[:,2]
    seq = tf.cast(tf.range(0, tf.shape(action_batch)[0]), tf.int32)
    action_idxs = tf.transpose(tf.stack([seq, tf.cast(action_batch, tf.int32)]))
    qvals = tf.gather_nd(y_pred, action_idxs)

    def update_priorities(_qvals, _target_qvals, _traces_idxs):
        """Computes the TD error and updates memory priorities."""
        td_error = _target_qvals - _qvals
        _traces_idxs = tf.cast(_traces_idxs, tf.int32)
        mem.update_priorities(_traces_idxs, td_error)
        return _qvals

    qvals = tf.py_function(func=update_priorities, inp=[qvals, target_qvals, traces_idxs], Tout=[tf.float32])
    return tf.keras.losses.mse(qvals, target_qvals)

However, because of the call to mem.update_priorities(_traces_idxs, td_error), I get the following error:

ValueError: An operation has `None` for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.

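I do not need gradients to be computed for update_priorities; I just want to call it at a specific point in the graph computation and forget about it. How can I do that?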

Using .numpy() on the tensors inside the wrapper function fixed the problem:

def update_priorities(_qvals, _target_qvals, _traces_idxs):
    """Computes the TD error and updates memory priorities."""
    td_error = np.abs((_target_qvals - _qvals).numpy())
    _traces_idxs = (tf.cast(_traces_idxs, tf.int32)).numpy()
    mem.update_priorities(_traces_idxs, td_error)
    return _qvals
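
A self-contained variant of the same idea is sketched below. It is not the code from the answer: it assumes TensorFlow 2.x, replaces `mem` with a module-level `priorities` array, picks an arbitrary `EPS`, and treats the call purely as a side effect by passing `Tout=[]` and wrapping the Q-values in `tf.stop_gradient`. Whether the side-effect-only call survives graph pruning in older graph-mode Keras setups, such as the one described in the question, is not verified here.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for `mem.priorities` from the question.
priorities = np.zeros(8, dtype=np.float32)
EPS = 1e-6  # assumed value; the question's `eps` is not shown


def _update_priorities(qvals, target_qvals, traces_idxs):
    """Runs eagerly inside tf.py_function; works on NumPy values, returns nothing."""
    td_error = np.abs(target_qvals.numpy() - qvals.numpy())
    idxs = traces_idxs.numpy().astype(np.int64)
    priorities[idxs] = td_error + EPS


def masked_q_loss(data, y_pred):
    """MSE between the Q-values of the taken actions and their targets."""
    action_batch, target_qvals, traces_idxs = data[:, 0], data[:, 1], data[:, 2]
    seq = tf.range(tf.shape(action_batch)[0])
    action_idxs = tf.stack([seq, tf.cast(action_batch, tf.int32)], axis=1)
    qvals = tf.gather_nd(y_pred, action_idxs)
    # Side effect only: Tout=[] declares that the wrapper returns no tensors,
    # and stop_gradient keeps the call off the differentiation path.
    tf.py_function(
        func=_update_priorities,
        inp=[tf.stop_gradient(qvals), target_qvals, traces_idxs],
        Tout=[],
    )
    return tf.reduce_mean(tf.square(target_qvals - qvals))


# Usage: two samples with three actions each.
# Columns of `data`: action taken, target Q-value, trace index.
y_pred = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
data = tf.constant([[2.0, 2.5, 0.0], [0.0, 5.0, 7.0]])
with tf.GradientTape() as tape:
    tape.watch(y_pred)
    loss = masked_q_loss(data, y_pred)
grad = tape.gradient(loss, y_pred)
```

Here the loss is computed from the original `qvals` tensor rather than from the output of `tf.py_function`, so the gradient never has to pass through the Python call.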

