![](/img/trans.png)
[英]How to save the structure and weights of trained tensorflow model?
[英]How to get the trained weights created by a model
我實現了一個簡單的邏輯回歸。 在運行訓練算法之前,我為權重創建了一個占位符,將所有權重初始化為0 ...
W = tf.Variable(tf.zeros([784, 10]))
正確初始化所有變量后,便實現了邏輯回歸(已測試並正確運行)...
for epoch in range(training_epochs):
avg_cost = 0
total_batch = int(mnist.train.num_examples/batch_size)
# loop over all batches
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
_, c = sess.run([optimizer, cost], feed_dict={x: batch_xs, y: batch_ys})
# compute average loss
avg_cost += c / total_batch
# display logs per epoch step
if (epoch + 1) % display_step == 0:
print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost))
我的問題是,我需要提取模型中使用的權重。 我為模型使用了以下內容...
pred = tf.nn.softmax(tf.matmul(x, W) + b) # Softmax
我嘗試通過以下方式提取...
var = [v for v in tf.trainable_variables() if v.name == "Variable:0"][0]
print(sess.run(var[0]))
我以為訓練后的權重將位於tf.training_variables()
,但是當我運行print
函數時,會得到一個零數組。
我想要的是所有的砝碼。 但是由於某種原因,我得到的是零數組,而不是分類器的實際權重。
變量W
應參考訓練后的權重。 請嘗試簡單地做: sess.run(W)
簡單得多,只需使用run函數評估權重,您將獲得具有值的numpy數組:
sess.run([x, W, b])
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.