
Visualizing weights of trained neural network in keras


Hi, I have trained an autoencoder network whose convolutional layer is 96*96*32.

Now I have obtained the weights of my model (named autoencoder):

layer = autoencoder.layers[1]
w = layer.get_weights()

Since w is a list, please help me order its elements and visualize the trained kernels. I guess it should be 32 kernels of size 96×96.

When I type

len(w)

it gives 2, so I have 2 arrays.

The top array has 9 sub-arrays, each containing 32 numbers. The last array has 32 elements, so those must be the biases.


[array([[[[-6.56146603e-03, -1.51752336e-02, -3.76937017e-02,
           -4.55160812e-03,  1.26366820e-02, -2.97747254e-02,
            3.76312323e-02, -1.56892575e-02,  2.03932393e-02,
            3.29606095e-03,  3.76580656e-02,  6.99581252e-03,
           -4.97130565e-02,  3.63005586e-02,  3.70187908e-02,
            2.63699284e-03,  4.42482866e-02,  8.26128479e-03,
            3.44854854e-02,  1.94760375e-02,  3.91177870e-02,
           -6.67006942e-03,  5.64308763e-02, -1.55166145e-02,
           -3.46037326e-03, -3.14556211e-02, -2.31548538e-03,
            5.77888393e-04,  2.17472352e-02, -8.16953406e-02,
            1.54041937e-02, -3.55066173e-02]],

         [[ 7.61649990e-03, -6.52475432e-02,  2.02584285e-02,
           -4.36152853e-02, -7.94242844e-02, -6.29556971e-03,
           -2.17294712e-02,  3.30206454e-02,  3.47386077e-02,
           -2.77627818e-03,  4.49984707e-02, -3.03241126e-02,
           -3.36903334e-02,  2.34354921e-02,  3.31020765e-02,
           -7.81059638e-03, -9.54489596e-03, -1.07985372e-02,
            4.10569459e-02,  5.06392084e-02, -1.64809041e-02,
            8.42852518e-03, -6.24148361e-03,  1.38165271e-02,
            4.47277874e-02, -1.68551356e-02,  2.87279133e-02,
           -4.17906158e-02, -3.29194516e-02,  5.37550561e-02,
           -3.10864598e-02, -4.53849025e-02]],

         [[ 5.37880100e-02,  2.00091377e-02, -8.04780126e-02,
            2.05146279e-02, -6.41385652e-03,  2.94176023e-02,
            2.42049675e-02,  2.98423916e-02,  1.30865928e-02,
           -9.23016574e-03, -2.63463743e-02, -1.58412699e-02,
           -4.76215854e-02, -1.53328422e-02, -2.54222248e-02,
            1.03113698e-02,  1.97005924e-02, -1.09527409e-02,
           -4.29149866e-02,  1.15255425e-02,  3.65356952e-02,
            2.26275604e-02,  8.76231957e-03, -1.82650369e-02,
            4.30952013e-02, -1.58966344e-03,  1.01399068e-02,
            7.15927547e-03,  2.70794444e-02, -1.93151142e-02,
            2.06329934e-02, -3.24055366e-02]]],


        [[[ 7.32885906e-04, -5.99233769e-02,  1.01583647e-02,
            2.62707975e-02, -1.60765275e-02,  4.54364009e-02,
            1.22182900e-02,  1.77695882e-02,  3.40870097e-02,
           -3.20678158e-03,  1.94115974e-02, -5.89495376e-02,
            5.51430099e-02,  1.08586736e-02, -2.14386974e-02,
           -1.10124948e-03, -1.41514605e-02, -8.40184465e-03,
           -4.09237854e-02,  2.27938611e-02,  2.82027805e-03,
            3.99805643e-02, -5.23957238e-02, -6.65743649e-02,
           -1.86213956e-03,  1.84283289e-03,  8.22036352e-04,
           -2.04587094e-02, -4.95675243e-02,  5.40869832e-02,
            4.00022417e-02, -4.74570543e-02]],

         [[-3.73015292e-02,  9.84914601e-03,  9.94246900e-02,
            3.19805741e-02,  8.14174674e-03,  2.72354241e-02,
           -1.58177980e-03, -5.65455444e-02, -2.13499945e-02,
            2.36055311e-02,  4.57456382e-03,  5.87781705e-02,
           -4.50953143e-03, -3.05559561e-02,  8.65572542e-02,
           -2.87776738e-02,  7.56273838e-03, -2.02421043e-02,
            4.32164557e-02,  1.07650533e-02,  1.74834915e-02,
           -2.26386450e-02, -4.51299828e-03, -7.19766971e-03,
           -5.64673692e-02, -3.46505865e-02, -9.57003422e-03,
           -4.17267382e-02,  2.74983943e-02,  7.50013590e-02,
           -1.39447292e-02, -2.10063234e-02]],

         [[-4.99953330e-03, -1.95915010e-02,  7.38414973e-02,
            3.00457701e-02,  4.11909744e-02, -4.93509434e-02,
           -3.72827090e-02, -4.84874584e-02, -1.73344277e-02,
            2.13540550e-02,  2.63152272e-02,  5.11181913e-02,
            5.94335012e-02, -8.46157200e-04, -3.79960015e-02,
           -2.01609023e-02,  2.21411046e-02, -1.14003820e-02,
           -1.78077854e-02, -6.17240835e-03, -9.96494666e-03,
           -2.70768851e-02,  3.32489684e-02, -1.18451891e-02,
            7.48611614e-02,  3.68427448e-02, -1.70680200e-04,
            2.78645731e-03,  3.37152109e-02, -6.00774325e-02,
            3.43431458e-02,  6.80516511e-02]]],


        [[[ 4.51148823e-02,  4.12209071e-02, -1.92945134e-02,
           -2.68811788e-02,  4.68725041e-02, -2.08357088e-02,
           -3.62888947e-02, -1.60191804e-02,  3.19913588e-02,
            1.54639455e-02, -7.92380888e-03, -4.85247411e-02,
           -3.52074914e-02, -1.04825860e-02, -6.63231388e-02,
            4.35819328e-02,  1.74060687e-02, -3.14022303e-02,
           -2.88435258e-02, -2.56987382e-03, -4.61222306e-02,
            9.01424140e-03, -3.54990773e-02,  3.61517034e-02,
           -4.51472104e-02, -1.96188372e-02,  2.76502203e-02,
           -3.39846462e-02, -5.75804268e-04, -4.55158725e-02,
            2.47761561e-03,  5.08131757e-02]],

         [[ 3.74217257e-02,  4.53428067e-02, -4.36269939e-02,
           -1.65079869e-02, -2.69084796e-02, -2.38134293e-03,
            2.26788968e-02, -3.10470518e-02, -4.33242172e-02,
            1.89485904e-02, -5.52747138e-02,  6.01334386e-02,
           -1.70235410e-02, -4.17503342e-02, -1.59652822e-03,
           -3.10646854e-02, -1.94913559e-02,  5.42740058e-03,
            5.47912866e-02,  2.19548331e-03, -2.94116754e-02,
            2.24571414e-02, -1.57341175e-02, -5.24678500e-03,
            4.41270098e-02,  1.79115515e-02, -3.40841003e-02,
           -2.95497216e-02,  4.40835916e-02,  4.28234115e-02,
           -4.25039157e-02,  5.90493456e-02]],

         [[-2.71476209e-02,  6.84098527e-02, -2.91980486e-02,
           -2.52507403e-02, -6.22444265e-02,  3.67519422e-03,
            5.06899729e-02,  3.09969904e-03,  4.50362265e-02,
            8.56801707e-05,  4.21552844e-02, -3.78406122e-02,
           -1.73772611e-02,  4.68185954e-02, -6.93227863e-03,
           -4.71074954e-02,  5.72011899e-03, -1.59831103e-02,
           -1.66428182e-02,  1.12894354e-02,  5.62585844e-03,
            1.36870472e-02, -2.89466791e-02, -2.87153292e-03,
           -3.21626514e-02, -3.75866666e-02, -1.62240565e-02,
            3.01954672e-02, -2.69964593e-03, -2.27513053e-02,
            2.10835561e-02, -4.13369946e-02]]]], dtype=float32),
 array([-1.1922461e-03, -2.0752363e-04,  1.1357996e-05,  1.6377015e-05,
        -2.5950783e-04,  1.9307183e-05, -1.5572178e-06, -1.3648998e-03,
        -8.6763187e-04,  4.4856939e-04,  2.7988455e-03, -7.7398616e-04,
        -5.1178242e-04, -6.8265648e-04,  1.8571866e-04, -7.1992702e-04,
        -5.5880222e-04, -3.6114815e-04, -9.7678707e-04,  2.6443407e-03,
         1.1190268e-03, -1.0251488e-03, -1.1638318e-03,  7.1209669e-04,
         4.9417594e-04,  2.3746442e-04, -4.8552561e-04,  1.4480414e-03,
        -1.8445569e-05,  4.2989667e-04,  1.0579359e-04, -3.2821635e-04],
       dtype=float32)]

A summary of the first few layers of the model:


Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 96, 96, 1)    0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 96, 96, 32)   320         input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 96, 96, 32)   128         conv2d_1[0][0]                   


Now how can I order them and visualize them?

I am using keras.

Thanks.

Generally, if you are using a Dense layer, the first length of 2 corresponds to the weight array and the bias array.

Since I don't know your layer type, I have added an example that explains the shapes for Dense and Conv2D layers.

The first length always corresponds to weights and biases. The shapes at the second level differ between the two: the bias is always a 1D array, whereas for Dense the weights have shape (input_dim, output_dim), and for Conv2D they have shape (kernel_h, kernel_w, channels, number of filters).

from tensorflow.keras.layers import *
from tensorflow.keras.models import *
import numpy as np

i1 = Input(shape=(32,32,3))
c1 = Conv2D(32, 3)(i1)
f1 = Flatten()(c1)
d1 = Dense(5)(f1)

m = Model(i1, d1)

m.summary()

y = m(np.zeros((1, 32, 32, 3)))

print(m.layers)
cw1 = np.array(m.layers[1].get_weights())
print(cw1.shape) # (2,) -> 1 kernel array, 1 bias array
print(cw1[0].shape) # (3, 3, 3, 32) -> 3x3 kernels, 3 input channels, 32 filters
print(cw1[1].shape) # 32 biases

cw1 = np.array(m.layers[2].get_weights())
print(cw1.shape) # this is just a flatten operation, so there are no weights

cw1 = np.array(m.layers[3].get_weights())
print(cw1.shape) # 2 -> 1 weight, 1 bias
print(cw1[0].shape) # 28800 inputs, 5 outputs, 28800 by 5 weight matrix
print(cw1[1].shape) # 5 biases
Model: "model_13"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_14 (InputLayer)        [(None, 32, 32, 3)]       0         
_________________________________________________________________
conv2d_13 (Conv2D)           (None, 30, 30, 32)        896       
_________________________________________________________________
flatten_13 (Flatten)         (None, 28800)             0         
_________________________________________________________________
dense_13 (Dense)             (None, 5)                 144005    
=================================================================
Total params: 144,901
Trainable params: 144,901
Non-trainable params: 0
_________________________________________________________________
[<tensorflow.python.keras.engine.input_layer.InputLayer object at 0x7fb8ce3bb828>, <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fb8ce5fd6d8>, <tensorflow.python.keras.layers.core.Flatten object at 0x7fb8ce3bb940>, <tensorflow.python.keras.layers.core.Dense object at 0x7fb8ce3bbb70>]
(2,)
(3, 3, 3, 32)
(32,)
(0,)
(2,)
(28800, 5)
(5,)

The visualization depends entirely on the dimensionality of the weights.

If it is 1D,

import matplotlib.pyplot as plt
plt.plot(weight)
plt.show()
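For example, the 32 conv biases from the example model above could be plotted this way (a minimal sketch reusing the m defined earlier; the second element returned by get_weights() is the bias vector):

import matplotlib.pyplot as plt

bias = m.layers[1].get_weights()[1]  # (32,) bias vector of the Conv2D layer
plt.plot(bias)
plt.title("conv2d biases")
plt.show()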

If it is 2D,

import matplotlib.pyplot as plt
plt.imshow(weight)
plt.show()
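For example, the (28800, 5) Dense kernel from the example model above can be shown the same way (a minimal sketch; aspect='auto' keeps the tall matrix readable):

import matplotlib.pyplot as plt

dense_w = m.layers[3].get_weights()[0]  # (28800, 5) weight matrix of the Dense layer
plt.imshow(dense_w, aspect='auto', cmap='viridis')
plt.colorbar()
plt.show()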

If it is 3D,

you can select a single channel and plot just that part.


# plotting the 32 conv filters
import matplotlib.pyplot as plt
cw1 = np.array(m.layers[1].get_weights())
for i in range(32):
  plt.imshow(cw1[0][:,:,:,i])
  plt.show()
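For the network in the question, conv2d_1 has 320 parameters (3*3*1*32 weights + 32 biases), so w[0] should have shape (3, 3, 1, 32): 32 kernels of size 3x3 over a single input channel, rather than 96×96. A minimal sketch for viewing them as a grid (assuming w is the list returned by autoencoder.layers[1].get_weights() as in the question):

import matplotlib.pyplot as plt

kernels = w[0]  # expected shape (3, 3, 1, 32)
fig, axes = plt.subplots(4, 8, figsize=(12, 6))
for i, ax in enumerate(axes.flat):
    ax.imshow(kernels[:, :, 0, i], cmap='gray')  # 3x3 kernel of filter i
    ax.set_title(str(i), fontsize=8)
    ax.axis('off')
plt.tight_layout()
plt.show()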

