I would like to implement a feed-forward neural network, the only difference from a usual one being that I manually control the correspondence between input features and first-hidden-layer neurons. For example, the input layer has features f1, f2, ..., f100, and the first hidden layer has neurons h1, h2, ..., h10. I want features f1-f10 fed into h1, f11-f20 fed into h2, and so on.
Graphically: unlike dropout, the common deep-learning technique that prevents over-fitting by randomly omitting hidden nodes in a layer, what I want here is to statically (permanently) omit certain edges between the input and hidden layers.
I am implementing this in TensorFlow and haven't found a way to specify this requirement. I also looked into other platforms such as PyTorch and Theano, but still haven't found an answer. Any idea for an implementation in Python would be appreciated!
Take the snippet below:
#!/usr/bin/env python3
import tensorflow as tf

features = tf.constant([1, 2, 3, 4])
hidden_1 = tf.constant([1, 1])
hidden_2 = tf.constant([2, 2])

# Each hidden "neuron" sees only its own slice of the features.
res1 = hidden_1 * tf.slice(features, [0], [2])  # f1, f2 -> h1
res2 = hidden_2 * tf.slice(features, [2], [2])  # f3, f4 -> h2
final = tf.concat([res1, res2], axis=0)

sess = tf.InteractiveSession()
print(sess.run(final))  # [1 2 6 8]
Assume features are your input features. With tf.slice they are split into individual slices; each slice is then a separate subgraph (in this example each is multiplied by hidden_1 or hidden_2), and at the end they are merged back together with tf.concat.
The result is [1, 2, 6, 8], because [1, 2] is multiplied by [1, 1] and [3, 4] is multiplied by [2, 2].
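For the 100-feature, 10-neuron case from the question, the same slice-and-merge idea can be sketched in plain NumPy (the variable names and random weights here are illustrative, not from the answer above):

```python
import numpy as np

features = np.arange(1, 101, dtype=float)  # f1 .. f100

# One 10-element weight vector per hidden neuron (random values for illustration).
rng = np.random.default_rng(0)
weights = [rng.standard_normal(10) for _ in range(10)]

# h_j depends only on features f(10j+1) .. f(10j+10): take the j-th slice,
# apply the j-th weight vector, then merge the scalar results.
hidden = np.array([w @ features[10 * j:10 * j + 10]
                   for j, w in enumerate(weights)])
# hidden has shape (10,), one value per hidden neuron
```

Each hidden value is computed from its own ten-feature slice, mirroring what tf.slice and tf.concat do in the graph above.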
I finally implemented the requirement by forcing certain blocks of the weight matrix of the first layer to be constant zero. That is, rather than simply defining w1 = tf.Variable(tf.random_normal([100, 10])), I define ten 10-by-1 weight vectors and pad them with zeros to form a block-diagonal matrix as the final w1.
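A NumPy sketch of this block-diagonal structure (the mask-multiplication detail is my own illustration of the idea, not the author's exact code):

```python
import numpy as np

n_in, n_hidden = 100, 10
block = n_in // n_hidden  # 10 inputs feed each hidden neuron

# Binary mask: entry (i, j) is 1 only if input i belongs to block j.
mask = np.zeros((n_in, n_hidden))
for j in range(n_hidden):
    mask[j * block:(j + 1) * block, j] = 1.0

# Multiplying a dense random matrix by the mask zeroes every off-block
# weight, yielding the block-diagonal w1 described above.
w1 = np.random.randn(n_in, n_hidden) * mask

x = np.random.randn(1, n_in)
h = x @ w1  # h[0, j] depends only on x[0, j*10:(j+1)*10]
```

In TensorFlow the same effect follows from multiplying the trainable weight variable by such a constant mask before the matmul, so the zeroed edges stay zero throughout training.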