簡體   English   中英

貝葉斯神經網絡中的不相容形狀誤差

[英]Incompatible shapes error in bayesian neural network

我是機器學習的新手。 我有一個有關貝葉斯神經網絡的項目,可以預測足球比賽的結果。 然后我按照此鏈接的說明進行操作。 然后我編寫如下代碼:

import sys
from math import floor

import edward as ed
import numpy as np
import pandas as pd
import tensorflow as tf
from edward.models import Normal, Categorical
from fancyimpute import KNN
from tqdm import tqdm
from sklearn.preprocessing import OneHotEncoder

data = pd.read_csv('features_dummies_with_label.csv', sep=',')


def impute_missing_values_by_KNN():
    home_data = data[[col for col in data.columns if 'hp' in col]]
    away_data = data[[col for col in data.columns if 'ap' in col]]
    label_data = data[[col for col in data.columns if 'label' in col]]

    home_filled = pd.DataFrame(KNN(3).complete(home_data))
    home_filled.columns = home_data.columns
    home_filled.index = home_data.index

    away_filled = pd.DataFrame(KNN(3).complete(away_data))
    away_filled.columns = away_data.columns
    away_filled.index = away_data.index

    data_frame_out = pd.concat([home_filled, away_filled, label_data], axis=1)

    return data_frame_out


dataset = impute_missing_values_by_KNN()

dataset = pd.DataFrame(data=dataset)

data_x = dataset.loc[:, dataset.columns != 'label'].as_matrix().astype(np.float32)
data_y_ = dataset.loc[:, 'label'].as_matrix().astype(np.float32)

enc = OneHotEncoder(sparse=False)
integer_encoded = np.array(data_y_).reshape(-1)
integer_encoded = integer_encoded.reshape(len(integer_encoded), 1)
onehot_encoded = enc.fit_transform(integer_encoded)

data_y = onehot_encoded

train_size = 0.9

train_cnt = floor(data_x.shape[0] * train_size)

N = int(train_cnt)

train_x, test_x = data_x[0:N], data_x[N:]
train_y, test_y = data_y[0:N], data_y[N:]

in_size = train_x.shape[1]
out_size = train_y.shape[1]

EPOCH_SUM = 5
BATCH_SIZE = 10

train_y2 = np.argmax(train_y, axis=1)
test_y2 = np.argmax(test_y, axis=1)

n_nodes_hl1 = 500

x_ = tf.placeholder(tf.float32, [None, in_size])
y_ = tf.placeholder(tf.float32)

# def neural_network_model(data):
w_h1 = Normal(loc=tf.zeros([in_size, out_size]), scale=tf.ones([in_size, out_size]))

b_h1 = Normal(loc=tf.zeros([out_size]), scale=tf.ones([out_size]))

y_pre = Normal(tf.matmul(x_, w_h1) + b_h1, scale=1.0)

qw_h1 = Normal(loc=tf.Variable(tf.random_normal([in_size, out_size])),
               scale=tf.Variable(tf.random_normal([in_size, out_size])))

qb_h1 = Normal(loc=tf.Variable(tf.random_normal([out_size])), scale=tf.Variable(tf.random_normal([out_size])))

y = Normal(tf.matmul(x_, qw_h1) + qb_h1, scale=1.0)

inference = ed.KLqp({w_h1: qw_h1, b_h1: qb_h1}, data={y_pre: y_})
inference.initialize()

sess = tf.Session()
sess.run(tf.global_variables_initializer())

with sess:
    samples_num = 100
    for epoch in tqdm(range(EPOCH_SUM), file=sys.stdout):
        perm = np.random.permutation(N)
        for i in range(0, N, BATCH_SIZE):
            batch_x = train_x[perm[i:i + BATCH_SIZE]]
            batch_y = train_y2[perm[i:i + BATCH_SIZE]]
            inference.update(feed_dict={x_: batch_x, y_: batch_y})
        y_samples = y.sample(samples_num).eval(feed_dict={x_: train_x})
        acc = (np.round(y_samples.sum(axis=0) / samples_num) == train_y2).mean()
        y_samples = y.sample(samples_num).eval(feed_dict={x_: test_x})
        tets_acc = (np.round(y_samples.sum(axis=0) / samples_num) == test_y2).mean()
        if (epoch + 1) % 1 == 0:
            tqdm.write('epoch:\t{}\taccuracy:\t{}\tvaridation accuracy:\t{}'.format(epoch + 1, acc, tets_acc))

但是,當我調試它時,會出現如下錯誤:

InvalidArgumentError (see above for traceback): Incompatible shapes: [10] vs. [10,3]
     [[Node: inference/sample/Normal_2/log_prob/standardize/sub = Sub[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_Placeholder_1_0_1, inference/sample/Normal_2/loc)]]

在這一行:

inference.update(feed_dict = {x_:batch_x,y_:batch_y})

錯誤是什么意思? 以及如何解決呢?

沒有看到回溯,就很難調試錯誤。 但是我假設您要傳遞給inference.update的張量或數組與聲明中定義的形狀或數組具有另一種形狀。 因此,我將檢查以下形狀:例如batch_x,train_x(每次迭代),w_h1,qw_h1,...。 打印出這些陣列/張量或使用tfdbg進行調試並進行比較。

請不要將此答案視為最終答案,而應將其作為評論。 但是由於我的得分<50分,所以我無法發表評論。 但是,我想做出貢獻,因為我認為我的帖子可以為解決方案做出貢獻。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM