簡體   English   中英

使用ScikitLearn的神經網絡實現時出現問題

[英]Problems while using ScikitLearn's Neural Network implementation

我正在嘗試使用Scikit Learn提供的神經網絡實現來實現圖像處理。 我有接近10,000張'JPG'格式的彩色圖像,我將這些圖像轉換為'PNG'格式並刪除了顏色信息。 新圖像均為黑色或白色圖像。 在將這些圖像轉換為矢量格式之后,這些圖像矢量形成了對神經網絡的輸入。

對於每個圖像,還有一個輸出,它形成神經網絡的輸出。

輸入文件只有0和1的值,而沒有任何其他值。 每個圖像的輸出對應於一個連續的矢量,介於0和1之間,長度為22。 即每個圖像的輸出是長度為22的向量。

為了開始處理,我開始只有100個圖像及其相應的輸出,並得到以下錯誤:

ValueError: Array contains NaN or infinity

我還想指出第一次迭代在這里完成,我在第二次迭代期間遇到了這個錯誤。

為了嘗試不同的東西,我將輸入和輸出調整為每個10張圖像。 使用相同的代碼(即將出現),我能夠完成7次迭代(我已經將迭代次數設置為20次),然后收到相同的錯誤。

然后我將迭代次數更改為5,只是為了檢查它是否有效。 在此更改后,我收到以下錯誤:

ValueError: bad input shape (10, 22)

我也嘗試在輸入和輸出上使用np.reval() ,但這又給了我NaN or Infinity錯誤。

這是我在整個過程中使用的代碼:

import numpy as np
import csv
import matplotlib.pyplot as plt
from scipy.ndimage import convolve
from sklearn import linear_model, datasets, metrics
from sklearn.cross_validation import train_test_split
from sklearn.neural_network import BernoulliRBM
from sklearn.pipeline import Pipeline


def ReadCsv(fileName):
    in_file = open(fileName, 'rUb')
    reader = csv.reader(in_file, delimiter=',', quotechar='"')
    data = [[]]
    for row in reader:
        data.append(row)

    data.pop(0)
    return data

X_train = np.asarray(ReadCsv('100Images.csv'), 'float32')
Y_train = np.asarray(ReadCsv('100Images_Y_new.csv'), 'float32')
X_test = np.asarray(ReadCsv('ImagesForTest.csv'), 'float32')
Y_test = np.asarray(ReadCsv('ImagesForTest_Y_new.csv'), 'float32')

logistic = linear_model.LogisticRegression()
rbm = BernoulliRBM(random_state=0, verbose=True)

classifier = Pipeline(steps=[('rbm', rbm), ('logistic', logistic)])

rbm.learning_rate = 0.06
rbm.n_iter = 5

rbm.n_components = 100
logistic.C = 6000.0

classifier.fit(X_train, Y_train)

print()
print("Logistic regression using RBM features:\n%s\n" % (
    metrics.classification_report(
        Y_test,
        classifier.predict(X_test))))

我真的很感激這個問題的任何幫助。

TIA。

將學習率更改為較小的值可能會解決此問題。 (即rbm.learning_rate)

至少這解決了我之前遇到的問題。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM