Bidirectional_a_star.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. """
  2. Bidirectional_a_star 2D
  3. @author: huiming zhou
  4. """
  5. import os
  6. import sys
  7. import math
  8. sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
  9. "/../../Search-based Planning/")
  10. from Search_2D import queue
  11. from Search_2D import plotting
  12. from Search_2D import env
  13. class BidirectionalAstar:
  14. def __init__(self, s_start, s_goal, heuristic_type):
  15. self.s_start, self.s_goal = s_start, s_goal
  16. self.heuristic_type = heuristic_type
  17. self.Env = env.Env() # class Env
  18. self.u_set = self.Env.motions # feasible input set
  19. self.obs = self.Env.obs # position of obstacles
  20. self.g_fore = {self.s_start: 0, self.s_goal: float("inf")} # cost to come: from s_start
  21. self.g_back = {self.s_goal: 0, self.s_start: float("inf")} # cost to come: form s_goal
  22. self.OPEN_fore = queue.QueuePrior() # OPEN set for foreward searching
  23. self.OPEN_fore.put(self.s_start,
  24. self.g_fore[self.s_start] + self.h(self.s_start, self.s_goal))
  25. self.OPEN_back = queue.QueuePrior() # OPEN set for backward searching
  26. self.OPEN_back.put(self.s_goal,
  27. self.g_back[self.s_goal] + self.h(self.s_goal, self.s_start))
  28. self.CLOSED_fore = [] # CLOSED set for foreward
  29. self.CLOSED_back = [] # CLOSED set for backward
  30. self.PARENT_fore = {self.s_start: self.s_start}
  31. self.PARENT_back = {self.s_goal: self.s_goal}
  32. def searching(self):
  33. s_meet = self.s_start
  34. while self.OPEN_fore and self.OPEN_back:
  35. # solve foreward-search
  36. s_fore = self.OPEN_fore.get()
  37. if s_fore in self.PARENT_back:
  38. s_meet = s_fore
  39. break
  40. self.CLOSED_fore.append(s_fore)
  41. for s_n in self.get_neighbor(s_fore):
  42. new_cost = self.g_fore[s_fore] + self.cost(s_fore, s_n)
  43. if s_n not in self.g_fore:
  44. self.g_fore[s_n] = float("inf")
  45. if new_cost < self.g_fore[s_n]:
  46. self.g_fore[s_n] = new_cost
  47. self.PARENT_fore[s_n] = s_fore
  48. self.OPEN_fore.put(s_n, new_cost + self.h(s_n, self.s_goal))
  49. # solve backward-search
  50. s_back = self.OPEN_back.get()
  51. if s_back in self.PARENT_fore:
  52. s_meet = s_back
  53. break
  54. self.CLOSED_back.append(s_back)
  55. for s_n in self.get_neighbor(s_back):
  56. new_cost = self.g_back[s_back] + self.cost(s_back, s_n)
  57. if s_n not in self.g_back:
  58. self.g_back[s_n] = float("inf")
  59. if new_cost < self.g_back[s_n]:
  60. self.g_back[s_n] = new_cost
  61. self.PARENT_back[s_n] = s_back
  62. self.OPEN_back.put(s_n, new_cost + self.h(s_n, self.s_start))
  63. return self.extract_path(s_meet), self.CLOSED_fore, self.CLOSED_back
  64. def get_neighbor(self, s):
  65. """
  66. find neighbors of state s that not in obstacles.
  67. :param s: state
  68. :return: neighbors
  69. """
  70. s_list = set()
  71. for u in self.u_set:
  72. s_next = tuple([s[i] + u[i] for i in range(2)])
  73. if s_next not in self.obs:
  74. s_list.add(s_next)
  75. return s_list
  76. def extract_path(self, s_meet):
  77. """
  78. extract path from start and goal
  79. :param s_meet: meet point of bi-direction a*
  80. :return: path
  81. """
  82. # extract path for foreward part
  83. path_fore = [s_meet]
  84. s = s_meet
  85. while True:
  86. s = self.PARENT_fore[s]
  87. path_fore.append(s)
  88. if s == self.s_start:
  89. break
  90. # extract path for backward part
  91. path_back = []
  92. s = s_meet
  93. while True:
  94. s = self.PARENT_back[s]
  95. path_back.append(s)
  96. if s == self.s_goal:
  97. break
  98. return list(reversed(path_fore)) + list(path_back)
  99. def h(self, s, goal):
  100. """
  101. Calculate heuristic value.
  102. :param s: current node (state)
  103. :param goal: goal node (state)
  104. :return: heuristic value
  105. """
  106. heuristic_type = self.heuristic_type
  107. if heuristic_type == "manhattan":
  108. return abs(goal[0] - s[0]) + abs(goal[1] - s[1])
  109. else:
  110. return math.hypot(goal[0] - s[0], goal[1] - s[1])
  111. def cost(self, s_start, s_goal):
  112. """
  113. Calculate cost for this motion
  114. :param s_start: starting node
  115. :param s_goal: end node
  116. :return: cost for this motion
  117. :note: cost function could be more complicate!
  118. """
  119. if self.is_collision(s_start, s_goal):
  120. return float("inf")
  121. return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
  122. def is_collision(self, s_start, s_end):
  123. if s_start in self.obs or s_end in self.obs:
  124. return True
  125. if s_start[0] != s_end[0] and s_start[1] != s_end[1]:
  126. if s_end[0] - s_start[0] == s_start[1] - s_end[1]:
  127. s1 = (min(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
  128. s2 = (max(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
  129. else:
  130. s1 = (min(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
  131. s2 = (max(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
  132. if s1 in self.obs or s2 in self.obs:
  133. return True
  134. return False
  135. def main():
  136. x_start = (5, 5)
  137. x_goal = (45, 25)
  138. bastar = BidirectionalAstar(x_start, x_goal, "euclidean")
  139. plot = plotting.Plotting(x_start, x_goal)
  140. path, visited_fore, visited_back = bastar.searching()
  141. plot.animation_bi_astar(path, visited_fore, visited_back, "Bidirectional-A*") # animation
  142. if __name__ == '__main__':
  143. main()