[英]Tensorflow error : Dimensions must be equal
我有一個25000張彩色圖片100 * 100(* 3)的數據集,我試圖用一個卷積層構建一個簡單的神經網絡。 它顯示了受瘧疾感染或未受瘧疾感染的細胞的圖片,因此我的輸出為2。但是似乎我的尺寸不匹配,而且我不知道我的錯誤來自哪里。
我的神經網絡:
def simple_nn(X_training, Y_training, X_test, Y_test):
input = 100*100*3
batch_size = 25
X = tf.placeholder(tf.float32, [batch_size, 100, 100, 3])
#Was:
# W = tf.Variable(tf.zeros([input, 2]))
# b = tf.Variable(tf.zeros([2]))
#Now:
W = tf.Variable(tf.truncated_normal([4, 4, 3, 3], stddev=0.1))
B = tf.Variable(tf.ones([3])/10) # What should I put here ??
init = tf.global_variables_initializer()
# model
#Was:
# Y = tf.nn.softmax(tf.matmul(tf.reshape(X, [-1, input]), W) + b)
#Now:
stride = 1 # output is still 28x28
Ycnv = tf.nn.conv2d(X, W, strides=[1, stride, stride, 1], padding='SAME')
Y = tf.nn.relu(Ycnv + B)
# placeholder for correct labels
Y_ = tf.placeholder(tf.float32, [None, 2])
# loss function
cross_entropy = -tf.reduce_sum(Y_ * tf.log(Y))
# % of correct answers found in batch
is_correct = tf.equal(tf.argmax(Y,1), tf.argmax(Y_,1))
accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))
learning_rate = 0.00001
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
train_step = optimizer.minimize(cross_entropy)
sess = tf.Session()
sess.run(init)
#Training here...
我的錯誤:
Traceback (most recent call last):
File "neural_net.py", line 135, in <module>
simple_nn(X_training, Y_training, X_test, Y_test)
File "neural_net.py", line 69, in simple_nn
cross_entropy = -tf.reduce_sum(Y_ * tf.log(Y))
...
ValueError: Dimensions must be equal, but are 2 and 3 for 'mul' (op: 'Mul') with input shapes: [?,2], [25,100,100,3].
我之前使用過一個簡單的圖層,它正在工作。 我改變了自己的體重和偏見,說實話,我不知道為什么要這樣設置偏見,我遵循了一個教程( https://codelabs.developers.google.com/codelabs/cloud-tensorflow-mnist/#11 ),但沒有解釋。 我也將Y替換為conv2D。 而且,我不知道如果要得到大小為2 * 1的向量,我的輸出應該是什么。
您已正確定義標簽為
Y_ = tf.placeholder(tf.float32, [None, 2])
因此,最后一維是2。但是,卷積步驟的輸出並不直接適合將其與標簽進行比較。 我的意思是:如果你這樣做
Ycnv = tf.nn.conv2d(X, W, strides=[1, stride, stride, 1], padding='SAME')
Y = tf.nn.relu(Ycnv + B)
錯誤的維度將為四個:
ValueError: Dimensions must be equal, but are 2 and 3 for 'mul' (op: 'Mul') with input shapes: [?,2], [25,100,100,3].
因此,不可能直接將卷積的輸出與標簽相乘(或運算)。 我建議將卷積的輸出展平(僅重整為一維),並將其傳遞到2個單位的完全連接的層(與您擁有的類一樣多)。 像這樣:
Y = tf.reshape(Y, [1,-1])
logits = tf.layers.dense(Y, units= 2)
您可以將其傳遞給損失。
另外,我建議您將損失更改為更適當的版本。 例如, tf.losses.sigmoid_cross_entropy
。
另外,使用卷積的方式很奇怪。 為什么要在卷積中放入手工過濾器? 此外,您還必須進行初始化並將其放入集合中。 總之,我建議您刪除以下所有代碼:
W = tf.Variable(tf.truncated_normal([4, 4, 3, 3], stddev=0.1))
B = tf.Variable(tf.ones([3])/10) # What should I put here ??
init = tf.global_variables_initializer()
# model
#Was:
# Y = tf.nn.softmax(tf.matmul(tf.reshape(X, [-1, input]), W) + b)
#Now:
stride = 1 # output is still 28x28
Ycnv = tf.nn.conv2d(X, W, strides=[1, stride, stride, 1], padding='SAME')
Y = tf.nn.relu(Ycnv + B)
並替換為:
conv1 = tf.layers.conv2d(X, filters=64, kernel_size=3,
strides=1, padding='SAME',
activation=tf.nn.relu, name="conv1")
同樣, init = tf.global_variable_initializer()
應該位於圖構造的末尾,因為如果不這樣,將會有一些無法捕獲的變量。
我最后的工作代碼是:
def simple_nn():
inp = 100*100*3
batch_size = 2
X = tf.placeholder(tf.float32, [batch_size, 100, 100, 3])
Y_ = tf.placeholder(tf.float32, [None, 2])
#Was:
# W = tf.Variable(tf.zeros([input, 2]))
# b = tf.Variable(tf.zeros([2]))
#Now:
# model
#Was:
# Y = tf.nn.softmax(tf.matmul(tf.reshape(X, [-1, input]), W) + b)
#Now:
stride = 1 # output is still 28x28
conv1 = tf.layers.conv2d(X, filters=64, kernel_size=3,
strides=1, padding='SAME',
activation=tf.nn.relu, name="conv1")
Y = tf.reshape(conv1, [1,-1])
logits = tf.layers.dense(Y, units=2, activation=tf.nn.relu)
# placeholder for correct labels
# loss function
cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y_, logits=logits)
loss = tf.reduce_mean(cross_entropy)
# % of correct answers found in batch
is_correct = tf.equal(tf.argmax(Y,1), tf.argmax(Y_,1))
accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))
learning_rate = 0.00001
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
train_step = optimizer.minimize(cross_entropy)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
...
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.