创建您自己的自定义环境¶
本文档概述了创建新环境以及 Gymnasium 中为创建新环境而设计的相关有用 wrappers、实用程序和测试。
设置¶
推荐解决方案¶
按照 pipx 文档 安装
pipx
。然后安装 Copier
pipx install copier
替代解决方案¶
使用 Pip 或 Conda 安装 Copier
pip install copier
或者
conda install -c conda-forge copier
生成您的环境¶
您可以通过运行以下命令来检查 Copier
是否已正确安装,该命令应输出版本号
copier --version
然后,您可以直接运行以下命令,并将字符串 path/to/directory
替换为要创建新项目的目录路径。
copier copy https://github.com/Farama-Foundation/gymnasium-env-template.git "path/to/directory"
回答问题,完成后您应该获得类似以下的项目结构
.
├── gymnasium_env
│ ├── envs
│ │ ├── grid_world.py
│ │ └── __init__.py
│ ├── __init__.py
│ └── wrappers
│ ├── clip_reward.py
│ ├── discrete_actions.py
│ ├── __init__.py
│ ├── reacher_weighted_reward.py
│ └── relative_position.py
├── LICENSE
├── pyproject.toml
└── README.md
继承 gymnasium.Env¶
在学习如何创建自己的环境之前,您应该查看 Gymnasium API 文档。
为了说明继承 gymnasium.Env
的过程,我们将实现一个非常简单的游戏,称为 GridWorldEnv
。我们将编写自定义环境的代码,放在 gymnasium_env/envs/grid_world.py
中。该环境由一个固定大小的二维方形网格组成(在构建期间通过 size
参数指定)。代理可以在每个时间步长在网格单元之间垂直或水平移动。代理的目标是导航到网格上的目标,目标在情节开始时随机放置。
观察提供目标和代理的位置。
我们的环境中有 4 个动作,分别对应于“向右”、“向上”、“向左”和“向下”的移动。
一旦代理导航到目标所在的网格单元,就会发出 done 信号。
奖励是二进制和稀疏的,这意味着即时奖励始终为零,除非代理已到达目标,则奖励为 1。
该环境中一个情节(size=5
)可能如下所示
其中蓝点是代理,红色方块表示目标。
让我们逐段查看 GridWorldEnv
的源代码
声明和初始化¶
我们的自定义环境将继承自抽象类 gymnasium.Env
。您不应该忘记将 metadata
属性添加到您的类中。在其中,您应该指定环境支持的渲染模式(例如 "human"
、"rgb_array"
、"ansi"
)以及环境应渲染的帧率。每个环境都应支持 None
作为渲染模式;您无需在元数据中添加它。在 GridWorldEnv
中,我们将支持“rgb_array”和“human”模式,并以 4 FPS 的速率渲染。
环境的 __init__
方法将接受整数 size
,它决定方形网格的大小。我们将为渲染设置一些变量,并定义 self.observation_space
和 self.action_space
。在我们的例子中,观察应该提供有关代理和目标在二维网格上的位置的信息。我们将选择以字典的形式表示观察,其键为 "agent"
和 "target"
。观察可能如下所示:{"agent": array([1, 0]), "target": array([0, 3])}
。由于我们的环境中有 4 个动作(“向右”、“向上”、“向左”、“向下”),我们将使用 Discrete(4)
作为动作空间。以下是 GridWorldEnv
的声明以及 __init__
的实现
# gymnasium_env/envs/grid_world.py
from enum import Enum
import numpy as np
import pygame
import gymnasium as gym
from gymnasium import spaces
class Actions(Enum):
RIGHT = 0
UP = 1
LEFT = 2
DOWN = 3
class GridWorldEnv(gym.Env):
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}
def __init__(self, render_mode=None, size=5):
self.size = size # The size of the square grid
self.window_size = 512 # The size of the PyGame window
# Observations are dictionaries with the agent's and the target's location.
# Each location is encoded as an element of {0, ..., `size`}^2, i.e. MultiDiscrete([size, size]).
self.observation_space = spaces.Dict(
{
"agent": spaces.Box(0, size - 1, shape=(2,), dtype=int),
"target": spaces.Box(0, size - 1, shape=(2,), dtype=int),
}
)
self._agent_location = np.array([-1, -1], dtype=int)
self._target_location = np.array([-1, -1], dtype=int)
# We have 4 actions, corresponding to "right", "up", "left", "down"
self.action_space = spaces.Discrete(4)
"""
The following dictionary maps abstract actions from `self.action_space` to
the direction we will walk in if that action is taken.
i.e. 0 corresponds to "right", 1 to "up" etc.
"""
self._action_to_direction = {
Actions.RIGHT.value: np.array([1, 0]),
Actions.UP.value: np.array([0, 1]),
Actions.LEFT.value: np.array([-1, 0]),
Actions.DOWN.value: np.array([0, -1]),
}
assert render_mode is None or render_mode in self.metadata["render_modes"]
self.render_mode = render_mode
"""
If human-rendering is used, `self.window` will be a reference
to the window that we draw to. `self.clock` will be a clock that is used
to ensure that the environment is rendered at the correct framerate in
human-mode. They will remain `None` until human-mode is used for the
first time.
"""
self.window = None
self.clock = None
从环境状态构建观察¶
由于我们将在 reset
和 step
中计算观察,因此通常使用一个(私有)方法 _get_obs
将环境的状态转换为观察会比较方便。但是,这不是强制性的,您也可以分别在 reset
和 step
中计算观察。
def _get_obs(self):
return {"agent": self._agent_location, "target": self._target_location}
我们还可以为由 step
和 reset
返回的辅助信息实现类似的方法。在我们的例子中,我们希望提供代理和目标之间的曼哈顿距离
def _get_info(self):
return {
"distance": np.linalg.norm(
self._agent_location - self._target_location, ord=1
)
}
通常,info 还将包含一些仅在 step
方法内部可用的数据(例如,单个奖励项)。在这种情况下,我们将不得不更新由 step
中的 _get_info
返回的字典。
重置¶
将调用 reset
方法来启动一个新情节。您可以假设 step
方法在调用 reset
之前不会被调用。此外,每当发出 done 信号时,都应该调用 reset
。用户可以将 seed
关键字传递给 reset
,以将环境使用的任何随机数生成器初始化为确定性状态。建议使用环境基类 gymnasium.Env
提供的随机数生成器 self.np_random
。如果您只使用此 RNG,则无需过多担心播种,但您需要记住调用 ``super().reset(seed=seed)`` 以确保 gymnasium.Env
正确地对 RNG 播种。完成此操作后,我们可以随机设置环境的状态。在我们的例子中,我们将随机选择代理的位置并随机抽取目标位置,直到它不与代理的位置重合。
reset
方法应该返回一个包含初始观察和一些辅助信息的元组。我们可以为此使用前面实现的 _get_obs
和 _get_info
方法
def reset(self, seed=None, options=None):
# We need the following line to seed self.np_random
super().reset(seed=seed)
# Choose the agent's location uniformly at random
self._agent_location = self.np_random.integers(0, self.size, size=2, dtype=int)
# We will sample the target's location randomly until it does not coincide with the agent's location
self._target_location = self._agent_location
while np.array_equal(self._target_location, self._agent_location):
self._target_location = self.np_random.integers(
0, self.size, size=2, dtype=int
)
observation = self._get_obs()
info = self._get_info()
if self.render_mode == "human":
self._render_frame()
return observation, info
步骤¶
通常情况下,step
方法包含了环境逻辑的大部分。它接受一个 action
,计算应用该动作后环境的状态,并返回 5 元组 (observation, reward, terminated, truncated, info)
。请参阅 gymnasium.Env.step()
。一旦计算出环境的新状态,我们就可以检查它是否为终端状态,并相应地设置 done
。由于我们在 GridWorldEnv
中使用稀疏二进制奖励,因此一旦我们知道 done
,计算 reward
就会变得很简单。为了收集 observation
和 info
,我们可以再次使用 _get_obs
和 _get_info
def step(self, action):
# Map the action (element of {0,1,2,3}) to the direction we walk in
direction = self._action_to_direction[action]
# We use `np.clip` to make sure we don't leave the grid
self._agent_location = np.clip(
self._agent_location + direction, 0, self.size - 1
)
# An episode is done iff the agent has reached the target
terminated = np.array_equal(self._agent_location, self._target_location)
reward = 1 if terminated else 0 # Binary sparse rewards
observation = self._get_obs()
info = self._get_info()
if self.render_mode == "human":
self._render_frame()
return observation, reward, terminated, False, info
渲染¶
在这里,我们使用 PyGame 进行渲染。许多包含在 Gymnasium 中的环境都使用了类似的渲染方法,您可以将其用作构建自己环境的模板。
def render(self):
if self.render_mode == "rgb_array":
return self._render_frame()
def _render_frame(self):
if self.window is None and self.render_mode == "human":
pygame.init()
pygame.display.init()
self.window = pygame.display.set_mode(
(self.window_size, self.window_size)
)
if self.clock is None and self.render_mode == "human":
self.clock = pygame.time.Clock()
canvas = pygame.Surface((self.window_size, self.window_size))
canvas.fill((255, 255, 255))
pix_square_size = (
self.window_size / self.size
) # The size of a single grid square in pixels
# First we draw the target
pygame.draw.rect(
canvas,
(255, 0, 0),
pygame.Rect(
pix_square_size * self._target_location,
(pix_square_size, pix_square_size),
),
)
# Now we draw the agent
pygame.draw.circle(
canvas,
(0, 0, 255),
(self._agent_location + 0.5) * pix_square_size,
pix_square_size / 3,
)
# Finally, add some gridlines
for x in range(self.size + 1):
pygame.draw.line(
canvas,
0,
(0, pix_square_size * x),
(self.window_size, pix_square_size * x),
width=3,
)
pygame.draw.line(
canvas,
0,
(pix_square_size * x, 0),
(pix_square_size * x, self.window_size),
width=3,
)
if self.render_mode == "human":
# The following line copies our drawings from `canvas` to the visible window
self.window.blit(canvas, canvas.get_rect())
pygame.event.pump()
pygame.display.update()
# We need to ensure that human-rendering occurs at the predefined framerate.
# The following line will automatically add a delay to keep the framerate stable.
self.clock.tick(self.metadata["render_fps"])
else: # rgb_array
return np.transpose(
np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2)
)
关闭¶
close
方法应关闭环境使用的所有打开资源。在许多情况下,您实际上不必费心实现此方法。但是,在我们的示例中,render_mode
可能是 "human"
,我们可能需要关闭已打开的窗口。
def close(self):
if self.window is not None:
pygame.display.quit()
pygame.quit()
在其他环境中,close
也可以关闭已打开的文件或释放其他资源。在调用 close
后,您不应与环境交互。
注册环境¶
为了使自定义环境能够被 Gymnasium 检测到,必须按照以下步骤进行注册。我们将选择将此代码放在 gymnasium_env/__init__.py
中。
from gymnasium.envs.registration import register
register(
id="gymnasium_env/GridWorld-v0",
entry_point="gymnasium_env.envs:GridWorldEnv",
)
环境 ID 由三个部分组成,其中两个是可选的:一个可选的命名空间(这里为:gymnasium_env
),一个必填的名称(这里为:GridWorld
)和一个可选但推荐的版本(这里为 v0)。它也可以注册为 GridWorld-v0
(推荐的方法),GridWorld
或 gymnasium_env/GridWorld
,在环境创建过程中应使用相应的 ID。
关键字参数 max_episode_steps=300
将确保通过 gymnasium.make
实例化的 GridWorld 环境将被包装在一个 TimeLimit
包装器中(有关更多信息,请参阅 包装器文档)。如果代理已到达目标或当前剧集已执行 300 步,则将产生一个完成信号。要区分截断和终止,您可以检查 info["TimeLimit.truncated"]
。
除了 id
和 entrypoint
之外,您还可以将以下附加关键字参数传递给 register
名称 |
类型 |
默认值 |
描述 |
---|---|---|---|
|
|
|
任务被认为已解决之前的奖励阈值 |
|
|
|
即使在播种后,此环境是否仍然是非确定性的 |
|
|
|
一个剧集可以包含的最大步数。如果为 |
|
|
|
是否在环境中包装一个 |
|
|
|
传递给环境类的默认 kwargs |
这些关键字中的大多数(除了 max_episode_steps
、order_enforce
和 kwargs
)不会改变环境实例的行为,而只是提供有关您的环境的一些额外信息。注册后,我们的自定义 GridWorldEnv
环境可以通过 env = gymnasium.make('gymnasium_env/GridWorld-v0')
创建。
gymnasium_env/envs/__init__.py
应该包含
from gymnasium_env.envs.grid_world import GridWorldEnv
如果您的环境未注册,您也可以选择传递一个要导入的模块,该模块将在创建环境之前注册您的环境,方法如下:env = gymnasium.make('module:Env-v0')
,其中 module
包含注册代码。对于 GridWorld 环境,注册代码通过导入 gymnasium_env
来运行,因此如果无法显式导入 gymnasium_env,您可以在制作时注册,方法是:env = gymnasium.make('gymnasium_env:gymnasium_env/GridWorld-v0')
。这在您只能将环境 ID 传递到第三方代码库(例如学习库)时尤其有用。这使您可以在无需编辑库源代码的情况下注册您的环境。
创建包¶
最后一步是将我们的代码构建为 Python 包。这需要配置 pyproject.toml
。以下是执行此操作的最小示例
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[project]
name = "gymnasium_env"
version = "0.0.1"
dependencies = [
"gymnasium",
"pygame==2.1.3",
"pre-commit",
]
创建环境实例¶
现在,您可以使用以下命令在本地安装您的包
pip install -e .
然后,您可以通过以下方式创建环境实例
# run_gymnasium_env.py
import gymnasium
import gymnasium_env
env = gymnasium.make('gymnasium_env/GridWorld-v0')
您也可以将环境构造函数的关键字参数传递给 gymnasium.make
以自定义环境。在我们的例子中,我们可以这样做
env = gymnasium.make('gymnasium_env/GridWorld-v0', size=10)
有时,您可能会发现跳过注册并直接调用环境的构造函数更为方便。有些人可能发现这种方法更符合 Python 的风格,这样实例化的环境也是完全可以的(但请记住也要添加包装器!)。
使用包装器¶
通常情况下,我们希望使用自定义环境的不同变体,或者我们希望修改 Gymnasium 或其他方提供的环境的行为。包装器允许我们执行此操作,而无需更改环境实现或添加任何样板代码。查看 包装器文档 以了解有关如何使用包装器以及实现您自己的包装器的说明。在我们的示例中,观察结果不能直接在学习代码中使用,因为它们是字典。但是,我们实际上不需要触碰我们的环境实现来解决这个问题!我们只需在环境实例之上添加一个包装器来将观察结果扁平化为单个数组
import gymnasium
import gymnasium_env
from gymnasium.wrappers import FlattenObservation
env = gymnasium.make('gymnasium_env/GridWorld-v0')
wrapped_env = FlattenObservation(env)
print(wrapped_env.reset()) # E.g. [3 0 3 3], {}
包装器具有很大的优势,它们使环境高度模块化。例如,您可以不将 GridWorld 的观察结果扁平化,而只想查看目标和代理的相对位置。在 ObservationWrappers 部分中,我们实现了一个执行此工作的包装器。此包装器也存在于 gymnasium_env/wrappers/relative_position.py
中
import gymnasium
import gymnasium_env
from gymnasium_env.wrappers import RelativePosition
env = gymnasium.make('gymnasium_env/GridWorld-v0')
wrapped_env = RelativePosition(env)
print(wrapped_env.reset()) # E.g. [-3 3], {}