Incremental model training with tensorflow

I have a simple linear model that takes (x, y) pairs as input and fits b0 and b1 in y = b0 + b1 * x; the key code is below. It trains on a dataset of known size. Now I want to train it continuously, i.e. feed in each new batch of (x, y) as it arrives and update the coefficients according to the new data. The amount of input is unbounded.

    x = tf.placeholder(tf.float32, [data_len], name="x")
    y = ...
    b0 = tf.Variable([0.8], trainable=True)
    b1 = ...
    # the model
    y = tf.add(tf.multiply(x, b1), b0)  # tf.mul was renamed tf.multiply in TensorFlow 1.0
    y_act = tf.placeholder(tf.float32, [data_len], name="y_act")
    # element-wise absolute error: sqrt((y - y_act)^2)
    error = tf.sqrt((y - y_act) * (y - y_act))
    train_step = tf.train.AdamOptimizer(0.01).minimize(error)
    x_in = ...
    y_in = ...
    init = tf.global_variables_initializer()  # replaces the deprecated tf.initialize_all_variables
    sess.run(init)
    feed_dict = { ... }
    fetches_in = { b0: b0, b1: b1, y: y, train_step: train_step }
    # 50 training steps on the same batch
    for i in range(0, 50):
        fetches = sess.run(fetches_in, feed_dict)

My idea is to remember the coefficients trained so far, initialize a model with them, and then repeat the training on the new portion of data, doing this for each new input. Is this the right way to go? The model will probably be replaced later by something more complex.

It sounds like you're talking about online training, i.e. continuously training a model on incoming data while simultaneously using it. You're right that you should be able to pick up where you left off and just feed in new data. What you need is a way to save and load the variables between training sessions. You can use a tf.train.Saver to do this in "raw" TensorFlow.
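
As a rough sketch of the Saver approach (not your exact code): the function below rebuilds the graph, restores the latest checkpoint if one exists, trains on one batch, and saves again. The names train_on_batch and ckpt_dir, the checkpoint prefix "linear", the mean-squared-error loss, and the starting value of b1 are assumptions made for illustration. The import goes through tensorflow.compat.v1 on the assumption that your install provides that module (TensorFlow 1.14 or later); on older versions a plain "import tensorflow as tf" would take its place.

```python
import os

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph-mode API, as in the question


def train_on_batch(x_batch, y_batch, ckpt_dir, steps=50):
    """Restore the latest checkpoint (if any), train on one batch, save again."""
    os.makedirs(ckpt_dir, exist_ok=True)
    graph = tf.Graph()
    with graph.as_default():
        # [None] lets each incoming batch have a different length.
        x = tf.placeholder(tf.float32, [None], name="x")
        y_act = tf.placeholder(tf.float32, [None], name="y_act")
        b0 = tf.Variable([0.8], name="b0")
        b1 = tf.Variable([0.0], name="b1")  # assumed start value
        y = b0 + b1 * x
        # Assumption: mean squared error instead of the question's loss.
        loss = tf.reduce_mean(tf.square(y - y_act))
        train_step = tf.train.AdamOptimizer(0.01).minimize(loss)
        # Created after minimize(), so Adam's own variables are saved too.
        saver = tf.train.Saver()
        init = tf.global_variables_initializer()

        with tf.Session() as sess:
            latest = tf.train.latest_checkpoint(ckpt_dir)
            if latest is None:
                sess.run(init)  # first batch ever: start from scratch
            else:
                saver.restore(sess, latest)  # later batches: resume
            for _ in range(steps):
                sess.run(train_step, {x: x_batch, y_act: y_batch})
            saver.save(sess, os.path.join(ckpt_dir, "linear"))
            return sess.run([b0, b1])
```

You would call train_on_batch(new_x, new_y, "some_dir") each time a batch arrives; every call continues from the coefficients the previous call saved. Rebuilding the graph per call keeps the sketch simple; in a long-running process you could just as well keep one session open and only save periodically.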

You can also use a tf.contrib.learn.Estimator to do this for you. You just give it a model_fn that constructs your model, and a model_dir to save the model in, and it will take care of the rest. Of course, there's already a linear model in tf.contrib.learn.LinearEstimator. With estimators, you'd just call fit(...) whenever you have new data and it will load your variables and continue running the training steps you've defined.
