簡體   English   中英

如何在 Keras/TensorFlow 中可視化 RNN/LSTM 梯度?

[英]How to visualize RNN/LSTM gradients in Keras/TensorFlow?

我遇到過研究出版物和問答討論需要檢查每個反向傳播時間 (BPTT) 的 RNN 梯度 - 即每個時間步長的梯度。 主要用途是自省:我們如何知道 RNN 是否正在學習長期依賴 一個自己主題的問題,但最重要的見解是梯度流

  • 如果非零梯度流經每個時間步,那么每個時間步都有助於學習——即,最終梯度源於對每個輸入時間步的考慮,因此整個序列會影響權重更新
  • 如上所述,RNN不再忽略長序列的一部分,而是被迫向它們學習

...但是我如何在 Keras / TensorFlow 中實際可視化這些梯度? 一些相關的答案是在正確的方向上,但它們似乎對雙向 RNN 失敗了,並且只展示了如何獲得層的梯度,而不是如何有意義地可視化它們(輸出是一個 3D 張量 - 我該如何繪制它?)

可以通過權重輸出獲取梯度——我們將需要后者。 此外,為了獲得最佳結果,需要進行特定於架構的處理。 下面的代碼和解釋涵蓋了 Keras/TF RNN 的所有可能情況,並且應該可以輕松擴展到任何未來的 API 更改。


完整性:顯示的代碼是一個簡化版本 - 完整版本可以在我的存儲庫中找到,參見 RNN (這篇文章包含更大的圖像); 包括:

  • 更大的視覺可定制性
  • 解釋所有功能的文檔字符串
  • 支持 Eager、Graph、TF1、TF2,以及from kerasfrom tf.keras
  • 激活可視化
  • 權重梯度可視化(即將推出)
  • 權重可視化(即將推出)

I/O 維度(所有 RNN):

  • 輸入: (batch_size, timesteps, channels) - 或者,等效地, (samples, timesteps, features)
  • 輸出:與輸入相同,除了:
    • channels / features現在是 RNN 單元的數量,並且:
    • return_sequences=True --> timesteps_out = timesteps_in (輸出每個輸入時間步的預測)
    • return_sequences=False --> timesteps_out = 1 (僅在處理的最后一個時間步輸出預測)

可視化方法

  • 一維繪圖網格:繪制每個通道的梯度與時間步長
  • 2D 熱圖:使用梯度強度熱圖繪制通道與時間步長
  • 0D 對齊散射:為每個樣本繪制每個通道的梯度
  • 直方圖:沒有表示“與時間步長”關系的好方法
  • 一個樣本:對單個樣本執行上述每一項
  • 整批:對一批中的所有樣品進行上述每項操作; 需要仔細治療
# for below examples
grads = get_rnn_gradients(model, x, y, layer_idx=1) # return_sequences=True
grads = get_rnn_gradients(model, x, y, layer_idx=2) # return_sequences=False

EX 1:一個樣本,uni-LSTM,6 個單位-- return_sequences=True ,訓練了 20 次迭代
show_features_1D(grads[0], n_rows=2)

  • 注意:梯度是從右到左讀取的,因為它們是計算的(從最后一個時間步到第一個)
  • 最右邊(最新)的時間步長始終具有更高的梯度
  • 消失梯度:約 75% 的最左側時間步的梯度為零,表明時間依賴性學習不佳

在此處輸入圖片說明


EX 2:所有 (16) 個樣本,uni-LSTM,6 個單位-- return_sequences=True ,訓練 20 次迭代
show_features_1D(grads, n_rows=2)
show_features_2D(grads, n_rows=4, norm=(-.01, .01))

  • 每個樣本以不同的顏色顯示(但跨通道每個樣本的顏色相同)
  • 一些樣本的表現比上面顯示的要好,但相差不大
  • 熱圖繪制通道(y 軸)與時間步長(x 軸); 藍色=-0.01,紅色=0.01,白色=0(梯度值)

在此處輸入圖片說明 在此處輸入圖片說明


EX 3:所有 (16) 個樣本,uni-LSTM,6 個單位-- return_sequences=True ,訓練 200 次迭代
show_features_1D(grads, n_rows=2)
show_features_2D(grads, n_rows=4, norm=(-.01, .01))

  • 兩個圖都顯示 LSTM 在 180 次額外迭代后表現明顯更好
  • 梯度仍然消失了大約一半的時間步長
  • 所有 LSTM 單元都能更好地捕捉一個特定樣本(藍色曲線,所有圖)的時間依賴性——我們可以從熱圖中看出它是第一個樣本。 我們可以繪制該樣本與其他樣本的圖以嘗試了解差異

在此處輸入圖片說明 在此處輸入圖片說明


EX 4:2D 與 1D,uni-LSTM :256 個單位, return_sequences=True ,訓練 200 次迭代
show_features_1D(grads[0])
show_features_2D(grads[:, :, 0], norm=(-.0001, .0001))

  • 2D 更適合比較少量樣本中的多個通道
  • 1D 更適合比較多個通道中的多個樣本

在此處輸入圖片說明


EX 5:bi-GRU,256 個單位(總共 512 個) —— return_sequences=True ,訓練了 400 次迭代
show_features_2D(grads[0], norm=(-.0001, .0001), reflect_half=True)

  • 向后層的梯度被翻轉以保持與時間軸的一致性
  • 繪圖揭示了 Bi-RNN 一個鮮為人知的優勢——信息效用:集體梯度覆蓋了大約兩倍的數據。 然而,這不是免費的午餐:每一層都是一個獨立的特征提取器,所以學習並不是真正的補充
  • 預計更多單位的norm較低,大約為相同的損失派生梯度分布在更多參數上(因此平方數字平均值較小)


EX 6: 0D, all (16) samples, uni-LSTM, 6 units -- return_sequences=False , 訓練了 200 次迭代
show_features_0D(grads)

  • return_sequences=False僅利用最后一個時間步的梯度(它仍然來自所有時間步,除非使用截斷的 BPTT),需要一種新的方法
  • 在樣本中對每個 RNN 單元進行一致的顏色編碼以進行比較(可以使用一種顏色代替)
  • 評估梯度流不那么直接,而且在理論上涉及更多。 一種簡單的方法是比較訓練初期和后期的分布:如果差異不顯着,則 RNN 在學習長期依賴關系方面表現不佳


EX 7: LSTM vs. GRU vs. SimpleRNN, unidir, 256 units -- return_sequences=True , 訓練了 250 次迭代
show_features_2D(grads, n_rows=8, norm=(-.0001, .0001), show_xy_ticks=[0,0], show_title=False)

  • 注意:比較意義不大; 每個網絡都使用不同的超參數蓬勃發展,而所有網絡都使用相同的超參數。 LSTM,一方面,每單位承擔最多的參數,淹沒了 SimpleRNN
  • 在這個設置中,LSTM 最終擊敗了 GRU 和 SimpleRNN

在此處輸入圖片說明


可視化功能

def get_rnn_gradients(model, input_data, labels, layer_idx=None, layer_name=None, 
                      sample_weights=None):
    if layer is None:
        layer = _get_layer(model, layer_idx, layer_name)

    grads_fn = _make_grads_fn(model, layer, mode)
    sample_weights = sample_weights or np.ones(len(input_data))
    grads = grads_fn([input_data, sample_weights, labels, 1])

    while type(grads) == list:
        grads = grads[0]
    return grads

def _make_grads_fn(model, layer):
    grads = model.optimizer.get_gradients(model.total_loss, layer.output)
    return K.function(inputs=[model.inputs[0],  model.sample_weights[0],
                              model._feed_targets[0], K.learning_phase()], outputs=grads) 

def _get_layer(model, layer_idx=None, layer_name=None):
    if layer_idx is not None:
        return model.layers[layer_idx]

    layer = [layer for layer in model.layers if layer_name in layer.name]
    if len(layer) > 1:
        print("WARNING: multiple matching layer names found; "
              + "picking earliest")
    return layer[0]


def show_features_1D(data, n_rows=None, label_channels=True,
                     equate_axes=True, max_timesteps=None, color=None,
                     show_title=True, show_borders=True, show_xy_ticks=[1,1], 
                     title_fontsize=14, channel_axis=-1, 
                     scale_width=1, scale_height=1, dpi=76):
    def _get_title(data, show_title):
        if len(data.shape)==3:
            return "((Gradients vs. Timesteps) vs. Samples) vs. Channels"
        else:        
            return "((Gradients vs. Timesteps) vs. Channels"

    def _get_feature_outputs(data, subplot_idx):
        if len(data.shape)==3:
            feature_outputs = []
            for entry in data:
                feature_outputs.append(entry[:, subplot_idx-1][:max_timesteps])
            return feature_outputs
        else:
            return [data[:, subplot_idx-1][:max_timesteps]]

    if len(data.shape)!=2 and len(data.shape)!=3:
        raise Exception("`data` must be 2D or 3D")

    if len(data.shape)==3:
        n_features = data[0].shape[channel_axis]
    else:
        n_features = data.shape[channel_axis]
    n_cols = int(n_features / n_rows)

    if color is None:
        n_colors = len(data) if len(data.shape)==3 else 1
        color = [None] * n_colors

    fig, axes = plt.subplots(n_rows, n_cols, sharey=equate_axes, dpi=dpi)
    axes = np.asarray(axes)

    if show_title:
        title = _get_title(data, show_title)
        plt.suptitle(title, weight='bold', fontsize=title_fontsize)
    fig.set_size_inches(12*scale_width, 8*scale_height)

    for ax_idx, ax in enumerate(axes.flat):
        feature_outputs = _get_feature_outputs(data, ax_idx)
        for idx, feature_output in enumerate(feature_outputs):
            ax.plot(feature_output, color=color[idx])

        ax.axis(xmin=0, xmax=len(feature_outputs[0]))
        if not show_xy_ticks[0]:
            ax.set_xticks([])
        if not show_xy_ticks[1]:
            ax.set_yticks([])
        if label_channels:
            ax.annotate(str(ax_idx), weight='bold',
                        color='g', xycoords='axes fraction',
                        fontsize=16, xy=(.03, .9))
        if not show_borders:
            ax.set_frame_on(False)

    if equate_axes:
        y_new = []
        for row_axis in axes:
            y_new += [np.max(np.abs([col_axis.get_ylim() for
                                     col_axis in row_axis]))]
        y_new = np.max(y_new)
        for row_axis in axes:
            [col_axis.set_ylim(-y_new, y_new) for col_axis in row_axis]
    plt.show()


def show_features_2D(data, n_rows=None, norm=None, cmap='bwr', reflect_half=False,
                     timesteps_xaxis=True, max_timesteps=None, show_title=True,
                     show_colorbar=False, show_borders=True, 
                     title_fontsize=14, show_xy_ticks=[1,1],
                     scale_width=1, scale_height=1, dpi=76):
    def _get_title(data, show_title, timesteps_xaxis, vmin, vmax):
        if timesteps_xaxis:
            context_order = "(Channels vs. %s)" % "Timesteps"
        if len(data.shape)==3:
            extra_dim = ") vs. Samples"
            context_order = "(" + context_order
        return "{} vs. {}{} -- norm=({}, {})".format(context_order, "Timesteps",
                                                     extra_dim, vmin, vmax)

    vmin, vmax = norm or (None, None)
    n_samples = len(data) if len(data.shape)==3 else 1
    n_cols = int(n_samples / n_rows)

    fig, axes = plt.subplots(n_rows, n_cols, dpi=dpi)
    axes = np.asarray(axes)

    if show_title:
        title = _get_title(data, show_title, timesteps_xaxis, vmin, vmax)
        plt.suptitle(title, weight='bold', fontsize=title_fontsize)

    for ax_idx, ax in enumerate(axes.flat):
        img = ax.imshow(data[ax_idx], cmap=cmap, vmin=vmin, vmax=vmax)
        if not show_xy_ticks[0]:
            ax.set_xticks([])
        if not show_xy_ticks[1]:
            ax.set_yticks([])
        ax.axis('tight')
        if not show_borders:
            ax.set_frame_on(False)

    if show_colorbar:
        fig.colorbar(img, ax=axes.ravel().tolist())

    plt.gcf().set_size_inches(8*scale_width, 8*scale_height)
    plt.show()


def show_features_0D(data, marker='o', cmap='bwr', color=None,
                     show_y_zero=True, show_borders=False, show_title=True,
                     title_fontsize=14, markersize=15, markerwidth=2,
                     channel_axis=-1, scale_width=1, scale_height=1):
    if color is None:
        cmap = cm.get_cmap(cmap)
        cmap_grad = np.linspace(0, 256, len(data[0])).astype('int32')
        color = cmap(cmap_grad)
        color = np.vstack([color] * data.shape[0])
    x = np.ones(data.shape) * np.expand_dims(np.arange(1, len(data) + 1), -1)

    if show_y_zero:
        plt.axhline(0, color='k', linewidth=1)
    plt.scatter(x.flatten(), data.flatten(), marker=marker,
                s=markersize, linewidth=markerwidth, color=color)
    plt.gca().set_xticks(np.arange(1, len(data) + 1), minor=True)
    plt.gca().tick_params(which='minor', length=4)

    if show_title:
        plt.title("(Gradients vs. Samples) vs. Channels",
                  weight='bold', fontsize=title_fontsize)
    if not show_borders:
        plt.box(None)
    plt.gcf().set_size_inches(12*scale_width, 4*scale_height)
    plt.show()

完整的最小示例:請參閱存儲庫的自述文件


獎金代碼

  • 如何在不閱讀源代碼的情況下檢查重量/門排序?
rnn_cell = model.layers[1].cell          # unidirectional
rnn_cell = model.layers[1].forward_layer # bidirectional; also `backward_layer`
print(rnn_cell.__dict__)

更方便的代碼見repo的rnn_summary


額外的事實:如果您在GRU上運行,您可能會注意到bias沒有門; 為什么這樣? 文檔

有兩種變體。 默認的基於 1406.1078v3,並在矩陣乘法之前將重置門應用於隱藏狀態。 另一個是基於原來的 1406.1078v1 並且順序顛倒了。

第二個變體與 CuDNNGRU(僅限 GPU)兼容,並允許在 CPU 上進行推理。 因此它對 kernel 和 recurrent_kernel 有不同的偏差。 使用 'reset_after'=True 和 recurrent_activation='sigmoid'。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM