簡體   English   中英

如何獲得模型創建的訓練權重

[英]How to get the trained weights created by a model

我實現了一個簡單的邏輯回歸。 在運行訓練算法之前,我為權重創建了一個占位符,將所有權重初始化為0 ...

W = tf.Variable(tf.zeros([784, 10]))

正確初始化所有變量后,便實現了邏輯回歸(已測試並正確運行)...

for epoch in range(training_epochs):
    avg_cost = 0
    total_batch = int(mnist.train.num_examples/batch_size)
    # loop over all batches
    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs, y: batch_ys})

        # compute average loss
        avg_cost += c / total_batch
    # display logs per epoch step
    if (epoch + 1) % display_step == 0:
        print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost))

我的問題是,我需要提取模型中使用的權重。 我為模型使用了以下內容...

pred = tf.nn.softmax(tf.matmul(x, W) + b)  # Softmax

我嘗試通過以下方式提取...

var = [v for v in tf.trainable_variables() if v.name == "Variable:0"][0]
print(sess.run(var[0]))

我以為訓練后的權重將位於tf.training_variables() ,但是當我運行print函數時,會得到一個零數組。

我想要的是所有的砝碼。 但是由於某種原因,我得到的是零數組,而不是分類器的實際權重。

變量W應參考訓練后的權重。 請嘗試簡單地做: sess.run(W)

簡單得多,只需使用run函數評估權重,您將獲得具有值的numpy數組:

sess.run([x, W, b])

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM