返回

井字棋实战:蒙特卡洛树搜索(MCTS)详解与代码示例

Ai

蒙特卡洛树搜索(MCTS)在井字棋中的实战应用

玩游戏的时候,想让电脑也聪明点,能跟人对弈?蒙特卡洛树搜索(MCTS)是个好办法。它能帮电脑在棋盘上找出胜算最大的下一步。这篇东西,咱们就聊聊 MCTS 在井字棋里咋用。

一、 咋回事?为啥 MCTS 能行?

你可能知道 MCTS 的基本步骤:选择、扩展、模拟、反向传播。 但在实际对局中,比如井字棋,它是咋一步步走的? 有个关键问题:是先一口气把整个搜索树建好,然后每一步都查表走棋?还是每走一步,都稍微建一点树,边下边算?

要搞清这个问题,先得弄明白,MCTS 本质上干了啥。

简单说,MCTS 就是个“试错”的算法。它不断模拟下棋的过程,记录哪些走法赢得多,哪些输得多。 模拟次数越多,它对棋局的“感觉”就越准。 最后,它会根据这些“感觉”,挑一个赢面最大的走法。

但为啥“试错”也能这么厉害? 因为 MCTS 有几个巧妙的设计:

  • 选择(Selection) : 它不会瞎试,而是会优先尝试那些“看起来不错”的走法。 这就好像高手下棋,不会每个空都去试,而是会重点考虑几个关键位置。
  • 扩展(Expansion) : 如果某个走法“看起来不错”,但还没试过几次,MCTS 就会多试几次,看看是不是真的好。
  • 模拟(Simulation) : 试的时候,MCTS 会让“两个自己”对打,直到分出胜负。 这样,它就能知道这个走法是好是坏。
  • 反向传播(Backpropagation) : 打完之后,MCTS 会把结果告诉“祖宗们”——也就是之前选择这个走法的那些节点。 这样,“祖宗们”就知道自己选得对不对了。

二、 咋用?井字棋实战!

好,明白了 MCTS 的原理,咱们来看看在井字棋里咋用。 还是回到刚才的问题:搜索树是提前建好,还是边下边建?

1. 边下边建!

在井字棋这种比较简单的游戏里,更常用的做法是边下边建树 。 为啥?

  • 省事儿 :井字棋虽然简单,但所有可能的局面也有不少。 如果一开始就把所有情况都算出来,那得算到猴年马月。
  • 灵活 : 边下边建,意味着每走一步,MCTS 都会根据当前的局面,重新“思考”一下。 这样,它就不会被之前的“错误”带偏,能更好地适应对手的走法。

那具体咋操作?

代码示例 (Python):
(下面这段代码只是为了演示原理,实际应用中,需要根据具体情况进行优化。)

import random
import math

class Node:
    def __init__(self, state, parent=None, move=None):
        self.state = state  # 当前棋局状态
        self.parent = parent # 父节点
        self.move = move # 从父节点到当前节点的走法
        self.children = [] # 子节点
        self.visits = 0 # 访问次数
        self.wins = 0   # 获胜次数

    def ucb1(self, total_visits):
       #UCB1公式, 用于选择子节点
        if self.visits == 0:
            return float('inf')
        return self.wins / self.visits + 1.41 * math.sqrt(math.log(total_visits) / self.visits)

    def select_child(self):
        #选择最佳UCB1分数的子节点
        best_child = None
        best_score = -float('inf')
        for child in self.children:
            score = child.ucb1(self.visits)
            if score > best_score:
                best_score = score
                best_child = child
        return best_child

    def expand(self, legal_moves):
       # 展开, 针对每个合法的移动都创建一个新的子节点.
        for move in legal_moves:
            new_state = self.state.copy()  # 复制当前状态
            new_state[move] = self.state.current_player()  # 走这一步棋
            child = Node(new_state, parent=self, move=move)
            self.children.append(child)

    def simulate(self):
      # 模拟, 从这个节点快速走到底(双方随机走, 直到分出胜负或者和棋).
        state = self.state.copy()
        while not state.is_game_over():
            legal_moves = state.legal_moves()
            move = random.choice(legal_moves)
            state[move] = state.current_player()
        return state.winner()

    def backpropagate(self, result):
      # 反向传播,更新节点及其所有祖先的访问次数和获胜次数
        self.visits += 1
        if result == self.state.current_player():
            self.wins += 1
        elif result == 0: # Tie
            self.wins += 0.5
        if self.parent:
            self.parent.backpropagate(-result)  # 注意这里的负号

class TicTacToeState:
    # 井字棋棋盘状态
    def __init__(self, board=None, player=1):
      # 构造函数
        if board is None:
            self.board = [0] * 9  # 棋盘,0表示空,1表示玩家1,-1表示玩家2
        else:
            self.board = board
        self.player = player #轮到谁下, 1 或 -1.

    def copy(self):
        # 复制棋盘状态
        return TicTacToeState(self.board.copy(), self.player)

    def current_player(self):
        #当前玩家
        return self.player
    
    def legal_moves(self):
        # 返回当前局面下所有合法的走法
        return [i for i, x in enumerate(self.board) if x == 0]

    def is_game_over(self):
        #判断游戏是否结束
        return self.winner() is not None or all(x != 0 for x in self.board)

    def winner(self):
      #检查是否有玩家获胜, 返回获胜的玩家 (1 或者 -1), 平局返回 0, 游戏未结束返回None.
        # 检查行
        for i in range(0, 9, 3):
            if self.board[i] == self.board[i+1] == self.board[i+2] != 0:
                return self.board[i]
        # 检查列
        for i in range(3):
            if self.board[i] == self.board[i+3] == self.board[i+6] != 0:
                return self.board[i]
        # 检查对角线
        if self.board[0] == self.board[4] == self.board[8] != 0:
            return self.board[0]
        if self.board[2] == self.board[4] == self.board[6] != 0:
            return self.board[2]
        #没有获胜者
        if  all(x != 0 for x in self.board):
          return 0 #Tie
        return None
    def __getitem__(self, index):
        return self.board[index]

    def __setitem__(self, index, value):
         self.board[index] = value
         self.player*=-1 #Change Player
    
    def __str__(self):
        #打印棋盘的格式化
        s = ""
        for i in range(0, 9, 3):
            s += " | ".join(['X' if self.board[i+j] == 1 else 'O' if self.board[i+j] == -1 else ' ' for j in range(3)]) + "\n"
            if i < 6:
                s += "---------\n"
        return s

def mcts_search(root_state, iterations):
  # 执行蒙特卡洛树搜索,返回最佳走法。

    root = Node(root_state)

    for _ in range(iterations):
        node = root
        # 选择
        while node.children:
            node = node.select_child()

        # 扩展 & 模拟
        if not node.state.is_game_over():
            legal_moves = node.state.legal_moves()
            node.expand(legal_moves)
            if(len(node.children)>0):
                node = random.choice(node.children)  # 随机选择一个子节点
            result = node.simulate()
        else:
          result = node.state.winner()

        # 反向传播
        node.backpropagate(result)

    # 选择访问次数最多的子节点作为最佳走法
    best_move = None
    best_visits = -1
    for child in root.children:
        if child.visits > best_visits:
            best_visits = child.visits
            best_move = child.move

    return best_move

# 示例用法
initial_state = TicTacToeState()
print("初始棋盘:")
print(initial_state)

while not initial_state.is_game_over():
    if initial_state.current_player() == 1:  # 假设玩家1先走
        move = mcts_search(initial_state, 1000)  # 可以调整迭代次数
        print("AI走了:", move)
        initial_state[move] = 1
    else:
        # 玩家2 (人类) 走
        while True:
             
            try:
              move = int(input(f"请玩家2输入走棋位置 (0-8): "))
              if move < 0 or move > 8:
                  print("输入超出范围。")
                  continue

              if  initial_state[move] != 0 :
                print("该位置已经被占据")
                continue 
              break
            except ValueError:
              print("请输入数字。")
        initial_state[move] = -1

    print(initial_state)

winner = initial_state.winner()
if winner == 1:
    print("AI 获胜!")
elif winner == -1:
    print("玩家2 获胜!")
else:
    print("平局!")

代码解释:

  1. Node 类:

    • state: 当前节点的棋盘状态。
    • parent: 父节点。
    • move: 从父节点到当前节点的走法(0-8 的数字)。
    • children: 子节点列表。
    • visits: 访问次数。
    • wins: 获胜次数(或平局的0.5次)。
    • ucb1(): 计算 UCB1 值,用于选择子节点。
    • select_child(): 根据 UCB1 值选择最佳子节点。
    • expand(): 扩展节点,创建所有可能的子节点。
    • simulate(): 模拟对局,直到游戏结束。
    • backpropagate(): 反向传播结果,更新节点信息。
  2. TicTacToeState 类:

    • 表示井字棋的棋盘状态,包括棋盘布局和当前玩家。
    • 提供了 copy()(复制状态)、current_player()(获取当前玩家)、legal_moves()(获取合法走法)、is_game_over()(判断游戏是否结束)、winner()(判断赢家)等方法。
  3. mcts_search() 函数:

    • root_state: 当前棋局状态。
    • iterations: 模拟次数。
    • 核心逻辑:
      1. 创建根节点。
      2. 循环 iterations 次:
        • 选择: 从根节点开始,根据 UCB1 值选择子节点,直到叶子节点。
        • 扩展 & 模拟: 如果叶子节点未结束游戏,则扩展并模拟。
        • 反向传播: 将模拟结果反向传播到根节点。
      3. 选择访问次数最多的子节点作为最佳走法。

步骤分解:

  1. 确定当前玩家 : 看看轮到谁了 (AI 还是对手)。
  2. 构建/更新搜索树 :
    • 如果这是第一步,就创建一个根节点,代表当前棋盘状态。
    • 如果不是第一步,找到与当前棋盘状态对应的节点(之前可能已经创建过了),把它作为新的根节点。
  3. 多次模拟 : 从根节点开始,循环执行 MCTS 的四个步骤:选择、扩展、模拟、反向传播。 循环次数越多,结果越可靠(但也会更慢)。
  4. 选最佳走法 : 模拟结束后,看看根节点的哪个子节点访问次数最多(或者胜率最高),就选那个子节点对应的走法。
  5. 走棋 : 把 AI 选好的走法应用到棋盘上。
  6. 对手走棋 : 等对手下棋。
  7. 循环 : 回到第 1 步,重复以上过程,直到游戏结束。

2. 几个小提示:

  • 迭代次数: 迭代次数越多,MCTS 的“棋力”越强。 你可以根据实际情况调整。
  • UCB1 参数: 代码中的 1.41 是个经验值,可以微调。
  • 代码只是示例,实际应用需要优化以提升效率。 例如:
    * 内存管理: 清理不再需要的节点。避免内存泄漏。
    * 并行化: 如果硬件允许,可以使用多线程或多进程来加速模拟过程。
    * 开局库: 对一些常见的开局, 直接用已知的最佳走法,不用再算。

三、 总结

MCTS 在井字棋中,通常是“边下边算”。每次轮到 AI 走棋,都会根据当前局面,构建或更新搜索树,通过多次模拟来找到最佳走法。 这样既节省了计算资源,又能灵活应对对手的策略。记住: 实践出真知。 多动手试试,看看不同策略对结果有啥影响.