gypsophila/gobang

mcts_alphaZero.py 5.97 KB
gypsophila committed on 2023-07-10 10:31: "It runs now, first-generation model"

import numpy as np
import copy


def Softmax(x):
    probs = np.exp(x - np.max(x))
    probs /= np.sum(probs)
    return probs
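
# Note: subtracting np.max(x) before exponentiating does not change the result,
# since exp(x - m) / sum(exp(x - m)) == exp(x) / sum(exp(x)); it only guards
# against overflow when the logits are large.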

class TreeNode(object):
    """A node in the Monte Carlo search tree."""

    def __init__(self, parent, prior_p):
        self._parent = parent
        self._children = {}
        self._n_visits = 0
        self._Q = 0
        self._u = 0  # upper confidence bound term
        self._P = prior_p

    # c_puct balances exploration and exploitation; it is passed in from
    # outside so that it is easy to tune.
    def select(self, c_puct):
        """Return: A tuple of (action, next_node)"""
        return max(self._children.items(),
                   key=lambda act_node: act_node[1].get_value(c_puct))

    def get_value(self, c_puct):
        self._u = (c_puct * self._P *
                   np.sqrt(self._parent._n_visits) / (1 + self._n_visits))
        return self._Q + self._u

    # Expansion
    def expand(self, action_priors):
        for action, prob in action_priors:
            if action not in self._children:
                self._children[action] = TreeNode(self, prob)

    # Update this node from a leaf evaluation
    def update(self, leaf_value):
        self._n_visits += 1
        self._Q += 1.0 * (leaf_value - self._Q) / self._n_visits

    # Recursively update all ancestor nodes as well
    def update_recursive(self, leaf_value):
        if self._parent:
            # Note: -leaf_value is passed upward because adjacent levels of
            # the tree represent different players.
            self._parent.update_recursive(-leaf_value)
        self.update(leaf_value)

    def is_leaf(self):
        return self._children == {}

    def is_root(self):
        return self._parent is None
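
# For reference, TreeNode.get_value computes the PUCT score used by AlphaZero:
#     score(a) = Q(a) + c_puct * P(a) * sqrt(N_parent) / (1 + N(a))
# where Q(a) is the running mean value maintained by update(), P(a) is the
# prior from the policy network and N counts visits. A larger c_puct puts more
# weight on the prior-driven exploration term relative to the learned value Q.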

class MCTS(object):
    """An implementation of Monte Carlo Tree Search"""

    def __init__(self, policy_value_fn, c_puct=5, n_playout=10000):
        """
        :param policy_value_fn: the policy-value network used for evaluation
        :param c_puct: controls the amount of exploration
        :param n_playout: number of playouts per move
        """
        self._root = TreeNode(None, 1.0)
        self._policy = policy_value_fn
        self._c_puct = c_puct
        self._n_playout = n_playout

    # Takes the current board state and runs the phases of one MCTS playout.
    def _playout(self, state):
        """Run a full playout: selection, expansion/evaluation and backup."""
        node = self._root
        # Selection
        while True:
            if node.is_leaf():
                break
            action, node = node.select(self._c_puct)
            state.do_move(action)
        # Expansion and evaluation
        action_probs, leaf_value = self._policy(state)
        end, winner = state.game_end()
        if not end:
            node.expand(action_probs)
        else:
            if winner == -1:  # tie
                leaf_value = 0.0
            else:
                leaf_value = (1.0 if winner == state.get_current_player()
                              else -1.0)
        # Backup
        node.update_recursive(-leaf_value)

    # Return all feasible actions in this state together with their probabilities.
    def get_move_probs(self, state, temp=1e-3):
        """
        :param state: the current board state
        :param temp: controls the amount of exploration; AlphaZero keeps it at 1
                     for the first 30 moves and close to 0 for the rest
        """
        # Run _playout() repeatedly on copies of the state
        for n in range(self._n_playout):
            state_copy = copy.deepcopy(state)
            self._playout(state_copy)
        act_visits = [(act, node._n_visits)
                      for act, node in self._root._children.items()]
        acts, visits = zip(*act_visits)
        act_probs = Softmax(1.0 / temp * np.log(np.array(visits) + 1e-10))
        return acts, act_probs
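
    # Note: Softmax(1/temp * log(N)) is proportional to N ** (1/temp), so the
    # probabilities above are the visit counts raised to the power 1/temp and
    # renormalized. With the default temp=1e-3 this is almost deterministic
    # (nearly all mass on the most visited move); temp=1 returns the raw
    # visit-count frequencies.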
"""在自我对弈的过程中复用搜索的子树,传入上一步最终执行的动作 last_move 正常来说该动作对应了当前搜索树的根节点的子结点
那么将该节点设为根节点就能继续向下搜索;如果不是根节点,则会重新初始化一个根节点,这样后续执行蒙特卡洛搜索时就会完全重新开始了
起到了重置的效果"""
def update_with_move(self, last_move):
if last_move in self._root._children:
self._root = self._root._children[last_move]
self._root._parent = None
else:
self._root = TreeNode(None, 1.0)

class MCTSPlayer(object):
    """AI player based on MCTS"""

    # The search itself is carried out by an MCTS instance; the extra
    # _is_selfplay flag records whether the player is currently in self-play.
    def __init__(self, policy_value_function, c_puct=5, n_playout=2000,
                 is_selfplay=0):
        """
        :param policy_value_function: the policy-value network
        :param c_puct: controls the amount of exploration
        :param n_playout: number of playouts per move
        :param is_selfplay: whether the player is currently in self-play
        """
        self.mcts = MCTS(policy_value_function, c_puct, n_playout)
        self._is_selfplay = is_selfplay

    def get_action(self, board, temp=1e-3, return_prob=0):
        sensible_moves = board.available
        move_probs = np.zeros(board.width * board.height)
        if len(sensible_moves) > 0:
            acts, probs = self.mcts.get_move_probs(board, temp)
            move_probs[list(acts)] = probs
            if self._is_selfplay:
                # Mix Dirichlet noise into the probabilities for extra
                # exploration during self-play.
                move = np.random.choice(
                    acts,
                    p=0.75 * probs + 0.25 * np.random.dirichlet(0.3 * np.ones(len(probs))))
                # Update the root node and reuse the subtree
                self.mcts.update_with_move(move)
            else:
                move = np.random.choice(acts, p=probs)
                # Reset the root node
                self.mcts.update_with_move(-1)
                location = board.move_to_location(move)
                print("AI move: %d,%d\n" % (location[0], location[1]))
            if return_prob:
                return move, move_probs
            else:
                return move
        else:
            print("WARNING: the board is full")

    def set_player_ind(self, p):
        self.player = p

    def reset_player(self):
        self.mcts.update_with_move(-1)

    def __str__(self):
        return "MCTS {} ".format(self.player)