[英]More efficient way to do the Mandlebrot example in Tensorflow
The tensorflow website has an example where TF is used to compute the Mandlebrot set. 在tensorflow网站上有一个示例 ,其中TF用于计算Mandlebrot集。 Here's the relevant snippet:
以下是相关代码段:
# Compute the new values of z: z^2 + x
zs_ = zs*zs + xs
# Have we diverged with this new value?
not_diverged = tf.abs(zs_) < 4
# Operation to update the zs and the iteration count.
#
# Note: We keep computing zs after they diverge! This
# is very wasteful! There are better, if a little
# less simple, ways to do this.
#
step = tf.group(
zs.assign(zs_),
ns.assign_add(tf.cast(not_diverged, tf.float32))
)
The relevant quote is "There are better, if a little less simple, ways to do this." 相关的引言是“有更好的方法,即使不太简单,也可以做到这一点”。 Does anyone know what a better way might be?
谁知道有什么更好的方法吗?
I'm messing around with ray tracing in TF, and I'm encountering situations where 90% of pixels have converged, but I keep recomputing them because I don't know how to update a subset of the entires in a tensor without sacrificing the speed benefits of using vector operations everywhere. 我搞乱了TF中的光线跟踪,遇到了90%的像素会聚的情况,但是我一直在重新计算它们,因为我不知道如何在不牺牲张量的情况下更新整个张量的子集加快在任何地方使用向量运算的好处。
I figured out a way of doing it. 我想出了一种方法。
The key is to use tf.where
to find the pixels that haven't yet diverged, tf.gather_nd
to pull those pixels into a smaller array, perform the update step on those specific pixels, then use tf.scatter_nd
(or another "scatter" function) to apply sparse updates to some variable representing the state. 关键是使用
tf.where
查找尚未发散的像素, tf.gather_nd
将这些像素拉入较小的阵列,对这些特定像素执行更新步骤,然后使用tf.scatter_nd
(或另一个“分散” ”)将稀疏更新应用于表示状态的某些变量。
In my use case, this saved quite a lot of time. 在我的用例中,这节省了大量时间。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.