Training Variational Auto Encoder in Keras raises “InvalidArgumentError: Incompatible shapes” error

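I've been trying to get this VAE to work all night, but I keep running into the same problem over and over, and I'm not sure what the issue is. I've tried removing the callbacks, removing validation, changing the loss function, and changing the sampling method. The traceback always points at whichever argument comes last in the call to fit (the EarlyStopping callback in the version below). I'm at a loss as to how to get this working.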

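Below is reproducible code, followed by the error I keep getting. Note that changing the batch size does change the error message, but the mismatched size it reports also shrinks as the batch size shrinks.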

import pandas as pd
from sklearn.datasets import make_blobs 
from sklearn.preprocessing import MinMaxScaler

import keras.backend as K
import tensorflow as tf

from keras.layers import Input, Dense, Lambda, Layer, Add, Multiply
from keras.models import Model, Sequential
from keras.callbacks import EarlyStopping, LearningRateScheduler
from keras.objectives import binary_crossentropy


x, labels = make_blobs(n_samples=150000, n_features=110,  centers=16, cluster_std=4.0)
scaler = MinMaxScaler()
x = scaler.fit_transform(x)
x = pd.DataFrame(x)

train = x.sample(n = 100000)
train_indexs = train.index.values
test = x[~x.index.isin(train_indexs)]
print(train.shape, test.shape)

min_dim = 2
batch_size = 1024

def sampling(args):
    mu, log_sigma = args
    eps = K.random_normal(shape=(batch_size, min_dim), mean = 0.0, stddev = 1.0)
    return mu + K.exp(0.5 * log_sigma) * eps

#Encoder
inputs = Input(shape=(x.shape[1],))
down1 = Dense(64, activation='relu')(inputs)
mu = Dense(min_dim, activation='linear')(down1)
log_sigma = Dense(min_dim, activation='linear')(down1)

#Sampling
sample_set = Lambda(sampling, output_shape=(min_dim,))([mu, log_sigma])

#decoder
up1 = Dense(64, activation='relu')(sample_set)
output = Dense(x.shape[1], activation='sigmoid')(up1)

vae = Model(inputs, output)
encoder = Model(inputs, mu)

def vae_loss(y_true, y_pred):
    recon  = binary_crossentropy(y_true, y_pred)
    kl = - 0.5 * K.mean(1 + log_sigma - K.square(mu) - K.exp(log_sigma), axis=-1)
    return recon + kl

vae.compile(optimizer='adam', loss=vae_loss)
vae.fit(train, train, shuffle = True, epochs = 1000, 
        batch_size = batch_size, validation_data = (test, test), 
        callbacks = [EarlyStopping(patience=50)])

Error:


  File "<ipython-input-2-7aa4be21434d>", line 62, in <module>
    callbacks = [EarlyStopping(patience=50)])

  File "C:\Users\se01040434\Anaconda3\lib\site-packages\keras\engine\training.py", line 1239, in fit
    validation_freq=validation_freq)

  File "C:\Users\se01040434\Anaconda3\lib\site-packages\keras\engine\training_arrays.py", line 196, in fit_loop
    outs = fit_function(ins_batch)

  File "C:\Users\se01040434\Anaconda3\lib\site-packages\tensorflow\python\keras\backend.py", line 3792, in __call__
    outputs = self._graph_fn(*converted_inputs)

  File "C:\Users\se01040434\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py", line 1605, in __call__
    return self._call_impl(args, kwargs)

  File "C:\Users\se01040434\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py", line 1645, in _call_impl
    return self._call_flat(args, self.captured_inputs, cancellation_manager)

  File "C:\Users\se01040434\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py", line 1746, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))

  File "C:\Users\se01040434\Anaconda3\lib\site-packages\tensorflow\python\eager\function.py", line 598, in call
    ctx=ctx)

  File "C:\Users\se01040434\Anaconda3\lib\site-packages\tensorflow\python\eager\execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)

InvalidArgumentError:  Incompatible shapes: [672] vs. [1024]
     [[node gradients/loss/dense_5_loss/vae_loss/weighted_loss/mul_grad/Mul_1 (defined at C:\Users\se01040434\Anaconda3\lib\site-packages\keras\backend\tensorflow_backend.py:3009) ]] [Op:__inference_keras_scratch_graph_1515]

Function call stack:
keras_scratch_graph

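You are creating a random tensor with batch_size rows, where batch_size is a fixed value set in your code. However, the model does not necessarily receive exactly batch_size input samples: the last batch of the training or test data may contain fewer. Here, train has 100,000 rows and 100000 = 97 × 1024 + 672, so the final batch of each epoch contains only 672 samples, which is exactly the [672] vs. [1024] in the error message. When your model implementation depends on the batch size like this, read it dynamically with the keras.backend.shape function instead: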

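def sampling(args):
    mu, log_sigma = args
    # Take the batch size from mu at run time instead of the fixed batch_size
    eps = K.random_normal(shape=(K.shape(mu)[0], min_dim), mean=0.0, stddev=1.0)
    return mu + K.exp(0.5 * log_sigma) * eps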
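
The [672] vs. [1024] in the error message can be reproduced with plain arithmetic. The sketch below is a NumPy analogy only, not the Keras code itself: it computes the size of the final partial batch from the numbers in the question, then shows that noise with a fixed number of rows cannot be combined with a smaller batch, while noise shaped from the input can. The function names sampling_fixed and sampling_dynamic are hypothetical and exist only for this illustration.

```python
import numpy as np

n_train = 100_000   # rows in `train` from the question
batch_size = 1024   # fixed batch size from the question
min_dim = 2         # latent dimension from the question

# Size of the final, partial batch in each epoch.
last_batch = n_train % batch_size
print(last_batch)  # 672


def sampling_fixed(mu, log_sigma, rng):
    # Mirrors the question's code: noise always has batch_size rows.
    eps = rng.standard_normal((batch_size, min_dim))
    return mu + np.exp(0.5 * log_sigma) * eps


def sampling_dynamic(mu, log_sigma, rng):
    # Mirrors the fix: noise takes its row count from mu itself.
    eps = rng.standard_normal((mu.shape[0], min_dim))
    return mu + np.exp(0.5 * log_sigma) * eps


rng = np.random.default_rng(0)
mu = np.zeros((last_batch, min_dim))
log_sigma = np.zeros((last_batch, min_dim))

try:
    sampling_fixed(mu, log_sigma, rng)
    fixed_ok = True
except ValueError:
    # NumPy cannot broadcast (672, 2) against (1024, 2).
    fixed_ok = False

z = sampling_dynamic(mu, log_sigma, rng)
print(fixed_ok, z.shape)
```

Changing batch_size in this sketch changes the remainder, which matches the observation in the question that the mismatched number moves with the batch size.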
