簡體   English   中英

TFAGENTS:關於 DqnAgent 代理的 observation_and_action_constraint_splitter 用法的說明

[英]TFAGENTS: clarification on the usage of observation_and_action_constraint_splitter for DqnAgent agents

根據這篇文章,我正在嘗試創建一個帶有有效/無效操作掩碼的 DqnAgent 代理,我應該為observation_and_action_constraint_splitter arg 指定一個splitter_fn 根據 tf_agents 文檔

, splitter_fn會是這樣的:

def observation_and_action_constraint_splitter(observation):
  return observation['network_input'], observation['constraint'] 

在我看來,我認為變量observation應該是env.step(action).observation返回的數組,在我的例子中它是一個形狀為 (56,) 的數組(它是一個扁平化的數組,原始形狀為 (14,4 ), 每行是每個選擇的4個特征值,有5-14個選擇,如果選擇無效則相應的特征將全部為0),所以我這樣寫了我的splitter_fn:

def observation_and_action_constrain_splitter(observation):
     print(observation)
     temp = observation.reshape(14,-1)
     action_mask = (~(temp==0).all(axis=1)).astype(np.int32).ravel()
     return observation, tf.convert_to_tensor(action_mask, dtype=tf.int32)

agent = DqnAgent(
    tf_time_step_spec,
    tf_action_spec,
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=tf_common.element_wise_squared_loss,
    train_step_counter=train_step_counter,
    observation_and_action_constraint_splitter=observation_and_action_constrain_splitter
)

但是,在運行上述代碼單元時,它返回了以下錯誤:

BoundedTensorSpec(shape=(56,), dtype=tf.float32, name='observation', minimum=array(-3.4028235e+38, dtype=float32), maximum=array(3.4028235e+38, dtype=float32))
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-213-07450ea5ba21> in <module>()
     13     td_errors_loss_fn=tf_common.element_wise_squared_loss,
     14     train_step_counter=train_step_counter,
---> 15     observation_and_action_constraint_splitter=observation_and_action_constrain_splitter
     16     )
     17 

4 frames
<ipython-input-212-dbfee6076511> in observation_and_action_constrain_splitter(observation)
      1 def observation_and_action_constrain_splitter(observation):
      2      print(observation)
----> 3      temp = observation.reshape(14,-1)
      4      action_mask = (~(temp==0).all(axis=1)).astype(np.int32).ravel()
      5      return observation, tf.convert_to_tensor(action_mask, dtype=tf.int32)

AttributeError: 'BoundedTensorSpec' object has no attribute 'reshape'
  In call to configurable 'DqnAgent' (<class 'tf_agents.agents.dqn.dqn_agent.DqnAgent'>)

結果是print(observation)返回一個BoundedTensorSpec object,而不是一個數組,也不是一個tf.Tensor object。我如何從BoundedTensorSpec創建我的動作掩碼,它甚至不包含用於觀察的數組?

提前致謝!

PS:tf_agents 版本為 0.12.0

我遇到了同樣的問題。 我通過將 function observation_and_action_constrain_splitter傳遞給策略而不是DqnAgent來解決它

agent = categorical_dqn_agent.CategoricalDqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    categorical_q_network=categorical_q_net,
    optimizer=optimizer,
    min_q_value=min_q_value,
    max_q_value=max_q_value,
    n_step_update=n_step_update,
    td_errors_loss_fn=common.element_wise_squared_loss,
    gamma=gamma,
    train_step_counter=train_step_counter)
agent.initialize()

random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                train_env.action_spec(),
                                                observation_and_action_constraint_splitter=observation_and_action_constraint_splitter)

我希望這對你有幫助。

Cheick 的解決方案將起作用,但前提是您單獨定義策略 但是,我有同樣的問題,如果你想直接在Dqn Agent上使用observation_and_action_constraint_splitter function,我找到了另一個解決方案。

例如,如果您查看 DqnAgent 的__init__ (對於 CategoricalDqnAgent 也是如此),您將在構造函數的頂部看到此部分:

net_observation_spec = time_step_spec.observation
if observation_and_action_constraint_splitter:
    net_observation_spec, _ = observation_and_action_constraint_splitter(net_observation_spec)
q_network.create_variables(net_observation_spec)

它的作用是使用代理構造函數中給出的observation_and_action_constraint_splitter function 並使用 .net_observation_spec object 調用它,這實際上是觀察規范!

注意:此調用僅在初始化時進行一次。 之后function會被策略正常調用

現在,在構造函數下面的幾行中,調用了 function _setup_policy,它定義了代理策略和收集策略。

例如,Dqn 代理策略將定義如下:

policy = q_policy.QPolicy(
    time_step_spec,
    action_spec,
    q_network=self._q_network,
    emit_log_probability=emit_log_probability,
    observation_and_action_constraint_splitter=(
        self._observation_and_action_constraint_splitter
    ))

這就是您的 function 獲取保單的方式。

因此,我的解決方案是簡單地在observation_and_action_constraint_splitter function 中添加一個檢查,如果觀察參數是 TensorSpec(BoundedTensorSpec 的父級),則按原樣返回觀察。

我的解決方案:

def observation_and_action_constraint_splitter(observation):
    if isinstance(observation, tf.TensorSpec):
        return observation, None
    
    # rest of your method here!
    # ...

    # return observation and action_mask
    return observation, action_mask

我不確定它為什么這樣做,但我懷疑這是為了防止您有一個環境,您不希望將所有觀察結果作為輸入,而只是其中的一個子集。 在這種情況下,您可以使用它來通知代理修改后的 observation_spec? 我不確定。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM