
zhoub86/multiplayer-alphazero

trainer.py 4.06 KB
Nick Petosa committed on 2019-11-24 16:05 · Initial commit
import time
import numpy as np
from multiprocessing.dummy import Pool as ThreadPool
from mcts import MCTS
from play import play_match
from players.uninformed_mcts_player import UninformedMCTSPlayer
from players.deep_mcts_player import DeepMCTSPlayer

# Object that coordinates AlphaZero training.
class Trainer:

    def __init__(self, game, nn, num_simulations, num_games, num_updates, buffer_size_limit, cpuct, num_threads):
        self.game = game
        self.nn = nn
        self.num_simulations = num_simulations
        self.num_games = num_games
        self.num_updates = num_updates
        self.buffer_size_limit = buffer_size_limit
        self.training_data = np.zeros((0, 3))
        self.cpuct = cpuct
        self.num_threads = num_threads
        self.error_log = []
    # Plays one game of self-play and generates training samples.
    def self_play(self, temperature):
        s = self.game.get_initial_state()
        tree = MCTS(self.game, self.nn)
        data = []
        scores = self.game.check_game_over(s)
        root = True
        alpha = 1     # Dirichlet concentration parameter for root noise.
        weight = .25  # Mixing weight between the MCTS distribution and the noise.
        while scores is None:
            # Think: run MCTS simulations from the current state.
            for _ in range(self.num_simulations):
                tree.simulate(s, cpuct=self.cpuct)

            # Fetch action distribution and append training example template.
            dist = tree.get_distribution(s, temperature=temperature)

            # Add Dirichlet noise to the root distribution to encourage exploration.
            if root:
                noise = np.random.dirichlet(alpha * np.ones_like(dist[:, 1].astype(np.float32)))
                dist[:, 1] = dist[:, 1]*(1 - weight) + noise*weight
                root = False
            data.append([s, dist[:, 1], None])  # state, action probabilities, outcome (filled in after the game)

            # Sample an action
            idx = np.random.choice(len(dist), p=dist[:, 1].astype(float))
            a = tuple(dist[idx, 0])

            # Apply action
            available = self.game.get_available_actions(s)
            template = np.zeros_like(available)
            template[a] = 1
            s = self.game.take_action(s, template)

            # Check scores
            scores = self.game.check_game_over(s)

        # Update training examples with the final outcome.
        for i, _ in enumerate(data):
            data[i][-1] = scores
        return np.array(data)

    # Performs one iteration of policy improvement:
    # plays some number of self-play games, then updates the network parameters
    # some number of times on the accumulated training data.
    def policy_iteration(self, verbose=False):
        temperature = 1

        if verbose:
            print("SIMULATING " + str(self.num_games) + " games")
        start = time.time()
        if self.num_threads > 1:
            jobs = [temperature]*self.num_games
            pool = ThreadPool(self.num_threads)
            new_data = pool.map(self.self_play, jobs)
            pool.close()
            pool.join()
            self.training_data = np.concatenate([self.training_data] + new_data, axis=0)
        else:
            for _ in range(self.num_games):  # Self-play games
                new_data = self.self_play(temperature)
                self.training_data = np.concatenate([self.training_data, new_data], axis=0)
        if verbose:
            print("Simulating took " + str(int(time.time() - start)) + " seconds")

        # Prune oldest training samples if a buffer size limit is set.
        if self.buffer_size_limit is not None:
            self.training_data = self.training_data[-self.buffer_size_limit:, :]

        if verbose:
            print("TRAINING")
        start = time.time()
        mean_loss = None
        count = 0
        for _ in range(self.num_updates):
            self.nn.train(self.training_data)
            new_loss = self.nn.latest_loss.item()
            if mean_loss is None:
                mean_loss = new_loss
            else:
                mean_loss = (mean_loss*count + new_loss)/(count + 1)  # Running average of the loss.
            count += 1
        self.error_log.append(mean_loss)
        if verbose:
            print("Training took " + str(int(time.time() - start)) + " seconds")
            print("Average train error:", mean_loss)