簡體   English   中英

如何從 Rllib 的 PPO 算法中獲取一系列觀察值的價值函數/評論值?

[英]How do I get value function/critic values from Rllib's PPO algorithm for a range of observations?

目標:我想針對某個問題訓練 PPO 代理,並針對一系列觀察確定其最優值 function。 稍后我計划使用這個值 function(經濟不平等研究)。 該問題非常復雜,以至於動態規划技術不再適用。

方法:為了檢查我是否得到值 function 的正確輸出,我在一個簡單問題上訓練了 PPO,其解析解是已知的。 然而,值 function 的結果是垃圾,這就是為什么我懷疑我做錯了什么。

代碼:

from keras import backend as k_util
...

parser = argparse.ArgumentParser()

# Define framework to use
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.",
)
...

def get_rllib_config(seeds, debug=False, framework="tf") -> Dict:
...

def get_value_function(agent, min_state, max_state):
    policy = agent.get_policy()
    value_function = []
    for i in np.arange(min_state, max_state, 1):
        model_out, _ = policy.model({"obs": np.array([[i]], dtype=np.float32)})
        value = k_util.eval(policy.model.value_function())[0]
        value_function.append(value)
        print(i, value)
    return value_function


def train_schedule(config, reporter):
    rllib_config = config["config"]
    iterations = rllib_config.pop("training_iteration", 10)

    agent = PPOTrainer(env=rllib_config["env"], config=rllib_config)
    for _ in range(iterations):
        result = agent.train()
        reporter(**result)
    values = get_value_function(agent, 0, 100)
    print(values)
    agent.stop()

...

resources = PPO.default_resource_request(exp_config)
tune_analysis = tune.Tuner(tune.with_resources(train_schedule, resources=resources), param_space=exp_config).fit()
ray.shutdown()

所以首先我得到策略( policy = agent.get_policy() )並使用 100 個值中的每一個運行前向傳遞( model_out, _ = policy.model({"obs": np.array([[i]], dtype=np.float32)}) )。 然后,在每次前向傳遞后,我使用 value_function() 方法獲取 critic.network 的 output 並通過 keras 后端評估張量。

結果: True VF(分析溶液) VF output of Rllib

不幸的是,您可以看到結果並不那么有希望。 也許我錯過了預處理或后處理步驟? value_function() 方法甚至返回 critic.network 的最后一層嗎?

我非常感謝任何幫助!

它不是您腳本的一部分,但我假設您在嘗試從中獲取有用的值之前已經對策略進行了培訓。

您假設 value_function() 返回 RLlib 實現中 critic.network 最后一層的 output 是正確的。 查看值 function 指標,看看它是否真的在學習任何東西(RLlib 日志.../learner_stats/vf_loss.../learner_stats/vf_explained_var ),在訓練 model 之后。我也會嘗試直接查詢 model ,如果這樣看起來更好。 您在此處發布的代碼可能有問題。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM