繁体   English   中英

如何获得模型创建的训练权重

[英]How to get the trained weights created by a model

我实现了一个简单的逻辑回归。 在运行训练算法之前,我为权重创建了一个占位符,将所有权重初始化为0 ...

W = tf.Variable(tf.zeros([784, 10]))

正确初始化所有变量后,便实现了逻辑回归(已测试并正确运行)...

for epoch in range(training_epochs):
    avg_cost = 0
    total_batch = int(mnist.train.num_examples/batch_size)
    # loop over all batches
    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs, y: batch_ys})

        # compute average loss
        avg_cost += c / total_batch
    # display logs per epoch step
    if (epoch + 1) % display_step == 0:
        print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost))

我的问题是,我需要提取模型中使用的权重。 我为模型使用了以下内容...

pred = tf.nn.softmax(tf.matmul(x, W) + b)  # Softmax

我尝试通过以下方式提取...

var = [v for v in tf.trainable_variables() if v.name == "Variable:0"][0]
print(sess.run(var[0]))

我以为训练后的权重将位于tf.training_variables() ,但是当我运行print函数时,会得到一个零数组。

我想要的是所有的砝码。 但是由于某种原因,我得到的是零数组,而不是分类器的实际权重。

变量W应参考训练后的权重。 请尝试简单地做: sess.run(W)

简单得多,只需使用run函数评估权重,您将获得具有值的numpy数组:

sess.run([x, W, b])

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM