Using batch size of one in tensorflow?
So, I have a model where the theoretical justification for the update procedure relies on having a batch size of 1. (For those curious, it's called Bayesian Personalized Ranking for recommender systems.)
Now, I have some standard code written. My input is a `tf.placeholder` variable. It's Nx3, and I run it as normal with the `feed_dict`. This is perfectly fine if I want N to be, say, 30K. However, if I want N to be 1, the `feed_dict` overhead really slows down my code.
For reference, I implemented the gradients by hand in pure Python, and it runs at about 70K iter/second. In contrast, `GradientDescentOptimizer` runs at about 1K iter/second. As you can see, this is just far too slow. So as I said, I suspect the problem is that `feed_dict` has too much overhead when called with a batch size of 1.
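The hand-rolled pure-Python updater isn't shown in the post; for context, a minimal sketch of what a BPR-style SGD step over a single (user, positive item, negative item) triple might look like is below. The factor matrices `W` and `H`, the learning rate, and the regularization constant are my own illustrative assumptions, not the asker's actual code:

```python
import numpy as np

def bpr_update(W, H, u, i, j, lr=0.05, reg=0.01):
    # One SGD step on a single (user, positive item, negative item) triple,
    # ascending the BPR objective ln sigma(x_uij) with L2 regularization.
    # W: user factors (n_users x k), H: item factors (n_items x k).
    wu, hi, hj = W[u].copy(), H[i].copy(), H[j].copy()
    x_uij = wu @ (hi - hj)            # score gap between positive and negative item
    g = 1.0 / (1.0 + np.exp(x_uij))   # sigma(-x_uij), the gradient of ln sigma(x_uij)
    W[u] += lr * (g * (hi - hj) - reg * wu)
    H[i] += lr * (g * wu - reg * hi)
    H[j] += lr * (-g * wu - reg * hj)
    return x_uij
```

With regularization off, each step widens the score gap between the positive and negative item, which is exactly the kind of tight per-sample loop that plain NumPy handles without per-call session overhead.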
Here is the actual session code:
```python
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for iteration in range(100):
    samples = data.generate_train_samples(1000000)
    for sample in tqdm(samples):
        cvalues = sess.run([trainer, obj], feed_dict={input_data: [sample]})
    print("objective = " + str(cvalues[1]))
```
Is there a better way to do one update at a time?
Probably your code runs much slower for two reasons: `feed_dict` copies your data from Python into the runtime on every call, and you pay the fixed per-call overhead of `sess.run` for every single sample.
Luckily, TensorFlow has the `tf.data` API, which helps to solve both problems. You can try to do something like:
```python
inputs = tf.placeholder(tf.float32, your_shape)
labels = tf.placeholder(tf.float32, labels_shape)
dataset = tf.data.Dataset.from_tensor_slices((inputs, labels))
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer, {inputs: your_inputs, labels: your_labels})
```
And then, to get the next entry from the dataset, you just use `iterator.get_next()`.
If that's what you need, TensorFlow has exhaustive documentation on importing data with the `tf.data` API, where you can find something suitable for your use-case: documentation