From-scratch implementation of AlphaZero for Connect4


Step-by-step illustration of how to implement AlphaZero for games using just PyTorch and standard Python libraries

In 2016, Google DeepMind created a big stir when its computer program AlphaGo defeated reigning Go world champion Lee Sedol 4–1 in a match watched by millions. No computer program had ever achieved this in the ultra-complicated game of Go, which had been dominated by humans until then. AlphaGo Zero, published by DeepMind about a year later in 2017, pushed the boundaries a big step further by achieving a similar feat without any human data (AlphaGo had used Go grandmaster games for its initial training). A subsequent paper by the same group successfully applied the same reinforcement learning + supervised learning framework to chess, outperforming the previous best chess program, Stockfish, after just 4 hours of training.

Awed by the power of such reinforcement learning models, I wanted to understand how they work, and there's nothing better for that than trying to build my own chess AI program from scratch, closely following the methods described in the papers above. However, things quickly got too expensive: even though the program was up and running, training it to a reasonable skill level would most likely have cost millions in GPU and TPU time.

Unable to match the deep pockets of Google, I decided to implement AlphaZero on Connect4 instead, a game that is much simpler than chess and much gentler on computational power. The point here is to demonstrate that the AlphaZero algorithm can, given enough training, produce a strong Connect4 AI program. The implementation scripts for the methods described here are all available on my GitHub repo.

The Connect4 Board

First, we need to create the Connect4 board in Python for us to play around with. I’ve created a class called “board” with 4 methods: “__init__”, “drop_piece”, “check_winner” and “actions”.

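import numpy as np

class board():
    def __init__(self):
        # 6 x 7 grid of strings; " " marks an empty cell
        self.init_board = np.zeros([6,7]).astype(str)
        self.init_board[self.init_board == "0.0"] = " "
        self.player = 0  # 0: "O" to move, 1: "X" to move
        self.current_board = self.init_board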
Connect4 board in Python

1) The “__init__” constructor initializes an empty Connect4 board of 6 rows and 7 columns as an np.array, and stores the board state as well as the current player to play

2) “drop_piece” updates the board with “X” or “O” as each player plays

3) “check_winner” returns True if somebody wins in the current board state

4) “actions” returns all possible moves which can be played given the current board state, so that no illegal moves are played
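The post only shows the constructor. As a rough illustration of what the other three methods have to do, here is a minimal, dependency-free sketch. It is not the author's code: the class name `Board`, the use of nested lists instead of an np.array, and row 5 being the bottom row are all assumptions. The player convention (0 means “O” to move, 1 means “X” to move, “O” moves first) follows the encoder and self-play code later in the post.

```python
class Board:
    """Hypothetical Connect4 board sketch (not the author's implementation)."""

    ROWS, COLS = 6, 7

    def __init__(self):
        # " " marks an empty cell; row 0 is the top, row 5 the bottom
        self.current_board = [[" "] * self.COLS for _ in range(self.ROWS)]
        self.player = 0  # 0: "O" to move, 1: "X" to move

    def actions(self):
        # a column is playable while its top cell is still empty
        return [c for c in range(self.COLS) if self.current_board[0][c] == " "]

    def drop_piece(self, column):
        if column not in self.actions():
            raise ValueError(f"illegal move: column {column}")
        piece = "O" if self.player == 0 else "X"
        # the piece falls to the lowest empty row of the column
        for row in range(self.ROWS - 1, -1, -1):
            if self.current_board[row][column] == " ":
                self.current_board[row][column] = piece
                break
        self.player = 1 - self.player  # pass the turn to the other player

    def check_winner(self):
        b = self.current_board
        # look for four in a row to the right, downward and along both diagonals
        for r in range(self.ROWS):
            for c in range(self.COLS):
                piece = b[r][c]
                if piece == " ":
                    continue
                for dr, dc in ((0, 1), (1, 0), (1, 1), (1, -1)):
                    end_r, end_c = r + 3 * dr, c + 3 * dc
                    if (0 <= end_r < self.ROWS and 0 <= end_c < self.COLS
                            and all(b[r + k * dr][c + k * dc] == piece for k in range(4))):
                        return True
        return False
```

For example, if “O” plays column 0 four times while “X” answers in column 1, `check_winner()` returns True and `player` is 1, which is exactly the state the self-play loop later interprets as an “O” win.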

The Big Picture

There are 3 key components in AlphaZero, and I will describe their implementations in more detail later. They are:

1) Deep convolutional residual neural network
Input: Connect4 board state
Outputs: policy (probability distribution over possible moves), value (O wins: +1, X wins: -1, draw: 0)

2) Monte-Carlo Tree Search (MCTS)
Self-play guided by the neural net's policy, generating a dataset of games that is used to train the neural net, in an iterative process

3) Evaluate neural network
Two players, guided by the current net and the previous net respectively, play against each other; the net that wins the match is kept for the next iteration

Deep Convolutional Residual Neural Network

Neural network architecture of AlphaZero used here

We use a deep convolutional residual neural network (in PyTorch), with the architecture above similar to AlphaZero's, to map an input Connect4 board state to its associated policy and value. The policy is essentially a probability distribution over the next moves the player could make from the current board state (the strategy), and the value is a number between -1 and +1 estimating the expected outcome of the game from that board state (+1 for an O win, -1 for an X win, as above). This neural net is an integral part of the MCTS, where its policy and value outputs help guide the tree search, as we will see later. We build this neural net (which I call ConnectNet) from one initial convolution block, followed by 19 residual blocks and finally one output block, as detailed below.

Convolutional Block

import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    def __init__(self):
        super(ConvBlock, self).__init__()
        self.action_size = 7
        self.conv1 = nn.Conv2d(3, 128, 3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(128)

    def forward(self, s):
        s = s.view(-1, 3, 6, 7)  # batch_size x channels x board_x x board_y
        s = F.relu(self.bn1(self.conv1(s)))
        return s

Residual Block

class ResBlock(nn.Module):
    def __init__(self, inplanes=128, planes=128, stride=1, downsample=None):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        out = F.relu(out)
        return out

Output Block

class OutBlock(nn.Module):
    def __init__(self):
        super(OutBlock, self).__init__()
        self.conv = nn.Conv2d(128, 3, kernel_size=1) # value head
        self.bn = nn.BatchNorm2d(3)
        self.fc1 = nn.Linear(3*6*7, 32)
        self.fc2 = nn.Linear(32, 1)
        
        self.conv1 = nn.Conv2d(128, 32, kernel_size=1) # policy head
        self.bn1 = nn.BatchNorm2d(32)
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.fc = nn.Linear(6*7*32, 7)
    
    def forward(self,s):
        v = F.relu(self.bn(self.conv(s))) # value head
        v = v.view(-1, 3*6*7)  # flatten to batch_size x (channels*height*width)
        v = F.relu(self.fc1(v))
        v = torch.tanh(self.fc2(v))
        
        p = F.relu(self.bn1(self.conv1(s))) # policy head
        p = p.view(-1, 6*7*32)
        p = self.fc(p)
        p = self.logsoftmax(p).exp()
        return p, v

Putting it altogether

class ConnectNet(nn.Module):
    def __init__(self):
        super(ConnectNet, self).__init__()
        self.conv = ConvBlock()
        for block in range(19):
            setattr(self, "res_%i" % block,ResBlock())
        self.outblock = OutBlock()
    
    def forward(self,s):
        s = self.conv(s)
        for block in range(19):
            s = getattr(self, "res_%i" % block)(s)
        s = self.outblock(s)
        return s
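To see why the flattening sizes in `OutBlock` work out (3*6*7 for the value head, 32*6*7 for the policy head), here is a small plain-Python shape calculation using the standard convolution output-size formula, floor((n + 2*padding - kernel) / stride) + 1. It is only an illustration, not part of the author's code, and the helper name `conv_out` is made up.

```python
def conv_out(n, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d along one dimension (hypothetical helper)."""
    return (n + 2 * padding - kernel) // stride + 1

h, w = 6, 7  # Connect4 board
# ConvBlock and every ResBlock use 3x3 kernels with stride 1 and padding 1
for _ in range(1 + 2 * 19):  # 1 conv in ConvBlock + 2 convs in each of the 19 ResBlocks
    h, w = conv_out(h, 3, 1, 1), conv_out(w, 3, 1, 1)
# the 1x1 convolutions at the start of both heads in OutBlock keep the size as well
h, w = conv_out(h, 1), conv_out(w, 1)

value_flat = 3 * h * w    # input size of fc1 in the value head
policy_flat = 32 * h * w  # input size of fc in the policy head
print(h, w, value_flat, policy_flat)  # 6 7 126 1344
```

Because padding=1 exactly compensates for the 3x3 kernel, the 6 x 7 board keeps its shape all the way through the 19 residual blocks, which is also what allows the `out += residual` addition without any downsampling.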

The raw Connect4 board is encoded into a 6 by 7 by 3 matrix of 1’s and 0’s before being input into the neural net. Each of the 3 channels has the board’s 6 by 7 dimensions: the first two encode the presence of “O” and “X” respectively (1 being present and 0 being empty), and the third encodes the player to move (all 0 for “O”, all 1 for “X”).

import numpy as np

### Encoder to encode Connect4 board for neural net input
def encode_board(board):
    board_state = board.current_board
    encoded = np.zeros([6,7,3]).astype(int)
    encoder_dict = {"O":0, "X":1}
    for row in range(6):
        for col in range(7):
            if board_state[row,col] != " ":
                encoded[row, col, encoder_dict[board_state[row,col]]] = 1
    if board.player == 1:
        encoded[:,:,2] = 1 # player to move
    return encoded

Finally, to train this neural net with its two-headed output, a custom loss function (AlphaLoss) is defined as simply the sum of the mean-squared error value loss and the cross-entropy policy loss.

import torch

### Neural Net loss function implemented via PyTorch
class AlphaLoss(torch.nn.Module):
    def __init__(self):
        super(AlphaLoss, self).__init__()

    def forward(self, y_value, value, y_policy, policy):
        # y_value, y_policy: the net's predictions; value, policy: the self-play targets
        value_error = (value - y_value) ** 2
        # cross-entropy of the predicted policy against the target (1e-8 avoids log(0))
        policy_error = torch.sum(-policy * (1e-8 + y_policy.float()).log(), 1)
        total_error = (value_error.view(-1).float() + policy_error).mean()
        return total_error
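Per training sample, the loss above is (z - v)^2 - Σ_a π_a log(p_a + 1e-8), where z and π are the value and policy targets from self-play and v and p are the net's outputs; the batch loss is the mean over samples. Unlike the loss in the AlphaZero paper, there is no explicit L2 weight-regularization term. A plain-Python version for a single sample makes the arithmetic concrete; the function name `alpha_loss_single` and the numbers are purely illustrative.

```python
import math

def alpha_loss_single(z, v, pi, p, eps=1e-8):
    """Hypothetical single-sample AlphaLoss: squared value error + policy cross-entropy."""
    value_error = (z - v) ** 2
    policy_error = -sum(t * math.log(eps + q) for t, q in zip(pi, p))
    return value_error + policy_error

# target: O won (z = +1) and MCTS put all of its visits on column 3
pi = [0, 0, 0, 1, 0, 0, 0]
# prediction: v = 0.5 and a 50% probability on column 3
p = [0.05, 0.05, 0.1, 0.5, 0.1, 0.1, 0.1]
loss = alpha_loss_single(1.0, 0.5, pi, p)
print(round(loss, 4))  # 0.25 + ln 2, about 0.9431
```

The value term contributes 0.25 and the policy term -ln(0.5) ≈ 0.693; both shrink to zero as the predictions approach the targets.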

Monte-Carlo Tree Search

A game can be described as a tree in which the root is the current board state and the branches are all the possible states that can result from it. In a game such as Go, where the number of branches grows exponentially as the game progresses, it is practically impossible to evaluate every branch by brute force. The Monte-Carlo Tree Search (MCTS) algorithm was devised to search more selectively and efficiently. Essentially, it balances the exploration-exploitation tradeoff: it explores enough branches to discover strong moves (exploration) while concentrating the search on the branches that currently look best (exploitation). This is succinctly captured in a single equation that defines the upper confidence bound (UCB) of each action a from state s:

$$U(s,a) = Q(s,a) + c_{puct}\, P(s,a)\, \frac{\sqrt{\sum_b N(s,b)}}{1 + N(s,a)}$$

Here, Q(s,a) is the mean action value (average reward), c_puct is a constant controlling the level of exploration (set to 1), P(s,a) is the prior probability of choosing action a in state s, given by the policy output of the neural net, and N(s,a) is the number of times the branch corresponding to action a has been visited. The sum over b in the numerator runs over all branches (actions) from state s, so it equals the number of times state s, the parent of (s,a), has been visited.
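A minimal sketch of how this UCB score picks a child, with c_puct = 1 as in the text. The function names and the flat-list representation of the children are assumptions for illustration; Q is taken as the total value W divided by N, with unvisited children assigned Q = 0, which is one common convention and not necessarily the author's.

```python
import math

def ucb_scores(total_value, number_visits, prior, c_puct=1.0):
    """Hypothetical helper: UCB of each child, Q + c_puct * P * sqrt(sum_b N) / (1 + N)."""
    parent_visits = sum(number_visits)  # sum over all branches b of N(s, b)
    scores = []
    for w, n, p in zip(total_value, number_visits, prior):
        q = w / n if n > 0 else 0.0  # mean action value; unvisited children get 0
        u = c_puct * p * math.sqrt(parent_visits) / (1 + n)  # exploration term
        scores.append(q + u)
    return scores

def best_child(total_value, number_visits, prior, c_puct=1.0):
    scores = ucb_scores(total_value, number_visits, prior, c_puct)
    return max(range(len(scores)), key=scores.__getitem__)

# three children: the second has the best Q, but the unvisited third has a large prior
W = [0.0, 1.0, 0.0]
N = [1, 2, 0]
P = [0.2, 0.3, 0.5]
print([round(s, 3) for s in ucb_scores(W, N, P)])  # [0.173, 0.673, 0.866]
print(best_child(W, N, P))  # 2
```

Note that before any child has been visited, the square root is zero, so the exploration term vanishes for every child and ties are broken by index; some implementations add a small constant under the square root so that the priors already matter on the first visit.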

The MCTS algorithm proceeds in the following steps.

1. Select
Select — AlphaGo Zero
### Recursively selects the child with the highest UCB (best move) until a leaf node or terminal node is reached.
### Adds the node of the best move if it is not yet created.
def select_leaf(self):
    current = self
    while current.is_expanded:
        best_move = current.best_child()
        current = current.maybe_add_child(best_move)
    return current

Starting from s, the search repeatedly selects the branch with the highest UCB until it reaches a leaf node (a state none of whose branches have been explored yet) or a terminal node (an end-of-game state). If the mean reward Q of a branch is high, that branch is more likely to be chosen. The second, exploration term also plays a big part: if an action has been visited only a few times compared with its parent, this term is large and the algorithm is more likely to choose that branch. The neural net guides the selection by providing the prior probability P, which is essentially arbitrary while the network is still untrained.
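To get a feel for the exploration-exploitation balance described above, the following toy sketch applies the UCB rule (c_puct = 1) 100 times to a single node with three moves. Each move returns a fixed reward, a hypothetical stand-in for the values that backup would produce in the real search; the helper name `select_child` is made up for illustration.

```python
import math

def select_child(q_values, visits, priors, c_puct=1.0):
    """Hypothetical helper: index of the child with the highest UCB (ties go to the lowest index)."""
    parent_visits = sum(visits)
    def ucb(i):
        return q_values[i] + c_puct * priors[i] * math.sqrt(parent_visits) / (1 + visits[i])
    return max(range(len(visits)), key=ucb)

rewards = [0.1, 0.8, 0.3]  # fixed, made-up rewards for three moves
priors = [1 / 3] * 3       # uniform priors, as from an untrained net
visits = [0, 0, 0]
totals = [0.0, 0.0, 0.0]
for _ in range(100):
    q = [t / n if n > 0 else 0.0 for t, n in zip(totals, visits)]
    i = select_child(q, visits, priors)
    visits[i] += 1
    totals[i] += rewards[i]
print(visits)
```

Most of the 100 visits go to the move with reward 0.8, yet the two weaker moves are still tried a few times, because their exploration term grows with the parent's visit count while their own visit counts stay small.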

2. Expand and Evaluate

Expand and Evaluate — AlphaGo Zero

### Expand only nodes that result from legal moves, mask illegal moves and
### add Dirichlet noise to the prior probabilities of the root node.
def expand(self, child_priors):
    self.is_expanded = True
    action_idxs = self.game.actions(); c_p = child_priors
    if action_idxs == []:
        self.is_expanded = False  # no legal moves: nothing to expand
    self.action_idxes = action_idxs
    # mask all illegal actions
    for i in range(len(child_priors)):
        if i not in action_idxs:
            c_p[i] = 0.0
    # add dirichlet noise to child_priors in root node
    # (the root's parent is a DummyNode, whose own parent is None)
    if self.parent.parent is None:
        c_p = self.add_dirichlet_noise(action_idxs, c_p)
    self.child_priors = c_p

Here, the leaf node is expanded: its state is evaluated by the neural net, and the resulting policy output is stored as the prior probabilities P of its child nodes. Illegal moves should of course not be expanded, so they are masked by setting their prior probabilities to zero. If the node is the root node, we also add Dirichlet noise to its priors to inject randomness into the exploration, so that each search, and hence each self-play game, is likely to be different.
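The post does not show `add_dirichlet_noise`. Below is a hypothetical version following the AlphaZero recipe P = (1 - ε) p + ε η with η ~ Dir(α). The values ε = 0.25 and α = 0.3 are those the AlphaZero paper used for chess, borrowed here as assumptions since the author's parameters are not given. Unlike the author's `expand`, this sketch also renormalises the priors over the legal moves after masking.

```python
import random

def mask_and_add_noise(priors, legal_moves, epsilon=0.25, alpha=0.3, rng=random):
    """Hypothetical sketch: mask illegal moves, renormalise, then mix in Dirichlet noise."""
    masked = [p if i in legal_moves else 0.0 for i, p in enumerate(priors)]
    total = sum(masked)
    if total > 0:
        masked = [p / total for p in masked]
    else:  # the net put no mass on legal moves: fall back to uniform
        masked = [1.0 / len(legal_moves) if i in legal_moves else 0.0 for i in range(len(priors))]
    # a Dirichlet(alpha, ..., alpha) sample is a set of independent Gamma(alpha, 1) draws, normalised
    gammas = [rng.gammavariate(alpha, 1.0) for _ in legal_moves]
    g_total = sum(gammas)
    noise = dict(zip(legal_moves, (g / g_total for g in gammas)))
    return [(1 - epsilon) * p + epsilon * noise[i] if i in legal_moves else 0.0
            for i, p in enumerate(masked)]
```

Usage: `mask_and_add_noise(child_priors, board.actions())` returns 7 probabilities that sum to 1, with zeros for full columns. Since the noise only touches legal moves, masking and noise never conflict.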

3. Backup

Backup — AlphaGo Zero
### Recursively update the visit counts and values of nodes once the leaf node is evaluated.
def backup(self, value_estimate: float):
    current = self
    while current.parent is not None:
        current.number_visits += 1
        if current.game.player == 1:  # X to move here, i.e. O made the move into this node
            current.total_value += (1*value_estimate)  # value estimate +1 = O wins
        elif current.game.player == 0:  # O to move here, i.e. X made the move into this node
            current.total_value += (-1*value_estimate)
        current = current.parent

Next, the value v from the neural net's evaluation of the leaf node is propagated back up the tree: every node on the path from the leaf to the root has its visit count incremented and v added to its total value, so that Q, the average of these values, stays up to date. The sign of v depends on whose move led to each node, so that O and X both play to their best (minimax). For example, if the evaluation says O wins (v = +1), every node reached by an O move (where X is now to play) gets +1, while every node reached by an X move (where O is now to play) gets -1, marking that move as bad for X. A draw (v = 0) adds 0 to every node.
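Here is the sign rule from the backup code on a toy three-node path. The names `PathNode` and `backup` are hypothetical and the nodes are stripped down to what backup touches; the rule itself (add v where X is to move, subtract it where O is to move) is the one in the code above. The author's root hangs off a `DummyNode`, so the root is also updated there; here the root simply has no parent.

```python
class PathNode:
    """Hypothetical minimal node: player is whoever is to move in this node."""

    def __init__(self, player, parent=None):
        self.player = player  # 0: "O" to move, 1: "X" to move
        self.parent = parent
        self.number_visits = 0
        self.total_value = 0.0

def backup(leaf, value_estimate):
    """value_estimate is from O's point of view: +1 O wins, -1 X wins, 0 draw."""
    current = leaf
    while current is not None:
        current.number_visits += 1
        # X to move means O made the move into this node, so O's value counts positively
        current.total_value += value_estimate if current.player == 1 else -value_estimate
        current = current.parent

root = PathNode(player=0)                     # O to move
child = PathNode(player=1, parent=root)       # after O's move, X to move
grandchild = PathNode(player=0, parent=child)  # after X's reply, O to move
backup(grandchild, 1.0)                       # the net says O wins from here
print(root.total_value, child.total_value, grandchild.total_value)  # -1.0 1.0 -1.0
```

O's move into `child` is credited +1, while X's reply into `grandchild` is penalised, so the next selection from `child` will be less inclined to repeat that reply.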

import torch

### Code snippet for each simulation of Select, Expand and Evaluate, and Backup.
### num_reads here is the parameter controlling the number of simulations.
# UCTNode, DummyNode and ed.encode_board come from the author's code (not shown here)
def UCT_search(game_state, num_reads, net, temp):
    root = UCTNode(game_state, move=None, parent=DummyNode())
    device = next(net.parameters()).device  # run on whichever device holds the net
    for i in range(num_reads):
        leaf = root.select_leaf()
        # encode to 6x7x3, then transpose to channels-first (3x6x7) for the conv layers
        encoded_s = ed.encode_board(leaf.game).transpose(2, 0, 1)
        encoded_s = torch.from_numpy(encoded_s).float().to(device)
        with torch.no_grad():  # inference only, no gradients needed
            child_priors, value_estimate = net(encoded_s)
        child_priors = child_priors.cpu().numpy().reshape(-1)
        value_estimate = value_estimate.item()
        if leaf.game.check_winner() or leaf.game.actions() == []:  # somebody won, or draw
            leaf.backup(value_estimate)
            continue
        leaf.expand(child_priors)  # expand() masks illegal moves
        leaf.backup(value_estimate)
    return root

The process of Select, Expand and Evaluate, and Backup above makes up one search path, or simulation, from the root node. AlphaGo Zero runs 1600 such simulations per move; for this Connect4 implementation we run only 777, since the game is much simpler. After the 777 simulations, we form the policy p for the root node, which is defined to be proportional to the visit counts of its direct child nodes. The next move is sampled from this policy, the resulting board state becomes the root node for the next round of MCTS simulations, and so on until the game ends in a win or a draw. This whole procedure, running MCTS simulations from each root node while playing through to the end of the game, is called MCTS self-play.
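The post does not show `get_policy`. Since it takes a temperature argument (t = 1 for the first 11 moves of a self-play game and t = 0.1 afterwards), one plausible version, following the AlphaGo Zero convention π(a) ∝ N(a)^(1/t), is sketched below. The function name `visits_to_policy` and its details are assumptions, not the author's code.

```python
def visits_to_policy(visit_counts, temperature=1.0):
    """Hypothetical get_policy: pi(a) proportional to N(a) ** (1 / temperature)."""
    max_n = max(visit_counts)
    if max_n == 0:
        raise ValueError("no child has been visited yet")
    # scale by the largest count first so that large exponents (1/0.1 = 10) cannot overflow
    weights = [(n / max_n) ** (1.0 / temperature) for n in visit_counts]
    total = sum(weights)
    return [w / total for w in weights]

counts = [100, 0, 300, 0, 0, 0, 0]  # child visit counts after a search
print([round(p, 3) for p in visits_to_policy(counts, 1.0)])  # [0.25, 0.0, 0.75, 0.0, 0.0, 0.0, 0.0]
print(round(visits_to_policy(counts, 0.1)[2], 4))  # 1.0
```

With t = 1 the policy is simply proportional to visit counts, which keeps the opening moves varied; with t = 0.1 almost all of the probability goes to the most-visited move, so later moves are played nearly greedily.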

import copy
import datetime
import numpy as np

### Function to execute MCTS self-play
# c_board (the board class), ed.encode_board, UCT_search, get_policy,
# do_decode_n_move_pieces and save_as_pickle are helpers not shown in this post
def MCTS_self_play(connectnet, num_games, cpu):
    for idxx in range(num_games):
        current_board = c_board()
        checkmate = False
        dataset = []  # to store state, policy, value for neural network training
        states = []; value = 0; move_count = 0
        # play game against self
        while not checkmate and current_board.actions() != []:
            # temperature: explore (t=1) for the first 11 moves, then play near-greedily
            t = 1 if move_count < 11 else 0.1
            states.append(copy.deepcopy(current_board.current_board))
            board_state = copy.deepcopy(ed.encode_board(current_board))
            root = UCT_search(current_board, 777, connectnet, t)  # run 777 MCTS simulations
            policy = get_policy(root, t); print(policy)  # formulate policy based on results of MCTS simulations
            # sample a column from the policy, then decode the action and make the move
            current_board = do_decode_n_move_pieces(current_board, np.random.choice(np.arange(7), p=policy))
            dataset.append([board_state, policy])  # stores s, p
            print(current_board.current_board, current_board.player); print(" ")
            if current_board.check_winner():  # if somebody won, update v
                if current_board.player == 0:  # X wins
                    value = -1
                elif current_board.player == 1:  # O wins
                    value = 1
                checkmate = True
            move_count += 1
        # update v for all (s, p) except for starting board state s
        dataset_p = []
        for idx, (s, p) in enumerate(dataset):
            dataset_p.append([s, p, 0 if idx == 0 else value])
        del dataset
        # save (s,p,v) datasets for neural net training
        save_as_pickle("dataset_cpu%i_%i_%s" % (cpu, idxx, datetime.datetime.today().strftime("%Y-%m-%d")), dataset_p)

At each move of an MCTS self-play game we record the board state s and its MCTS policy p. When the game finishes, each pair is labelled with the value v, the final result of the game (+1 if O won, -1 if X won, 0 for a draw; the code assigns 0 to the starting position), giving a set of (s, p, v) values. This set of (s, p, v) values is then used to train the neural net to improve its policy and value predictions, and the trained net in turn guides the next round of MCTS. The idea is that, after many, many iterations, the neural net and MCTS together become very good at generating strong moves.

Evaluate Neural Network

After each iteration in which the neural net is trained on MCTS self-play data, the trained net is pitted against its previous version, again using MCTS guided by the respective nets. The net that performs better (e.g. wins the majority of games) is then used for the next iteration. This ensures that the net carried forward is always the stronger of the two.
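The keep-or-replace decision itself is simple. Below is a hypothetical helper; the name `keep_new_net`, counting draws as half a win, and the default threshold are assumptions. A threshold of 0.5 matches “wins the majority”; AlphaGo Zero used a stricter 55% margin.

```python
def keep_new_net(new_wins, old_wins, draws=0, threshold=0.5):
    """Hypothetical gating rule: keep the new net if its score exceeds the threshold."""
    games = new_wins + old_wins + draws
    if games == 0:
        raise ValueError("no evaluation games played")
    score = (new_wins + 0.5 * draws) / games  # draws count as half a win
    return score > threshold

print(keep_new_net(83, 17))  # True
print(keep_new_net(50, 50))  # False
```

A stricter threshold makes it less likely that a net which is only luckier, not stronger, replaces its predecessor.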

Iteration Pipeline

In summary, a full iteration pipeline consists of:

1. Self-play using MCTS to generate game datasets of (s, p, v) values, with the neural net guiding the search by providing the prior probabilities for choosing each action

2. Train the neural net on the (s, p, v) datasets generated from MCTS self-play

3. Evaluate the trained neural net (at predefined epoch checkpoints) by pitting it against the neural net from the previous iteration, again using MCTS guided by the respective neural nets, and keep only the neural net that performs better.

4. Rinse and repeat
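The control flow of this pipeline can be sketched as follows. This is a skeleton, not the author's script: `run_pipeline` and the three callables it receives (`self_play`, `train`, `evaluate`) are hypothetical placeholders for the MCTS self-play, training and evaluation steps above.

```python
def run_pipeline(net, self_play, train, evaluate, num_iterations):
    """Hypothetical skeleton of the self-play / train / evaluate loop."""
    best_net = net
    history = []
    for _ in range(num_iterations):
        dataset = self_play(best_net)         # 1. MCTS self-play -> list of (s, p, v)
        candidate = train(best_net, dataset)  # 2. train on the new data
        if evaluate(candidate, best_net):     # 3. keep whichever net plays better
            best_net = candidate
        history.append(best_net)
    return best_net, history                  # 4. rinse and repeat
```

For example, with toy stand-ins where a “net” is just an integer strength, `train` returns `n + 1` and `evaluate` checks `candidate > best`, three iterations end with a best net of strength 3.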

Results

Iteration 0:
alpha_net_0 (Initialized with random weights)
151 games of MCTS self-play generated


Iteration 1:
alpha_net_1 (trained from iteration 0)
148 games of MCTS self-play generated


Iteration 2:
alpha_net_2 (trained from iteration 1)
310 games of MCTS self-play generated


Evaluation 1:
alpha_net_2 is pitted against alpha_net_0
Out of 100 games played, alpha_net_2 won 83 and lost 17


Iteration 3:
alpha_net_3 (trained from iteration 2)
584 games of MCTS self-play generated


Iteration 4:
alpha_net_4 (trained from iteration 3)
753 games of MCTS self-play generated


Iteration 5:
alpha_net_5 (trained from iteration 4)
1286 games of MCTS self-play generated


Iteration 6:
alpha_net_6 (trained from iteration 5)
1670 games of MCTS self-play generated


Evaluation 2:
alpha_net_6 is pitted against alpha_net_3
Out of 100 games played, alpha_net_6 won 92 and lost 8.


Typical Loss vs Epoch for neural net training at each iteration.

Over several weeks of sporadic training on Google Colab, 6 iterations were completed, generating a total of 4902 MCTS self-play games. A typical loss vs epoch curve for the neural net training in each iteration is shown above, indicating that training proceeds quite well. In both evaluations 1 and 2, at selected points in the iterations, the newer net beat its older version by a wide margin (83 to 17 and 92 to 8), so the neural net is indeed getting stronger at generating winning moves.

Now it’s probably time to show some actual games! The gif below shows an example game between alpha_net_6 (playing as X) and alpha_net_3 (playing as O), where X won.

At the moment, I am still training the net and running MCTS self-play. I hope to reach a stage where the MCTS + net are able to generate perfect moves (Connect4 is a solved game: the player who moves first can always force a win), but who knows how many iterations that would need…

Anyway, that’s all folks! I hope you find this post interesting and useful. Any comments on the implementation and possible improvements are very welcome. For more details on how AlphaZero works, nothing beats reading DeepMind’s actual paper, which I highly recommend!

The original article was first published here.
