追記(2019/02/19)
猛烈に勘違いしていました。
AlphaZeroは、バリューネットワークの出力には[-1,1]レンジを使っていますが、MCTSでは[0,1]レンジだそうです。
talkchess.com
今までのいくつかの謎が氷解しました。
これに合わせてコードのコメントとオリジナルコードのバグを修正しました。
1年前大騒ぎになったAlphaZeroの論文が正式に出版されました。
擬似コードも公開されたので、今日はそれを読みましょう。
コメントを挿入しました。
"""Pseudocode description of the AlphaZero algorithm."""
import math
import numpy
import tensorflow as tf
from typing import List
class AlphaZeroConfig(object):
    """Hyperparameters for self-play, MCTS search, and training (chess values)."""

    def __init__(self):
        # --- Self-play ---
        self.num_actors = 5000           # parallel self-play workers
        self.num_sampling_moves = 30     # moves sampled (not argmax) at game start
        self.max_moves = 512             # hard cap on game length
        self.num_simulations = 800       # MCTS simulations per move

        # --- Root exploration noise ---
        self.root_dirichlet_alpha = 0.3
        self.root_exploration_fraction = 0.25

        # --- UCB formula constants ---
        self.pb_c_base = 19652
        self.pb_c_init = 1.25

        # --- Training ---
        self.training_steps = int(700e3)
        self.checkpoint_interval = int(1e3)
        self.window_size = int(1e6)      # replay buffer capacity, in games
        self.batch_size = 4096
        self.weight_decay = 1e-4
        self.momentum = 0.9
        # Step -> learning rate; the rate drops by 10x at each listed step.
        self.learning_rate_schedule = {0: 2e-1, 100e3: 2e-2, 300e3: 2e-3, 500e3: 2e-4}
class Node(object):
    """One MCTS tree node: search statistics for a single game state."""

    def __init__(self, prior: float):
        self.visit_count = 0   # N(s, a)
        self.to_play = -1      # player to move; -1 until evaluate() fills it in
        self.prior = prior     # P(s, a) from the policy network
        self.value_sum = 0     # W(s, a); accumulated in the [0, 1] range
        self.children = {}     # action -> Node

    def expanded(self):
        """True once the node has been evaluated and its children created."""
        return bool(self.children)

    def value(self):
        """Mean action value Q(s, a); 0 for a never-visited node."""
        if not self.visit_count:
            return 0
        return self.value_sum / self.visit_count
class Game(object):
    """Game-environment stub: move history plus stored search statistics.

    A concrete game would implement the rules (legality, termination,
    outcome) and board encoding; only the bookkeeping shared by all games
    is present here.
    """

    def __init__(self, history=None):
        self.history = history or []
        self.child_visits = []    # one normalized visit distribution per move played
        self.num_actions = 4672   # action-space size (chess move encoding)

    def terminal(self):
        # Game-specific termination test (stub).
        pass

    def terminal_value(self, to_play):
        # Final outcome from `to_play`'s perspective (stub).
        pass

    def legal_actions(self):
        # Legal actions in the current position (stub).
        return []

    def clone(self):
        """Independent copy sharing no mutable state with this game."""
        return Game(list(self.history))

    def apply(self, action):
        self.history.append(action)

    def store_search_statistics(self, root):
        """Record the root's normalized visit counts as this move's policy target."""
        total = sum(child.visit_count for child in root.children.values())
        distribution = []
        for a in range(self.num_actions):
            child = root.children.get(a)
            distribution.append(child.visit_count / total if child else 0)
        self.child_visits.append(distribution)

    def make_image(self, state_index: int):
        # Network input planes for the position at `state_index` (stub).
        return []

    def make_target(self, state_index: int):
        """(value, policy) training target for the position at `state_index`."""
        return (self.terminal_value(state_index % 2),
                self.child_visits[state_index])

    def to_play(self):
        """Side to move: 0 or 1, alternating with each half-move."""
        return len(self.history) % 2
class ReplayBuffer(object):
    """Fixed-size FIFO of recent self-play games, sampled for training batches."""

    # Forward-reference annotation so this class is importable on its own.
    def __init__(self, config: "AlphaZeroConfig"):
        self.window_size = config.window_size
        self.batch_size = config.batch_size
        self.buffer = []

    def save_game(self, game):
        """Append `game`, evicting the oldest so at most window_size remain.

        Bug fix: the original popped only when len > window_size, which let
        the buffer stabilize at window_size + 1 games; `>=` keeps it at
        exactly window_size.
        """
        if len(self.buffer) >= self.window_size:
            self.buffer.pop(0)
        self.buffer.append(game)

    def sample_batch(self):
        """Sample batch_size (image, target) pairs, uniform over stored moves.

        A game is drawn with probability proportional to its length, then a
        position uniformly within it — together uniform over every move in
        the buffer.
        """
        move_sum = float(sum(len(g.history) for g in self.buffer))
        games = numpy.random.choice(
            self.buffer,
            size=self.batch_size,
            p=[len(g.history) / move_sum for g in self.buffer])
        game_pos = [(g, numpy.random.randint(len(g.history))) for g in games]
        return [(g.make_image(i), g.make_target(i)) for (g, i) in game_pos]
class Network(object):
    """Neural-network placeholder; inference returns a dummy evaluation."""

    def inference(self, image):
        # Raw network output: (value in [-1, 1], policy logits keyed by action).
        return (-1, {})

    # Per matthewlai on talkchess
    # (http://talkchess.com/forum3/viewtopic.php?f=2&t=69175&start=70&sid=8eb37b9c943011e51c0c3a88b427b745):
    # "All the values in the search are [0, 1]. We store them as [-1, 1]
    # only for network training, to have training targets centered around 0.
    # At play time, when network evaluations come back, we shift them to
    # [0, 1] before doing anything with them. Yes, all values are
    # initialized to loss value."
    def inference_0to1value(self, image):
        """Run inference and shift the value from [-1, 1] to [0, 1] for MCTS."""
        value, policy = self.inference(image)
        return (value + 1) / 2, policy

    def get_weights(self):
        # Trainable variables of the network (stub).
        return []
class SharedStorage(object):
    """Step-indexed network checkpoints shared by self-play actors and trainer."""

    def __init__(self):
        self._networks = {}   # training step -> Network

    def latest_network(self) -> Network:
        """Newest checkpoint, or a fresh uniform network before any save."""
        if not self._networks:
            return make_uniform_network()
        newest_step = max(self._networks)
        return self._networks[newest_step]

    def save_network(self, step: int, network: Network):
        self._networks[step] = network
def alphazero(config: AlphaZeroConfig):
    """Top-level AlphaZero: launch self-play actors, train, return final net."""
    storage = SharedStorage()
    replay_buffer = ReplayBuffer(config)

    # Spawn the self-play workers (launch_job here simply calls the function).
    for _ in range(config.num_actors):
        launch_job(run_selfplay, config, storage, replay_buffer)

    train_network(config, storage, replay_buffer)
    return storage.latest_network()
def run_selfplay(config: AlphaZeroConfig, storage: SharedStorage,
                 replay_buffer: ReplayBuffer):
    """Actor loop: endlessly play games with the freshest network, store them."""
    while True:
        latest = storage.latest_network()
        finished_game = play_game(config, latest)
        replay_buffer.save_game(finished_game)
def play_game(config: AlphaZeroConfig, network: Network):
    """Play one self-play game to the end (or max_moves), one MCTS per move."""
    game = Game()
    while not game.terminal() and len(game.history) < config.max_moves:
        chosen_action, search_root = run_mcts(config, game, network)
        game.apply(chosen_action)
        # Keep the root's visit distribution as the policy training target.
        game.store_search_statistics(search_root)
    return game
def run_mcts(config: AlphaZeroConfig, game: Game, network: Network):
    """Run num_simulations MCTS simulations from `game`'s current position.

    Returns (chosen_action, root) so the caller can also record the root's
    visit statistics.
    """
    root = Node(0)
    evaluate(root, game, network)
    add_exploration_noise(config, root)

    for _ in range(config.num_simulations):
        # Selection: walk down via UCB until an unexpanded leaf is reached.
        node = root
        scratch_game = game.clone()
        search_path = [root]
        while node.expanded():
            action, node = select_child(config, node)
            scratch_game.apply(action)
            search_path.append(node)

        # Expansion + evaluation of the leaf, then back up along the path.
        leaf_value = evaluate(node, scratch_game, network)
        backpropagate(search_path, leaf_value, scratch_game.to_play())

    return select_action(config, game, root), root
def select_action(config: AlphaZeroConfig, game: Game, root: Node):
visit_counts = [(child.visit_count, action)
for action, child in root.children.items()]
if len(game.history) < config.num_sampling_moves:
_, action = softmax_sample(visit_counts)
else:
_, action = max(visit_counts)
return action
def select_child(config: AlphaZeroConfig, node: Node):
    """Return the (action, child) pair with the highest UCB score.

    Ties are broken by the larger action id (tuple comparison), matching the
    original (score, action, child) max.
    """
    scored = ((ucb_score(config, node, child), action, child)
              for action, child in node.children.items())
    _, best_action, best_child = max(scored)
    return best_action, best_child
def ucb_score(config: AlphaZeroConfig, parent: Node, child: Node):
pb_c = math.log((parent.visit_count + config.pb_c_base + 1) /
config.pb_c_base) + config.pb_c_init
pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)
prior_score = pb_c * child.prior
value_score = child.value()
return prior_score + value_score
def evaluate(node: Node, game: Game, network: Network):
    """Expand `node` with the network's policy and return its [0, 1] value."""
    value, policy_logits = network.inference_0to1value(game.make_image(-1))
    node.to_play = game.to_play()

    # Softmax restricted to the legal actions, then one child per action.
    exp_logits = {a: math.exp(policy_logits[a]) for a in game.legal_actions()}
    total = sum(exp_logits.values())
    for action, weight in exp_logits.items():
        node.children[action] = Node(weight / total)
    return value
def backpropagate(search_path: List[Node], value: float, to_play):
    """Propagate the leaf evaluation up the search path.

    `value` is in [0, 1] from `to_play`'s perspective; a node whose side to
    move is the opponent accumulates the complement (1 - value).
    """
    for node in search_path:
        contribution = value if node.to_play == to_play else 1 - value
        node.value_sum += contribution
        node.visit_count += 1
def add_exploration_noise(config: AlphaZeroConfig, node: Node):
actions = node.children.keys()
noise = numpy.random.gamma(config.root_dirichlet_alpha, 1, len(actions))
frac = config.root_exploration_fraction
for a, n in zip(actions, noise):
node.children[a].prior = node.children[a].prior * (1 - frac) + n * frac
def train_network(config: AlphaZeroConfig, storage: SharedStorage,
                  replay_buffer: ReplayBuffer):
    """Training loop: periodically checkpoint, then fit on sampled batches."""
    network = Network()
    # NOTE(review): the whole schedule dict is passed where MomentumOptimizer
    # expects a learning rate — pseudocode shorthand; a real implementation
    # would look up the rate for the current step. Confirm before porting.
    optimizer = tf.train.MomentumOptimizer(config.learning_rate_schedule,
                                           config.momentum)
    for i in range(config.training_steps):
        # Checkpoint so self-play actors pick up the newest weights.
        if i % config.checkpoint_interval == 0:
            storage.save_network(i, network)
        batch = replay_buffer.sample_batch()
        update_weights(optimizer, network, batch, config.weight_decay)
    # Final checkpoint after the last step.
    storage.save_network(config.training_steps, network)
def update_weights(optimizer: tf.train.Optimizer, network: Network, batch,
                   weight_decay: float):
    """One gradient step: value MSE + policy cross-entropy + L2 weight decay.

    `batch` is a sequence of (image, (target_value, target_policy)) pairs as
    produced by ReplayBuffer.sample_batch.
    """
    loss = 0
    for image, (target_value, target_policy) in batch:
        value, policy_logits = network.inference(image)
        # Value head: squared error vs. the game outcome; policy head:
        # cross-entropy vs. the MCTS visit distribution.
        loss += (
            tf.losses.mean_squared_error(value, target_value) +
            tf.nn.softmax_cross_entropy_with_logits(
                logits=policy_logits, labels=target_policy))
    # L2 regularization over all trainable weights.
    for weights in network.get_weights():
        loss += weight_decay * tf.nn.l2_loss(weights)
    optimizer.minimize(loss)
def softmax_sample(d):
    """Sample a (visit_count, action) pair with probability ∝ visit count.

    The original pseudocode left this as a stub returning (0, 0), which made
    every early move pick action 0. Sampling proportional to root visit
    counts is the temperature-1 behavior select_action relies on for the
    first num_sampling_moves moves. Exponentiating raw counts would overflow
    for realistic visit counts, so counts are used directly as weights.

    Args:
        d: non-empty list of (visit_count, action) pairs.

    Returns:
        One (visit_count, action) pair from `d` (uniform if all counts are 0).
    """
    counts = numpy.array([count for count, _ in d], dtype=float)
    total = counts.sum()
    if total > 0:
        probs = counts / total
    else:
        probs = numpy.full(len(d), 1.0 / len(d))
    choice = numpy.random.choice(len(d), p=probs)
    return d[choice]
def launch_job(f, *args):
    """Invoke `f(*args)` synchronously; a stand-in for spawning a worker."""
    f(*args)
def make_uniform_network():
    """Fresh, untrained network used before any checkpoint has been saved."""
    return Network()