|
1 | | -# Copyright 2022 The Brax Authors. |
| 1 | +# Copyright 2023 The Brax Authors. |
2 | 2 | # |
3 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 4 | # you may not use this file except in compliance with the License. |
|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | # pylint:disable=g-multiple-import |
16 | | -"""Some example environments to help get started quickly with brax.""" |
| 16 | +"""Environments for training and evaluating policies.""" |
17 | 17 |
|
18 | 18 | import functools |
19 | | -from typing import Callable, Optional, Type, Union, overload |
| 19 | +from typing import Optional, Type |
20 | 20 |
|
21 | | -from brax.envs import acrobot |
22 | 21 | from brax.envs import ant |
23 | 22 | from brax.envs import fast |
24 | | -from brax.envs import fetch |
25 | | -from brax.envs import grasp |
26 | 23 | from brax.envs import half_cheetah |
27 | 24 | from brax.envs import hopper |
28 | 25 | from brax.envs import humanoid |
29 | | -from brax.envs import humanoid_standup |
| 26 | +from brax.envs import humanoidstandup |
30 | 27 | from brax.envs import inverted_double_pendulum |
31 | 28 | from brax.envs import inverted_pendulum |
32 | 29 | from brax.envs import pusher |
33 | 30 | from brax.envs import reacher |
34 | | -from brax.envs import reacherangle |
35 | | -from brax.envs import swimmer |
36 | | -from brax.envs import ur5e |
37 | 31 | from brax.envs import walker2d |
38 | | -from brax.envs import wrappers |
39 | | -from brax.envs.env import Env, State, Wrapper |
40 | | -import gym |
41 | | - |
| 32 | +from brax.envs import wrapper |
| 33 | +from brax.envs.env import Env, State |
42 | 34 |
|
43 | 35 | _envs = { |
44 | | - 'acrobot': acrobot.Acrobot, |
45 | | - 'ant': functools.partial(ant.Ant, use_contact_forces=True), |
| 36 | + 'ant': ant.Ant, |
46 | 37 | 'fast': fast.Fast, |
47 | | - 'fetch': fetch.Fetch, |
48 | | - 'grasp': grasp.Grasp, |
49 | 38 | 'halfcheetah': half_cheetah.Halfcheetah, |
50 | 39 | 'hopper': hopper.Hopper, |
51 | 40 | 'humanoid': humanoid.Humanoid, |
52 | | - 'humanoidstandup': humanoid_standup.HumanoidStandup, |
| 41 | + 'humanoidstandup': humanoidstandup.HumanoidStandup, |
53 | 42 | 'inverted_pendulum': inverted_pendulum.InvertedPendulum, |
54 | 43 | 'inverted_double_pendulum': inverted_double_pendulum.InvertedDoublePendulum, |
55 | 44 | 'pusher': pusher.Pusher, |
56 | 45 | 'reacher': reacher.Reacher, |
57 | | - 'reacherangle': reacherangle.ReacherAngle, |
58 | | - 'swimmer': swimmer.Swimmer, |
59 | | - 'ur5e': ur5e.Ur5e, |
60 | 46 | 'walker2d': walker2d.Walker2d, |
61 | 47 | } |
62 | 48 |
|
63 | 49 |
|
64 | | -def get_environment(env_name, **kwargs) -> Env: |
| 50 | + |
| 51 | +def get_environment(env_name: str, **kwargs) -> Env: |
| 52 | + """Returns an environment from the environment registry. |
| 53 | +
|
| 54 | + Args: |
| 55 | + env_name: environment name string |
| 56 | + **kwargs: keyword arguments that get passed to the Env class constructor |
| 57 | +
|
| 58 | + Returns: |
| 59 | + env: an environment |
| 60 | + """ |
65 | 61 | return _envs[env_name](**kwargs) |
66 | 62 |
|
67 | 63 |
|
68 | 64 | def register_environment(env_name: str, env_class: Type[Env]): |
| 65 | + """Adds an environment to the registry. |
| 66 | +
|
| 67 | + Args: |
| 68 | + env_name: environment name string |
| 69 | + env_class: the Env class to add to the registry |
| 70 | + """ |
69 | 71 | _envs[env_name] = env_class |
70 | 72 |
|
71 | 73 |
|
72 | | -def create(env_name: str, |
73 | | - episode_length: int = 1000, |
74 | | - action_repeat: int = 1, |
75 | | - auto_reset: bool = True, |
76 | | - batch_size: Optional[int] = None, |
77 | | - eval_metrics: bool = False, |
78 | | - **kwargs) -> Env: |
79 | | - """Creates an Env with a specified brax system.""" |
| 74 | +def create( |
| 75 | + env_name: str, |
| 76 | + episode_length: int = 1000, |
| 77 | + action_repeat: int = 1, |
| 78 | + auto_reset: bool = True, |
| 79 | + batch_size: Optional[int] = None, |
| 80 | + **kwargs, |
| 81 | +) -> Env: |
| 82 | + """Creates an environment from the registry. |
| 83 | +
|
| 84 | + Args: |
| 85 | + env_name: environment name string |
| 86 | + episode_length: length of episode |
| 87 | + action_repeat: how many repeated actions to take per environment step |
| 88 | + auto_reset: whether to auto reset the environment after an episode is done |
| 89 | + batch_size: the number of environments to batch together |
| 90 | + **kwargs: keyword argments that get passed to the Env class constructor |
| 91 | +
|
| 92 | + Returns: |
| 93 | + env: an environment |
| 94 | + """ |
80 | 95 | env = _envs[env_name](**kwargs) |
| 96 | + |
81 | 97 | if episode_length is not None: |
82 | | - env = wrappers.EpisodeWrapper(env, episode_length, action_repeat) |
| 98 | + env = wrapper.EpisodeWrapper(env, episode_length, action_repeat) |
83 | 99 | if batch_size: |
84 | | - env = wrappers.VectorWrapper(env, batch_size) |
| 100 | + env = wrapper.VmapWrapper(env, batch_size) |
85 | 101 | if auto_reset: |
86 | | - env = wrappers.AutoResetWrapper(env) |
87 | | - if eval_metrics: |
88 | | - env = wrappers.EvalWrapper(env) |
| 102 | + env = wrapper.AutoResetWrapper(env) |
89 | 103 |
|
90 | 104 | return env # type: ignore |
91 | | - |
92 | | - |
93 | | -def create_fn(env_name: str, **kwargs) -> Callable[..., Env]: |
94 | | - """Returns a function that when called, creates an Env.""" |
95 | | - return functools.partial(create, env_name, **kwargs) |
96 | | - |
97 | | - |
98 | | -@overload |
99 | | -def create_gym_env(env_name: str, |
100 | | - batch_size: None = None, |
101 | | - seed: int = 0, |
102 | | - backend: Optional[str] = None, |
103 | | - **kwargs) -> gym.Env: |
104 | | - ... |
105 | | - |
106 | | - |
107 | | -@overload |
108 | | -def create_gym_env(env_name: str, |
109 | | - batch_size: int, |
110 | | - seed: int = 0, |
111 | | - backend: Optional[str] = None, |
112 | | - **kwargs) -> gym.vector.VectorEnv: |
113 | | - ... |
114 | | - |
115 | | - |
116 | | -def create_gym_env(env_name: str, |
117 | | - batch_size: Optional[int] = None, |
118 | | - seed: int = 0, |
119 | | - backend: Optional[str] = None, |
120 | | - **kwargs) -> Union[gym.Env, gym.vector.VectorEnv]: |
121 | | - """Creates a `gym.Env` or `gym.vector.VectorEnv` from a Brax environment.""" |
122 | | - environment = create(env_name=env_name, batch_size=batch_size, **kwargs) |
123 | | - if batch_size is None: |
124 | | - return wrappers.GymWrapper(environment, seed=seed, backend=backend) |
125 | | - if batch_size <= 0: |
126 | | - raise ValueError( |
127 | | - '`batch_size` should either be None or a positive integer.') |
128 | | - return wrappers.VectorGymWrapper(environment, seed=seed, backend=backend) |
0 commit comments