
How can I implement a recursive neural network in TensorFlow?

Is there some way of implementing a recursive neural network like the one in [Socher et al. 2011] using TensorFlow? Note that this is different from recurrent neural networks, which are nicely supported by TensorFlow. The difference is that the network is not replicated into a linear sequence of operations, but into a tree structure.
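To make the distinction concrete, here is a framework-free Python sketch. It assumes binary trees and a single-layer composition of the form parent = tanh(W[c1; c2] + b), the kind used in that line of work. The embeddings, weights, and function names are made up for illustration. The same W is applied at every internal node, but the order in which it is applied follows each input's tree rather than a fixed sequence.

```python
import math

# Hypothetical toy setup: 2-dimensional word vectors, one shared composition
# matrix W (2 x 4) and bias B (2). A real model would learn these.
EMBED = {
    "the": [0.1, -0.2],
    "cat": [0.4, 0.3],
    "sat": [-0.5, 0.2],
}
W = [[0.5, -0.1, 0.2, 0.3],
     [0.1, 0.4, -0.3, 0.2]]
B = [0.0, 0.1]


def compose(left, right):
    """parent = tanh(W @ [left; right] + B), same weights at every node."""
    x = left + right  # list concatenation builds [c1; c2]
    return [math.tanh(sum(w * xi for w, xi in zip(row, x)) + bias)
            for row, bias in zip(W, B)]


def encode(tree):
    """Recursively encode a binary tree given as nested 2-tuples of words."""
    if isinstance(tree, str):      # leaf: look up the word vector
        return EMBED[tree]
    left, right = tree             # internal node: encode both children first
    return compose(encode(left), encode(right))


root = encode((("the", "cat"), "sat"))  # the tree ((the cat) sat)
print(root)
```

The recursion mirrors the shape of each input tree, so the sequence of operations differs from example to example. That is exactly what a fixed, unrolled recurrent graph cannot express directly.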

I imagine that I could use the While op to construct something like a breadth-first traversal of the tree data structure for each entry of my dataset.
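Here is one way to turn that idea into something a single loop can consume, sketched in plain Python. The array layout and names are my own assumptions. The sketch numbers the nodes in breadth-first order and then reverses that order. BFS always visits a parent before its children, so the reversed order puts every child before its parent. A loop that walks the arrays from index 0 upward can therefore compute the nodes bottom-up and reach the root last.

```python
def linearize(tree):
    """Flatten a binary tree given as nested 2-tuples of words.

    Returns parallel lists (is_leaf, word, left, right) in which every
    child index is smaller than its parent's index and the root is last.
    """
    # Breadth-first pass: `nodes` doubles as the queue, read with a cursor.
    nodes = [tree]
    children = []                  # BFS positions of (left, right), or None
    cursor = 0
    while cursor < len(nodes):
        node = nodes[cursor]
        if isinstance(node, str):
            children.append(None)
        else:
            children.append((len(nodes), len(nodes) + 1))
            nodes.extend(node)     # enqueue left child, then right child
        cursor += 1

    # Reverse the BFS order: BFS position p becomes index n - 1 - p.
    n = len(nodes)
    is_leaf, word, left, right = [], [], [], []
    for pos in reversed(range(n)):
        if children[pos] is None:
            is_leaf.append(True)
            word.append(nodes[pos])
            left.append(-1)
            right.append(-1)
        else:
            l, r = children[pos]
            is_leaf.append(False)
            word.append(None)
            left.append(n - 1 - l)
            right.append(n - 1 - r)
    return is_leaf, word, left, right


is_leaf, word, left, right = linearize((("the", "cat"), "sat"))
```

Walking these arrays from index 0 upward visits `cat`, `the`, `sat`, then `(the cat)`, and finally the root. A post-order depth-first traversal would work equally well. The only property a bottom-up loop relies on is that children come first.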

Maybe it would be possible to implement tree traversal as a new C++ op in TensorFlow, similar to While (but more general)?

Your guess is correct: you can use tf.while_loop and tf.cond to represent the tree structure in a static graph. More info: https://github.com/bogatyy/cs224d/tree/master/assignment3
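Below is a framework-free sketch of the control flow such a static graph encodes. It assumes the tree has already been flattened into index arrays that store children before parents, with the root last. The arrays, weights, and names are hypothetical, and this is not the code from the linked repository. In TensorFlow, the outer loop would roughly correspond to a tf.while_loop, the per-node branch to a tf.cond, and the list of node states to a tf.TensorArray. The graph itself stays fixed. Only the index arrays fed into it change from tree to tree.

```python
import math

# Hypothetical pre-linearized tree for ((the cat) sat): children have smaller
# indices than their parent, and the root is the last node.
IS_LEAF = [True, True, True, False, False]
WORD_ID = [1, 0, 2, -1, -1]        # row of EMBEDDINGS for each leaf
LEFT = [-1, -1, -1, 1, 3]
RIGHT = [-1, -1, -1, 0, 2]

EMBEDDINGS = [[0.1, -0.2], [0.4, 0.3], [-0.5, 0.2]]   # the, cat, sat
W = [[0.5, -0.1, 0.2, 0.3], [0.1, 0.4, -0.3, 0.2]]
B = [0.0, 0.1]


def compose(left, right):
    """parent = tanh(W @ [left; right] + B)."""
    x = left + right
    return [math.tanh(sum(w * v for w, v in zip(row, x)) + bias)
            for row, bias in zip(W, B)]


def evaluate(is_leaf, word_id, left, right):
    # `states` plays the role of a TensorArray written once per node.
    states = [None] * len(is_leaf)
    i = 0
    while i < len(is_leaf):        # analogous to the tf.while_loop condition
        if is_leaf[i]:             # analogous to tf.cond(is_leaf[i], ...)
            states[i] = EMBEDDINGS[word_id[i]]
        else:
            states[i] = compose(states[left[i]], states[right[i]])
        i += 1
    return states[-1]              # the root is stored last


root = evaluate(IS_LEAF, WORD_ID, LEFT, RIGHT)
```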

In my evaluation, this makes training 16x faster than rebuilding the graph for every new tree.

Currently, these models are very hard to implement efficiently and cleanly in TensorFlow because the graph structure depends on the input. That also makes it very hard to do minibatching. It is possible using things like the while loop you mentioned, but doing it cleanly isn't easy.

You can build a new graph for each example, but this will be very annoying. If, for a given input size, you can enumerate a reasonably small number of possible graphs, you can build them all at once and select between them, but this won't be possible for larger inputs.
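The following calculation shows why enumerating graphs stops scaling. Assume binary trees over a sentence of n words with the word order fixed. The number of possible tree shapes is then the Catalan number C(n - 1), which grows exponentially in n. Here is a quick Python check (the function name is just for illustration; `math.comb` needs Python 3.8+):

```python
import math


def num_binary_trees(n_words):
    """Count full binary trees whose leaves are n_words words in fixed order.

    This is the Catalan number C(k) = comb(2k, k) // (k + 1) with k = n_words - 1.
    """
    k = n_words - 1
    return math.comb(2 * k, k) // (k + 1)


for n in (3, 5, 10, 20):
    print(n, num_binary_trees(n))
```

For 10 words there are already 4862 possible shapes, and for 20 words there are about 1.8 billion. Precompiling one graph per shape is therefore only viable for very short inputs.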

You can also route examples through your graph with complicated tf.gather logic and masks, but this can also be a huge pain.
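The next sketch shows one way such gather-based routing can be organized. It is my own illustration, not a specific library's API, and the flat layout and all names are assumptions. The nodes of several trees go into one flat list and are grouped by height. At each height, the sketch gathers the left and right child states of every node at that level and applies the composition to all of them in one step. In TensorFlow, each level would roughly be a tf.gather on the child indices followed by a single batched matrix multiply. Here, plain Python lists stand in for tensors. Nodes from different trees share levels, so this is also one way to get the minibatching mentioned above. Grouping by height also avoids the need for masks in this simple case.

```python
import math

# Hypothetical minibatch of two trees, (the cat) and ((the cat) sat), flattened
# into one node list. Leaves store a word id; internal nodes store child
# indices into the same list. Children are stored before their parents.
EMBEDDINGS = [[0.1, -0.2], [0.4, 0.3], [-0.5, 0.2]]   # the, cat, sat
W = [[0.5, -0.1, 0.2, 0.3], [0.1, 0.4, -0.3, 0.2]]
B = [0.0, 0.1]

WORD_ID = [0, 1, 0, 1, 2, -1, -1, -1]
LEFT = [-1, -1, -1, -1, -1, 0, 2, 6]
RIGHT = [-1, -1, -1, -1, -1, 1, 3, 4]
ROOTS = [5, 7]


def gather(values, indices):
    """Plain-Python stand-in for tf.gather along the first axis."""
    return [values[i] for i in indices]


def compose_batch(lefts, rights):
    """Apply parent = tanh(W @ [l; r] + B) to every (l, r) pair at once."""
    out = []
    for l, r in zip(lefts, rights):
        x = l + r
        out.append([math.tanh(sum(w * v for w, v in zip(row, x)) + bias)
                    for row, bias in zip(W, B)])
    return out


def encode_batch(word_id, left, right):
    n = len(word_id)
    # Height of each node (leaves = 0); relies on children preceding parents.
    height = [0] * n
    for i in range(n):
        if left[i] >= 0:
            height[i] = 1 + max(height[left[i]], height[right[i]])
    states = [None] * n
    leaves = [i for i in range(n) if height[i] == 0]
    for i, vec in zip(leaves, gather(EMBEDDINGS, gather(word_id, leaves))):
        states[i] = vec
    for h in range(1, max(height) + 1):
        level = [i for i in range(n) if height[i] == h]   # across all trees
        new = compose_batch(gather(states, gather(left, level)),
                            gather(states, gather(right, level)))
        for i, vec in zip(level, new):                    # scatter back
            states[i] = vec
    return states


root_states = gather(encode_batch(WORD_ID, LEFT, RIGHT), ROOTS)
```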

Ultimately, building the graph on the fly for each example is probably the easiest approach, and there is a chance that future alternatives will better support immediate-style execution. But as of v0.8, I would expect this to be a bit annoying and to introduce some overhead, as Yaroslav mentions in his comment.

Edit: Since I answered, here is an example that uses a static graph with while loops: https://github.com/bogatyy/cs224d/tree/master/assignment3 . I am not sure how well it performs compared with custom C++ code for models like this, although in principle it could be batched.
