[英]Accuracy very bad in tensorflow logistic regression
我正在嘗試編寫一個程序來預測是否患有惡性腫瘤或良性腫瘤
數據集來自: https : //archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+%28Prognostic%29
這是我的代碼,我的准確率約為65%,這不比擲硬幣好。 任何幫助,將不勝感激
import tensorflow as tf
import pandas as pd
import numpy as np
df = pd.read_csv(r'D:\wholedesktop\logisticReal.txt')
df.drop(['id'], axis=1, inplace=True)
x_data = np.array(df.drop(['class'], axis=1))
x_data = x_data.astype(np.float64)
y = df['class']
y.replace(2, 0, inplace=True)
y.replace(4, 1, inplace=True)
y_data = np.array(y)
# y shape = 681,1
# x shape = 681,9
x = tf.placeholder(name='x', dtype=np.float32)
y = tf.placeholder(name='y', dtype=np.float32)
w = tf.Variable(dtype=np.float32, initial_value=np.random.random((9, 1)))
b = tf.Variable(dtype=np.float32, initial_value=np.random.random((1, 1)))
y_ = (tf.add(tf.matmul(x, w), b))
error = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=y_, labels=y))
goal = tf.train.GradientDescentOptimizer(0.05).minimize(error)
prediction = tf.round(tf.sigmoid(y_))
correct = tf.cast(tf.equal(prediction, y), dtype=np.float64)
accuracy = tf.reduce_mean(correct)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(2000):
sess.run(goal, feed_dict={x: x_data, y: y_data})
print(i, sess.run(accuracy, feed_dict={x: x_data, y: y_data}))
weight = sess.run(w)
bias = sess.run(b)
print(weight)
print(bias)
您的神經網絡只有一層,所以它能做的最好就是將一條直線擬合到您的數據中,以分隔不同的類。 對於一般(高維)數據集,這是遠遠不夠的。 (深度)神經網絡的力量在於神經元的許多層之間的連通性。 在你的榜樣,你可以通過手動的輸出通過添加更多的圖層matmul
到一個新的matmul
以不同的權重和偏見,或者你可以使用contrib.layers
集合,使之更加簡潔:
x = tf.placeholder(name='x', dtype=np.float32)
fc1 = tf.contrib.layers.fully_connected(inputs=x, num_outputs=16, activation_fn=tf.nn.relu)
fc2 = tf.contrib.layers.fully_connected(inputs=fc1, num_outputs=32, activation_fn=tf.nn.relu)
fc3 = tf.contrib.layers.fully_connected(inputs=fc2, num_outputs=64, activation_fn=tf.nn.relu)
訣竅是將輸出從一層作為輸入傳遞到下一層。 隨着您添加越來越多的圖層,您的准確性將會提高(可能是由於過度擬合,請使用dropout
來彌補這一點)。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.