|
@@ -1,9 +1,3 @@
|
|
|
-#!/usr/bin/env python3
|
|
|
|
|
-# -*- coding: utf-8 -*-
|
|
|
|
|
-"""
|
|
|
|
|
-@author: huiming zhou
|
|
|
|
|
-"""
|
|
|
|
|
-
|
|
|
|
|
import env
|
|
import env
|
|
|
import plotting
|
|
import plotting
|
|
|
import motion_model
|
|
import motion_model
|
|
@@ -11,20 +5,21 @@ import motion_model
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
import sys
|
|
import sys
|
|
|
|
|
|
|
|
|
|
+
|
|
|
class Value_iteration:
|
|
class Value_iteration:
|
|
|
def __init__(self, x_start, x_goal):
|
|
def __init__(self, x_start, x_goal):
|
|
|
self.xI, self.xG = x_start, x_goal
|
|
self.xI, self.xG = x_start, x_goal
|
|
|
- self.e = 0.001 # threshold for convergence
|
|
|
|
|
- self.gamma = 0.9 # discount factor
|
|
|
|
|
|
|
+ self.e = 0.001 # threshold for convergence
|
|
|
|
|
+ self.gamma = 0.9 # discount factor
|
|
|
|
|
|
|
|
- self.env = env.Env(self.xI, self.xG) # class Env
|
|
|
|
|
- self.motion = motion_model.Motion_model(self.xI, self.xG) # class Motion_model
|
|
|
|
|
- self.plotting = plotting.Plotting(self.xI, self.xG) # class Plotting
|
|
|
|
|
|
|
+ self.env = env.Env(self.xI, self.xG) # class Env
|
|
|
|
|
+ self.motion = motion_model.Motion_model(self.xI, self.xG) # class Motion_model
|
|
|
|
|
+ self.plotting = plotting.Plotting(self.xI, self.xG) # class Plotting
|
|
|
|
|
|
|
|
- self.u_set = self.env.motions # feasible input set
|
|
|
|
|
- self.stateSpace = self.env.stateSpace # state space
|
|
|
|
|
- self.obs = self.env.obs_map() # position of obstacles
|
|
|
|
|
- self.lose = self.env.lose_map() # position of lose states
|
|
|
|
|
|
|
+ self.u_set = self.env.motions # feasible input set
|
|
|
|
|
+ self.stateSpace = self.env.stateSpace # state space
|
|
|
|
|
+ self.obs = self.env.obs_map() # position of obstacles
|
|
|
|
|
+ self.lose = self.env.lose_map() # position of lose states
|
|
|
|
|
|
|
|
self.name1 = "value_iteration, gamma=" + str(self.gamma)
|
|
self.name1 = "value_iteration, gamma=" + str(self.gamma)
|
|
|
self.name2 = "converge process, e=" + str(self.e)
|
|
self.name2 = "converge process, e=" + str(self.e)
|
|
@@ -34,7 +29,6 @@ class Value_iteration:
|
|
|
self.plotting.animation(self.path, self.name1)
|
|
self.plotting.animation(self.path, self.name1)
|
|
|
self.plotting.plot_diff(self.diff, self.name2)
|
|
self.plotting.plot_diff(self.diff, self.name2)
|
|
|
|
|
|
|
|
-
|
|
|
|
|
def iteration(self, xI, xG):
    """
    Run value iteration until the value table converges.

    Performs repeated in-place (Gauss-Seidel style) Bellman sweeps over
    ``self.stateSpace``: for every non-goal state, evaluates the Q-value of
    each feasible input via ``self.cal_Q_value`` and keeps the greedy
    maximum.  Convergence is declared once the largest single-state change
    in one sweep drops to ``self.e`` or below.

    :param xI: start state (not used by the sweep; kept for interface parity)
    :param xG: collection of goal states, excluded from the Bellman update
    :return: tuple ``(value_table, policy, diff)`` — converged value table,
             greedy policy, and the per-sweep maximum differences
    """
    value_table = {}        # state -> current value estimate
    policy = {}             # state -> greedy input
    diff = []               # max change per sweep (for convergence plotting)
    delta = sys.maxsize     # max change of the most recent sweep
    count = 0               # number of sweeps performed

    # Initialize value table for all feasible states.
    for x in self.stateSpace:
        value_table[x] = 0

    while delta > self.e:   # stop once the largest update is within e
        count += 1
        x_value = 0
        for x in self.stateSpace:
            if x not in xG:  # goal states keep their terminal value
                value_list = []
                for u in self.u_set:
                    # Motion model yields candidate next states and their
                    # transition probabilities for input u.
                    [x_next, p_next] = self.motion.move_next(x, u)
                    value_list.append(self.cal_Q_value(x_next, p_next, value_table))
                # Hoisted: the original recomputed max(value_list) twice.
                best = max(value_list)
                policy[x] = self.u_set[int(np.argmax(value_list))]  # greedy input
                v_diff = abs(value_table[x] - best)
                value_table[x] = best
                x_value = max(x_value, v_diff)
        delta = x_value      # largest single-state change this sweep
        diff.append(delta)

    self.message(count)      # report iteration statistics
    return value_table, policy, diff
|
|
|
|
|
|
|
|
-
|
|
|
|
|
def cal_Q_value(self, x, p, table):
    """
    Compute the Q-value of an action: the expectation over candidate next
    states of immediate reward plus the discounted current value estimate.

    :param x: list of candidate next states
    :param p: transition probabilities aligned element-wise with ``x``
    :param table: current value table mapping state -> value
    :return: expected discounted return (a number)
    """
    reward = self.env.get_reward(x)   # reward of each candidate next state
    gamma = self.gamma                # hoist attribute lookup out of the loop
    # zip keeps states, probabilities and rewards aligned without indexing.
    return sum(p_i * (r_i + gamma * table[x_i])
               for x_i, p_i, r_i in zip(x, p, reward))
|
|
|
|
|
|
|
|
-
|
|
|
|
|
def extract_path(self, xI, xG, policy):
|
|
def extract_path(self, xI, xG, policy):
|
|
|
"""
|
|
"""
|
|
|
extract path from converged policy.
|
|
extract path from converged policy.
|
|
@@ -112,7 +104,6 @@ class Value_iteration:
|
|
|
x = x_next
|
|
x = x_next
|
|
|
return path
|
|
return path
|
|
|
|
|
|
|
|
-
|
|
|
|
|
def message(self, count):
|
|
def message(self, count):
|
|
|
"""
|
|
"""
|
|
|
print important message.
|
|
print important message.
|
|
@@ -129,7 +120,7 @@ class Value_iteration:
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
if __name__ == '__main__':
|
|
|
- x_Start = (5, 5) # starting state
|
|
|
|
|
- x_Goal = [(49, 5), (49, 25)] # goal states
|
|
|
|
|
|
|
+ x_Start = (5, 5) # starting state
|
|
|
|
|
+ x_Goal = [(49, 5), (49, 25)] # goal states
|
|
|
|
|
|
|
|
VI = Value_iteration(x_Start, x_Goal)
|
|
VI = Value_iteration(x_Start, x_Goal)
|