
Sarsa Class

Sarsa learning algorithm.
Inheritance Hierarchy
System.Object
  Accord.MachineLearning.Sarsa

Namespace:  Accord.MachineLearning
Assembly:  Accord.MachineLearning (in Accord.MachineLearning.dll) Version: 3.8.0
Syntax
public class Sarsa

The Sarsa type exposes the following members.

Constructors
Public method Sarsa(Int32, Int32, IExplorationPolicy)
Initializes a new instance of the Sarsa class.
Public method Sarsa(Int32, Int32, IExplorationPolicy, Boolean)
Initializes a new instance of the Sarsa class; the Boolean argument specifies whether the initial action estimates are randomized.
Properties
Public property ActionsCount
Number of possible actions.
Public property DiscountFactor
Discount factor, [0, 1].
Public property ExplorationPolicy
Exploration policy.
Public property LearningRate
Learning rate, [0, 1].
Public property StatesCount
Number of possible states.
Methods
Public method Equals
Determines whether the specified object is equal to the current object.
(Inherited from Object.)
Protected method Finalize
Allows an object to try to free resources and perform other cleanup operations before it is reclaimed by garbage collection.
(Inherited from Object.)
Public method GetAction
Gets the next action for the specified state.
Public method GetHashCode
Serves as the default hash function.
(Inherited from Object.)
Public method GetType
Gets the Type of the current instance.
(Inherited from Object.)
Protected method MemberwiseClone
Creates a shallow copy of the current Object.
(Inherited from Object.)
Public method ToString
Returns a string that represents the current object.
(Inherited from Object.)
Public method UpdateState(Int32, Int32, Double)
Updates the Q-function's value for the previous state-action pair when the next state is terminal.
Public method UpdateState(Int32, Int32, Double, Int32, Int32)
Updates the Q-function's value for the previous state-action pair, given the next state and the next action.
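The two UpdateState overloads can be understood in terms of the standard tabular Sarsa update: for a non-terminal step, Q(s, a) <- Q(s, a) + alpha * (r + gamma * Q(s', a') - Q(s, a)); for a terminal step there is no successor, so Q(s, a) <- Q(s, a) + alpha * (r - Q(s, a)). The sketch below illustrates that rule under stated assumptions: the `TabularSarsa` class and its members are hypothetical names written for this illustration and are not the Accord.NET Sarsa implementation.

```csharp
using System;

// Hypothetical tabular Sarsa learner, written only to illustrate the update
// rule. This is NOT the Accord.NET Sarsa class or its implementation.
public class TabularSarsa
{
    // q[state, action]: current estimate of the expected discounted return.
    private readonly double[,] q;

    public double LearningRate { get; set; } = 0.5;    // alpha, in [0, 1]
    public double DiscountFactor { get; set; } = 0.95; // gamma, in [0, 1]

    public TabularSarsa(int states, int actions)
    {
        q = new double[states, actions];
    }

    // Read-only access to a single Q-value.
    public double this[int state, int action] => q[state, action];

    // Non-terminal step: bootstrap from Q(s', a'), where a' is the action the
    // agent actually selected in the next state (what makes Sarsa on-policy).
    public void Update(int state, int action, double reward, int nextState, int nextAction)
    {
        double target = reward + DiscountFactor * q[nextState, nextAction];
        q[state, action] += LearningRate * (target - q[state, action]);
    }

    // Terminal step: there is no successor state, so the target is the reward alone.
    public void Update(int state, int action, double reward)
    {
        q[state, action] += LearningRate * (reward - q[state, action]);
    }
}
```

The five-argument Update has the same shape as UpdateState(Int32, Int32, Double, Int32, Int32), and the three-argument one has the same shape as the terminal-state call UpdateState(previousState, previousAction, reward) made after each episode in the example.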
Extension Methods
Public Extension Method HasMethod
Checks whether an object implements a method with the given name.
(Defined by ExtensionMethods.)
Public Extension Method IsEqual
Compares two objects for equality, performing an elementwise comparison if the elements are vectors or matrices.
(Defined by Matrix.)
Public Extension Method To(Type) (Overloaded.)
Converts an object into another type, irrespective of whether the conversion can be done at compile time or not. This can be used to convert generic types to numeric types during runtime.
(Defined by ExtensionMethods.)
Public Extension Method To<T> (Overloaded.)
Converts an object into another type, irrespective of whether the conversion can be done at compile time or not. This can be used to convert generic types to numeric types during runtime.
(Defined by ExtensionMethods.)
Remarks
This class implements the Sarsa algorithm, an on-policy temporal-difference (TD) control method.
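"On-policy" means that Sarsa's TD target bootstraps from the value of the action the exploration policy actually picks in the next state, whereas an off-policy method such as Q-learning bootstraps from the best action available there. The hypothetical helpers below (`TdTargets`, `SarsaTarget` and `QLearningTarget` are illustrative names, not Accord.NET API) compute both targets for a single transition, as a minimal sketch of the distinction.

```csharp
using System;
using System.Linq;

// Hypothetical helpers, for illustration only (not part of Accord.NET).
public static class TdTargets
{
    // On-policy (Sarsa) target: r + gamma * Q(s', a'), where a' is the action
    // actually chosen by the (possibly exploratory) policy in s'.
    public static double SarsaTarget(double reward, double gamma, double[] nextQ, int nextAction)
    {
        return reward + gamma * nextQ[nextAction];
    }

    // Off-policy (Q-learning) target, shown only for contrast:
    // r + gamma * max over a of Q(s', a).
    public static double QLearningTarget(double reward, double gamma, double[] nextQ)
    {
        return reward + gamma * nextQ.Max();
    }
}
```

For example, with nextQ = { 0.2, 1.0 }, a reward of 0, gamma = 0.5 and an exploratory choice of action 0, the Sarsa target is 0.1 while the Q-learning target is 0.5. Sarsa therefore folds the cost of exploration into the values it learns.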
Examples

The following example shows how to learn a model with reinforcement learning using the Sarsa algorithm. This code was inherited from the AForge.NET Framework and has not been modified since. If you have ideas on how to improve its interface, please share them in the project's issue tracker at https://github.com/accord-net/framework/issues.

If your ideas are feasible and promising enough, you can be named an official contributor to the project. You could also opt to "inherit" the reinforcement learning portion of the project, so that you would be free to commit to, modify and, more importantly, author those modules directly from your own GitHub account without having to wait for pull request approvals. You can be listed as an official author of the Accord.NET Framework, which would let you list the creation or shared authorship of the reinforcement learning project in your CV.

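using System;                      // Func, Tuple
using Accord.MachineLearning;      // Sarsa, EpsilonGreedyExploration, TabuSearchExploration
using Accord.Math;                 // Rows() and Columns() extension methods

// Fix the random number generator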
Accord.Math.Random.Generator.Seed = 0;

// In this example, we will be using the Sarsa algorithm
// to make a robot learn how to navigate a map. The map
// is shown below, where a 1 denotes a wall and 0 denotes
// areas where the robot can navigate:
// 
int[,] map =
{
    { 1, 1, 1, 1, 1, 1, 1, 1, 1 },
    { 1, 1, 0, 0, 0, 0, 0, 0, 1 },
    { 1, 1, 0, 0, 0, 1, 1, 0, 1 },
    { 1, 0, 0, 1, 0, 0, 0, 0, 1 },
    { 1, 0, 0, 1, 1, 1, 1, 0, 1 },
    { 1, 0, 0, 1, 1, 0, 0, 0, 1 },
    { 1, 1, 0, 1, 0, 0, 0, 0, 1 },
    { 1, 1, 0, 1, 0, 1, 1, 0, 1 },
    { 1, 1, 1, 1, 1, 1, 1, 1, 1 },
};

// Now, we define the initial and target points: the cell where the
// robot will be spawned and the cell it should reach, respectively:
int agentStartX = 1;
int agentStartY = 4;

int agentStopX = 7;
int agentStopY = 4;

// The robot is able to sense the environment through 8 sensors
// that capture whether the robot is near a wall or not. Based
// on the robot's current location, the sensors will return an
// integer number whose bits indicate which sensors have detected walls:

Func<int, int, int> getState = (int x, int y) =>
{
    int c1 = (map[y - 1, x - 1] != 0) ? 1 : 0;
    int c2 = (map[y - 1, x + 0] != 0) ? 1 : 0;
    int c3 = (map[y - 1, x + 1] != 0) ? 1 : 0;
    int c4 = (map[y + 0, x + 1] != 0) ? 1 : 0;
    int c5 = (map[y + 1, x + 1] != 0) ? 1 : 0;
    int c6 = (map[y + 1, x + 0] != 0) ? 1 : 0;
    int c7 = (map[y + 1, x - 1] != 0) ? 1 : 0;
    int c8 = (map[y + 0, x - 1] != 0) ? 1 : 0;

    return c1 | (c2 << 1) | (c3 << 2) | (c4 << 3) | (c5 << 4) | (c6 << 5) | (c7 << 6) | (c8 << 7);
};

// The actions are the possible directions the robot can go:
// 
//   - case 0: go to north (up)
//   - case 1: go to east (right)
//   - case 2: go to south (down)
//   - case 3: go to west (left)
// 

int learningIterations = 1000;
double explorationRate = 0.5;
double learningRate = 0.5;

double moveReward = 0;
double wallReward = -1;
double goalReward = 1;

// The function below specifies how the robot should perform an action given its 
// current position and an action number. This will cause the robot to update its 
// current X and Y locations given the direction (above) it was instructed to go:
Func<int, int, int, Tuple<double, int, int>> doAction = (int currentX, int currentY, int action) =>
{
    // default reward is equal to moving reward
    double reward = moveReward;

    // moving direction
    int dx = 0, dy = 0;

    switch (action)
    {
        case 0:         // go to north (up)
            dy = -1;
            break;
        case 1:         // go to east (right)
            dx = 1;
            break;
        case 2:         // go to south (down)
            dy = 1;
            break;
        case 3:         // go to west (left)
            dx = -1;
            break;
    }

    int newX = currentX + dx;
    int newY = currentY + dy;

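    // check the agent's new coordinates; the bounds are tested first so
    // that an out-of-range position never indexes into the map array
    if ((newX < 0) || (newX >= map.Columns()) || (newY < 0) || (newY >= map.Rows()) || (map[newY, newX] != 0))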
    {
        // we found a wall or got outside of the world
        reward = wallReward;
    }
    else
    {
        currentX = newX;
        currentY = newY;

        // check if we found the goal
        if ((currentX == agentStopX) && (currentY == agentStopY))
            reward = goalReward;
    }

    return Tuple.Create(reward, currentX, currentY);
};


// After defining all those functions, we create a new Sarsa algorithm:
var explorationPolicy = new EpsilonGreedyExploration(explorationRate);
var tabuPolicy = new TabuSearchExploration(4, explorationPolicy);
var sarsa = new Sarsa(256, 4, tabuPolicy);

// current coordinates of the agent
int agentCurrentX = -1;
int agentCurrentY = -1;

bool needToStop = false;
int iteration = 0;

// main learning loop: each iteration runs one episode from start to goal
while ((!needToStop) && (iteration < learningIterations))
{
    // set exploration rate for this iteration
    explorationPolicy.Epsilon = explorationRate - ((double)iteration / learningIterations) * explorationRate;

    // set learning rate for this iteration
    sarsa.LearningRate = learningRate - ((double)iteration / learningIterations) * learningRate;

    // clear tabu list
    tabuPolicy.ResetTabuList();

    // reset agent's coordinates to the starting position
    agentCurrentX = agentStartX;
    agentCurrentY = agentStartY;

    // steps performed by agent to get to the goal
    int steps = 1;

    // previous state and action
    int previousState = getState(agentCurrentX, agentCurrentY);
    int previousAction = sarsa.GetAction(previousState);

    // update the agent's current position and get its reward
    var r = doAction(agentCurrentX, agentCurrentY, previousAction);
    double reward = r.Item1;
    agentCurrentX = r.Item2;
    agentCurrentY = r.Item3;

    while ((!needToStop) && ((agentCurrentX != agentStopX) || (agentCurrentY != agentStopY)))
    {
        steps++;

        // set tabu action
        tabuPolicy.SetTabuAction((previousAction + 2) % 4, 1);

        // get agent's next state
        int nextState = getState(agentCurrentX, agentCurrentY);

        // get agent's next action
        int nextAction = sarsa.GetAction(nextState);

        // learning step: update the agent's Q-function
        sarsa.UpdateState(previousState, previousAction, reward, nextState, nextAction);

        // update the agent's new position and get its reward
        r = doAction(agentCurrentX, agentCurrentY, nextAction);
        reward = r.Item1;
        agentCurrentX = r.Item2;
        agentCurrentY = r.Item3;

        previousState = nextState;
        previousAction = nextAction;
    }

    if (!needToStop)
    {
        // update Q-function if terminal state was reached
        sarsa.UpdateState(previousState, previousAction, reward);
    }

    iteration++;
}

// The end position for the robot will be (7, 4):
int finalPosX = agentCurrentX; // 7
int finalPosY = agentCurrentY; // 4
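The getState function packs the eight wall sensors into one byte (bit i holds sensor c(i+1)), which is why the Sarsa instance is created with 256 states. A minimal sketch of reading such a state back is shown below; the `SensorState.Decode` helper is a hypothetical name introduced for illustration and is not part of Accord.NET.

```csharp
using System;

// Hypothetical helper (not part of Accord.NET) that unpacks a state produced
// by getState. Element i of the result corresponds to sensor c(i+1):
// c1 = NW, c2 = N, c3 = NE, c4 = E, c5 = SE, c6 = S, c7 = SW, c8 = W.
public static class SensorState
{
    public static bool[] Decode(int state)
    {
        // getState can only produce values that fit in 8 bits.
        if (state < 0 || state > 255)
            throw new ArgumentOutOfRangeException(nameof(state), "State must be in [0, 255].");

        var walls = new bool[8];
        for (int i = 0; i < 8; i++)
            walls[i] = ((state >> i) & 1) == 1; // bit i set => sensor c(i+1) sees a wall
        return walls;
    }
}
```

In the map above, the start cell (1, 4) has walls only to the north-west, south-west and west, so getState(1, 4) returns 1 + 64 + 128 = 193, and Decode(193) reports walls for sensors c1, c7 and c8 only.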
See Also