Iterative Policy Improvement

Introduction

In this tutorial, I am going to code up the iterative policy improvement algorithm. It starts from a uniform random policy, computes the value function, and uses it to update the policy. This process continues until the optimal policy is reached. The policy evaluation part was covered in a previous tutorial; this tutorial is an extension of that.
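In the standard policy iteration picture (sketched here with $E$ for an evaluation step and $I$ for an improvement step), the loop alternates the two steps until the policy stops changing:

$$P_0 \xrightarrow{E} v_{P_0} \xrightarrow{I} P_1 \xrightarrow{E} v_{P_1} \xrightarrow{I} P_2 \xrightarrow{E} \cdots \xrightarrow{I} P_*$$

where $P_0$ is the uniform random policy and $P_*$ is the optimal policy reached once an improvement step no longer changes anything.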

The grid world example and pseudo-code are taken from the University of Alberta’s Fundamentals of Reinforcement Learning course. The lecture corresponding to this tutorial can be found here. Please note that my implementation has a slight variation in the value function; this is because of the way I coded up the environment.

The Grid World

The grid world is a 4×4 matrix, with the first and last cells being terminal states. The only actions the agent can take at any time are up, down, left, or right. Every action yields a reward of −1 unless it lands in a terminal state, where the reward is 0. If the agent bumps into the wall, it remains in the same state. Since the agent only escapes the −1 reward at the terminal states, it learns to navigate to a terminal cell from any other cell.
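Written out, the reward for taking action $a$ in cell $s$ and reaching cell $s'$ (with off-grid moves leaving the agent in place) is

$$r(s, a) = \begin{cases} 0 & \text{if } s' \text{ is a terminal cell,} \\ -1 & \text{otherwise.} \end{cases}$$

This matches the checkNextState method in the code below, and is presumably the source of the slight variation mentioned in the introduction, since the course example gives −1 on every transition out of a non-terminal state.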

Maze example

Iterative Policy Evaluation

We can see that the value function converges to a fixed one, no matter what the initial value function is. Around 200 iterations appear to be enough for convergence.
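The backup implemented in the policyEvaluation function below (a sweep over all cells, with equal probability 0.25 for each action, no discounting, and terminal cells held at 0) is

$$v_{k+1}(s) = \sum_{a \in \{U, D, L, R\}} \frac{1}{4}\,\bigl[\, r(s, a) + v_k(s') \,\bigr],$$

where $s'$ is the cell reached from $s$ by action $a$. Because the same array ends up holding $v_k$ and $v_{k+1}$ after the first sweep, later sweeps effectively update the values in place, which still converges to the same fixed point.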

 

Optimal Value Function and Policy

The policy is updated using the value function and vice versa. This continues until the optimal policy is reached, i.e. policy $P_{n+1}$ becomes the same as policy $P_n$. The following figure shows this conceptual process of iterative policy and value function improvement.

Convergence of policy and value function
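Concretely, the improvement step makes the policy greedy with respect to the current value function: in each non-terminal cell it keeps exactly the action(s) maximising the one-step look-ahead,

$$P_{n+1}(s) = \arg\max_{a \in \{U, D, L, R\}} \bigl[\, r(s, a) + v(s') \,\bigr],$$

where $v$ is the value function computed in the preceding evaluation step. When several actions tie for the maximum, all of them are kept, which is why some cells of the optimal policy below show more than one arrow.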

The optimal policy and value function of this grid world example are shown below.

Optimal policy and value function

Learned Agent

Once the optimal policy is found, the learned agent (denoted by the smiley) can navigate around this environment starting from any arbitrary cell. Here I have taken a few starting points and show the decision making and movement of the learned agent as it reaches the closest goal.
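The essence of this navigation is only a few lines. Below is a minimal sketch of greedily following a 4×4×4 policy array like the one produced by the full listing in the next section, breaking ties at random. The navigate helper and its argument names are illustrative and not part of the original code.

import numpy as np

ACTIONS = ['U', 'D', 'L', 'R']
MOVES = {'U': (-1, 0), 'D': (1, 0), 'L': (0, -1), 'R': (0, 1)}

def navigate(policy, start, terminals=((0, 0), (3, 3)), max_steps=20):
    # follow the greedy policy from 'start', breaking ties at random,
    # until a terminal cell is reached
    state = tuple(start)
    path = [state]
    for _ in range(max_steps):
        if state in terminals:
            break
        best = np.where(policy[state[0], state[1]] == 1)[0]
        act = ACTIONS[np.random.choice(best)]       # random tie-break
        di, dj = MOVES[act]
        ni, nj = state[0] + di, state[1] + dj
        if 0 <= ni <= 3 and 0 <= nj <= 3:            # off-grid moves keep the agent in place
            state = (ni, nj)
        path.append(state)
    return path

# e.g. navigate(myAgent.policy, [2, 1]) after running the listing below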

Code

The code to reproduce the results is given below.

#!/usr/bin/python
# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt


class Env(object):

    def __init__(self):
        # The value function is initialised randomly; the iteration
        # converges regardless of the initial values.
        self.states = np.random.random((4, 4))
        # self.states = np.zeros((4, 4))

        # fstates holds the freshly computed values of the current sweep
        self.fstates = np.random.random((4, 4))
        # self.fstates = np.zeros((4, 4))


class Agent(object):

    def __init__(self, start=[0, 0]):
        self.start = start
        self.cstate = start
        self.reward = 0.
        self.validActions = ['U', 'D', 'L', 'R']
        # policy[i, j, a] > 0 means action a is allowed in cell (i, j);
        # starting from all ones corresponds to the uniform random policy
        self.policy = np.ones((4, 4, 4))
        # self.policy = np.random.randint(0, 2, size=(4, 4, 4))
        # npolicy holds the improved (greedy) policy
        self.npolicy = np.zeros((4, 4, 4))

    def showPolicy(self, E, k=0, m=''):
        # Draw the value function E as a heat map, overlay the greedy
        # action arrows from npolicy and mark the agent's current cell.
        arrows = u'\u2191\u2193\u2190\u2192'  # up, down, left, right
        shift = np.array([(-0.4, 0.), (0.4, 0.), (0., -0.4), (0., 0.4)])
        (fig, ax) = plt.subplots()
        for ((i, j), z) in np.ndenumerate(E):
            ax.text(j, i, round(z, 4), ha='center', va='center')
        for action in self.validActions:
            ai = self.validActions.index(action)
            dirs = self.npolicy[:, :, ai] > 0
            self.npolicy[:, :, ai] = dirs
            for ((i, j), f) in np.ndenumerate(dirs):
                if f:
                    ax.text(j + shift[ai][1], i + shift[ai][0],
                            arrows[ai], ha='center', va='center')
        ax.text(self.cstate[1] - 0.25, self.cstate[0], ':-)',
                ha='center', va='center')
        ax.matshow(E, cmap='winter')
        plt.title(m)
        plt.axis('off')
        # plt.show()

        # frames are written to the 'Junkyard' directory, which must exist
        plt.savefig('Junkyard/anim' + '{0:0=4d}'.format(k) + '.png', dpi=80)
        plt.close()

    def checkNextState(self, action):
        # Return the state reached by taking 'action' from the current
        # state and set the corresponding reward, without moving the agent.
        if action not in self.validActions:
            print('Not a valid action')
            return self.cstate
        elif action == 'U':
            nstate = [self.cstate[0] - 1, self.cstate[1]]
        elif action == 'D':
            nstate = [self.cstate[0] + 1, self.cstate[1]]
        elif action == 'L':
            nstate = [self.cstate[0], self.cstate[1] - 1]
        elif action == 'R':
            nstate = [self.cstate[0], self.cstate[1] + 1]
        self.reward = -1.
        # bumping into the wall keeps the agent in the same state
        if nstate[0] < 0 or nstate[0] > 3 or nstate[1] < 0 or nstate[1] > 3:
            nstate = list(self.cstate)
        # moving into a terminal cell gives a reward of 0
        if nstate == [0, 0] or nstate == [3, 3]:
            self.reward = 0.
        return nstate

    def takeAction(self, action):
        self.cstate = self.checkNextState(action)

    def gpolicy(self, state, a):
        a = ['U', 'D', 'L', 'R'].index(a)
        return self.policy[state[0], state[1], a]


def policyEvaluation(myEnv, myAgent, commands):
    # Evaluate the equiprobable random policy: every non-terminal state is
    # backed up as the average (probability 0.25) over the four actions of
    # reward plus the value of the next state (no discounting).
    k = 0
    while True:
        k += 1
        for ((i, j), z) in np.ndenumerate(myEnv.states):
            sums = 0
            myAgent.start = [i, j]
            myAgent.cstate = [i, j]
            for act in commands:
                if myAgent.start == [0, 0]:
                    step = [0, 0]
                    myAgent.reward = 0.
                    break
                elif myAgent.start == [3, 3]:
                    step = [3, 3]
                    myAgent.reward = 0.
                    break
                else:
                    step = myAgent.checkNextState(act)
                sums += 0.25 * (myAgent.reward
                                + myEnv.states[step[0], step[1]])
            myEnv.fstates[i, j] = sums
        m = 'Performing Policy Evaluation-Iteration ' + str(k)
        # myAgent.showPolicy(myEnv.fstates, k, m)

        # after this assignment states and fstates refer to the same array,
        # so later sweeps update the value function in place
        myEnv.states = myEnv.fstates
        if k > 200:
            break


def policyImprovement(myEnv, myAgent, commands):
    # Greedy policy improvement: in every non-terminal cell keep the
    # action(s) that maximise reward plus the value of the next state.
    k = 0
    for ((i, j), z) in np.ndenumerate(myEnv.states):
        k += 1
        myAgent.start = [i, j]
        myAgent.cstate = [i, j]
        nextas = []
        for act in commands:
            if myAgent.start == [0, 0]:
                step = [0, 0]
                myAgent.reward = 0.
                break
            elif myAgent.start == [3, 3]:
                step = [3, 3]
                myAgent.reward = 0.
                break
            else:
                step = myAgent.checkNextState(act)
            nextv = myAgent.reward + myEnv.states[step[0], step[1]]
            nextas.append(nextv)
        if len(nextas):
            newdir = nextas == np.max(nextas)
            olddir = myAgent.policy[i, j, :] == 1.
            if sum(olddir != newdir) > 0:
                # the greedy action set changed in this cell
                # print('Found Change at: ', i, j)
                myAgent.npolicy[i, j, :] = newdir

        # m = 'Performing Policy Improvement at Cell (' + str(i) + ',' + str(j) + ')'
        # myAgent.showPolicy(myEnv.fstates, k, m)

policyStable = False
myEnv = Env()
myAgent = Agent([0, 0])
commands = ['U', 'D', 'L', 'R']

# alternate evaluation and improvement until the policy stops changing
while not policyStable:
    policyEvaluation(myEnv, myAgent, commands)
    policyImprovement(myEnv, myAgent, commands)
    policyStable = (myAgent.policy == myAgent.npolicy).all()
    if not policyStable:
        myAgent.policy = myAgent.npolicy.copy()

if policyStable:
    m = 'Found Optimum Policy & Value Function'
    print(m)

    # k = 0
    # myAgent.showPolicy(myEnv.fstates, k, m)

# start the learned agent from a random non-terminal cell
myAgent.cstate = list(np.random.randint(0, 4, size=2))
while myAgent.cstate in ([0, 0], [3, 3]):
    myAgent.cstate = list(np.random.randint(0, 4, size=2))

# myAgent.cstate = [1, 2]

k = 0
while True:
    myAgent.showPolicy(myEnv.fstates, k, m='Navigating')
    k += 1
    a = myAgent.policy[myAgent.cstate[0], myAgent.cstate[1]]
    bacti = np.where(a == 1)[0]
    if bacti.shape[0] > 1:
        # several actions are equally good; pick one at random
        myAgent.showPolicy(myEnv.fstates, k, m='Breaking Ties')
        k += 1
        bacti = np.random.choice(bacti)
    else:
        bacti = bacti[0]
    myAgent.showPolicy(myEnv.fstates, k, m='Taking ' + commands[bacti])
    k += 1
    myAgent.takeAction(commands[bacti])
    if myAgent.cstate == [0, 0] or myAgent.cstate == [3, 3]:
        myAgent.showPolicy(myEnv.fstates, k, m='Reached.... :-)')
        k += 1
        break
