簡體   English   中英

張量流邏輯回歸的准確性非常差

[英]Accuracy very bad in tensorflow logistic regression

我正在嘗試編寫一個程序來預測是否患有惡性腫瘤或良性腫瘤

數據集來自: https : //archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+%28Prognostic%29

這是我的代碼,我的准確率約為65%,這不比擲硬幣好。 任何幫助,將不勝感激

import tensorflow as tf
import pandas as pd
import numpy as np


df = pd.read_csv(r'D:\wholedesktop\logisticReal.txt')
df.drop(['id'], axis=1, inplace=True)

x_data = np.array(df.drop(['class'], axis=1))
x_data = x_data.astype(np.float64)
y = df['class']
y.replace(2, 0, inplace=True)
y.replace(4, 1, inplace=True)
y_data = np.array(y)
# y shape = 681,1
# x shape = 681,9

x = tf.placeholder(name='x', dtype=np.float32)
y = tf.placeholder(name='y', dtype=np.float32)

w = tf.Variable(dtype=np.float32, initial_value=np.random.random((9, 1)))
b = tf.Variable(dtype=np.float32, initial_value=np.random.random((1, 1)))

y_ = (tf.add(tf.matmul(x, w), b))
error = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=y_, labels=y))
goal = tf.train.GradientDescentOptimizer(0.05).minimize(error)

prediction = tf.round(tf.sigmoid(y_))
correct = tf.cast(tf.equal(prediction, y), dtype=np.float64)
accuracy = tf.reduce_mean(correct)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(2000):
        sess.run(goal, feed_dict={x: x_data, y: y_data})
        print(i, sess.run(accuracy, feed_dict={x: x_data, y: y_data}))

    weight = sess.run(w)
    bias = sess.run(b)
    print(weight)
    print(bias)

您的神經網絡只有一層,所以它能做的最好就是將一條直線擬合到您的數據中,以分隔不同的類。 對於一般(高維)數據集,這是遠遠不夠的。 (深度)神經網絡的力量在於神經元的許多層之間的連通性。 在你的榜樣,你可以通過手動的輸出通過添加更多的圖層matmul到一個新的matmul以不同的權重和偏見,或者你可以使用contrib.layers集合,使之更加簡潔:

x = tf.placeholder(name='x', dtype=np.float32)
fc1 = tf.contrib.layers.fully_connected(inputs=x, num_outputs=16, activation_fn=tf.nn.relu)
fc2 = tf.contrib.layers.fully_connected(inputs=fc1, num_outputs=32, activation_fn=tf.nn.relu)
fc3 = tf.contrib.layers.fully_connected(inputs=fc2, num_outputs=64, activation_fn=tf.nn.relu)

訣竅是將輸出從一層作為輸入傳遞到下一層。 隨着您添加越來越多的圖層,您的准確性將會提高(可能是由於過度擬合,請使用dropout來彌補這一點)。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM