
Training CNN with tfdataset


I'm trying to train a CNN using a TFRecordDataset (I don't think that's relevant, but it's my setup) and I get the following error:

ValueError: slice index 0 of dimension 0 out of bounds. for '{{node strided_slice}} = StridedSlice[Index=DT_INT32,T=DT_INT32,begin_mask=0,ellipsis_mask=0,end_mask=0,new_axis_mask=0,shrink_axis_mask=1](Shape, strided_slice/stack, strided_slice/stack_1, strided_slice/stack_2)' with input shapes: [0], [1], [1], [1] and with computed input tensors: input[1] = <0>, input[2] = <1>, input[3] = <1>.

For reference, this is the code I'm running:

CNN:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model

def get_cnn_model(input_shape=(31, 31, 9), n_outputs=4, convolutions=3, optimizer='adam', seed=26):
    tf.random.set_seed(seed=seed)
    _input = layers.Input(shape=input_shape, name='input')
    x = layers.Conv2D(64, (4, 4), activation='relu', padding='same', name='conv_0')(_input)
    x = layers.MaxPooling2D(2)(x)
    for i in range(convolutions - 1):
        x = layers.Conv2D(64, (4, 4), activation='relu', padding='same', name=f'conv_{i + 1}')(x)
        x = layers.MaxPooling2D(2)(x)
    x = layers.Flatten()(x)
    x = layers.Dense(128, activation='relu', name='dense_1')(x)
    x = layers.Dropout(0.35, name='dropout_1')(x)
    x = layers.Dense(128, activation='relu', name='dense_2')(x)
    x = layers.Dropout(0.35, name='dropout_2')(x)
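    # Two named output heads; the keys of the losses dict (and of any target dict) must match these names.
    p = layers.Dense(n_outputs, activation='tanh', name='p')(x)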
    v = layers.Dense(1, activation='tanh', name='v')(x)
    cnn_model = Model(inputs=_input, outputs=[v, p])
    losses = {
        "v": 'mean_squared_error',
        "p": keras.losses.BinaryCrossentropy()
    }
    cnn_model.compile(loss=losses, optimizer=optimizer)
    return cnn_model

cnn = get_cnn_model((31, 31, 9), n_outputs=16, convolutions=3, optimizer='adam', seed=26)
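One way to see what Keras expects from the training data is to inspect the model's symbolic shapes. The sketch below uses a hypothetical, cut-down stand-in for get_cnn_model (the layer sizes are arbitrary and only for illustration); the point is that every input and every output carries a leading None batch dimension, so each target fed to fit() is expected to have one too.

```python
import tensorflow as tf
from tensorflow.keras import layers, Model

# Hypothetical, cut-down stand-in for get_cnn_model: same input shape and
# the same two named heads, but far fewer layers.
inp = layers.Input(shape=(31, 31, 9), name='input')
h = layers.Flatten()(layers.Conv2D(8, (4, 4), padding='same')(inp))
v = layers.Dense(1, activation='tanh', name='v')(h)
p = layers.Dense(16, activation='tanh', name='p')(h)
model = Model(inputs=inp, outputs=[v, p])

# Keras prepends a None batch dimension to every symbolic input and output.
input_shape = tuple(model.inputs[0].shape)
output_shapes = [tuple(o.shape) for o in model.outputs]
print(input_shape)    # (None, 31, 31, 9)
print(output_shapes)  # [(None, 1), (None, 16)]
```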

Here is a sample dataset:

import numpy as np
import tensorflow as tf

v = 0.9
p = np.random.randn(16)
state = np.random.randn(31*31*9)

sample = tf.train.Example(
    features = tf.train.Features(
        feature = {
            'v': tf.train.Feature(float_list=tf.train.FloatList(value=[v])),
            'p': tf.train.Feature(float_list=tf.train.FloatList(value = p)),
            's': tf.train.Feature(float_list=tf.train.FloatList(value = state))
        }
    )
)

with tf.io.TFRecordWriter('tf_record_data') as f:
    f.write(sample.SerializeToString())
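Because every feature in this record has a fixed length (1, 16 and 31 * 31 * 9 = 8649 floats), it could also be parsed with tf.io.FixedLenFeature and an explicit shape, which yields dense tensors with static shapes instead of SparseTensors. A minimal round-trip sketch, assuming a temporary file instead of 'tf_record_data':

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Write one record in the question's layout to a temporary file.
path = os.path.join(tempfile.mkdtemp(), 'tf_record_data')
record = tf.train.Example(features=tf.train.Features(feature={
    'v': tf.train.Feature(float_list=tf.train.FloatList(value=[0.9])),
    'p': tf.train.Feature(float_list=tf.train.FloatList(value=np.random.randn(16))),
    's': tf.train.Feature(float_list=tf.train.FloatList(value=np.random.randn(31 * 31 * 9))),
}))
with tf.io.TFRecordWriter(path) as writer:
    writer.write(record.SerializeToString())

# Every feature has a known length, so FixedLenFeature parses it straight
# into a dense tensor; no tf.sparse.to_dense call is needed.
feature_desc = {
    'v': tf.io.FixedLenFeature([1], tf.float32),
    'p': tf.io.FixedLenFeature([16], tf.float32),
    's': tf.io.FixedLenFeature([31 * 31 * 9], tf.float32),
}
shapes = {}
for serialized in tf.data.TFRecordDataset([path]):
    parsed = tf.io.parse_single_example(serialized, feature_desc)
    shapes = {name: tuple(t.shape) for name, t in parsed.items()}

# shapes == {'v': (1,), 'p': (16,), 's': (8649,)}
print(shapes)
```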

This is the training step where I get the error above:

def read_tfrecord(example):
    feature_desc = {
        'v': tf.io.FixedLenFeature([], tf.float32),
        'p': tf.io.VarLenFeature(tf.float32),
        's': tf.io.VarLenFeature(tf.float32)
    }
    sample = tf.io.parse_single_example(example, feature_desc)
    x = tf.reshape(tf.sparse.to_dense(sample['s']), (1, 31, 31, 9))
    # 'v' was parsed with FixedLenFeature([]), so it is a scalar of shape ()
    y = {'v': sample['v'], 'p': tf.sparse.to_dense(sample['p'])}
    return x, y

ds = tf.data.TFRecordDataset(['tf_record_data'])
ds = ds.map(read_tfrecord)

cnn.fit(ds)
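A quick way to locate the problem is to inspect ds.element_spec before calling fit(). This sketch rebuilds the question's pipeline against a temporary one-record file (using sample['s'] rather than parsed['s']) and prints the static shapes: x already has a leading dimension of 1, but the 'v' target is a scalar with no dimension 0 at all, which is consistent with the Shape input of shape [0] reported in the error.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# One record in the question's layout, written to a temporary file.
path = os.path.join(tempfile.mkdtemp(), 'tf_record_data')
record = tf.train.Example(features=tf.train.Features(feature={
    'v': tf.train.Feature(float_list=tf.train.FloatList(value=[0.9])),
    'p': tf.train.Feature(float_list=tf.train.FloatList(value=np.random.randn(16))),
    's': tf.train.Feature(float_list=tf.train.FloatList(value=np.random.randn(31 * 31 * 9))),
}))
with tf.io.TFRecordWriter(path) as writer:
    writer.write(record.SerializeToString())

def read_tfrecord(example):
    # The question's parser, with sample['s'] in place of parsed['s'].
    feature_desc = {
        'v': tf.io.FixedLenFeature([], tf.float32),
        'p': tf.io.VarLenFeature(tf.float32),
        's': tf.io.VarLenFeature(tf.float32),
    }
    sample = tf.io.parse_single_example(example, feature_desc)
    x = tf.reshape(tf.sparse.to_dense(sample['s']), (1, 31, 31, 9))
    y = {'v': sample['v'], 'p': tf.sparse.to_dense(sample['p'])}
    return x, y

ds = tf.data.TFRecordDataset([path]).map(read_tfrecord)
x_spec, y_spec = ds.element_spec
print(tuple(x_spec.shape))       # (1, 31, 31, 9)
print(tuple(y_spec['v'].shape))  # (): a scalar, so there is no dimension 0
print(tuple(y_spec['p'].shape))  # (None,): length unknown statically
```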

Interestingly, running predictions on the dataset does work:

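import tensorflow as tf
feature_desc = {
    'v': tf.io.FixedLenFeature([], tf.float32),
    'p': tf.io.VarLenFeature(tf.float32),
    's': tf.io.VarLenFeature(tf.float32)
}
for serialized in tf.data.TFRecordDataset(['tf_record_data']):
    parsed = tf.io.parse_single_example(serialized, feature_desc)
    st = tf.sparse.to_dense(parsed['s'])
    t = tf.reshape(st, (1, 31, 31, 9))  # explicit batch dimension of 1
    print(cnn.predict(t))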

How can I fix this error?

I changed the map function for the data records to the following:

def read_tfrecord(example):
    feature_desc = {
        'v': tf.io.FixedLenFeature([], tf.float32),
        'p': tf.io.VarLenFeature(tf.float32),
        's': tf.io.VarLenFeature(tf.float32)
    }
    sample = tf.io.parse_single_example(example, feature_desc)
    x = tf.reshape(tf.sparse.to_dense(sample['s']), (1, 31, 31, 9))
    # Give every target an explicit leading batch dimension of 1, like x.
    p = tf.reshape(tf.sparse.to_dense(sample['p']), (1, 16))
    v = tf.reshape(sample['v'], (1, 1))

    y = {'v': v, 'p': p}
    return x, y

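Reshaping the targets solved the problem. Keras expects every target, like every input, to have a leading batch dimension, and it reads the batch size from dimension 0 of the targets. In the original map function v was parsed with FixedLenFeature([]) as a scalar of shape (), so there was no dimension 0 to slice, which is what the strided_slice error is complaining about. Reshaping v to (1, 1) and p to (1, 16) gives both targets the same batch dimension of 1 that x already had.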
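Reshaping inside map() works, but every element of the dataset is then a batch of exactly one example. A more common pattern, sketched here under the assumption that every record has the same fixed lengths, is to return per-example shapes from map() ((31, 31, 9) for x, (1,) for v, (16,) for p) and let Dataset.batch() add the batch dimension. The model is a hypothetical, cut-down stand-in for get_cnn_model, 'mse' is used for both heads only to keep the sketch simple, and the records are random data.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model

ROWS, COLS, CHANNELS, N_OUTPUTS = 31, 31, 9, 16

# Write a few records in the question's layout (random data, illustration only).
path = os.path.join(tempfile.mkdtemp(), 'tf_record_data')
with tf.io.TFRecordWriter(path) as writer:
    for _ in range(8):
        record = tf.train.Example(features=tf.train.Features(feature={
            'v': tf.train.Feature(float_list=tf.train.FloatList(value=[0.9])),
            'p': tf.train.Feature(float_list=tf.train.FloatList(value=np.random.randn(N_OUTPUTS))),
            's': tf.train.Feature(float_list=tf.train.FloatList(
                value=np.random.randn(ROWS * COLS * CHANNELS))),
        }))
        writer.write(record.SerializeToString())

feature_desc = {
    'v': tf.io.FixedLenFeature([1], tf.float32),
    'p': tf.io.FixedLenFeature([N_OUTPUTS], tf.float32),
    's': tf.io.FixedLenFeature([ROWS * COLS * CHANNELS], tf.float32),
}

def read_tfrecord(example):
    # Return per-example shapes only; batch() below adds the batch dimension.
    sample = tf.io.parse_single_example(example, feature_desc)
    x = tf.reshape(sample['s'], (ROWS, COLS, CHANNELS))
    return x, {'v': sample['v'], 'p': sample['p']}

ds = (tf.data.TFRecordDataset([path])
      .map(read_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)
      .batch(4)
      .prefetch(tf.data.AUTOTUNE))

# Hypothetical, cut-down stand-in for get_cnn_model with the same head names.
inp = layers.Input(shape=(ROWS, COLS, CHANNELS), name='input')
h = layers.Flatten()(layers.Conv2D(8, (4, 4), activation='relu', padding='same')(inp))
outputs = {
    'v': layers.Dense(1, activation='tanh', name='v')(h),
    'p': layers.Dense(N_OUTPUTS, activation='tanh', name='p')(h),
}
model = Model(inputs=inp, outputs=outputs)
model.compile(optimizer='adam', loss={'v': 'mse', 'p': 'mse'})

history = model.fit(ds, epochs=1, verbose=0)
print(ds.element_spec)
print(history.history['loss'])
```

After .batch(4) the element spec has shape (None, 31, 31, 9) for x and (None, 1) for v, which lines up with the model's output shapes, and the batch size can be changed without touching the parsing code.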

