
Why do I get 6 parameters in Keras in a simple 2-input, 2-output network?

I am learning about neural networks in Keras. I defined a simple model on made-up data.

import tensorflow as tf
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(2, input_dim=2))
model.compile(optimizer='sgd', loss='mean_squared_error')

I have two attributes to predict two values.

Here is where I initialize my data:

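import numpy as np
x, y = [], []
for x1 in range(6):
    for x2 in range(6):  # inner loop is not shown in the original; range(6) is assumed
        x.append([x1, x2])  # the two input attributes
        y.append([2*x1 + x2**2 - 2, x1*x2])  # the two target values
xs = np.array(x, dtype=float)
ys = np.array(y, dtype=float)
model.fit(xs, ys, epochs=500)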

Mind you, I use the data solely for the purpose of learning. Afterwards I tried to inspect the model, so I ran model.summary() and model.get_weights().

Model: "sequential"
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, 2)                 6         
Total params: 6
Trainable params: 6
Non-trainable params: 0
model weights  [array([[0.5137405, 5.477211 ],
       [8.750836 , 1.6910588]], dtype=float32), array([-5.701193, -7.874653], dtype=float32)]

I don't understand why there are 6 params and six weights. From my understanding there should be two going from each input. Or should I have explicitly defined the output layer somewhere?

The model architecture you have defined is shown pictorially below.

[Figure: diagram of the model architecture]

You have one dense layer with two neurons. Why two neurons? Because the first parameter to Dense is units, which sets the number of neurons. Each neuron computes the linear operation XW + b and then applies an activation function to the result (here none is specified, so Dense leaves the output linear). The learnable parameters of a neuron are W and b.

Since X has size 2 (two features), each neuron has 2 weights in W plus 1 bias b, which gives 3 parameters. With 2 such neurons, the layer has 2 × 3 = 6 parameters.
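
The counting above can be written as a small formula. The following is a minimal sketch in plain Python, not Keras code; the helper name `dense_param_count` is hypothetical and only illustrates the arithmetic for a fully connected layer with one bias per neuron.

```python
def dense_param_count(input_dim, units, use_bias=True):
    """Count trainable parameters of a fully connected layer.

    Hypothetical helper for illustration; not part of Keras.
    """
    weights = input_dim * units          # one weight per (input, neuron) pair
    biases = units if use_bias else 0    # one bias per neuron
    return weights + biases


# The layer from the question: 2 inputs, 2 neurons.
print(dense_param_count(input_dim=2, units=2))  # 6
```

Without the bias terms (what `use_bias=False` would do in `Dense`), the same layer would have only 4 parameters, which is probably the number the question expected.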

You have a single output layer with two neurons. Each neuron must have two weights (since the input has dimension 2) plus one more weight called the "bias". So each neuron has 3 weights.

In summary, you have 2 neurons and each one has 3 weights or trainable parameters, so in total there are 6 trainable parameters in your network.
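
To see where the six numbers go, here is a rough sketch that reuses the values printed by model.get_weights() in the question. It is plain Python rather than Keras, and `dense_forward` is a hypothetical helper. It assumes the usual Dense layout, where the first array has shape (inputs, neurons) and the second array holds one bias per neuron.

```python
# Values copied from the model.get_weights() output in the question.
kernel = [[0.5137405, 5.477211],   # weights leaving input 1, one per neuron
          [8.750836, 1.6910588]]   # weights leaving input 2, one per neuron
bias = [-5.701193, -7.874653]      # one bias per neuron


def dense_forward(x, kernel, bias):
    """Compute x @ kernel + bias for a single sample, with no activation."""
    n_units = len(bias)
    return [
        sum(x[i] * kernel[i][j] for i in range(len(x))) + bias[j]
        for j in range(n_units)
    ]


n_params = sum(len(row) for row in kernel) + len(bias)
print(n_params)                                  # 6
print(dense_forward([0.0, 0.0], kernel, bias))   # equals the bias vector
```

Each output uses one column of the 2x2 array plus one bias, so the 4 + 2 numbers in the printout are exactly the 6 parameters reported by model.summary().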

