簡體   English   中英

如何修復此 InvalidArgumentError?

[英]How do I fix this InvalidArgumentError?

我正在為我的機器學習課研究這種國際象棋算法,但我不確定出了什么問題。 我正在關注這里的視頻,但是當我嘗試擬合我的模型時,似乎一切都出錯了。 我附上了下面的代碼,它設置了一個棋盤,然后是一個卷積網絡。 我不斷收到錯誤:

InvalidArgumentError: Graph Execution Error which points to model.fit(x_train, y_train). 
The size of x_train is (150000, 14, 8, 8) while y_train is (150000, ) 

代碼:

 def random_board(max_depth=200):
        board = chess.Board()
        depth = random.randrange(0, max_depth)
        
        for _ in range(depth):
            all_moves = list(board.legal_moves)
            random_move = random.choice(all_moves)
            board.push(random_move)
            if board.is_game_over():
                break
        return board          
    
    
      squares_index = {
          'a': 0,
          'b': 1,
          'c': 2,
          'd': 3,
          'e': 4,
          'f': 5,
          'g': 6,
          'h': 7
        }
        
        
        # example: h3 -> 17
        def square_to_index(square):
            letter = chess.square_name(square)
            return 8 - int(letter[1]), squares_index[letter[0]]
        
        
        def split_dims(board):
          # create empty 3d matrix for board 
            board3d = numpy.zeros((14, 8, 8), dtype=numpy.int8)
          # here we add the pieces's view on the matrix
            for piece in chess.PIECE_TYPES:
                for square in board.pieces(piece, chess.WHITE):
                    idx = numpy.unravel_index(square, (8, 8))
                    board3d[piece - 1][7 - idx[0]][idx[1]] = 1
            for square in board.pieces(piece, chess.BLACK):
                    idx = numpy.unravel_index(square, (8, 8))
                    board3d[piece + 5][7 - idx[0]][idx[1]] = 1
        
          # add attacks and valid moves too
          # so the network knows what is being attacked
            aux = board.turn
            board.turn = chess.WHITE
            for move in board.legal_moves:
                i, j = square_to_index(move.to_square)
                board3d[12][i][j] = 1
                board.turn = chess.BLACK
            for move in board.legal_moves:
                i, j = square_to_index(move.to_square)
                board3d[13][i][j] = 1
            board.turn = aux
        
            return board3d
    
    import tensorflow.keras.models as models
    import tensorflow.keras.layers as layers
    import tensorflow.keras.utils as utils
    import tensorflow.keras.optimizers as optimizers 
    
    def build_model(conv_size, conv_depth):
        board3d = layers.Input(shape=(14, 8, 8))
        
        #convolutional layers
        x = layers.Conv2D(filters=conv_size, kernel_size=3, padding='same', data_format='channels_first')(board3d)
        for _ in range(conv_depth):
            previous = x
            x = layers.Conv2D(filters=conv_size, kernel_size=3, padding='same', data_format='channels_first')(x)
            x = layers.BatchNormalization()(x)
            x = layers.Activation('relu')(x)
            x = layers.Conv2D(filters=conv_size, kernel_size=3, padding='same', data_format='channels_first')(x)
            x = layers.BatchNormalization()(x)
            x = layers.Add()([x, previous])
            x = layers.Activation('relu')(x)
        x = layers.Flatten()(x)
        x = layers.Dense(1, 'sigmoid')(x)
    
        return models.Model(inputs=board3d, outputs=x)
    
    model = build_model(32, 4)
    utils.plot_model(model, to_file='model_plot.png', show_shapes=True, show_layer_names=False)
    
    import tensorflow.keras.callbacks as callbacks
    
    def get_dataset():
        container = numpy.load('dataset\\dataset.npz')
        b, v = container['b'], container['v']
        v = numpy.asarray(v / abs(v).max()/2 + 0.5, dtype=numpy.float32) #normalize 
        return b, v
    
    x_train, y_train = get_dataset()
    
    model.compile(optimizer=optimizers.Adam(5e-4), loss='mean_squared_error')
    model.summary()
    model.fit(x_train, y_train,
              batch_size=2048,
              epochs=1000,
              verbose=1,
              validation_split=0.1,
              callbacks=[callbacks.ReduceLROnPlateau(monitor='loss', patience=10),
                         callbacks.EarlyStopping(monitor='loss', patience=15, min_delta=1e-4)])
    model.save('model.h5')
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_15172/3395566405.py in <module>
      1 model.compile(optimizer=optimizers.Adam(5e-4), loss='mean_squared_error')
      2 model.summary()
----> 3 model.fit(x_train, y_train,
      4           batch_size=2048,
      5           epochs=1000,

C:\ProgramData\Anaconda3\lib\site-packages\keras\utils\traceback_utils.py in error_handler(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67       raise e.with_traceback(filtered_tb) from None
     68     finally:
     69       del filtered_tb

C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     52   try:
     53     ctx.ensure_initialized()
---> 54     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     55                                         inputs, attrs, num_outputs)
     56   except core._NotOkStatusException as e:

InvalidArgumentError: Graph execution error:

編輯:model.summary() 的輸出日志

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 14, 8, 8)]   0           []                               
                                                                                                  
 conv2d (Conv2D)                (None, 32, 8, 8)     4064        ['input_1[0][0]']                
                                                                                                  
 conv2d_1 (Conv2D)              (None, 32, 8, 8)     9248        ['conv2d[0][0]']                 
                                                                                                  
 batch_normalization (BatchNorm  (None, 32, 8, 8)    32          ['conv2d_1[0][0]']               
 alization)                                                                                       
                                                                                                  
 activation (Activation)        (None, 32, 8, 8)     0           ['batch_normalization[0][0]']    
                                                                                                  
 conv2d_2 (Conv2D)              (None, 32, 8, 8)     9248        ['activation[0][0]']             
                                                                                                  
 batch_normalization_1 (BatchNo  (None, 32, 8, 8)    32          ['conv2d_2[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 add (Add)                      (None, 32, 8, 8)     0           ['batch_normalization_1[0][0]',  
                                                                  'conv2d[0][0]']                 
                                                                                                  
 activation_1 (Activation)      (None, 32, 8, 8)     0           ['add[0][0]']                    
                                                                                                  
 conv2d_3 (Conv2D)              (None, 32, 8, 8)     9248        ['activation_1[0][0]']           
                                                                                                  
 batch_normalization_2 (BatchNo  (None, 32, 8, 8)    32          ['conv2d_3[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 activation_2 (Activation)      (None, 32, 8, 8)     0           ['batch_normalization_2[0][0]']  
                                                                                                  
 conv2d_4 (Conv2D)              (None, 32, 8, 8)     9248        ['activation_2[0][0]']           
                                                                                                  
 batch_normalization_3 (BatchNo  (None, 32, 8, 8)    32          ['conv2d_4[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 add_1 (Add)                    (None, 32, 8, 8)     0           ['batch_normalization_3[0][0]',  
                                                                  'activation_1[0][0]']           
                                                                                                  
 activation_3 (Activation)      (None, 32, 8, 8)     0           ['add_1[0][0]']                  
                                                                                                  
 conv2d_5 (Conv2D)              (None, 32, 8, 8)     9248        ['activation_3[0][0]']           
                                                                                                  
 batch_normalization_4 (BatchNo  (None, 32, 8, 8)    32          ['conv2d_5[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 activation_4 (Activation)      (None, 32, 8, 8)     0           ['batch_normalization_4[0][0]']  
                                                                                                  
 conv2d_6 (Conv2D)              (None, 32, 8, 8)     9248        ['activation_4[0][0]']           
                                                                                                  
 batch_normalization_5 (BatchNo  (None, 32, 8, 8)    32          ['conv2d_6[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 add_2 (Add)                    (None, 32, 8, 8)     0           ['batch_normalization_5[0][0]',  
                                                                  'activation_3[0][0]']           
                                                                                                  
 activation_5 (Activation)      (None, 32, 8, 8)     0           ['add_2[0][0]']                  
                                                                                                  
 conv2d_7 (Conv2D)              (None, 32, 8, 8)     9248        ['activation_5[0][0]']           
                                                                                                  
 batch_normalization_6 (BatchNo  (None, 32, 8, 8)    32          ['conv2d_7[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 activation_6 (Activation)      (None, 32, 8, 8)     0           ['batch_normalization_6[0][0]']  
                                                                                                  
 conv2d_8 (Conv2D)              (None, 32, 8, 8)     9248        ['activation_6[0][0]']           
                                                                                                  
 batch_normalization_7 (BatchNo  (None, 32, 8, 8)    32          ['conv2d_8[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 add_3 (Add)                    (None, 32, 8, 8)     0           ['batch_normalization_7[0][0]',  
                                                                  'activation_5[0][0]']           
                                                                                                  
 activation_7 (Activation)      (None, 32, 8, 8)     0           ['add_3[0][0]']                  
                                                                                                  
 flatten (Flatten)              (None, 2048)         0           ['activation_7[0][0]']           
                                                                                                  
 dense (Dense)                  (None, 1)            2049        ['flatten[0][0]']                
                                                                                                  
==================================================================================================
Total params: 80,353
Trainable params: 80,225
Non-trainable params: 128
__________________________________________________________________________________________________

data_format = channels_first是罪魁禍首,在刪除代碼執行完美之后(盡管 AI 並非完美無缺)。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM