Solving Blackjack with Q-Learning

[Figure: agent-environment diagram]

In this tutorial, we'll explore and solve the _Blackjack-v1_ environment.

**Blackjack** is one of the most popular casino card games and is also infamous for being beatable under certain conditions. This version of the game uses an infinite deck (we draw cards with replacement), so counting cards won't be a viable strategy in our simulated game. The full documentation can be found at https://gymnasium.org.cn/environments/toy_text/blackjack.

**Objective**: To win, your card sum should be greater than the dealer's without exceeding 21.

**Actions**: The agent can pick between two actions:
  • stand (0): the player takes no more cards

  • hit (1): the player is dealt another card, but the player could go over 21 and bust

**Approach**: To solve this environment yourself, you can pick your favorite discrete RL algorithm. This solution uses _Q-learning_ (a model-free RL algorithm).

Imports and Environment Setup

# Author: Till Zemann
# License: MIT License

from __future__ import annotations

from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib.patches import Patch
from tqdm import tqdm

import gymnasium as gym


# Let's start by creating the blackjack environment.
# Note: We are going to follow the rules from Sutton & Barto.
# Other versions of the game can be found below for you to experiment.

env = gym.make("Blackjack-v1", sab=True)
# Other possible environment configurations are:
#
# env = gym.make('Blackjack-v1', natural=True, sab=False)
# `natural=True` gives an additional reward for starting with a natural blackjack, i.e. starting with an ace and a ten (sum is 21).
#
# env = gym.make('Blackjack-v1', natural=False, sab=False)
# `sab` determines whether to follow the exact rules outlined in the book by Sutton and Barto. If `sab` is `True`, the keyword argument `natural` will be ignored.

Observing the environment

First of all, we call env.reset() to start an episode. This function resets the environment to a starting position and returns an initial observation. We usually also set done = False. This variable will be useful later to check whether a game is terminated (i.e., the player wins or loses).

# reset the environment to get the first observation
done = False
observation, info = env.reset()

# observation = (16, 9, False)

Note that our observation is a 3-tuple consisting of 3 values:

  • the player's current sum

  • the value of the dealer's face-up card

  • a boolean indicating whether or not the player holds a usable ace (an ace is usable if it counts as 11 without busting)
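
As a quick sanity check, you can also inspect the environment's spaces and unpack this tuple; the sketch below is only illustrative, and the concrete values will differ from run to run.

# inspect the spaces: the observation space is Tuple(Discrete(32), Discrete(11), Discrete(2)),
# the action space is Discrete(2)
print(env.observation_space)
print(env.action_space)

# unpack the 3-tuple returned by env.reset()
player_sum, dealer_card, usable_ace = observation
print(player_sum, dealer_card, usable_ace)  # e.g. 16 9 False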

Executing an action

After receiving our first observation, we are only going to use the env.step(action) function to interact with the environment. This function takes an action as input and executes it in the environment. Because that action changes the state of the environment, it returns five useful variables to us. These are:

  • next_state: the observation that the agent will receive after taking the action.

  • reward: the reward that the agent will receive after taking the action.

  • terminated: a boolean variable indicating whether or not the environment has terminated.

  • truncated: a boolean variable that also indicates whether the episode ended by early truncation, i.e., a time limit was reached.

  • info: a dictionary that may contain additional information about the environment.

The next_state, reward, terminated and truncated variables are self-explanatory, but the info variable requires some additional explanation. It contains a dictionary that might have some extra information about the environment, but in the Blackjack-v1 environment you can ignore it. For example, in Atari environments the info dictionary has an ale.lives key that tells us how many lives the agent has left. If the agent has 0 lives, the episode is over.

Note that it is not a good idea to call env.render() in your training loop, because rendering slows down training by a lot. Rather, try to build an extra loop to evaluate and showcase the agent after training.
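
As a rough illustration of such a separate loop, the sketch below evaluates the agent built later in this tutorial by always acting greedily with respect to its learned Q-values; the number of evaluation episodes is an arbitrary choice for this sketch.

# evaluate the trained agent over a fixed number of episodes without exploration
n_eval_episodes = 1_000
eval_env = gym.make("Blackjack-v1", sab=True)
total_reward = 0.0

for _ in range(n_eval_episodes):
    obs, info = eval_env.reset()
    done = False
    episode_reward = 0.0
    while not done:
        # always pick the greedy action during evaluation
        action = int(np.argmax(agent.q_values[obs]))
        obs, reward, terminated, truncated, info = eval_env.step(action)
        episode_reward += reward
        done = terminated or truncated
    total_reward += episode_reward

print(f"average reward per episode: {total_reward / n_eval_episodes:.3f}")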

# sample a random action from all valid actions
action = env.action_space.sample()
# action=1

# execute the action in our environment and receive infos from the environment
observation, reward, terminated, truncated, info = env.step(action)

# observation=(24, 10, False)
# reward=-1.0
# terminated=True
# truncated=False
# info={}

Once terminated = True or truncated = True, we should stop the current episode and begin a new one with env.reset(). If you continue executing actions without resetting the environment, it will still respond, but the output won't be useful for training (it might even be harmful if the agent learns on invalid data).

Building an agent

Let's build a Q-learning agent to solve _Blackjack-v1_! We'll need some functions for picking an action and updating the agent's action values. To ensure that the agent explores the environment, one possible solution is the epsilon-greedy strategy, where we pick a random action with probability epsilon and the greedy action (currently valued as the best) with probability 1 - epsilon.

class BlackjackAgent:
    def __init__(
        self,
        env,
        learning_rate: float,
        initial_epsilon: float,
        epsilon_decay: float,
        final_epsilon: float,
        discount_factor: float = 0.95,
    ):
        """Initialize a Reinforcement Learning agent with an empty dictionary
        of state-action values (q_values), a learning rate and an epsilon.

        Args:
            learning_rate: The learning rate
            initial_epsilon: The initial epsilon value
            epsilon_decay: The decay for epsilon
            final_epsilon: The final epsilon value
            discount_factor: The discount factor for computing the Q-value
        """
        self.q_values = defaultdict(lambda: np.zeros(env.action_space.n))

        self.lr = learning_rate
        self.discount_factor = discount_factor

        self.epsilon = initial_epsilon
        self.epsilon_decay = epsilon_decay
        self.final_epsilon = final_epsilon

        self.training_error = []

    def get_action(self, env, obs: tuple[int, int, bool]) -> int:
        """
        Returns the best action with probability (1 - epsilon)
        otherwise a random action with probability epsilon to ensure exploration.
        """
        # with probability epsilon return a random action to explore the environment
        if np.random.random() < self.epsilon:
            return env.action_space.sample()

        # with probability (1 - epsilon) act greedily (exploit)
        else:
            return int(np.argmax(self.q_values[obs]))

    def update(
        self,
        obs: tuple[int, int, bool],
        action: int,
        reward: float,
        terminated: bool,
        next_obs: tuple[int, int, bool],
    ):
        """Updates the Q-value of an action."""
        future_q_value = (not terminated) * np.max(self.q_values[next_obs])
        temporal_difference = (
            reward + self.discount_factor * future_q_value - self.q_values[obs][action]
        )

        self.q_values[obs][action] = (
            self.q_values[obs][action] + self.lr * temporal_difference
        )
        self.training_error.append(temporal_difference)

    def decay_epsilon(self):
        self.epsilon = max(self.final_epsilon, self.epsilon - self.epsilon_decay)

To train the agent, we will let the agent play one episode (one complete game is called an episode) and then update its Q-values after each step (a single action in a game is called a step).

The agent has to experience a lot of episodes to explore the environment sufficiently.

Now we should be ready to build the training loop.

# hyperparameters
learning_rate = 0.01
n_episodes = 100_000
start_epsilon = 1.0
epsilon_decay = start_epsilon / (n_episodes / 2)  # reduce the exploration over time
final_epsilon = 0.1

agent = BlackjackAgent(
    env=env,
    learning_rate=learning_rate,
    initial_epsilon=start_epsilon,
    epsilon_decay=epsilon_decay,
    final_epsilon=final_epsilon,
)

Great, let's train!

Info: The current hyperparameters are set to quickly train a decent agent. If you want to converge to the optimal policy, try increasing n_episodes by 10x and lowering the learning rate (e.g. to 0.001).
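
Purely as an illustration of that note, such a longer run could look like the commented-out block below; the values just mirror the suggestion above, and the agent has to be re-created with them before training.

# optional: a slower but more thorough configuration
# learning_rate = 0.001
# n_episodes = 1_000_000
# epsilon_decay = start_epsilon / (n_episodes / 2)
# agent = BlackjackAgent(
#     env=env,
#     learning_rate=learning_rate,
#     initial_epsilon=start_epsilon,
#     epsilon_decay=epsilon_decay,
#     final_epsilon=final_epsilon,
# )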

env = gym.wrappers.RecordEpisodeStatistics(env, deque_size=n_episodes)
for episode in tqdm(range(n_episodes)):
    obs, info = env.reset()
    done = False

    # play one episode
    while not done:
        action = agent.get_action(env, obs)
        next_obs, reward, terminated, truncated, info = env.step(action)

        # update the agent
        agent.update(obs, action, reward, terminated, next_obs)

        # update whether the environment is done and move to the next observation
        done = terminated or truncated
        obs = next_obs

    agent.decay_epsilon()

Visualizing the training

rolling_length = 500
fig, axs = plt.subplots(ncols=3, figsize=(12, 5))
axs[0].set_title("Episode rewards")
# compute and assign a rolling average of the data to provide a smoother graph
reward_moving_average = (
    np.convolve(
        np.array(env.return_queue).flatten(), np.ones(rolling_length), mode="valid"
    )
    / rolling_length
)
axs[0].plot(range(len(reward_moving_average)), reward_moving_average)
axs[1].set_title("Episode lengths")
length_moving_average = (
    np.convolve(
        np.array(env.length_queue).flatten(), np.ones(rolling_length), mode="same"
    )
    / rolling_length
)
axs[1].plot(range(len(length_moving_average)), length_moving_average)
axs[2].set_title("Training Error")
training_error_moving_average = (
    np.convolve(np.array(agent.training_error), np.ones(rolling_length), mode="same")
    / rolling_length
)
axs[2].plot(range(len(training_error_moving_average)), training_error_moving_average)
plt.tight_layout()
plt.show()
[Figure: training plots (episode rewards, episode lengths, training error)]

Visualizing the policy

def create_grids(agent, usable_ace=False):
    """Create value and policy grid given an agent."""
    # convert our state-action values to state values
    # and build a policy dictionary that maps observations to actions
    state_value = defaultdict(float)
    policy = defaultdict(int)
    for obs, action_values in agent.q_values.items():
        state_value[obs] = float(np.max(action_values))
        policy[obs] = int(np.argmax(action_values))

    player_count, dealer_count = np.meshgrid(
        # players count, dealers face-up card
        np.arange(12, 22),
        np.arange(1, 11),
    )

    # create the value grid for plotting
    value = np.apply_along_axis(
        lambda obs: state_value[(obs[0], obs[1], usable_ace)],
        axis=2,
        arr=np.dstack([player_count, dealer_count]),
    )
    value_grid = player_count, dealer_count, value

    # create the policy grid for plotting
    policy_grid = np.apply_along_axis(
        lambda obs: policy[(obs[0], obs[1], usable_ace)],
        axis=2,
        arr=np.dstack([player_count, dealer_count]),
    )
    return value_grid, policy_grid


def create_plots(value_grid, policy_grid, title: str):
    """Creates a plot using a value and policy grid."""
    # create a new figure with 2 subplots (left: state values, right: policy)
    player_count, dealer_count, value = value_grid
    fig = plt.figure(figsize=plt.figaspect(0.4))
    fig.suptitle(title, fontsize=16)

    # plot the state values
    ax1 = fig.add_subplot(1, 2, 1, projection="3d")
    ax1.plot_surface(
        player_count,
        dealer_count,
        value,
        rstride=1,
        cstride=1,
        cmap="viridis",
        edgecolor="none",
    )
    plt.xticks(range(12, 22), range(12, 22))
    plt.yticks(range(1, 11), ["A"] + list(range(2, 11)))
    ax1.set_title(f"State values: {title}")
    ax1.set_xlabel("Player sum")
    ax1.set_ylabel("Dealer showing")
    ax1.zaxis.set_rotate_label(False)
    ax1.set_zlabel("Value", fontsize=14, rotation=90)
    ax1.view_init(20, 220)

    # plot the policy
    fig.add_subplot(1, 2, 2)
    ax2 = sns.heatmap(policy_grid, linewidth=0, annot=True, cmap="Accent_r", cbar=False)
    ax2.set_title(f"Policy: {title}")
    ax2.set_xlabel("Player sum")
    ax2.set_ylabel("Dealer showing")
    ax2.set_xticklabels(range(12, 22))
    ax2.set_yticklabels(["A"] + list(range(2, 11)), fontsize=12)

    # add a legend
    legend_elements = [
        Patch(facecolor="lightgreen", edgecolor="black", label="Hit"),
        Patch(facecolor="grey", edgecolor="black", label="Stick"),
    ]
    ax2.legend(handles=legend_elements, bbox_to_anchor=(1.3, 1))
    return fig


# state values & policy with usable ace (ace counts as 11)
value_grid, policy_grid = create_grids(agent, usable_ace=True)
fig1 = create_plots(value_grid, policy_grid, title="With usable ace")
plt.show()
[Figure: state values and policy with a usable ace]
# state values & policy without usable ace (ace counts as 1)
value_grid, policy_grid = create_grids(agent, usable_ace=False)
fig2 = create_plots(value_grid, policy_grid, title="Without usable ace")
plt.show()
[Figure: state values and policy without a usable ace]

It is good practice to call env.close() at the end of your script, so that any resources used by the environment are closed.
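
Concretely, that is a single call at the very end of the script:

env.close()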

Think you can do better?

# You can visualize the environment using the play function
# and try to win a few games.
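
If you would rather not set up the interactive play utility, a simple text-based loop like the sketch below also lets you play a few hands yourself; it just prints the observation instead of rendering the table.

# play one hand interactively from the terminal
play_env = gym.make("Blackjack-v1", sab=True)
obs, info = play_env.reset()
done = False

while not done:
    player_sum, dealer_card, usable_ace = obs
    print(f"player sum: {player_sum}, dealer shows: {dealer_card}, usable ace: {usable_ace}")
    action = 1 if input("hit (1) or stand (0)? ").strip() == "1" else 0
    obs, reward, terminated, truncated, info = play_env.step(action)
    done = terminated or truncated

print("you won!" if reward > 0 else "you lost!" if reward < 0 else "draw")
play_env.close()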

Hopefully this tutorial helped you get a grip on how to interact with OpenAI-Gym environments and set you on a journey to solve many more RL challenges.

It is recommended that you solve this environment yourself (project-based learning is really effective!). You can apply your favorite discrete RL algorithm or give Monte Carlo ES a try (covered in section 5.3 of Sutton & Barto) - that way you can compare your results directly to the book.
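
As a rough starting point for the Monte Carlo route, the sketch below runs first-visit Monte Carlo control. True exploring starts would require sampling arbitrary starting state-action pairs, which Gymnasium's reset() does not expose, so this sketch only randomizes the first action of each episode and acts greedily afterwards; treat it as an approximation of the book's algorithm rather than a faithful reproduction (the episode count is also an arbitrary choice).

# first-visit Monte Carlo control with a randomized first action
mc_env = gym.make("Blackjack-v1", sab=True)
mc_q_values = defaultdict(lambda: np.zeros(mc_env.action_space.n))
mc_visit_counts = defaultdict(lambda: np.zeros(mc_env.action_space.n))

for _ in tqdm(range(500_000)):
    obs, info = mc_env.reset()
    episode_steps = []
    done = False
    first_action = True

    while not done:
        if first_action:
            # stand-in for exploring starts: only the first action is random
            action = mc_env.action_space.sample()
            first_action = False
        else:
            action = int(np.argmax(mc_q_values[obs]))
        next_obs, reward, terminated, truncated, info = mc_env.step(action)
        episode_steps.append((obs, action, reward))
        done = terminated or truncated
        obs = next_obs

    # update Q-values from the undiscounted return of the episode
    g = 0.0
    for t in reversed(range(len(episode_steps))):
        obs_t, action_t, reward_t = episode_steps[t]
        g += reward_t
        # first-visit check: only update the earliest occurrence of (obs, action)
        if all((o, a) != (obs_t, action_t) for o, a, _ in episode_steps[:t]):
            mc_visit_counts[obs_t][action_t] += 1
            mc_q_values[obs_t][action_t] += (
                g - mc_q_values[obs_t][action_t]
            ) / mc_visit_counts[obs_t][action_t]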

Have fun!