
Tensorflow tf.layers Dense Neural Net function vs. Class Interface

I am trying to implement a helper class for building a standard feedforward neural network in Python.

Since I want the class to be general, there is a method called addHiddenLayer() which should append layers to the computation graph.

To add layers to the graph I went through the tf.layers module, which provides two options. The first is tf.layers.dense: a function which returns a tensor that can act as the input to the next layer.

The second is tf.layers.Dense: a class whose attributes are almost identical to the parameters of tf.layers.dense(), and which implements essentially the same operation on its inputs.

After going through the documentation for both, I fail to see any extra functionality added by the class version. I think the function implementation should suffice for my use case, the skeleton for which is given below.

import tensorflow as tf

class myNeuralNet:
    def __init__(self, dim_input_data, dim_output_data):
        # Dimension of input data vectors (number of features)
        self.dim_input_data = dim_input_data
        # Dimension of output labels
        self.dim_output_data = dim_output_data
        # TF placeholder for input data
        self.x = tf.placeholder(tf.float32, [None, dim_input_data])
        # TF placeholder for labels
        self.y_ = tf.placeholder(tf.float32, [None, dim_output_data])
        # Container to hold the layers of the network
        self.layer_list = []

    def addHiddenLayer(self, layer_dim, activation_fn=None, regularizer_fn=None):
        # Add a layer of layer_dim units to the network and
        # append the new layer to the container of layers
        pass

    def addFinalLayer(self, activation_fn=None, regularizer_fn=None):
        pass

    def setup_training(self, learn_rate):
        # Define the loss, e.g. stored as self.loss
        # Define the train step as self.train_step = ...; use an optimizer
        # from tf.train and call minimize(self.loss)
        pass

    def setup_metrics(self):
        # Compare the predicted labels with the label placeholder
        # defined in __init__ to compute accuracy, stored as self.accuracy
        pass

    def train(self, sess, max_epochs, batch_size, train_size, print_step=100):
        pass

Can someone give an example of a situation where the class version would be required? References:

Related question on SO

Example of function usage

I've always used dense, because you get the output tensor that you can feed straight into the next layer.

It's probably more a matter of taste than anything else, though.

Using Dense has the advantage that you get the "layer object" that you can refer back to later. dense actually just calls Dense and then uses its apply() method immediately, discarding the layer object afterwards. Here are two example scenarios where Dense would be useful:

  1. Accessing variables. Let's say you want to do something with the dense layer weights, e.g. visualize them or use them in some kind of regularization. If you used dense, you have a problem: the layer object storing the variables was discarded. You can only get them back from the computation graph, which is really annoying and ugly -- see this question for an example. If you created a Dense layer object on the other hand, you can simply ask for the trainable_variables attribute of the layer.
    Furthermore, if you use eager execution I believe you have to have explicit variable storage because there is no computational graph -- if you use dense the variables would be discarded along with the layer object and your training wouldn't work (but don't quote me on this, I don't know much about eager execution).
  2. Re-using a layer. Let's say you want to apply the layer to some input, and then later apply it to a different input as well. With dense you have to use variable scopes and the reuse feature, which I personally find quite unintuitive and which makes your code harder to understand. If you used Dense, you can simply call the apply method of the layer object again.
