Sarsa Class
Namespace: Accord.MachineLearning
The Sarsa type exposes the following members.
Constructors

Name | Description
---|---
Sarsa(Int32, Int32, IExplorationPolicy) | Initializes a new instance of the Sarsa class.
Sarsa(Int32, Int32, IExplorationPolicy, Boolean) | Initializes a new instance of the Sarsa class.
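As used in the example at the end of this page, the first two Int32 parameters correspond to the number of possible states and the number of possible actions, and the IExplorationPolicy parameter is the policy used to select actions. A minimal construction sketch follows (the values are taken from that example: 256 states from 8 binary wall sensors and 4 movement actions; the Boolean parameter of the second overload is not covered here):

// A minimal construction sketch, reusing values from the example below:
// 256 possible states (8 binary wall sensors -> 2^8 combinations),
// 4 possible actions, and an epsilon-greedy exploration policy.
var policy = new EpsilonGreedyExploration(0.5);
var sarsa = new Sarsa(256, 4, policy);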
Properties

Name | Description
---|---
ActionsCount | Number of possible actions.
DiscountFactor | Discount factor, in the range [0, 1].
ExplorationPolicy | Exploration policy.
LearningRate | Learning rate, in the range [0, 1].
StatesCount | Number of possible states.
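LearningRate and the exploration policy are typically adjusted while training is in progress; the example at the end of this page anneals both over the learning iterations. A small tuning sketch is shown below (the values are illustrative only, and it is assumed here that DiscountFactor can be assigned in the same way as LearningRate):

// Illustrative tuning sketch, continuing from the construction sketch above.
// Assumption: DiscountFactor is assignable just like LearningRate.
sarsa.LearningRate = 0.5;    // how strongly new estimates replace old ones, [0, 1]
sarsa.DiscountFactor = 0.9;  // weight given to future rewards, [0, 1]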
Methods

Name | Description
---|---
Equals | Determines whether the specified object is equal to the current object. (Inherited from Object.)
Finalize | Allows an object to try to free resources and perform other cleanup operations before it is reclaimed by garbage collection. (Inherited from Object.)
GetAction | Gets the next action for the specified state.
GetHashCode | Serves as the default hash function. (Inherited from Object.)
GetType | Gets the Type of the current instance. (Inherited from Object.)
MemberwiseClone | Creates a shallow copy of the current Object. (Inherited from Object.)
ToString | Returns a string that represents the current object. (Inherited from Object.)
UpdateState(Int32, Int32, Double) | Updates the Q-function's value for the previous state-action pair.
UpdateState(Int32, Int32, Double, Int32, Int32) | Updates the Q-function's value for the previous state-action pair.
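GetAction and the two UpdateState overloads are the members exercised by the example at the end of this page. A condensed sketch of one on-policy learning step is shown below; the state encoding, the environment step, and the reward (currentState, nextState, reward) are assumed to come from the surrounding application, as the getState and doAction functions do in the full example:

// Condensed Sarsa step, mirroring the full example below. The variables
// currentState, nextState and reward are assumed to be provided by the
// environment (see the getState/doAction functions in the example).
int action = sarsa.GetAction(currentState);    // pick an action via the exploration policy
// ... apply 'action' to the environment, observe 'reward' and 'nextState' ...
int nextAction = sarsa.GetAction(nextState);   // pick the follow-up action (on-policy)
sarsa.UpdateState(currentState, action, reward, nextState, nextAction);

// When a terminal state is reached, the shorter overload closes the episode:
sarsa.UpdateState(currentState, action, reward);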
Extension Methods

Name | Description
---|---
HasMethod | Checks whether an object implements a method with the given name. (Defined by ExtensionMethods.)
IsEqual | Compares two objects for equality, performing an elementwise comparison if the elements are vectors or matrices. (Defined by Matrix.)
To(Type) | Overloaded. Converts an object into another type, irrespective of whether the conversion can be done at compile time or not. This can be used to convert generic types to numeric types during runtime. (Defined by ExtensionMethods.)
To&lt;T&gt; | Overloaded. Converts an object into another type, irrespective of whether the conversion can be done at compile time or not. This can be used to convert generic types to numeric types during runtime. (Defined by ExtensionMethods.)
Examples

The following example shows how to learn a model using reinforcement learning through the Sarsa algorithm. The code below was inherited from the AForge.NET Framework and has not been modified since. If you have ideas on how to improve its interface, please share them in the project's issue tracker at https://github.com/accord-net/framework/issues. If your ideas are feasible and promising enough, you can be named an official contributor of the project. If you would like, you could also opt to "inherit" the reinforcement learning portion of the project, so that you would be free to commit, modify and, more importantly, author those modules directly from your own GitHub account without having to wait for pull request approvals. You would then be listed as an official author of the Accord.NET Framework, making it possible to list the creation or shared authorship of the reinforcement learning project in your CV.
// Fix the random number generator
Accord.Math.Random.Generator.Seed = 0;

// In this example, we will be using the Sarsa algorithm
// to make a robot learn how to navigate a map. The map
// is shown below, where a 1 denotes a wall and 0 denotes
// areas where the robot can navigate:
//
int[,] map =
{
    { 1, 1, 1, 1, 1, 1, 1, 1, 1 },
    { 1, 1, 0, 0, 0, 0, 0, 0, 1 },
    { 1, 1, 0, 0, 0, 1, 1, 0, 1 },
    { 1, 0, 0, 1, 0, 0, 0, 0, 1 },
    { 1, 0, 0, 1, 1, 1, 1, 0, 1 },
    { 1, 0, 0, 1, 1, 0, 0, 0, 1 },
    { 1, 1, 0, 1, 0, 0, 0, 0, 1 },
    { 1, 1, 0, 1, 0, 1, 1, 0, 1 },
    { 1, 1, 1, 1, 1, 1, 1, 1, 1 },
};

// Now, we define the initial and target points from which the
// robot will be spawned and where it should go, respectively:
int agentStartX = 1;
int agentStartY = 4;

int agentStopX = 7;
int agentStopY = 4;

// The robot is able to sense the environment through 8 sensors
// that capture whether the robot is near a wall or not. Based
// on the robot's current location, the sensors will return an
// integer number representing which sensors have detected walls:
Func<int, int, int> getState = (int x, int y) =>
{
    int c1 = (map[y - 1, x - 1] != 0) ? 1 : 0;
    int c2 = (map[y - 1, x + 0] != 0) ? 1 : 0;
    int c3 = (map[y - 1, x + 1] != 0) ? 1 : 0;
    int c4 = (map[y + 0, x + 1] != 0) ? 1 : 0;
    int c5 = (map[y + 1, x + 1] != 0) ? 1 : 0;
    int c6 = (map[y + 1, x + 0] != 0) ? 1 : 0;
    int c7 = (map[y + 1, x - 1] != 0) ? 1 : 0;
    int c8 = (map[y + 0, x - 1] != 0) ? 1 : 0;

    return c1 | (c2 << 1) | (c3 << 2) | (c4 << 3)
         | (c5 << 4) | (c6 << 5) | (c7 << 6) | (c8 << 7);
};

// The actions are the possible directions the robot can go:
//
//   - case 0: go to north (up)
//   - case 1: go to east (right)
//   - case 2: go to south (down)
//   - case 3: go to west (left)
//

int learningIterations = 1000;
double explorationRate = 0.5;
double learningRate = 0.5;

double moveReward = 0;
double wallReward = -1;
double goalReward = 1;

// The function below specifies how the robot should perform an action given its
// current position and an action number. This will cause the robot to update its
// current X and Y locations given the direction (above) it was instructed to go:
Func<int, int, int, Tuple<double, int, int>> doAction = (int currentX, int currentY, int action) =>
{
    // default reward is equal to moving reward
    double reward = moveReward;

    // moving direction
    int dx = 0, dy = 0;

    switch (action)
    {
        case 0: // go to north (up)
            dy = -1;
            break;
        case 1: // go to east (right)
            dx = 1;
            break;
        case 2: // go to south (down)
            dy = 1;
            break;
        case 3: // go to west (left)
            dx = -1;
            break;
    }

    int newX = currentX + dx;
    int newY = currentY + dy;

    // check the new coordinates: bounds first, then walls
    if ((newX < 0) || (newX >= map.Columns()) ||
        (newY < 0) || (newY >= map.Rows()) ||
        (map[newY, newX] != 0))
    {
        // we found a wall or got outside of the world
        reward = wallReward;
    }
    else
    {
        currentX = newX;
        currentY = newY;

        // check if we found the goal
        if ((currentX == agentStopX) && (currentY == agentStopY))
            reward = goalReward;
    }

    return Tuple.Create(reward, currentX, currentY);
};


// After defining all those functions, we create a new Sarsa algorithm:
var explorationPolicy = new EpsilonGreedyExploration(explorationRate);
var tabuPolicy = new TabuSearchExploration(4, explorationPolicy);
var sarsa = new Sarsa(256, 4, tabuPolicy);

// current coordinates of the agent
int agentCurrentX = -1;
int agentCurrentY = -1;

bool needToStop = false;
int iteration = 0;

// loop
while ((!needToStop) && (iteration < learningIterations))
{
    // set exploration rate for this iteration
    explorationPolicy.Epsilon = explorationRate - ((double)iteration / learningIterations) * explorationRate;

    // set learning rate for this iteration
    sarsa.LearningRate = learningRate - ((double)iteration / learningIterations) * learningRate;

    // clear tabu list
    tabuPolicy.ResetTabuList();

    // reset agent's coordinates to the starting position
    agentCurrentX = agentStartX;
    agentCurrentY = agentStartY;

    // steps performed by agent to get to the goal
    int steps = 1;

    // previous state and action
    int previousState = getState(agentCurrentX, agentCurrentY);
    int previousAction = sarsa.GetAction(previousState);

    // update agent's current position and get its reward
    var r = doAction(agentCurrentX, agentCurrentY, previousAction);

    double reward = r.Item1;
    agentCurrentX = r.Item2;
    agentCurrentY = r.Item3;

    while ((!needToStop) && ((agentCurrentX != agentStopX) || (agentCurrentY != agentStopY)))
    {
        steps++;

        // set tabu action
        tabuPolicy.SetTabuAction((previousAction + 2) % 4, 1);

        // get agent's next state
        int nextState = getState(agentCurrentX, agentCurrentY);

        // get agent's next action
        int nextAction = sarsa.GetAction(nextState);

        // do learning of the agent - update its Q-function
        sarsa.UpdateState(previousState, previousAction, reward, nextState, nextAction);

        // update agent's new position and get its reward
        r = doAction(agentCurrentX, agentCurrentY, nextAction);
        reward = r.Item1;
        agentCurrentX = r.Item2;
        agentCurrentY = r.Item3;

        previousState = nextState;
        previousAction = nextAction;
    }

    if (!needToStop)
    {
        // update Q-function if terminal state was reached
        sarsa.UpdateState(previousState, previousAction, reward);
    }

    iteration++;
}

// The end position for the robot will be (7, 4):
int finalPosX = agentCurrentX; // 7
int finalPosY = agentCurrentY; // 4