训练代理¶
本页简要概述了如何为 Gymnasium 环境训练代理,特别是,我们将使用基于表格的 Q 学习来解决二十一点 v1 环境。有关此教程的完整完整版本以及更多针对其他环境和算法的训练教程,请参阅此处。在阅读本页之前,请阅读基本用法。在我们实现任何代码之前,这里是对二十一点和 Q 学习的概述。
二十一点是最受欢迎的赌场纸牌游戏之一,它也因在特定条件下可被击败而臭名昭著。此版本的 trò chơi sử dụng một bộ bài vô hạn (chúng ta rút bài có thay thế), vì vậy việc đếm bài sẽ không phải là một chiến lược khả thi trong trò chơi mô phỏng của chúng ta. 观察结果是玩家当前总和、庄家正面朝上的牌的价值以及玩家是否持有可用 ace 的布尔值的元组。 代理可以选择两种动作:stand (0),玩家不再要牌,hit (1),玩家要牌。 为了获胜,你的牌总和必须大于庄家的牌,但不能超过 21。 如果玩家选择 stand 或牌总和超过 21,游戏结束。 完整文档可以在https://gymnasium.org.cn/environments/toy_text/blackjack找到。
Q 学习是 Watkins 在 1989 年提出的一个无模型非策略学习算法,用于具有离散动作空间的环境,它因成为第一个证明在特定条件下收敛到最优策略的强化学习算法而闻名。
执行动作¶
在收到第一个观察结果后,我们只将使用env.step(action)
函数与环境交互。 此函数将动作作为输入,并在环境中执行它。 由于该动作改变了环境的状态,因此它会向我们返回四个有用的变量。 这些是
next observation
: 这是代理在采取动作后将收到的观察结果。reward
: 这是代理在采取动作后将收到的奖励。terminated
: 这是一个布尔变量,指示环境是否已终止,即由于内部条件而结束。truncated
: 这是一个布尔变量,也指示情节是否因过早截断而结束,即时间限制已达到。info
: 这是一个可能包含有关环境的附加信息的字典。
next observation
、reward
、terminated
和 truncated
变量不言自明,但 info
变量需要一些额外的解释。 此变量包含一个字典,其中可能包含有关环境的某些额外信息,但在二十一点 v1 环境中,你可以忽略它。 例如,在 Atari 环境中,info 字典具有 ale.lives
键,它告诉我们代理还有多少条命。 如果代理的生命值为 0,则情节结束。
请注意,在你的训练循环中调用 env.render()
不是一个好主意,因为渲染会大大降低训练速度。 相反,请尝试构建一个额外的循环来评估和展示训练后的代理。
构建代理¶
让我们构建一个 Q 学习代理来解决二十一点! 我们将需要一些函数来选择动作并更新代理的动作值。 为了确保代理探索环境,一种可能的解决方案是 epsilon 贪婪策略,在这种策略中,我们以 epsilon
的百分比选择一个随机动作,以 1 - epsilon
的百分比选择贪婪动作(当前被认为是最好的)。
from collections import defaultdict
import gymnasium as gym
import numpy as np
class BlackjackAgent:
def __init__(
self,
env: gym.Env,
learning_rate: float,
initial_epsilon: float,
epsilon_decay: float,
final_epsilon: float,
discount_factor: float = 0.95,
):
"""Initialize a Reinforcement Learning agent with an empty dictionary
of state-action values (q_values), a learning rate and an epsilon.
Args:
env: The training environment
learning_rate: The learning rate
initial_epsilon: The initial epsilon value
epsilon_decay: The decay for epsilon
final_epsilon: The final epsilon value
discount_factor: The discount factor for computing the Q-value
"""
self.env = env
self.q_values = defaultdict(lambda: np.zeros(env.action_space.n))
self.lr = learning_rate
self.discount_factor = discount_factor
self.epsilon = initial_epsilon
self.epsilon_decay = epsilon_decay
self.final_epsilon = final_epsilon
self.training_error = []
def get_action(self, obs: tuple[int, int, bool]) -> int:
"""
Returns the best action with probability (1 - epsilon)
otherwise a random action with probability epsilon to ensure exploration.
"""
# with probability epsilon return a random action to explore the environment
if np.random.random() < self.epsilon:
return self.env.action_space.sample()
# with probability (1 - epsilon) act greedily (exploit)
else:
return int(np.argmax(self.q_values[obs]))
def update(
self,
obs: tuple[int, int, bool],
action: int,
reward: float,
terminated: bool,
next_obs: tuple[int, int, bool],
):
"""Updates the Q-value of an action."""
future_q_value = (not terminated) * np.max(self.q_values[next_obs])
temporal_difference = (
reward + self.discount_factor * future_q_value - self.q_values[obs][action]
)
self.q_values[obs][action] = (
self.q_values[obs][action] + self.lr * temporal_difference
)
self.training_error.append(temporal_difference)
def decay_epsilon(self):
self.epsilon = max(self.final_epsilon, self.epsilon - self.epsilon_decay)
训练代理¶
为了训练代理,我们将让代理一次玩一个情节(一个完整的游戏称为一个情节),然后在每个情节后更新它的 Q 值。 代理必须经历很多情节才能充分探索环境。
# hyperparameters
learning_rate = 0.01
n_episodes = 100_000
start_epsilon = 1.0
epsilon_decay = start_epsilon / (n_episodes / 2) # reduce the exploration over time
final_epsilon = 0.1
agent = BlackjackAgent(
learning_rate=learning_rate,
initial_epsilon=start_epsilon,
epsilon_decay=epsilon_decay,
final_epsilon=final_epsilon,
)
信息:当前超参数设置为快速训练一个不错的代理。 如果你想收敛到最优策略,尝试将 n_episodes
增加 10 倍,并将学习速率降低(例如,降低到 0.001)。
from tqdm import tqdm
env = gym.make("Blackjack-v1", sab=False)
env = gym.wrappers.RecordEpisodeStatistics(env, deque_size=n_episodes)
for episode in tqdm(range(n_episodes)):
obs, info = env.reset()
done = False
# play one episode
while not done:
action = agent.get_action(obs)
next_obs, reward, terminated, truncated, info = env.step(action)
# update the agent
agent.update(obs, action, reward, terminated, next_obs)
# update if the environment is done and the current obs
done = terminated or truncated
obs = next_obs
agent.decay_epsilon()
可视化策略¶
希望本教程帮助你了解如何与 Gymnasium 环境交互,并让你踏上解决更多 RL 挑战的旅程。
建议你自己解决这个问题(基于项目的学习非常有效!)。 您可以应用您最喜欢的离散 RL 算法或尝试一下蒙特卡罗 ES(在Sutton & Barto <http://incompleteideas.net/book/the-book-2nd.html>
_ 中的 5.3 节中介绍) - 这样你可以将你的结果直接与书中进行比较。
祝你好运!