簡體   English   中英

預測新結果時檢查模型輸入角點時出錯

[英]Error when checking model input keras when predicting new results

我正在嘗試使用基於新數據構建的keras模型,除了在嘗試預測預測時出現輸入錯誤。

這是我的模型代碼:

def build_model(max_features, maxlen):
    """Build LSTM model"""
    model = Sequential()
    model.add(Embedding(max_features, 128, input_length=maxlen))
    model.add(LSTM(128))
    model.add(Dropout(0.5))
    model.add(Dense(1))
    model.add(Activation('sigmoid'))

    model.compile(loss='binary_crossentropy',
                  optimizer='rmsprop')

    return model

我的代碼可以預測新數據的輸出預測:

LSTM_model = load_model('LSTMmodel.h5')
data = pickle.load(open('traindata.pkl', 'rb'))


#### LSTM ####

"""Run train/test on logistic regression model"""

# Extract data and labels
X = [x[1] for x in data]
labels = [x[0] for x in data]

# Generate a dictionary of valid characters
valid_chars = {x:idx+1 for idx, x in enumerate(set(''.join(X)))}

max_features = len(valid_chars) + 1
maxlen = np.max([len(x) for x in X])

# Convert characters to int and pad
X = [[valid_chars[y] for y in x] for x in X]
X = sequence.pad_sequences(X, maxlen=maxlen)

# Convert labels to 0-1
y = [0 if x == 'benign' else 1 for x in labels]


y_pred = LSTM_model.predict(X)

運行此代碼時出現的錯誤:

ValueError: Error when checking input: expected embedding_1_input to have shape (57,) but got array with shape (36,)

我的錯誤來自maxlen因為對於我的訓練數據, maxlen=57和新數據, maxlen=36

因此,我嘗試在預測代碼中設置maxlen=57但隨后出現此錯誤:

tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[31,53] = 38 is not in [0, 38)
     [[Node: embedding_1/embedding_lookup = GatherV2[Taxis=DT_INT32, Tindices=DT_INT32, Tparams=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](embedding_1/embeddings/read, embedding_1/Cast, embedding_1/embedding_lookup/axis)]]

為了解決這些問題我該怎么辦? 更改我的嵌入層?

可以將Embedding層的input_length設置為您將在數據集中看到的最大長度,或者僅使用在pad_sequences構建模型時使用的maxlen值。 在這種情況下,任何序列短於maxlen將被填充任何序列長於maxlen將被截斷。

進一步確保您使用的功能在訓練和測試時間內都相同(即,其編號不應更改)。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM