
How to convert a keras tensor to a numpy array

I am trying to create a Q-learning chess engine where the output of the last layer of the neural network (whose number of units equals the number of legal moves) is run through an argmax() function, which returns an integer that I use as an index into the array where the legal moves are stored. Here is part of my code:

#imports

env = gym.make('ChessAlphaZero-v0')   #builds environment
obs = env.reset()
type(obs)

done = False   #game is not won

num_actions = len(env.legal_moves)   #array where legal moves are stored

obs = chess.Board() 

model = models.Sequential()

def dqn(board):
    
    #dense layers
    
    action = layers.Dense(num_actions)(layer5)
    
    i = np.argmax(action)
    move = env.legal_moves[i]

    return keras.Model(inputs=inputs, outputs=move)

But when I run the code, I get the following error:

TypeError: Cannot convert a symbolic Keras input/output to a numpy array. This error may indicate that you're trying to pass a symbolic value to a NumPy call, which is not supported. Or, you may be trying to pass Keras symbolic inputs/outputs to a TF API that does not register dispatching, preventing Keras from automatically converting the API call to a lambda layer in the Functional Model.

Any code examples would be greatly appreciated, thanks.

The correct way to build a model and forward an input through it in Keras is:

1. Build the model

model = models.Sequential()
model.add(layers.Input(shape=observation_shape))
model.add(layers.Dense(units=128, activation='relu'))
model.add(layers.Dense(units=num_actions, activation='softmax'))

Or

inputs = layers.Input(observation_shape)
x = layers.Dense(units=128, activation='relu')(inputs)
outputs = layers.Dense(units=num_actions, activation='softmax')(x)

model = keras.Model(inputs, outputs)

The two approaches are equivalent.
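As a quick check of that equivalence, here is a minimal sketch that builds the same network both ways and compares the results. The sizes (64 input features, 10 actions) are placeholder assumptions, not values from the question, and the helper names build_sequential and build_functional are hypothetical.

```python
import numpy as np
import tensorflow as tf

keras = tf.keras
layers = keras.layers
models = keras.models

# Placeholder sizes, assumed for illustration only
observation_shape = (64,)
num_actions = 10

def build_sequential():
    # Sequential API: layers are stacked one after another
    model = models.Sequential()
    model.add(layers.Input(shape=observation_shape))
    model.add(layers.Dense(units=128, activation='relu'))
    model.add(layers.Dense(units=num_actions, activation='softmax'))
    return model

def build_functional():
    # Functional API: symbolic tensors are wired from inputs to outputs
    inputs = layers.Input(shape=observation_shape)
    x = layers.Dense(units=128, activation='relu')(inputs)
    outputs = layers.Dense(units=num_actions, activation='softmax')(x)
    return keras.Model(inputs, outputs)

seq_model = build_sequential()
fun_model = build_functional()

# predict() takes a batch, so a single observation gets a leading axis of 1
batch = np.zeros((1,) + observation_shape, dtype=np.float32)
seq_out = seq_model.predict(batch, verbose=0)
fun_out = fun_model.predict(batch, verbose=0)

print(seq_model.count_params(), fun_model.count_params())
print(seq_out.shape, fun_out.shape)
```

Both models have the same layer structure and parameter count. Their weights are initialised independently, so the predicted values themselves will differ until the models are trained or share weights.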

2. Forward an observation and pick the best action

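action_values = model.predict(observation)   # observation needs a batch axis: shape (1, ...)
best_action_index = int(np.argmax(action_values[0]))   # predict() returns a NumPy array
best_action = env.legal_moves[best_action_index]   # look up the move, as in the question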
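To make the indexing step concrete, here is a small NumPy-only sketch. The action_values array is a hand-written stand-in for what model.predict returns for a single observation (shape (1, num_actions)), and legal_moves is a hypothetical list of move strings; neither comes from the question. It shows why the argmax should be taken over the action axis rather than the batch axis.

```python
import numpy as np

# Stand-in for model.predict(observation): one row per observation in the batch
action_values = np.array([[0.1, 0.6, 0.05, 0.25]], dtype=np.float32)

# Hypothetical legal moves, one per output unit
legal_moves = ['e2e4', 'd2d4', 'g1f3', 'c2c4']

# Reducing over axis 0 (the batch axis) compares the single row with itself,
# so every entry is 0 and says nothing about which action is best.
wrong = np.argmax(action_values, axis=0)

# Reducing over the last axis gives one action index per observation.
best_action_index = int(np.argmax(action_values, axis=-1)[0])
best_move = legal_moves[best_action_index]

print(wrong, best_action_index, best_move)
```

Because predict already returns concrete NumPy values, no conversion from a symbolic tensor is needed at this point; the argmax runs outside the model, after the forward pass.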

Implementing DQN yourself in Keras can be very frustrating. You may want to use a DRL framework such as tf_agents, which implements many agents: https://www.tensorflow.org/agents

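This repository contains a clean and easy-to-understand DQN implementation for OpenAI Gym environments. It also contains examples that use the tf_agents library, as well as more complex agents: https://github.com/kochlisGit/Tensorflow-DQN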

