注意

此示例与 Gymnasium 1.2.0 版本兼容。

使用表格型Q-学习解决21点问题

本教程使用表格型Q-学习训练一个玩21点的智能体。

agent-environment-diagram agent-environment-diagram

在本教程中,我们将探索并解决 Blackjack-v1 环境。

21点 是最受欢迎的赌场纸牌游戏之一,也因在某些条件下可被击败而闻名。此版本游戏使用无限牌堆(我们抽牌时会放回),因此在我们的模拟游戏中,算牌不是一个可行的策略。完整文档可在 https://gymnasium.org.cn/environments/toy_text/blackjack 找到

目标:要获胜,您的牌点数总和应大于庄家,且不超过21点。

行动:智能体可以选择两种行动
  • 停牌 (0):玩家不再要牌

  • 要牌 (1):玩家将获得另一张牌,但玩家可能会超过21点而爆牌

方法:要自行解决此环境,您可以选择您喜欢的离散RL算法。本文提出的解决方案使用了 Q-学习(一种无模型RL算法)。

导入和环境设置

# Author: Till Zemann
# License: MIT License

from __future__ import annotations

from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib.patches import Patch
from tqdm import tqdm

import gymnasium as gym


# Let's start by creating the blackjack environment.
# Note: We are going to follow the rules from Sutton & Barto.
# Other versions of the game can be found below for you to experiment.

env = gym.make("Blackjack-v1", sab=True)
# Other possible environment configurations are:

env = gym.make('Blackjack-v1', natural=True, sab=False)
# Whether to give an additional reward for starting with a natural blackjack, i.e. starting with an ace and ten (sum is 21).

env = gym.make('Blackjack-v1', natural=False, sab=False)
# Whether to follow the exact rules outlined in the book by Sutton and Barto. If `sab` is `True`, the keyword argument `natural` will be ignored.

观察环境

首先,我们调用 env.reset() 来开始一个回合。此函数将环境重置到起始位置并返回一个初始 observation。我们通常还会将 done = False。此变量稍后将用于检查游戏是否已终止(即玩家赢或输)。

# reset the environment to get the first observation
done = False
observation, info = env.reset()

# observation = (16, 9, False)

请注意,我们的观测是一个包含3个值的3元组

  • 玩家当前的点数总和

  • 庄家明牌的点数

  • 玩家是否持有可用A的布尔值(如果A计作11点而不会爆牌,则该A可用)

执行行动

在收到我们的第一个观测后,我们将只使用 env.step(action) 函数与环境交互。此函数将一个行动作为输入并在环境中执行它。因为该行动会改变环境的状态,它会返回四个有用的变量给我们。它们是

  • next_state:这是智能体采取行动后将收到的观测。

  • reward:这是智能体采取行动后将收到的奖励。

  • terminated:这是一个布尔变量,指示环境是否已终止。

  • truncated:这是一个布尔变量,也指示回合是否因提前截断而结束,即达到时间限制。

  • info:这是一个字典,可能包含有关环境的额外信息。

next_staterewardterminatedtruncated 变量不言自明,但 info 变量需要一些额外的解释。此变量包含一个字典,其中可能包含有关环境的一些额外信息,但在 Blackjack-v1 环境中您可以忽略它。例如,在雅达利环境中,info 字典有一个 ale.lives 键,它告诉我们智能体还剩下多少条生命。如果智能体生命为0,则该回合结束。

请注意,在训练循环中调用 env.render() 不是一个好主意,因为渲染会大大减慢训练速度。不如尝试构建一个额外的循环,在训练后评估和展示智能体。

# sample a random action from all valid actions
action = env.action_space.sample()
# action=1

# execute the action in our environment and receive infos from the environment
observation, reward, terminated, truncated, info = env.step(action)

# observation=(24, 10, False)
# reward=-1.0
# terminated=True
# truncated=False
# info={}

一旦 terminated = Truetruncated=True,我们应该停止当前回合,并使用 env.reset() 开始新回合。如果您在不重置环境的情况下继续执行行动,它仍然会响应,但输出对训练没有用(如果智能体在无效数据上学习,甚至可能有害)。

构建智能体

让我们构建一个 Q-学习 智能体 来解决 *Blackjack-v1*!我们将需要一些函数来选择行动和更新智能体的行动值。为了确保智能体探索环境,一个可能的解决方案是 epsilon-greedy 策略,我们以 epsilon 的百分比选择一个随机行动,并以 1 - epsilon 的百分比选择贪婪行动(当前被认为是最好的行动)。

class BlackjackAgent:
    def __init__(
        self,
        env,
        learning_rate: float,
        initial_epsilon: float,
        epsilon_decay: float,
        final_epsilon: float,
        discount_factor: float = 0.95,
    ):
        """Initialize a Reinforcement Learning agent with an empty dictionary
        of state-action values (q_values), a learning rate and an epsilon.

        Args:
            learning_rate: The learning rate
            initial_epsilon: The initial epsilon value
            epsilon_decay: The decay for epsilon
            final_epsilon: The final epsilon value
            discount_factor: The discount factor for computing the Q-value
        """
        self.q_values = defaultdict(lambda: np.zeros(env.action_space.n))

        self.lr = learning_rate
        self.discount_factor = discount_factor

        self.epsilon = initial_epsilon
        self.epsilon_decay = epsilon_decay
        self.final_epsilon = final_epsilon

        self.training_error = []

    def get_action(self, env, obs: tuple[int, int, bool]) -> int:
        """
        Returns the best action with probability (1 - epsilon)
        otherwise a random action with probability epsilon to ensure exploration.
        """
        # with probability epsilon return a random action to explore the environment
        if np.random.random() < self.epsilon:
            return env.action_space.sample()

        # with probability (1 - epsilon) act greedily (exploit)
        else:
            return int(np.argmax(self.q_values[obs]))

    def update(
        self,
        obs: tuple[int, int, bool],
        action: int,
        reward: float,
        terminated: bool,
        next_obs: tuple[int, int, bool],
    ):
        """Updates the Q-value of an action."""
        future_q_value = (not terminated) * np.max(self.q_values[next_obs])
        temporal_difference = (
            reward + self.discount_factor * future_q_value - self.q_values[obs][action]
        )

        self.q_values[obs][action] = (
            self.q_values[obs][action] + self.lr * temporal_difference
        )
        self.training_error.append(temporal_difference)

    def decay_epsilon(self):
        self.epsilon = max(self.final_epsilon, self.epsilon - self.epsilon_decay)

为了训练智能体,我们将让智能体一次玩一个回合(一个完整的游戏称为一个回合),然后在每个步骤(游戏中的一个单独行动称为一个步骤)之后更新其Q值。

智能体需要经历大量回合才能充分探索环境。

现在我们应该准备好构建训练循环了。

# hyperparameters
learning_rate = 0.01
n_episodes = 100_000
start_epsilon = 1.0
epsilon_decay = start_epsilon / (n_episodes / 2)  # reduce the exploration over time
final_epsilon = 0.1

agent = BlackjackAgent(
    env=env,
    learning_rate=learning_rate,
    initial_epsilon=start_epsilon,
    epsilon_decay=epsilon_decay,
    final_epsilon=final_epsilon,
)

太棒了,开始训练吧!

信息:当前的超参数设置旨在快速训练一个不错的智能体。如果您想收敛到最优策略,请尝试将 n_episodes 增加10倍并降低 learning_rate(例如,降至0.001)。

env = gym.wrappers.RecordEpisodeStatistics(env, buffer_length=n_episodes)
for episode in tqdm(range(n_episodes)):
    obs, info = env.reset()
    done = False

    # play one episode
    while not done:
        action = agent.get_action(env, obs)
        next_obs, reward, terminated, truncated, info = env.step(action)

        # update the agent
        agent.update(obs, action, reward, terminated, next_obs)

        # update if the environment is done and the current obs
        done = terminated or truncated
        obs = next_obs

    agent.decay_epsilon()

可视化训练过程

rolling_length = 500
fig, axs = plt.subplots(ncols=3, figsize=(12, 5))
axs[0].set_title("Episode rewards")
# compute and assign a rolling average of the data to provide a smoother graph
reward_moving_average = (
    np.convolve(
        np.array(env.return_queue).flatten(), np.ones(rolling_length), mode="valid"
    )
    / rolling_length
)
axs[0].plot(range(len(reward_moving_average)), reward_moving_average)
axs[1].set_title("Episode lengths")
length_moving_average = (
    np.convolve(
        np.array(env.length_queue).flatten(), np.ones(rolling_length), mode="same"
    )
    / rolling_length
)
axs[1].plot(range(len(length_moving_average)), length_moving_average)
axs[2].set_title("Training Error")
training_error_moving_average = (
    np.convolve(np.array(agent.training_error), np.ones(rolling_length), mode="same")
    / rolling_length
)
axs[2].plot(range(len(training_error_moving_average)), training_error_moving_average)
plt.tight_layout()
plt.show()
../../../_images/blackjack_training_plots.png

可视化策略

def create_grids(agent, usable_ace=False):
    """Create value and policy grid given an agent."""
    # convert our state-action values to state values
    # and build a policy dictionary that maps observations to actions
    state_value = defaultdict(float)
    policy = defaultdict(int)
    for obs, action_values in agent.q_values.items():
        state_value[obs] = float(np.max(action_values))
        policy[obs] = int(np.argmax(action_values))

    player_count, dealer_count = np.meshgrid(
        # players count, dealers face-up card
        np.arange(12, 22),
        np.arange(1, 11),
    )

    # create the value grid for plotting
    value = np.apply_along_axis(
        lambda obs: state_value[(obs[0], obs[1], usable_ace)],
        axis=2,
        arr=np.dstack([player_count, dealer_count]),
    )
    value_grid = player_count, dealer_count, value

    # create the policy grid for plotting
    policy_grid = np.apply_along_axis(
        lambda obs: policy[(obs[0], obs[1], usable_ace)],
        axis=2,
        arr=np.dstack([player_count, dealer_count]),
    )
    return value_grid, policy_grid


def create_plots(value_grid, policy_grid, title: str):
    """Creates a plot using a value and policy grid."""
    # create a new figure with 2 subplots (left: state values, right: policy)
    player_count, dealer_count, value = value_grid
    fig = plt.figure(figsize=plt.figaspect(0.4))
    fig.suptitle(title, fontsize=16)

    # plot the state values
    ax1 = fig.add_subplot(1, 2, 1, projection="3d")
    ax1.plot_surface(
        player_count,
        dealer_count,
        value,
        rstride=1,
        cstride=1,
        cmap="viridis",
        edgecolor="none",
    )
    plt.xticks(range(12, 22), range(12, 22))
    plt.yticks(range(1, 11), ["A"] + list(range(2, 11)))
    ax1.set_title(f"State values: {title}")
    ax1.set_xlabel("Player sum")
    ax1.set_ylabel("Dealer showing")
    ax1.zaxis.set_rotate_label(False)
    ax1.set_zlabel("Value", fontsize=14, rotation=90)
    ax1.view_init(20, 220)

    # plot the policy
    fig.add_subplot(1, 2, 2)
    ax2 = sns.heatmap(policy_grid, linewidth=0, annot=True, cmap="Accent_r", cbar=False)
    ax2.set_title(f"Policy: {title}")
    ax2.set_xlabel("Player sum")
    ax2.set_ylabel("Dealer showing")
    ax2.set_xticklabels(range(12, 22))
    ax2.set_yticklabels(["A"] + list(range(2, 11)), fontsize=12)

    # add a legend
    legend_elements = [
        Patch(facecolor="lightgreen", edgecolor="black", label="Hit"),
        Patch(facecolor="grey", edgecolor="black", label="Stick"),
    ]
    ax2.legend(handles=legend_elements, bbox_to_anchor=(1.3, 1))
    return fig


# state values & policy with usable ace (ace counts as 11)
value_grid, policy_grid = create_grids(agent, usable_ace=True)
fig1 = create_plots(value_grid, policy_grid, title="With usable ace")
plt.show()
../../../_images/blackjack_with_usable_ace.png
# state values & policy without usable ace (ace counts as 1)
value_grid, policy_grid = create_grids(agent, usable_ace=False)
fig2 = create_plots(value_grid, policy_grid, title="Without usable ace")
plt.show()
../../../_images/blackjack_without_usable_ace.png

在脚本结束时调用 env.close() 是一个好习惯,这样环境使用的任何资源都将被关闭。

你觉得你能做得更好吗?

# You can visualize the environment using the play function
# and try to win a few games.

希望本教程能帮助您掌握如何与 OpenAI-Gym 环境交互,并开启您解决更多RL挑战的旅程。

建议您自行解决此环境(项目式学习非常有效!)。您可以应用您喜欢的离散RL算法,或者尝试蒙特卡洛ES(在 Sutton & Barto 第5.3节中介绍)——这样您可以将您的结果直接与书中内容进行比较。

祝您玩得愉快!