
Using a batch size of one in TensorFlow?

So, I have a model where the theoretical justification for the update procedure relies on having a batch size of 1. (For those curious, it's called Bayesian Personalized Ranking for recommender systems.)

Now, I have some standard code written. My input is a tf.placeholder. It's Nx3, and I run it as normal with feed_dict. This works fine if N is, say, 30K. However, if I want N to be 1, the feed_dict overhead really slows down my code.

For reference, I implemented the gradients by hand in pure Python, and that version runs at about 70K iterations/second. In contrast, GradientDescentOptimizer runs at about 1K iterations/second, which is far too slow. So, as I said, I suspect the problem is that feed_dict has too much overhead to be called with a batch size of 1.
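For context, a hand-written single-sample update of the kind mentioned above might look like the sketch below. This is not the asker's code: it assumes a BPR matrix-factorization model (score x_ui = dot(W[u], H[i])) and uses the standard BPR stochastic gradient step on one (user, positive item, negative item) triple; `bpr_sgd_step`, `lr` and `reg` are hypothetical names and values.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def bpr_sgd_step(W, H, u, i, j, lr=0.05, reg=0.01):
    """One BPR gradient-ascent step on a single (u, i, j) triple.

    Assumes a matrix-factorization model, x_ui = dot(W[u], H[i]).
    W and H are lists of factor lists and are updated in place.
    Assumes i != j. Returns x_uij = x_ui - x_uj from before the update.
    """
    w_u, h_i, h_j = W[u], H[i], H[j]
    x_uij = sum(w * (p - n) for w, p, n in zip(w_u, h_i, h_j))
    # d/dx ln(sigmoid(x)) = sigmoid(-x)
    g = sigmoid(-x_uij)
    for f in range(len(w_u)):
        w, p, n = w_u[f], h_i[f], h_j[f]  # read old values before updating
        w_u[f] += lr * (g * (p - n) - reg * w)
        h_i[f] += lr * (g * w - reg * p)
        h_j[f] += lr * (-g * w - reg * n)
    return x_uij
```

Because each call touches only three factor vectors, there is no per-step framework overhead, which is one plausible reason a pure-Python loop like this can outpace one `sess.run` per sample.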

Here is the actual session code:

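```python
import tensorflow as tf  # TensorFlow 1.x API
from tqdm import tqdm

# input_data, trainer, obj and data are defined elsewhere in the model code
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for iteration in range(100):
    samples = data.generate_train_samples(1000000)
    for sample in tqdm(samples):
        # one sess.run call (and one feed_dict copy) per training sample
        cvalues = sess.run([trainer, obj], feed_dict={input_data: [sample]})
    print("objective = " + str(cvalues[1]))
```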

Is there a faster way to do one update per sample?

Your code is probably much slower for two reasons:

  1. Each sess.run call copies the fed data into GPU memory (if you use a GPU), and you make that call once per sample, which is very time-consuming.
  2. Data feeding and training run one after the other in a single thread.

Fortunately, TensorFlow has the tf.data API, which helps with both problems. You could try something like this:

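```python
import tensorflow as tf  # TensorFlow 1.x API

# your_shape, labels_shape, your_inputs and your_labels stand for your own data
inputs = tf.placeholder(tf.float32, your_shape)
labels = tf.placeholder(tf.float32, labels_shape)
dataset = tf.data.Dataset.from_tensor_slices((inputs, labels))

iterator = dataset.make_initializable_iterator()

sess = tf.Session()
# the whole array is fed once; the iterator then serves elements inside the graph
sess.run(iterator.initializer, {inputs: your_inputs, labels: your_labels})
```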

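Then, to get the next entry from the dataset, use iterator.get_next(). It returns tensors that you can use directly as the model's input, so no feed_dict is needed at each training step.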
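Putting these pieces together, a training loop with a batch size of one might look like the following. This is a minimal sketch, not the asker's model: it assumes a BPR matrix-factorization objective, random (user, positive item, negative item) triples, and hypothetical sizes and hyperparameters. It uses `tensorflow.compat.v1` so the TF 1.x graph API also runs on TF 2.x.

```python
import numpy as np
import tensorflow.compat.v1 as tf  # TF1-style graph API (TF >= 1.14, including 2.x)

tf.disable_v2_behavior()

# Hypothetical sizes and random (user, positive item, negative item) triples.
n_users, n_items, n_factors = 50, 100, 8
rng = np.random.RandomState(0)
triples = np.stack([rng.randint(0, n_users, 10000),
                    rng.randint(0, n_items, 10000),
                    rng.randint(0, n_items, 10000)], axis=1).astype(np.int32)

triples_ph = tf.placeholder(tf.int32, shape=[None, 3])
dataset = (tf.data.Dataset.from_tensor_slices(triples_ph)
           .shuffle(10000)
           .repeat()
           .batch(1))  # batch size of one, as the model requires
iterator = tf.data.make_initializable_iterator(dataset)
batch = iterator.get_next()  # int32 tensor of shape [1, 3]; no feed_dict per step

# Hypothetical BPR-MF objective built directly on the iterator output.
W = tf.Variable(tf.random.normal([n_users, n_factors], stddev=0.1))
H = tf.Variable(tf.random.normal([n_items, n_factors], stddev=0.1))
w_u = tf.gather(W, batch[:, 0])
h_i = tf.gather(H, batch[:, 1])
h_j = tf.gather(H, batch[:, 2])
x_uij = tf.reduce_sum(w_u * (h_i - h_j), axis=1)
reg = 0.01 * tf.add_n([tf.reduce_sum(tf.square(t)) for t in (w_u, h_i, h_j)])
obj = tf.reduce_mean(-tf.math.log_sigmoid(x_uij)) + reg
trainer = tf.train.GradientDescentOptimizer(0.05).minimize(obj)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # The whole array crosses the feed_dict boundary once, at initialization.
    sess.run(iterator.initializer, feed_dict={triples_ph: triples})
    for step in range(1000):
        _, objective = sess.run([trainer, obj])  # no feed_dict here
    print("objective = " + str(objective))
```

Within a single `sess.run`, `get_next()` is evaluated once, so `trainer` and `obj` see the same sample. Each step is still one `sess.run` call with its own fixed cost, so this removes the per-step feeding but will not necessarily match a pure-Python loop.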

If that's what you need, TensorFlow has extensive documentation on importing data with the tf.data API, where you can find an approach suited to your use case: documentation
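If the samples come from a Python generator such as `data.generate_train_samples`, one option from the tf.data API is `Dataset.from_generator` combined with `prefetch`, so the next sample is prepared in the background while the current step runs. A minimal sketch follows; the generator below is a hypothetical stand-in, and the triple shape and item range are assumptions. The generator itself still runs Python code for every element.

```python
import numpy as np
import tensorflow.compat.v1 as tf  # TF1-style graph API (TF >= 1.14, including 2.x)

tf.disable_v2_behavior()


def generate_train_samples(n):
    """Hypothetical stand-in for data.generate_train_samples: yields (u, i, j) triples."""
    rng = np.random.RandomState(0)
    for _ in range(n):
        yield rng.randint(0, 100, size=3).astype(np.int32)


dataset = (tf.data.Dataset.from_generator(lambda: generate_train_samples(1000),
                                          output_types=tf.int32,
                                          output_shapes=tf.TensorShape([3]))
           .batch(1)      # batch size of one
           .prefetch(1))  # prepare the next sample in the background
iterator = tf.data.make_initializable_iterator(dataset)
batch = iterator.get_next()  # use this tensor in place of input_data

n_batches = 0
with tf.Session() as sess:
    sess.run(iterator.initializer)
    try:
        while True:
            sess.run(batch)  # in real code: sess.run([trainer, obj])
            n_batches += 1
    except tf.errors.OutOfRangeError:
        pass  # the generator is exhausted
print(n_batches)
```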

