录制代理

在训练或评估代理期间,记录代理在整个事件中的行为并记录累积的总奖励可能很有趣。这可以通过两个包装器来实现:RecordEpisodeStatisticsRecordVideo,第一个跟踪事件数据,例如总奖励、事件长度和花费时间,而第二个使用环境渲染生成 mp4 视频。

我们展示了如何针对两种类型的问题应用这些包装器;第一个用于记录每事件的数据(通常是评估),第二个用于定期记录数据(用于正常训练)。

录制每个事件

给定一个经过训练的代理,您可能希望在评估期间记录多个事件,以查看代理的行为。下面我们提供了一个示例脚本,使用 RecordEpisodeStatisticsRecordVideo 来做到这一点。

import gymnasium as gym
from gymnasium.wrappers import RecordEpisodeStatistics, RecordVideo

num_eval_episodes = 4

env = gym.make("CartPole-v1", render_mode="rgb_array")  # replace with your environment
env = RecordVideo(env, video_folder="cartpole-agent", name_prefix="eval",
                  episode_trigger=lambda x: True)
env = RecordEpisodeStatistics(env, buffer_length=num_eval_episodes)

for episode_num in range(num_eval_episodes):
    obs, info = env.reset()

    episode_over = False
    while not episode_over:
        action = env.action_space.sample()  # replace with actual agent
        obs, reward, terminated, truncated, info = env.step(action)

        episode_over = terminated or truncated
env.close()

print(f'Episode time taken: {env.time_queue}')
print(f'Episode total rewards: {env.return_queue}')
print(f'Episode lengths: {env.length_queue}')

在上面的脚本中,对于 RecordVideo 包装器,我们指定了三个不同的变量:video_folder 用于指定应保存视频的文件夹(更改您的问题),name_prefix 用于视频本身的前缀,最后是 episode_trigger,以便记录每个事件。这意味着对于环境的每个事件,都会录制一个视频,并以“cartpole-agent/eval-episode-x.mp4”的格式保存。

对于 RecordEpisodicStatistics,我们只需要指定缓冲区长度,这就是内部 time_queuereturn_queuelength_queue 的最大长度。我们无需单独收集每个事件的数据,可以使用数据队列在评估结束时打印信息。

为了加快评估环境的速度,可以使用矢量环境来实现这一点,以便以并行的方式而不是串行的方式同时评估 N 个事件。

在训练期间录制代理

在训练期间,代理将在数百或数千个事件中行动,因此,您无法为每个事件录制视频,但开发人员可能仍然想知道代理在训练中的不同点是如何行动的,在训练期间定期录制事件。虽然对于事件统计来说,了解每个事件的这些数据更有用。以下脚本提供了一个示例,说明如何在定期录制事件的同时录制每个事件的统计信息(我们使用 python 的记录器,但 tensorboardwandb 和其他模块都可用)。

import logging

import gymnasium as gym
from gymnasium.wrappers import RecordEpisodeStatistics, RecordVideo

training_period = 250  # record the agent's episode every 250
num_training_episodes = 10_000  # total number of training episodes

env = gym.make("CartPole-v1", render_mode="rgb_array")  # replace with your environment
env = RecordVideo(env, video_folder="cartpole-agent", name_prefix="training",
                  episode_trigger=lambda x: x % training_period == 0)
env = RecordEpisodeStatistics(env)

for episode_num in range(num_training_episodes):
    obs, info = env.reset()

    episode_over = False
    while not episode_over:
        action = env.action_space.sample()  # replace with actual agent
        obs, reward, terminated, truncated, info = env.step(action)

        episode_over = terminated or truncated

    logging.info(f"episode-{episode_num}", info["episode"])
env.close()

更多信息