![](/img/trans.png)
[英]Shape of _observation_spec and shape of _action_spec in the Tf-agents environments example
[英]TFAGENTS: clarification on the usage of observation_and_action_constraint_splitter for DqnAgent agents
根據這篇文章,我正在嘗試創建一個帶有有效/無效操作掩碼的 DqnAgent 代理,我應該為observation_and_action_constraint_splitter
arg 指定一個splitter_fn
。 根據 tf_agents 文檔
, splitter_fn
會是這樣的:
def observation_and_action_constraint_splitter(observation):
return observation['network_input'], observation['constraint']
在我看來,我認為變量observation
應該是env.step(action).observation
返回的數組,在我的例子中它是一個形狀為 (56,) 的數組(它是一個扁平化的數組,原始形狀為 (14,4 ), 每行是每個選擇的4個特征值,有5-14個選擇,如果選擇無效則相應的特征將全部為0),所以我這樣寫了我的splitter_fn:
def observation_and_action_constrain_splitter(observation):
print(observation)
temp = observation.reshape(14,-1)
action_mask = (~(temp==0).all(axis=1)).astype(np.int32).ravel()
return observation, tf.convert_to_tensor(action_mask, dtype=tf.int32)
agent = DqnAgent(
tf_time_step_spec,
tf_action_spec,
q_network=q_net,
optimizer=optimizer,
td_errors_loss_fn=tf_common.element_wise_squared_loss,
train_step_counter=train_step_counter,
observation_and_action_constraint_splitter=observation_and_action_constrain_splitter
)
但是,在運行上述代碼單元時,它返回了以下錯誤:
BoundedTensorSpec(shape=(56,), dtype=tf.float32, name='observation', minimum=array(-3.4028235e+38, dtype=float32), maximum=array(3.4028235e+38, dtype=float32))
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-213-07450ea5ba21> in <module>()
13 td_errors_loss_fn=tf_common.element_wise_squared_loss,
14 train_step_counter=train_step_counter,
---> 15 observation_and_action_constraint_splitter=observation_and_action_constrain_splitter
16 )
17
4 frames
<ipython-input-212-dbfee6076511> in observation_and_action_constrain_splitter(observation)
1 def observation_and_action_constrain_splitter(observation):
2 print(observation)
----> 3 temp = observation.reshape(14,-1)
4 action_mask = (~(temp==0).all(axis=1)).astype(np.int32).ravel()
5 return observation, tf.convert_to_tensor(action_mask, dtype=tf.int32)
AttributeError: 'BoundedTensorSpec' object has no attribute 'reshape'
In call to configurable 'DqnAgent' (<class 'tf_agents.agents.dqn.dqn_agent.DqnAgent'>)
結果是print(observation)
返回一個BoundedTensorSpec
object,而不是一個數組,也不是一個tf.Tensor
object。我如何從BoundedTensorSpec
創建我的動作掩碼,它甚至不包含用於觀察的數組?
提前致謝!
PS:tf_agents 版本為 0.12.0
我遇到了同樣的問題。 我通過將 function observation_and_action_constrain_splitter
傳遞給策略而不是DqnAgent
來解決它
agent = categorical_dqn_agent.CategoricalDqnAgent(
train_env.time_step_spec(),
train_env.action_spec(),
categorical_q_network=categorical_q_net,
optimizer=optimizer,
min_q_value=min_q_value,
max_q_value=max_q_value,
n_step_update=n_step_update,
td_errors_loss_fn=common.element_wise_squared_loss,
gamma=gamma,
train_step_counter=train_step_counter)
agent.initialize()
random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
train_env.action_spec(),
observation_and_action_constraint_splitter=observation_and_action_constraint_splitter)
我希望這對你有幫助。
Cheick 的解決方案將起作用,但前提是您單獨定義策略。 但是,我有同樣的問題,如果你想直接在Dqn Agent上使用observation_and_action_constraint_splitter
function,我找到了另一個解決方案。
例如,如果您查看 DqnAgent 的__init__
(對於 CategoricalDqnAgent 也是如此),您將在構造函數的頂部看到此部分:
net_observation_spec = time_step_spec.observation
if observation_and_action_constraint_splitter:
net_observation_spec, _ = observation_and_action_constraint_splitter(net_observation_spec)
q_network.create_variables(net_observation_spec)
它的作用是使用代理構造函數中給出的observation_and_action_constraint_splitter
function 並使用 .net_observation_spec object 調用它,這實際上是觀察規范!
注意:此調用僅在初始化時進行一次。 之后function會被策略正常調用
現在,在構造函數下面的幾行中,調用了 function _setup_policy,它定義了代理策略和收集策略。
例如,Dqn 代理策略將定義如下:
policy = q_policy.QPolicy(
time_step_spec,
action_spec,
q_network=self._q_network,
emit_log_probability=emit_log_probability,
observation_and_action_constraint_splitter=(
self._observation_and_action_constraint_splitter
))
這就是您的 function 獲取保單的方式。
因此,我的解決方案是簡單地在observation_and_action_constraint_splitter
function 中添加一個檢查,如果觀察參數是 TensorSpec(BoundedTensorSpec 的父級),則按原樣返回觀察。
我的解決方案:
def observation_and_action_constraint_splitter(observation):
if isinstance(observation, tf.TensorSpec):
return observation, None
# rest of your method here!
# ...
# return observation and action_mask
return observation, action_mask
我不確定它為什么這樣做,但我懷疑這是為了防止您有一個環境,您不希望將所有觀察結果作為輸入,而只是其中的一個子集。 在這種情況下,您可以使用它來通知代理修改后的 observation_spec? 我不確定。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.