
Dict Observation Space for Stable Baselines3 Not Working

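I've created a minimal reproducible example below that can easily be run in a fresh Google Colab notebook. After the first install finishes, just use Runtime > Restart and Run All for it to take effect.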

I made a simple roulette game environment below for testing. For the observation space, I created a gym.spaces.Dict(), which you'll see below (the code is well commented).

It trains just fine, but when it gets to the testing iterations, I get the error:

ValueError                                Traceback (most recent call last)
<ipython-input-56-7c2cb900b44f> in <module>
      6 obs = env.reset()
      7 for i in range(1000):
----> 8     action, _state = model.predict(obs, deterministic=True)
      9     obs, reward, done, info = env.step(action)
     10     env.render()

ValueError: Error: Unexpected observation shape () for Box environment, please use (1,) or (n_env, 1) for the observation shape.

I read somewhere that dict spaces need to be flattened with gym.wrappers.FlattenObservation, so I changed this line:

    action, _state = model.predict(obs, deterministic=True)

...to:

    action, _state = model.predict(FlattenObservation(obs), deterministic=True)

...which led to this error:

AttributeError                            Traceback (most recent call last)
<ipython-input-57-87824c61fc45> in <module>
      6 obs = env.reset()
      7 for i in range(1000):
----> 8     action, _state = model.predict(FlattenObservation(obs), deterministic=True)
      9     obs, reward, done, info = env.step(action)
     10     env.render()

AttributeError: 'collections.OrderedDict' object has no attribute 'observation_space'

I also tried this, and got the same error as above:

obs = env.reset()
obs = FlattenObservation(obs)

Clearly I'm not doing something right, but I just don't know what it is, since this is my first time using a Dict space.

import os, sys
if not os.path.isdir('/usr/local/lib/python3.7/dist-packages/stable_baselines3'):
    !pip3 install stable_baselines3
    print("\n\n\n Stable Baselines3 has been installed, Restart and Run All now. DO NOT factory reset, or you'll have to start over\n")
    sys.exit(0)

from random import randint
from numpy import inf, float32, array, int32, int64
import gym
from gym.wrappers import FlattenObservation
from stable_baselines3 import A2C, DQN, PPO

"""Roulette environment class"""
class Roulette_Environment(gym.Env):

    metadata = {'render.modes': ['human', 'text']}

    """Initialize the environment"""
    def __init__(self):
        super(Roulette_Environment, self).__init__()

        # Some global variables
        self.max_table_limit = 1000
        self.initial_bankroll = 2000

        # Spaces
        # Each number on roulette board can have 0-1000 units placed on it
        self.action_space = gym.spaces.Box(low=0, high=1000, shape=(37,))

        # We're going to keep track of how many times each number shows up
        # while we're playing, plus our current bankroll and the max
        # table betting limit so the agent knows how much $ in total is allowed
        # to be placed on the table. Going to use a Dict space for this.
        self.observation_space = gym.spaces.Dict(
            {
                "0": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "1": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "2": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "3": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "4": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "5": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "6": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "7": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "8": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "9": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "10": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "11": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "12": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "13": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "14": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "15": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "16": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "17": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "18": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "19": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "20": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "21": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "22": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "23": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "24": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "25": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "26": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "27": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "28": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "29": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "30": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "31": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "32": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "33": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "34": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "35": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "36": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                
                "current_bankroll": gym.spaces.Box(low=-inf, high=inf, shape=(1,), dtype=int),
                
                "max_table_limit": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
            }
        )

    """Reset the Environment"""
    def reset(self):
        self.current_bankroll = self.initial_bankroll
        self.done = False

        # Take a sample from the observation_space to modify the values of
        self.current_state = self.observation_space.sample()
        
        # Reset each number being tracked throughout gameplay to 0
        for i in range(0, 37):
            self.current_state[str(i)] = 0

        # Reset our globals
        self.current_state['current_bankroll'] = self.current_bankroll
        self.current_state['max_table_limit'] = self.max_table_limit
        
        return self.current_state


    """Step Through the Environment"""
    def step(self, action):
        
        # Convert actions to ints cuz they show up as floats,
        # even when defined as ints in the environment.
        # https://github.com/openai/gym/issues/3107
        for i in range(len(action)):
            action[i] = int(action[i])
        self.current_action = action
        
        # Subtract your bets from bankroll
        sum_of_bets = sum([bet for bet in self.current_action])

        # Spin the wheel
        self.current_number = randint(a=0, b=36)

        # Calculate payout/reward
        self.reward = 36 * self.current_action[self.current_number] - sum_of_bets

        self.current_bankroll += self.reward

        # Update the current state
        self.current_state['current_bankroll'] = self.current_bankroll
        self.current_state[str(self.current_number)] += 1

        # If we've doubled our money, or lost our money
        if self.current_bankroll >= self.initial_bankroll * 2 or self.current_bankroll <= 0:
            self.done = True

        return self.current_state, self.reward, self.done, {}


    """Render the Environment"""
    def render(self, mode='text'):
        # Text rendering
        if mode == "text":
            print(f'Bets Placed: {self.current_action}')
            print(f'Number rolled: {self.current_number}')
            print(f'Reward: {self.reward}')
            print(f'New Bankroll: {self.current_bankroll}')

env = Roulette_Environment()

model = PPO('MultiInputPolicy', env, verbose=1)
model.learn(total_timesteps=10000)

obs = env.reset()
# obs = FlattenObservation(obs)

for i in range(1000):
    action, _state = model.predict(obs, deterministic=True)
    # action, _state = model.predict(FlattenObservation(obs), deterministic=True)
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
      obs = env.reset()

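Unfortunately, stable-baselines3 is quite picky about the observation format.
I ran into the same problem over the last few days.
Some of the documentation, as well as an example model, helped me solve it: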

Dict observations can be used.

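However, the values for Box spaces must be mapped to numpy.ndarrays with the correct dtypes.
For Discrete observations, the observation can also be passed as an int value. However, I'm not entirely sure whether this still holds for multi-dimensional MultiDiscrete spaces.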
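A minimal sketch of that conversion, assuming every key in the Dict space is a `Box` of shape `(1,)` as in the question. `to_box_obs` is a hypothetical helper, not part of SB3 or gym, and `np.int64` is an assumed dtype (the question's `dtype=int` maps to the platform's default integer type):

```python
import numpy as np


def to_box_obs(values, dtype=np.int64):
    """Wrap each scalar in a 1-element ndarray so it matches a Box of shape (1,).

    Hypothetical helper: assumes every sub-space is Box(shape=(1,), dtype=dtype).
    """
    return {key: np.array([value], dtype=dtype) for key, value in values.items()}


raw = {"0": 0, "1": 3, "current_bankroll": 2000, "max_table_limit": 1000}
obs = to_box_obs(raw)

for key, value in obs.items():
    # Every entry is now an ndarray with shape (1,) and an integer dtype
    print(key, value, value.shape, value.dtype)
```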

A very simple solution

The fix for your example is to replace the code every time you reassign a value in the Dict, as follows:
self.current_state[key] = np.array([value], dtype=int)
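The same rule applies in `step`, not only in `reset`. A rough sketch of the counter bookkeeping from the question, written outside the environment class so it runs on its own; `reset_state` and `record_spin` are hypothetical names, and the 37 pockets plus two extra keys mirror the question's Dict space:

```python
import numpy as np

N_POCKETS = 37  # numbers 0-36, as in the question


def reset_state(bankroll, table_limit):
    """Build the initial Dict observation with 1-element int arrays."""
    state = {str(i): np.array([0], dtype=int) for i in range(N_POCKETS)}
    state["current_bankroll"] = np.array([bankroll], dtype=int)
    state["max_table_limit"] = np.array([table_limit], dtype=int)
    return state


def record_spin(state, number, bankroll):
    """Update the state after a spin without ever storing a bare int."""
    state[str(number)] += 1  # in-place add keeps shape (1,) and dtype int
    state["current_bankroll"] = np.array([bankroll], dtype=int)
    return state


state = reset_state(2000, 1000)
state = record_spin(state, 17, 1950)
print(state["17"], state["current_bankroll"])  # [1] [1950]
```

Adding the scalar `1` to a 1-element array broadcasts, so it is equivalent to the `+= np.array([1], dtype=int)` used in the working code.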

Below you can find a working implementation (my system has Python 3.10 installed, but it should also work with older versions).

Working code:

import os, sys

from random import randint
from numpy import inf, float32, array, int32, int64
import gym
from gym.wrappers import FlattenObservation
from stable_baselines3 import A2C, DQN, PPO
import numpy as np

"""Roulette environment class"""
class Roulette_Environment(gym.Env):

    metadata = {'render.modes': ['human', 'text']}

    """Initialize the environment"""
    def __init__(self):
        super(Roulette_Environment, self).__init__()

        # Some global variables
        self.max_table_limit = 1000
        self.initial_bankroll = 2000

        # Spaces
        # Each number on roulette board can have 0-1000 units placed on it
        self.action_space = gym.spaces.Box(low=0, high=1000, shape=(37,))

        # We're going to keep track of how many times each number shows up
        # while we're playing, plus our current bankroll and the max
        # table betting limit so the agent knows how much $ in total is allowed
        # to be placed on the table. Going to use a Dict space for this.
        self.observation_space = gym.spaces.Dict(
            {
                "0": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "1": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "2": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "3": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "4": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "5": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "6": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "7": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "8": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "9": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "10": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "11": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "12": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "13": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "14": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "15": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "16": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "17": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "18": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "19": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "20": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "21": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "22": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "23": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "24": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "25": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "26": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "27": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "28": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "29": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "30": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "31": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "32": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "33": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "34": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "35": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                "36": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
                
                "current_bankroll": gym.spaces.Box(low=-inf, high=inf, shape=(1,), dtype=int),
                
                "max_table_limit": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
            }
        )

    """Reset the Environment"""
    def reset(self):
        self.current_bankroll = self.initial_bankroll
        self.done = False

        # Take a sample from the observation_space to modify the values of
        self.current_state = self.observation_space.sample()
        
        # Reset each number being tracked throughout gameplay to 0
        for i in range(0, 37):
            self.current_state[str(i)] = np.array([0], dtype=int)

        # Reset our globals
        self.current_state['current_bankroll'] = np.array([self.current_bankroll], dtype=int)
        self.current_state['max_table_limit'] = np.array([self.max_table_limit], dtype=int)
        
        return self.current_state


    """Step Through the Environment"""
    def step(self, action):
        
        # Convert actions to ints cuz they show up as floats,
        # even when defined as ints in the environment.
        # https://github.com/openai/gym/issues/3107
        for i in range(len(action)):
            action[i] = int(action[i])
        self.current_action = action
        
        # Subtract your bets from bankroll
        sum_of_bets = sum([bet for bet in self.current_action])

        # Spin the wheel
        self.current_number = randint(a=0, b=36)

        # Calculate payout/reward
        self.reward = 36 * self.current_action[self.current_number] - sum_of_bets

        self.current_bankroll += self.reward

        # Update the current state
        self.current_state['current_bankroll'] = np.array([self.current_bankroll], dtype=int)
        self.current_state[str(self.current_number)] += np.array([1], dtype=int)

        # If we've doubled our money, or lost our money
        if self.current_bankroll >= self.initial_bankroll * 2 or self.current_bankroll <= 0:
            self.done = True

        return self.current_state, self.reward, self.done, {}


    """Render the Environment"""
    def render(self, mode='text'):
        # Text rendering
        if mode == "text":
            print(f'Bets Placed: {self.current_action}')
            print(f'Number rolled: {self.current_number}')
            print(f'Reward: {self.reward}')
            print(f'New Bankroll: {self.current_bankroll}')

env = Roulette_Environment()

model = PPO('MultiInputPolicy', env, verbose=1)
model.learn(total_timesteps=10)

obs = env.reset()
# obs = FlattenObservation(obs)

for i in range(1000):
    action, _state = model.predict(obs, deterministic=True)
    # action, _state = model.predict(FlattenObservation(obs), deterministic=True)
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
      obs = env.reset()

You have 3 different issues here.

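First, your main issue is in the reset method. You've defined your items as Boxes with shape=(1,). However, in reset you just assign plain integers to your items, e.g. here, self.current_state[str(i)] = 0, and later for the current_bankroll and max_table_limit keys. In predict, SB3's BasePolicy wraps each of your dict values with np.array(your_integer_value), which has shape (), and it raises an exception because that is incompatible with your box shape. Change your initial values to 1-size arrays instead, e.g. self.current_state[str(i)] = [0]. Also change your step method to update 1-size lists instead of integers. This will fix your shape mismatch.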
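The mismatch can be reproduced with NumPy alone. The check below is a simplified stand-in for the validation SB3 performs on `Box` observations (it is not SB3's actual code): a single observation must have the Box's shape, and a batched one must have shape `(n_env, *box_shape)`, matching the wording of the error message. `is_valid_box_obs` is a hypothetical name:

```python
import numpy as np


def is_valid_box_obs(value, box_shape):
    """Simplified stand-in for SB3's Box observation-shape check."""
    shape = np.array(value).shape  # what the wrapped value looks like
    single = shape == box_shape
    batched = len(shape) == len(box_shape) + 1 and shape[1:] == box_shape
    return single or batched


box_shape = (1,)
print(is_valid_box_obs(0, box_shape))           # False: np.array(0).shape == ()
print(is_valid_box_obs([0], box_shape))         # True: shape (1,)
print(is_valid_box_obs([[0], [5]], box_shape))  # True: shape (n_env=2, 1)
```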

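Second, you could actually get rid of the Dict altogether by manually flattening all the single-shape Boxes into one. Your low would then become a list (if you change the low of current_bankroll to 0, you won't even need to edit low; it can stay an integer).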
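A minimal sketch of that flattening, assuming the 37 pocket counters come first, followed by `current_bankroll` and `max_table_limit` in a single vector. Only NumPy is used here; in the environment, these bounds would go into one `gym.spaces.Box(low=..., high=..., shape=(39,), dtype=int)` in place of the Dict. `flatten_state` is a hypothetical helper:

```python
import numpy as np

N_POCKETS = 37
OBS_SIZE = N_POCKETS + 2  # 37 counters + bankroll + table limit

# If current_bankroll keeps low=-inf, low must be a per-entry array;
# with low=0 for the bankroll, a scalar low=0 would be enough.
low = np.zeros(OBS_SIZE)
low[-2] = -np.inf  # position of current_bankroll in the flat vector
high = np.full(OBS_SIZE, np.inf)


def flatten_state(counts, bankroll, table_limit):
    """Pack the per-pocket counts and the two globals into one 1-D vector."""
    counts = np.asarray(counts, dtype=np.int64)
    if counts.shape != (N_POCKETS,):
        raise ValueError(f"expected {N_POCKETS} counts, got shape {counts.shape}")
    return np.concatenate([counts, [bankroll, table_limit]]).astype(np.int64)


counts = np.zeros(N_POCKETS, dtype=np.int64)
counts[17] = 2
obs = flatten_state(counts, 2000, 1000)
print(obs.shape)                  # (39,)
print(obs[17], obs[-2], obs[-1])  # 2 2000 1000
```

With a flat Box observation, the plain `'MlpPolicy'` could be used instead of `'MultiInputPolicy'`.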

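Third, apart from the above, your environment looks correct. However, there is a bug in sb3. I assume you installed sb3 with pip at the latest 1.6.2 tag (October 10). That version has a bug involving np.ndarray in BaseAlgorithm.predict, which was later fixed in the master branch. So install sb3 directly from git.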


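Disclaimer: the technical posts on this site are licensed under CC BY-SA 4.0. If you repost them, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.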

 
Guangdong ICP Registration No. 18138465  © 2020-2024 STACKOOM.COM