LRTA_star.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. """
  2. LRTA_star_N 2D
  3. @author: huiming zhou
  4. """
  5. import os
  6. import sys
  7. import copy
  8. import matplotlib.pyplot as plt
  9. sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
  10. "/../../Search-based Planning/")
  11. from Search_2D import queue
  12. from Search_2D import plotting
  13. from Search_2D import env
  14. class LrtAstarN:
  15. def __init__(self, x_start, x_goal, heuristic_type):
  16. self.xI, self.xG = x_start, x_goal
  17. self.heuristic_type = heuristic_type
  18. self.Env = env.Env() # class Env
  19. self.u_set = self.Env.motions # feasible input set
  20. self.obs = self.Env.obs # position of obstacles
  21. self.N = 150
  22. self.visited = []
  23. def searching(self):
  24. s_start = self.xI
  25. path = []
  26. count = 0
  27. while True:
  28. # if count == 2:
  29. # return path
  30. # count += 1
  31. h_table = {}
  32. OPEN, CLOSED = self.Astar(s_start, self.N)
  33. if OPEN == "end":
  34. path.append(CLOSED)
  35. return path
  36. for x in CLOSED:
  37. h_table[x] = 2000
  38. while True:
  39. h_table_rec = copy.deepcopy(h_table)
  40. for s in CLOSED:
  41. h_list = []
  42. for u in self.u_set:
  43. s_next = tuple([s[i] + u[i] for i in range(2)])
  44. if s_next not in self.obs:
  45. if s_next not in CLOSED:
  46. h_list.append(self.get_cost(s, s_next) + self.h(s_next))
  47. else:
  48. h_list.append(self.get_cost(s, s_next) + h_table[s_next])
  49. h_table[s] = min(h_list)
  50. if h_table == h_table_rec:
  51. break
  52. path_k = [s_start]
  53. x = s_start
  54. while True:
  55. h_xlist = {}
  56. for u in self.u_set:
  57. x_next = tuple([x[i] + u[i] for i in range(2)])
  58. if x_next not in self.obs:
  59. if x_next in CLOSED:
  60. h_xlist[x_next] = h_table[x_next]
  61. else:
  62. h_xlist[x_next] = self.h(x_next)
  63. s_key = min(h_xlist, key=h_xlist.get)
  64. path_k.append(s_key)
  65. x = s_key
  66. if s_key not in CLOSED:
  67. break
  68. s_start = path_k[-1]
  69. path.append(path_k)
  70. def Astar(self, x_start, N):
  71. OPEN = queue.QueuePrior()
  72. OPEN.put(x_start, self.h(x_start))
  73. CLOSED = set()
  74. g_table = {x_start: 0, self.xG: float("inf")}
  75. parent = {x_start: x_start}
  76. count = 0
  77. visited = []
  78. while not OPEN.empty():
  79. count += 1
  80. s = OPEN.get()
  81. CLOSED.add(s)
  82. visited.append(s)
  83. if s == self.xG:
  84. path = self.extract_path(x_start, parent)
  85. self.visited.append(visited)
  86. return "end", path
  87. for u in self.u_set:
  88. s_next = tuple([s[i] + u[i] for i in range(len(s))])
  89. if s_next not in self.obs and s_next not in CLOSED:
  90. new_cost = g_table[s] + self.get_cost(s, u)
  91. if s_next not in g_table:
  92. g_table[s_next] = float("inf")
  93. if new_cost < g_table[s_next]: # conditions for updating cost
  94. g_table[s_next] = new_cost
  95. parent[s_next] = s
  96. OPEN.put(s_next, g_table[s_next] + self.h(s_next))
  97. if count == N:
  98. break
  99. self.visited.append(visited)
  100. return OPEN, CLOSED
  101. def extract_path(self, x_start, parent):
  102. """
  103. Extract the path based on the relationship of nodes.
  104. :return: The planning path
  105. """
  106. path_back = [self.xG]
  107. x_current = self.xG
  108. while True:
  109. x_current = parent[x_current]
  110. path_back.append(x_current)
  111. if x_current == x_start:
  112. break
  113. return list(reversed(path_back))
  114. def h(self, s):
  115. heuristic_type = self.heuristic_type
  116. goal = self.xG
  117. if heuristic_type == "manhattan":
  118. return abs(goal[0] - s[0]) + abs(goal[1] - s[1])
  119. elif heuristic_type == "euclidean":
  120. return ((goal[0] - s[0]) ** 2 + (goal[1] - s[1]) ** 2) ** (1 / 2)
  121. else:
  122. print("Please choose right heuristic type!")
  123. @staticmethod
  124. def get_cost(x, u):
  125. """
  126. Calculate cost for this motion
  127. :param x: current node
  128. :param u: input
  129. :return: cost for this motion
  130. :note: cost function could be more complicate!
  131. """
  132. return 1
  133. def main():
  134. x_start = (10, 5) # Starting node
  135. x_goal = (45, 25) # Goal node
  136. lrtastarn = LrtAstarN(x_start, x_goal, "euclidean")
  137. plot = plotting.Plotting(x_start, x_goal)
  138. path = lrtastarn.searching()
  139. plot.plot_grid("LRTA_star_N")
  140. for k in range(len(path)):
  141. plot.plot_visited(lrtastarn.visited[k])
  142. plt.pause(0.5)
  143. plot.plot_path(path[k])
  144. plt.pause(0.5)
  145. plt.pause(0.5)
  146. path_u = []
  147. for i in range(len(path)):
  148. for j in range(len(path[i])):
  149. path_u.append(path[i][j])
  150. plot.plot_path(path_u)
  151. plt.pause(0.2)
  152. plt.show()
  153. if __name__ == '__main__':
  154. main()