
Get coefficients of a linear regression in TensorFlow

I've done a simple linear regression in TensorFlow. How can I find the coefficients of the regression? I've read the docs but I can't find it anywhere! (https://www.tensorflow.org/api_docs/python/tf/estimator/LinearRegressor)

EDIT: code example

import numpy as np
import tensorflow as tf

# Custom model_fn: a linear model y = W * x + b on one real-valued feature 'x'
def model_fn(features, labels, mode):
  # Build a linear model and predict values
  W = tf.get_variable("W", [1], dtype=tf.float64)
  b = tf.get_variable("b", [1], dtype=tf.float64)
  y = W * features['x'] + b
  # Loss sub-graph
  loss = tf.reduce_sum(tf.square(y - labels))
  # Training sub-graph
  global_step = tf.train.get_global_step()
  optimizer = tf.train.GradientDescentOptimizer(0.01)
  train = tf.group(optimizer.minimize(loss),
                   tf.assign_add(global_step, 1))
  # EstimatorSpec connects subgraphs we built to the
  # appropriate functionality.
  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=y,
      loss=loss,
      train_op=train)

estimator = tf.estimator.Estimator(model_fn=model_fn)
# define our data sets
x_train = np.array([1., 2., 3., 4.])
y_train = np.array([0., -1., -2., -3.])
x_eval = np.array([2., 5., 8., 1.])
y_eval = np.array([-1.01, -4.1, -7, 0.])
input_fn = tf.estimator.inputs.numpy_input_fn(
    {"x": x_train}, y_train, batch_size=4, num_epochs=None, shuffle=True)
train_input_fn = tf.estimator.inputs.numpy_input_fn(
    {"x": x_train}, y_train, batch_size=4, num_epochs=1000, shuffle=False)
eval_input_fn = tf.estimator.inputs.numpy_input_fn(
    {"x": x_eval}, y_eval, batch_size=4, num_epochs=1000, shuffle=False)

# train
estimator.train(input_fn=input_fn, steps=1000)
# Here we evaluate how well our model did.
train_metrics = estimator.evaluate(input_fn=train_input_fn)
eval_metrics = estimator.evaluate(input_fn=eval_input_fn)
print("train metrics: %r"% train_metrics)
print("eval metrics: %r"% eval_metrics)
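If you want a reference value to compare against whatever the estimator reports, note that the training data above is exactly linear (y = 1 - x), so ordinary least squares gives W = -1 and b = 1. The sketch below is plain Python, not TensorFlow: it mirrors the model_fn (summed squared error, plain gradient descent with learning rate 0.01, the whole 4-sample batch at every step, 1000 steps) and compares the result with the closed-form fit. Starting W and b at 0.0 is an assumption; the real model initialises them with tf.get_variable's default initializer, so the estimator's trained values should be close to these but not bit-identical.

```python
# Plain-Python mirror of the question's model_fn (no TensorFlow needed).
# Assumption: W and b start at 0.0; TensorFlow would use tf.get_variable's default initializer.


def closed_form_fit(xs, ys):
    """Ordinary least squares for y = W * x + b with a single feature."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var = sum((x - mean_x) ** 2 for x in xs)
    w = cov / var
    return w, mean_y - w * mean_x


def gradient_descent_fit(xs, ys, learning_rate=0.01, steps=1000):
    """Minimise sum((W*x + b - y)**2) with full-batch gradient descent."""
    w, b = 0.0, 0.0
    for _ in range(steps):
        residuals = [w * x + b - y for x, y in zip(xs, ys)]
        # Gradients of the summed squared error, the same loss as
        # tf.reduce_sum(tf.square(y - labels)); both use the pre-update w and b.
        grad_w = 2.0 * sum(r * x for r, x in zip(residuals, xs))
        grad_b = 2.0 * sum(residuals)
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
    return w, b


x_train = [1.0, 2.0, 3.0, 4.0]
y_train = [0.0, -1.0, -2.0, -3.0]

w_exact, b_exact = closed_form_fit(x_train, y_train)
w_gd, b_gd = gradient_descent_fit(x_train, y_train)
print("closed form:      W=%.4f b=%.4f" % (w_exact, b_exact))
print("gradient descent: W=%.4f b=%.4f" % (w_gd, b_gd))
```

Because each training batch contains all four samples, shuffling does not change the summed gradient, which is why a full-batch loop is a reasonable stand-in for the estimator's training here.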

EDIT: As Jason Ching points out, there have been some changes since this answer was posted. Estimators now have the methods get_variable_names() and get_variable_value(name), and the estimator weights no longer seem to be added automatically to tf.GraphKeys.MODEL_VARIABLES. For the model_fn in the question, estimator.get_variable_value("W") and estimator.get_variable_value("b") return the trained coefficients.


Estimators are designed to work essentially as black boxes, so there is no direct API to retrieve the weights. Even if you define the model yourself, as in your case (as opposed to using a canned estimator), you do not have direct access to the parameters from the estimator object.

That said, you can still retrieve the variables by other means. If you know the variable names, one option is to get them from the graph object with get_operation_by_name or get_tensor_by_name. A more practical and general option is to use a collection: either when you call tf.get_variable (through its collections argument) or afterwards with tf.add_to_collection, you can put the model variables under a common collection name for later retrieval. If you look at how a tf.estimator.LinearRegressor is actually built (search for the function linear_model in this module), all model variables are added to both tf.GraphKeys.GLOBAL_VARIABLES and tf.GraphKeys.MODEL_VARIABLES. This is presumably (I haven't checked) common to all the available canned estimators, so when using one of those you should usually be able to simply do:

model_vars = tf.get_collection(tf.GraphKeys.MODEL_VARIABLES)

It is preferable to use tf.GraphKeys.MODEL_VARIABLES here rather than tf.GraphKeys.GLOBAL_VARIABLES, which is more general-purpose and likely to contain other, unrelated variables as well (for example, the global step counter).
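To make the difference between the two collection keys concrete, here is a toy, pure-Python model of graph collections. It is not TensorFlow's implementation: the ToyGraph class and its string keys are hypothetical, chosen only to mirror the behaviour described above, where model parameters are registered under both keys while bookkeeping variables such as the global step counter land only in the global one.

```python
from collections import defaultdict

# Hypothetical collection keys for this toy; they are not TensorFlow's tf.GraphKeys values.
MODEL_VARIABLES = "MODEL_VARIABLES"
GLOBAL_VARIABLES = "GLOBAL_VARIABLES"


class ToyGraph:
    """Hypothetical registry: each collection name maps to an ordered list of items."""

    def __init__(self):
        self._collections = defaultdict(list)

    def add_to_collection(self, name, value):
        self._collections[name].append(value)

    def get_collection(self, name):
        # Return a copy so callers cannot mutate the registry by accident.
        return list(self._collections[name])


graph = ToyGraph()

# Model parameters go into both collections, as the answer describes for LinearRegressor.
for var_name in ("W", "b"):
    graph.add_to_collection(GLOBAL_VARIABLES, var_name)
    graph.add_to_collection(MODEL_VARIABLES, var_name)

# A bookkeeping variable such as the step counter is global but not a model variable.
graph.add_to_collection(GLOBAL_VARIABLES, "global_step")

print(graph.get_collection(MODEL_VARIABLES))   # ['W', 'b']
print(graph.get_collection(GLOBAL_VARIABLES))  # ['W', 'b', 'global_step']
```

In real TensorFlow 1.x code, the equivalent for the question's custom model_fn would be calling tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, W) (and likewise for b) inside model_fn, so that a later tf.get_collection on that key finds them.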

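Another approach is to restore the latest checkpoint written by the estimator and read the variables from it. Try this: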

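import tensorflow as tf

# Assumes LR is a tf.estimator.LinearRegressor (or any Estimator) built with
# model_dir=tf_data, and train_input_data is its training input_fn.
LR.train(input_fn=train_input_data, steps=1)

with tf.Session() as sess:
    # Locate the newest checkpoint the estimator wrote to its model_dir.
    last_check = tf.train.latest_checkpoint(tf_data)
    # Rebuild the saved graph (including its collections) in the default graph.
    saver = tf.train.import_meta_graph(last_check + '.meta')
    print(last_check + '.meta')
    # Load the trained variable values into this session.
    saver.restore(sess, last_check)
    # Use tf.GraphKeys.GLOBAL_VARIABLES instead if MODEL_VARIABLES comes back empty
    # (e.g. for a custom model_fn that never adds its variables to that collection).
    model_vars = tf.get_collection(tf.GraphKeys.MODEL_VARIABLES)
    for var in model_vars:
        print(var.name + '  -->  ' + str(sess.run(var)))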

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please cite the site URL or the original address. For questions, contact: yoyou2525@163.com.

 