
How can I access weights from a layer created with tf.contrib.layers.fully_connected?

I'm using tf.contrib.layers.fully_connected to create a layer in the following code.

library(tensorflow)

x <- tf$placeholder(tf$float32, shape(NULL, 784L))
logits <- tf$contrib$layers$fully_connected(x, 10L)
y <- tf$nn$softmax(logits)

How can I access the layer's weights, the way I can with sess$run(W) in the following block of code?

x <- tf$placeholder(tf$float32, shape(NULL, 784L))
W <- tf$Variable(tf$zeros(shape(784L, 10L)))
b <- tf$Variable(tf$zeros(shape(10L)))
y <- tf$nn$softmax(tf$matmul(x, W) + b)

Note: I'm using TensorFlow for R, but this should be the same as TensorFlow for Python if you replace $ with a dot.

You can get the list of all global variables using tf$global_variables(). It is not an ideal solution (because running it returns an unnamed list of values), but it should get you what you need. A reproducible example:

library(tensorflow)

datasets <- tf$contrib$learn$datasets
mnist <- datasets$mnist$read_data_sets("MNIST-data", one_hot = TRUE)

x <- tf$placeholder(tf$float32, shape(NULL, 784L))
logits <- tf$contrib$layers$fully_connected(x, 10L)

y <- tf$nn$softmax(logits)
y_ <- tf$placeholder(tf$float32, shape(NULL,10L))

cross_entropy <- tf$reduce_mean(-tf$reduce_sum(y_ * tf$log(y), reduction_indices=1L))
train_step <- tf$train$GradientDescentOptimizer(0.5)$minimize(cross_entropy)

sess <- tf$Session()
sess$run(tf$global_variables_initializer())

for (i in 1:1000) {
  batches <- mnist$train$next_batch(100L)
  batch_xs <- batches[[1]]
  batch_ys <- batches[[2]]
  sess$run(train_step,
           feed_dict = dict(x = batch_xs, y_ = batch_ys))
}

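# Fetch the current values of all global variables
# (here, the weights and biases of the fully connected layer).
lst.variables <- sess$run(tf$global_variables())
str(lst.variables)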
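Since the fetched values come back as an unnamed list, one way to tell weights from biases is to label each value with the graph name of its variable. This is a minimal sketch, assuming TensorFlow 1.x (where tf$contrib is available) and that each variable object exposes its graph name through its name attribute. The names vars, var.names and values are illustrative and not part of the original answer.

```r
library(tensorflow)

# Same layer as in the question (assumes TensorFlow 1.x with tf$contrib).
x <- tf$placeholder(tf$float32, shape(NULL, 784L))
logits <- tf$contrib$layers$fully_connected(x, 10L)

sess <- tf$Session()
sess$run(tf$global_variables_initializer())

# Collect the variable objects and read the graph name of each one.
vars <- tf$global_variables()
var.names <- sapply(vars, function(v) v$name)

# Fetch the values and label them with the variable names.
values <- sess$run(vars)
names(values) <- var.names
str(values)
```

The printed names show which entry holds the weight matrix and which holds the bias vector, so an entry can then be picked out with values[[some_name]].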

Alternatively, you can pass the name of the tensor to the session's run function. Inspect the graph to find the name of the tensor that the function added.
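Fetching by name could look like the sketch below. It assumes TensorFlow 1.x, and it assumes that fully_connected stores its variables as "weights" and "biases" under the layer's scope. The scope name "fc1" is a hypothetical choice made here so that the tensor names are predictable; without it, the names have to be read off the graph first, as the listing step does.

```r
library(tensorflow)

x <- tf$placeholder(tf$float32, shape(NULL, 784L))
# "fc1" is an illustrative scope name, not from the original question.
logits <- tf$contrib$layers$fully_connected(x, 10L, scope = "fc1")

sess <- tf$Session()
sess$run(tf$global_variables_initializer())

# Inspect the graph: list the operation names that belong to the layer.
graph <- tf$get_default_graph()
op.names <- sapply(graph$get_operations(), function(op) op$name)
print(grep("^fc1/", op.names, value = TRUE))

# Fetch by tensor name; the ":0" suffix selects the first output of the op.
W <- sess$run("fc1/weights:0")
b <- sess$run("fc1/biases:0")
dim(W)
```

If the layer is created without an explicit scope, or more than once in the same graph, the prefix will differ, so it is worth checking the printed names before hard-coding one.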
