LRTAstar.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. """
  2. LRTA_star 2D (Learning Real-time A*)
  3. @author: huiming zhou
  4. """
  5. import os
  6. import sys
  7. import copy
  8. import matplotlib.pyplot as plt
  9. sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
  10. "/../../Search-based Planning/")
  11. from Search_2D import queue
  12. from Search_2D import plotting
  13. from Search_2D import env
  # ---- LRTA* planner (Learning Real-time A*) -------------------------------
class LrtAstarN:
    """Learning Real-time A* (LRTA*) planner on a 2D grid.

    Each planning round runs an A* lookahead that expands at most ``N`` nodes
    from the agent's current node. It then backs up the heuristic values of
    the expanded (CLOSED) nodes and moves the agent greedily through CLOSED
    to the most promising frontier node. Backed-up values are stored in
    ``self.h_table``, so later rounds (and later lookaheads) use what earlier
    rounds learned instead of the static heuristic.

    Attributes:
        xI, xG: start and goal nodes.
        visited: one list per round with the nodes expanded, in order.
        path: one list per round with the nodes the agent moved through.
        h_table: learned heuristic values keyed by node; nodes not present
            fall back to the static heuristic ``h``.
    """

    def __init__(self, x_start, x_goal, N, heuristic_type):
        self.xI, self.xG = x_start, x_goal
        self.heuristic_type = heuristic_type

        self.Env = env.Env()
        self.u_set = self.Env.motions  # feasible input set
        self.obs = self.Env.obs  # position of obstacles

        self.N = N  # number of nodes expanded per planning round
        self.visited = []  # order of visited nodes in planning
        self.path = []  # path of each iteration
        self.h_table = {}  # learned heuristic values (persist across rounds)

    def searching(self):
        """Plan from ``self.xI`` to ``self.xG``.

        Results are appended to ``self.path`` and ``self.visited``.

        :raises RuntimeError: if the goal is unreachable from the start
        """
        s_start = self.xI  # initialize start node

        while True:
            OPEN, CLOSED = self.Astar(s_start, self.N)

            if OPEN == "FOUND":  # goal reached; CLOSED holds the final path
                self.path.append(CLOSED)
                return

            if OPEN.empty():
                # Every node reachable from s_start was expanded without
                # meeting the goal. No frontier node exists to move to, so
                # stop instead of looping forever.
                raise RuntimeError("No path found: goal is unreachable from start.")

            h_value = self.iteration(CLOSED)  # backed-up values of CLOSED nodes
            self.h_table.update(h_value)  # remember what this round learned

            # Move s_start -> the chosen node on the OPEN frontier.
            s_start, path_k = self.extract_path_in_CLOSE(s_start, h_value)
            self.path.append(path_k)

    def _neighbors(self, s):
        """Yield ``(u, s_next)`` for every input ``u`` leading to a free cell."""
        for u in self.u_set:
            s_next = tuple(s[i] + u[i] for i in range(len(s)))
            if s_next not in self.obs:
                yield u, s_next

    def _h_learned(self, s):
        """Return the learned heuristic of ``s``, defaulting to ``h(s)``."""
        if s in self.h_table:
            return self.h_table[s]
        return self.h(s)

    def extract_path_in_CLOSE(self, s_start, h_value):
        """Move greedily from ``s_start`` through CLOSED until leaving it.

        At each step the successor minimizing ``c(s, s') + h(s')`` is taken.
        CLOSED nodes use ``h_value``; all other nodes use the learned
        heuristic.

        :param s_start: node the agent starts moving from (a CLOSED node)
        :param h_value: backed-up values of the CLOSED nodes (from iteration)
        :return: (first node reached outside CLOSED, list of nodes traversed)
        """
        path = [s_start]
        s = s_start

        while True:
            scores = {}
            for u, s_next in self._neighbors(s):
                if s_next in h_value:
                    h_next = h_value[s_next]
                else:
                    h_next = self._h_learned(s_next)
                scores[s_next] = self.get_cost(s, u) + h_next

            s = min(scores, key=scores.get)  # best successor
            path.append(s)

            if s not in h_value:  # left CLOSED: reached the frontier node
                return s, path

    def iteration(self, CLOSED):
        """Back up heuristic values over CLOSED until they converge.

        Each CLOSED node ``s`` gets ``min over s' of c(s, s') + h(s')``.
        Successors inside CLOSED use the values being computed, which start
        at infinity. Successors outside CLOSED use their learned heuristic.

        :param CLOSED: nodes expanded by the last lookahead search
        :return: dict mapping every CLOSED node to its backed-up value
        """
        h_value = {s: float("inf") for s in CLOSED}

        while True:
            h_value_rec = dict(h_value)  # values are floats: shallow copy is enough
            for s in CLOSED:
                h_value[s] = min(
                    self.get_cost(s, u)
                    + (h_value[s_next] if s_next in CLOSED else self._h_learned(s_next))
                    for u, s_next in self._neighbors(s)
                )
            if h_value == h_value_rec:  # converged
                return h_value

    def Astar(self, x_start, N):
        """Run an A* lookahead from ``x_start`` expanding at most ``N`` nodes.

        Priorities use the learned heuristic, so nodes whose values were
        raised in earlier rounds are explored later.

        :return: ("FOUND", path to the goal) if the goal was expanded,
                 otherwise (OPEN queue, CLOSED set) of this lookahead
        """
        OPEN = queue.QueuePrior()  # OPEN set
        OPEN.put(x_start, self._h_learned(x_start))
        CLOSED = set()  # CLOSED set
        g_table = {x_start: 0}  # cost to come
        PARENT = {x_start: x_start}  # relations
        visited = []  # order of visited nodes
        count = 0  # number of distinct nodes expanded

        while not OPEN.empty():
            s = OPEN.get()
            if s in CLOSED:
                # A node is re-queued whenever its cost improves, so OPEN may
                # return an already-expanded node (assuming QueuePrior keeps
                # duplicates; TODO confirm). Skip it so that it does not use
                # up one of the N expansions.
                continue

            count += 1
            CLOSED.add(s)
            visited.append(s)

            if s == self.xG:  # reach the goal node
                self.visited.append(visited)
                return "FOUND", self.extract_path(x_start, PARENT)

            for u, s_next in self._neighbors(s):
                if s_next in CLOSED:
                    continue
                new_cost = g_table[s] + self.get_cost(s, u)
                if new_cost < g_table.get(s_next, float("inf")):
                    g_table[s_next] = new_cost
                    PARENT[s_next] = s
                    OPEN.put(s_next, new_cost + self._h_learned(s_next))

            if count == N:  # expansion budget for this round used up
                break

        self.visited.append(visited)  # visited nodes in each iteration
        return OPEN, CLOSED

    def extract_path(self, x_start, parent):
        """Extract the path from ``x_start`` to the goal via parent links.

        :param x_start: first node of the path
        :param parent: maps each node to its predecessor in the search tree
        :return: the planning path, start first and goal last
        """
        path_back = [self.xG]
        x_current = self.xG

        # Loop test first, so start == goal yields [goal] rather than [goal, goal].
        while x_current != x_start:
            x_current = parent[x_current]
            path_back.append(x_current)

        return list(reversed(path_back))

    def h(self, s):
        """Static heuristic distance from ``s`` to the goal.

        :raises ValueError: if ``self.heuristic_type`` is neither
            "manhattan" nor "euclidean"
        """
        goal = self.xG
        if self.heuristic_type == "manhattan":
            return abs(goal[0] - s[0]) + abs(goal[1] - s[1])
        if self.heuristic_type == "euclidean":
            return ((goal[0] - s[0]) ** 2 + (goal[1] - s[1]) ** 2) ** (1 / 2)
        raise ValueError(
            "Unknown heuristic type %r; choose 'manhattan' or 'euclidean'."
            % (self.heuristic_type,)
        )

    @staticmethod
    def get_cost(x, u):
        """
        Calculate cost for this motion
        :param x: current node
        :param u: input
        :return: cost for this motion
        :note: cost function could be more complicate!
        """
        return 1
  # ---- Demo entry point -----------------------------------------------------
def main():
    """Run LRTA* from (10, 5) to (45, 25) and animate every planning round."""
    x_start, x_goal = (10, 5), (45, 25)  # start and goal nodes

    # 220 node expansions per round, Euclidean heuristic.
    planner = LrtAstarN(x_start, x_goal, 220, "euclidean")
    plotter = plotting.Plotting(x_start, x_goal)
    title = "Learning Real-time A* (LRTA*)"

    planner.searching()
    plotter.animation_lrta(planner.path, planner.visited, title)


# Only run the demo when this file is executed as a script.
if __name__ == '__main__':
    main()