
How to pass the batch size for a custom environment in TF-Agents

I am using the TF-Agents library to build a contextual bandit, and for this I am building a custom environment.
I am creating a BanditPyEnvironment and wrapping it in a TFPyEnvironment.

The TFPyEnvironment automatically adds a batch dimension to the observation spec. I need to account for this batch dimension in the _observe and _apply_action methods: _observe should return batch-size many observations, and _apply_action should accept batch-size many actions and return one reward for each.

I cannot find a single example showing how to tell the TF environment what the batch size is, rather than letting it automatically add a 1 as the first dimension. Can someone please clarify?

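# Imports and class header were omitted in the original snippet; these are assumed.
import numpy as np
from tf_agents.bandits.environments import bandit_py_environment
from tf_agents.specs.tensor_spec import BoundedTensorSpec


class SampleEnvironment(bandit_py_environment.BanditPyEnvironment):

  def __init__(self, batch_size):
    self.batchsize = batch_size
    observation_spec = BoundedTensorSpec(
        (2,), np.int32, minimum=[1, 1], maximum=[5, 2], name='observation')
    action_spec = BoundedTensorSpec(
        shape=(), dtype=np.int32, minimum=0, maximum=6, name='action')
    super(SampleEnvironment, self).__init__(observation_spec, action_spec)

  def _observe(self):
    # Build one observation per batch entry, giving shape (batchsize, 2).
    batch = []
    for i in range(self.batchsize):
      each = np.array(
          [np.random.choice([1, 2, 3, 4, 5]), np.random.choice([1, 2])],
          dtype=np.int32)
      batch.append(each)
    self.observation = np.array(batch)
    print("in observe", self.observation)
    return self.observation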

When I try to account for the batch size in the _observe method as above (using a for loop over the batch size), the TF environment still adds a 1 as the first dimension. Is there a way to tell the environment that the batch size is, say, 3, instead of having it add 1 automatically? And how would I account for this batch size in the replay buffer and the agents?

This can be done with the BatchedPyEnvironment class, as shown in the example below. The bandit environment above appears to be treated as a non-batched environment, which is why TFPyEnvironment adds a batch dimension of 1.

SampleEnvironment below is the BanditPyEnvironment subclass shown in the question.

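from tf_agents.environments import batched_py_environment
from tf_agents.environments import tf_py_environment

batch_size = 4
# Each SampleEnvironment must be unbatched: _observe returns one observation of shape (2,).
# A separate instance per batch entry avoids sharing state between entries.
py_envs = [SampleEnvironment() for _ in range(batch_size)]
batched_env = batched_py_environment.BatchedPyEnvironment(envs=py_envs)
tfenv = tf_py_environment.TFPyEnvironment(batched_env)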
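
To make the shape contract concrete, here is a minimal NumPy-only sketch of what stacking unbatched environments amounts to. It does not use TF-Agents; SingleContextEnv, StackedEnv and the reward rule are hypothetical names and assumptions made up for illustration, not the library's implementation. Each inner environment returns one observation of shape (2,), and the wrapper stacks them so that observations have shape (batch_size, 2) while actions and rewards have shape (batch_size,).

```python
import numpy as np


class SingleContextEnv:
    """Hypothetical unbatched environment: one observation of shape (2,)."""

    def __init__(self, seed):
        self._rng = np.random.default_rng(seed)
        self._observation = None

    def observe(self):
        # First feature in 1..5, second in 1..2, matching the question's spec.
        self._observation = np.array(
            [self._rng.integers(1, 6), self._rng.integers(1, 3)],
            dtype=np.int32)
        return self._observation

    def apply_action(self, action):
        # Placeholder reward (an assumption): 1.0 if the action equals
        # the first feature of the last observation, else 0.0.
        return float(action == self._observation[0])


class StackedEnv:
    """Hypothetical wrapper that batches by stacking unbatched environments."""

    def __init__(self, envs):
        self._envs = list(envs)

    @property
    def batch_size(self):
        return len(self._envs)

    def observe(self):
        # Stack per-environment observations along a new leading axis.
        return np.stack([env.observe() for env in self._envs])

    def apply_action(self, actions):
        actions = np.asarray(actions)
        if actions.shape != (self.batch_size,):
            raise ValueError(
                "expected %d actions, got shape %s"
                % (self.batch_size, actions.shape))
        # One action in, one reward out, per environment.
        return np.array(
            [env.apply_action(a) for env, a in zip(self._envs, actions)],
            dtype=np.float32)


batch_size = 4
stacked = StackedEnv([SingleContextEnv(seed=i) for i in range(batch_size)])
observations = stacked.observe()
rewards = stacked.apply_action(np.zeros(batch_size, dtype=np.int32))
print(observations.shape, rewards.shape)  # (4, 2) (4,)
```

The point of the sketch is that the per-environment code never sees the batch dimension; only the wrapper does. For the replay buffer part of the question, the TF-Agents bandit tutorials construct the buffer with a batch_size argument that matches the environment's batch size (for example tfenv.batch_size); check this against the library documentation for your version.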
