簡體   English   中英

穩定基線不適用於 tensorflow

[英]Stable Baselines doesn't work with tensorflow

所以我最近重新開始學習機器學習,並決定開始“ConnectX”的 Kaggle 課程( https://www.kaggle.com/learn/intro-to-game-ai-and-reinforcement-learning )。 我正在嘗試做第 4 課,其中我使用 stable-baselines + Tensorflow 來制作 AI。 問題是,我似乎無法正確使用穩定基線,因為當我嘗試導入它時它會立即給我一個錯誤。 這是錯誤消息:

---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
<ipython-input-13-f5986851ce81> in <module>
      1 import os
----> 2 from stable_baselines.bench import Monitor
      3 from stable_baselines.common.vec_env import DummyVecEnv
      4 
      5 # Create directory for logging training information

~\Anaconda3\lib\site-packages\stable_baselines\__init__.py in <module>
----> 1 from stable_baselines.a2c import A2C
      2 from stable_baselines.acer import ACER
      3 from stable_baselines.acktr import ACKTR
      4 from stable_baselines.deepq import DQN
      5 from stable_baselines.her import HER

~\Anaconda3\lib\site-packages\stable_baselines\a2c\__init__.py in <module>
----> 1 from stable_baselines.a2c.a2c import A2C

~\Anaconda3\lib\site-packages\stable_baselines\a2c\a2c.py in <module>
      5 import tensorflow as tf
      6 
----> 7 from stable_baselines import logger
      8 from stable_baselines.common import explained_variance, tf_util, ActorCriticRLModel, SetVerbosity, TensorboardWriter
      9 from stable_baselines.common.policies import ActorCriticPolicy, RecurrentActorCriticPolicy

~\Anaconda3\lib\site-packages\stable_baselines\logger.py in <module>
     15 from tensorflow.python.util import compat
     16 
---> 17 from stable_baselines.common.misc_util import mpi_rank_or_zero
     18 
     19 DEBUG = 10

~\Anaconda3\lib\site-packages\stable_baselines\common\__init__.py in <module>
      2 from stable_baselines.common.console_util import fmt_row, fmt_item, colorize
      3 from stable_baselines.common.dataset import Dataset
----> 4 from stable_baselines.common.math_util import discount, discount_with_boundaries, explained_variance, \
      5     explained_variance_2d, flatten_arrays, unflatten_vector
      6 from stable_baselines.common.misc_util import zipsame, set_global_seeds, boolean_flag

~\Anaconda3\lib\site-packages\stable_baselines\common\math_util.py in <module>
      1 import numpy as np
----> 2 import scipy.signal
      3 
      4 
      5 def safe_mean(arr):

~\Anaconda3\lib\site-packages\scipy\signal\__init__.py in <module>
    287 
    288 """
--> 289 from . import sigtools, windows
    290 from .waveforms import *
    291 from ._max_len_seq import max_len_seq

~\Anaconda3\lib\site-packages\scipy\signal\windows\__init__.py in <module>
     39 """
     40 
---> 41 from .windows import *
     42 
     43 __all__ = ['boxcar', 'triang', 'parzen', 'bohman', 'blackman', 'nuttall',

~\Anaconda3\lib\site-packages\scipy\signal\windows\windows.py in <module>
      5 
      6 import numpy as np
----> 7 from scipy import linalg, special, fft as sp_fft
      8 
      9 __all__ = ['boxcar', 'triang', 'parzen', 'bohman', 'blackman', 'nuttall',

~\Anaconda3\lib\site-packages\scipy\special\__init__.py in <module>
    631 from .sf_error import SpecialFunctionWarning, SpecialFunctionError
    632 
--> 633 from . import _ufuncs
    634 from ._ufuncs import *
    635 

ImportError: DLL load failed: The specified module could not be found.

scipy看起來有問題,但我不知道我能做些什么來解決它。 即使我運行import stable_baselines發生此錯誤。 這是我為創建虛擬環境而運行的代碼(順便說一句,這是在 PowerShell b/c 中,這是 Jupyter Lab 給我的):

python -m venv myenv
.\myenv\Scripts\Activate.ps1
pip install stable-baselines

注意:我不知道這是否有任何意義,但是當我安裝stable-baselines時,會出現錯誤: ERROR: gym 0.17.2 has requirement cloudpickle<1.4.0,>=1.2.0, but you'll have cloudpickle 1.5.0 which is incompatible.

PS:我在這里發現了同樣的問題,但我不知道他們是如何解決的。 The answer just says "I used anaconda", but there is no stable-baselines package in anaconda, I tried installing tensorflow from anaconda and stable-baselines from pip. 但它仍然給出了同樣的錯誤。

最后編輯:看起來這個問題在. 導入並且僅在jupyter notebook中有效(與tensorflow沒有任何關系 - 它在 Python CLI 中工作正常)。 我已經在關於opencv 新問題中解釋了它。

〜阿尤什

穩定的基線網站聲稱他們還不支持 tf2.X。 所以這可能是你的問題

試試以下,

pip install tensorflow==1.14.0
pip install stable-baselines[mpi]==2.10.0

直到今天(2020 年 9 月 4 日),他們似乎都在為我工作。

我知道這可能有點晚了,但我現在發現了你的問題,並決定盡我所能回答。 祝你好運!

如果您正在專門尋找穩定基線的 TF2 版本,請檢查以下(實驗性)分支之一:

或者,嘗試基於PyTorch而不是 Tensorflow 的穩定基線 3 (目前處於測試階段),旨在替換當前基於 TF1 的 SB2 版本:

import stable_baselines不起作用。

import stable_baselines3有效。


要安裝穩定基線,您可以使用

  • conda install -c conda-forge stable-baselines3如果你使用 Anaconda
  • 如果要使用 PIP 安裝,請安裝pip install stable-baselines

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM