Astar.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. """
  2. A_star 2D
  3. @author: huiming zhou
  4. """
  5. import os
  6. import sys
  7. import math
  8. import heapq
  9. sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
  10. "/../../Search_based_Planning/")
  11. from Search_2D import plotting, env
  12. class AStar:
  13. """AStar set the cost + heuristics as the priority
  14. """
  15. def __init__(self, s_start, s_goal, heuristic_type):
  16. self.s_start = s_start
  17. self.s_goal = s_goal
  18. self.heuristic_type = heuristic_type
  19. self.Env = env.Env() # class Env
  20. self.u_set = self.Env.motions # feasible input set
  21. self.obs = self.Env.obs # position of obstacles
  22. self.OPEN = [] # priority queue / OPEN set
  23. self.CLOSED = [] # CLOSED set / VISITED order
  24. self.PARENT = dict() # recorded parent
  25. self.g = dict() # cost to come
  26. def searching(self):
  27. """
  28. A_star Searching.
  29. :return: path, visited order
  30. """
  31. self.PARENT[self.s_start] = self.s_start
  32. self.g[self.s_start] = 0
  33. self.g[self.s_goal] = math.inf
  34. heapq.heappush(self.OPEN,
  35. (self.f_value(self.s_start), self.s_start))
  36. while self.OPEN:
  37. _, s = heapq.heappop(self.OPEN)
  38. self.CLOSED.append(s)
  39. if s == self.s_goal: # stop condition
  40. break
  41. for s_n in self.get_neighbor(s):
  42. new_cost = self.g[s] + self.cost(s, s_n)
  43. if s_n not in self.g:
  44. self.g[s_n] = math.inf
  45. if new_cost < self.g[s_n]: # conditions for updating Cost
  46. self.g[s_n] = new_cost
  47. self.PARENT[s_n] = s
  48. heapq.heappush(self.OPEN, (self.f_value(s_n), s_n))
  49. return self.extract_path(self.PARENT), self.CLOSED
  50. def searching_repeated_astar(self, e):
  51. """
  52. repeated A*.
  53. :param e: weight of A*
  54. :return: path and visited order
  55. """
  56. path, visited = [], []
  57. while e >= 1:
  58. p_k, v_k = self.repeated_searching(self.s_start, self.s_goal, e)
  59. path.append(p_k)
  60. visited.append(v_k)
  61. e -= 0.5
  62. return path, visited
  63. def repeated_searching(self, s_start, s_goal, e):
  64. """
  65. run A* with weight e.
  66. :param s_start: starting state
  67. :param s_goal: goal state
  68. :param e: weight of a*
  69. :return: path and visited order.
  70. """
  71. g = {s_start: 0, s_goal: float("inf")}
  72. PARENT = {s_start: s_start}
  73. OPEN = []
  74. CLOSED = []
  75. heapq.heappush(OPEN,
  76. (g[s_start] + e * self.heuristic(s_start), s_start))
  77. while OPEN:
  78. _, s = heapq.heappop(OPEN)
  79. CLOSED.append(s)
  80. if s == s_goal:
  81. break
  82. for s_n in self.get_neighbor(s):
  83. new_cost = g[s] + self.cost(s, s_n)
  84. if s_n not in g:
  85. g[s_n] = math.inf
  86. if new_cost < g[s_n]: # conditions for updating Cost
  87. g[s_n] = new_cost
  88. PARENT[s_n] = s
  89. heapq.heappush(OPEN, (g[s_n] + e * self.heuristic(s_n), s_n))
  90. return self.extract_path(PARENT), CLOSED
  91. def get_neighbor(self, s):
  92. """
  93. find neighbors of state s that not in obstacles.
  94. :param s: state
  95. :return: neighbors
  96. """
  97. return [(s[0] + u[0], s[1] + u[1]) for u in self.u_set]
  98. def cost(self, s_start, s_goal):
  99. """
  100. Calculate Cost for this motion
  101. :param s_start: starting node
  102. :param s_goal: end node
  103. :return: Cost for this motion
  104. :note: Cost function could be more complicate!
  105. """
  106. if self.is_collision(s_start, s_goal):
  107. return math.inf
  108. return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
  109. def is_collision(self, s_start, s_end):
  110. """
  111. check if the line segment (s_start, s_end) is collision.
  112. :param s_start: start node
  113. :param s_end: end node
  114. :return: True: is collision / False: not collision
  115. """
  116. if s_start in self.obs or s_end in self.obs:
  117. return True
  118. if s_start[0] != s_end[0] and s_start[1] != s_end[1]:
  119. if s_end[0] - s_start[0] == s_start[1] - s_end[1]:
  120. s1 = (min(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
  121. s2 = (max(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
  122. else:
  123. s1 = (min(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
  124. s2 = (max(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
  125. if s1 in self.obs or s2 in self.obs:
  126. return True
  127. return False
  128. def f_value(self, s):
  129. """
  130. f = g + h. (g: Cost to come, h: heuristic value)
  131. :param s: current state
  132. :return: f
  133. """
  134. return self.g[s] + self.heuristic(s)
  135. def extract_path(self, PARENT):
  136. """
  137. Extract the path based on the PARENT set.
  138. :return: The planning path
  139. """
  140. path = [self.s_goal]
  141. s = self.s_goal
  142. while True:
  143. s = PARENT[s]
  144. path.append(s)
  145. if s == self.s_start:
  146. break
  147. return list(path)
  148. def heuristic(self, s):
  149. """
  150. Calculate heuristic.
  151. :param s: current node (state)
  152. :return: heuristic function value
  153. """
  154. heuristic_type = self.heuristic_type # heuristic type
  155. goal = self.s_goal # goal node
  156. if heuristic_type == "manhattan":
  157. return abs(goal[0] - s[0]) + abs(goal[1] - s[1])
  158. else:
  159. return math.hypot(goal[0] - s[0], goal[1] - s[1])
  160. def main():
  161. s_start = (5, 5)
  162. s_goal = (45, 25)
  163. astar = AStar(s_start, s_goal, "euclidean")
  164. plot = plotting.Plotting(s_start, s_goal)
  165. path, visited = astar.searching()
  166. plot.animation(path, visited, "A*") # animation
  167. # path, visited = astar.searching_repeated_astar(2.5) # initial weight e = 2.5
  168. # plot.animation_ara_star(path, visited, "Repeated A*")
  169. if __name__ == '__main__':
  170. main()