zhm-real 5 anni fa
parent
commit
8549bd00d3

+ 22 - 15
Model-free Control/Q-learning.py

@@ -17,14 +17,13 @@ class QLEARNING:
     def __init__(self, x_start, x_goal):
         self.u_set = motion_model.motions                       # feasible input set
         self.xI, self.xG = x_start, x_goal
-        self.M = 500
+        self.M = 500                                            # number of episodes (outer loop of Monte_Carlo)
         self.gamma = 0.9                                        # discount factor
         self.alpha = 0.5
-        self.epsilon = 0.1
+        self.epsilon = 0.1                                      # value passed to epsilon_greedy for action selection
         self.obs = env.obs_map()                                # position of obstacles
         self.lose = env.lose_map()                              # position of lose states
         self.name1 = "Qlearning, M=" + str(self.M)
-        self.name2 = "convergence of error"
 
 
     def Monte_Carlo(self):
@@ -34,23 +33,21 @@ class QLEARNING:
         :return: Q_table, policy
         """
 
-        Q_table = self.table_init()
-        policy = {}
-        count = 0
-
-        for k in range(self.M):
-            count += 1
-            x = self.state_init()
-            while x != self.xG:
-                u = self.epsilon_greedy(int(np.argmax(Q_table[x])), self.epsilon)
-                x_next = self.move_next(x, self.u_set[u])
-                reward = env.get_reward(x_next, self.lose)
+        Q_table = self.table_init()                                                 # Q_table initialization
+        policy = {}                                                                 # policy table
+
+        for k in range(self.M):                                                     # iterations
+            x = self.state_init()                                                   # initial state
+            while x != self.xG:                                                     # stop condition
+                u = self.epsilon_greedy(int(np.argmax(Q_table[x])), self.epsilon)   # epsilon_greedy policy
+                x_next = self.move_next(x, self.u_set[u])                           # next state
+                reward = env.get_reward(x_next, self.lose)                          # reward observed
                 Q_table[x][u] = (1 - self.alpha) * Q_table[x][u] + \
                                 self.alpha * (reward + self.gamma * max(Q_table[x_next]))
                 x = x_next
 
         for x in Q_table:
-            policy[x] = int(np.argmax(Q_table[x]))
+            policy[x] = int(np.argmax(Q_table[x]))                                  # extract policy
 
         return Q_table, policy
 
@@ -152,10 +149,20 @@ class QLEARNING:
                     tools.plot_dots(x)  # each state in optimal path
                     path.append(x)
         plt.show()
+        self.message()
 
         return path
 
 
+    def message(self):                                          # print start/goal states and the learning parameters
+        print("starting state: ", self.xI)
+        print("goal state: ", self.xG)
+        print("iteration numbers: ", self.M)
+        print("discount factor: ", self.gamma)
+        print("epsilon error: ", self.epsilon)
+        print("alpha: ", self.alpha)
+
+
 if __name__ == '__main__':
     x_Start = (1, 1)
     x_Goal = (12, 1)

+ 20 - 14
Model-free Control/Sarsa.py

@@ -10,21 +10,19 @@ import motion_model
 
 import matplotlib.pyplot as plt
 import numpy as np
-import sys
 
 
 class SARSA:
     def __init__(self, x_start, x_goal):
         self.u_set = motion_model.motions                       # feasible input set
         self.xI, self.xG = x_start, x_goal
-        self.M = 500
+        self.M = 500                                            # number of episodes (outer loop of Monte_Carlo)
         self.gamma = 0.9                                        # discount factor
         self.alpha = 0.5
-        self.epsilon = 0.1
+        self.epsilon = 0.1                                      # value passed to epsilon_greedy for action selection
         self.obs = env.obs_map()                                # position of obstacles
         self.lose = env.lose_map()                              # position of lose states
         self.name1 = "SARSA, M=" + str(self.M)
-        self.name2 = "convergence of error"
 
 
     def Monte_Carlo(self):
@@ -34,24 +32,22 @@ class SARSA:
         :return: Q_table, policy
         """
 
-        Q_table = self.table_init()
-        policy = {}
-        count = 0
+        Q_table = self.table_init()                                             # Q_table initialization
+        policy = {}                                                             # policy table
 
-        for k in range(self.M):
-            count += 1
-            x = self.state_init()
+        for k in range(self.M):                                                 # iterations
+            x = self.state_init()                                               # initial state
             u = self.epsilon_greedy(int(np.argmax(Q_table[x])), self.epsilon)
-            while x != self.xG:
-                x_next = self.move_next(x, self.u_set[u])
-                reward = env.get_reward(x_next, self.lose)
+            while x != self.xG:                                                 # stop condition
+                x_next = self.move_next(x, self.u_set[u])                       # next state
+                reward = env.get_reward(x_next, self.lose)                      # reward observed
                 u_next = self.epsilon_greedy(int(np.argmax(Q_table[x_next])), self.epsilon)
                 Q_table[x][u] = (1 - self.alpha) * Q_table[x][u] + \
                                 self.alpha * (reward + self.gamma * Q_table[x_next][u_next])
                 x, u = x_next, u_next
 
         for x in Q_table:
-            policy[x] = int(np.argmax(Q_table[x]))
+            policy[x] = int(np.argmax(Q_table[x]))                              # extract policy
 
         return Q_table, policy
 
@@ -153,10 +149,20 @@ class SARSA:
                     tools.plot_dots(x)  # each state in optimal path
                     path.append(x)
         plt.show()
+        self.message()
 
         return path
 
 
+    def message(self):                                          # print start/goal states and the learning parameters
+        print("starting state: ", self.xI)
+        print("goal state: ", self.xG)
+        print("iteration numbers: ", self.M)
+        print("discount factor: ", self.gamma)
+        print("epsilon error: ", self.epsilon)
+        print("alpha: ", self.alpha)
+
+
 if __name__ == '__main__':
     x_Start = (1, 1)
     x_Goal = (12, 1)

BIN
Model-free Control/__pycache__/tools.cpython-37.pyc


+ 1 - 1
Model-free Control/tools.py

@@ -84,7 +84,7 @@ def plot_dots(x):
     :return: a plot
     """
 
-    plt.plot(x[0], x[1], linewidth='3', color='#808080', marker='o', ms = 24)    # plot dots for animation
+    plt.plot(x[0], x[1], linewidth='3', color='#808080', marker='o', ms = 23)    # plot dots for animation
     plt.gcf().canvas.mpl_connect('key_release_event',
                                  lambda event: [exit(0) if event.key == 'escape' else None])
     plt.pause(0.001)