bidirectional_a_star.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. """
  2. Bidirectional_a_star 2D
  3. @author: huiming zhou
  4. """
  5. import os
  6. import sys
  7. import math
  8. sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
  9. "/../../Search-based Planning/")
  10. from Search_2D import queue
  11. from Search_2D import plotting
  12. from Search_2D import env
  13. class BidirectionalAstar:
  14. def __init__(self, s_start, s_goal, heuristic_type):
  15. self.s_start, self.s_goal = s_start, s_goal
  16. self.heuristic_type = heuristic_type
  17. self.Env = env.Env() # class Env
  18. self.u_set = self.Env.motions # feasible input set
  19. self.obs = self.Env.obs # position of obstacles
  20. self.g_fore = {self.s_start: 0, self.s_goal: float("inf")} # cost to come: from s_start
  21. self.g_back = {self.s_goal: 0, self.s_start: float("inf")} # cost to come: form s_goal
  22. self.OPEN_fore = queue.QueuePrior() # U set for foreward searching
  23. self.OPEN_fore.put(self.s_start,
  24. self.g_fore[self.s_start] + self.h(self.s_start, self.s_goal))
  25. self.OPEN_back = queue.QueuePrior() # U set for backward searching
  26. self.OPEN_back.put(self.s_goal,
  27. self.g_back[self.s_goal] + self.h(self.s_goal, self.s_start))
  28. self.CLOSED_fore = [] # CLOSED set for foreward
  29. self.CLOSED_back = [] # CLOSED set for backward
  30. self.PARENT_fore = {self.s_start: self.s_start}
  31. self.PARENT_back = {self.s_goal: self.s_goal}
  32. def searching(self):
  33. s_meet = self.s_start
  34. while self.OPEN_fore and self.OPEN_back:
  35. # solve foreward-search
  36. s_fore = self.OPEN_fore.get()
  37. if s_fore in self.PARENT_back:
  38. s_meet = s_fore
  39. break
  40. self.CLOSED_fore.append(s_fore)
  41. for s_n in self.get_neighbor(s_fore):
  42. new_cost = self.g_fore[s_fore] + self.cost(s_fore, s_n)
  43. if s_n not in self.g_fore:
  44. self.g_fore[s_n] = float("inf")
  45. if new_cost < self.g_fore[s_n]:
  46. self.g_fore[s_n] = new_cost
  47. self.PARENT_fore[s_n] = s_fore
  48. self.OPEN_fore.put(s_n, new_cost + self.h(s_n, self.s_goal))
  49. # solve backward-search
  50. s_back = self.OPEN_back.get()
  51. if s_back in self.PARENT_fore:
  52. s_meet = s_back
  53. break
  54. self.CLOSED_back.append(s_back)
  55. for s_n in self.get_neighbor(s_back):
  56. new_cost = self.g_back[s_back] + self.cost(s_back, s_n)
  57. if s_n not in self.g_back:
  58. self.g_back[s_n] = float("inf")
  59. if new_cost < self.g_back[s_n]:
  60. self.g_back[s_n] = new_cost
  61. self.PARENT_back[s_n] = s_back
  62. self.OPEN_back.put(s_n, new_cost + self.h(s_n, self.s_start))
  63. return self.extract_path(s_meet), self.CLOSED_fore, self.CLOSED_back
  64. def get_neighbor(self, s):
  65. """
  66. find neighbors of state s that not in obstacles.
  67. :param s: state
  68. :return: neighbors
  69. """
  70. s_list = set()
  71. for u in self.u_set:
  72. s_next = tuple([s[i] + u[i] for i in range(2)])
  73. if s_next not in self.obs:
  74. s_list.add(s_next)
  75. return s_list
  76. def extract_path(self, s_meet):
  77. """
  78. extract path from start and goal
  79. :param s_meet: meet point of bi-direction a*
  80. :return: path
  81. """
  82. # extract path for foreward part
  83. path_fore = [s_meet]
  84. s = s_meet
  85. while True:
  86. s = self.PARENT_fore[s]
  87. path_fore.append(s)
  88. if s == self.s_start:
  89. break
  90. # extract path for backward part
  91. path_back = []
  92. s = s_meet
  93. while True:
  94. s = self.PARENT_back[s]
  95. path_back.append(s)
  96. if s == self.s_goal:
  97. break
  98. return list(reversed(path_fore)) + list(path_back)
  99. def h(self, s, goal):
  100. """
  101. Calculate heuristic value.
  102. :param s: current node (state)
  103. :param goal: goal node (state)
  104. :return: heuristic value
  105. """
  106. heuristic_type = self.heuristic_type
  107. if heuristic_type == "manhattan":
  108. return abs(goal[0] - s[0]) + abs(goal[1] - s[1])
  109. else:
  110. return math.hypot(goal[0] - s[0], goal[1] - s[1])
  111. @staticmethod
  112. def cost(s_start, s_goal):
  113. """
  114. Calculate cost for this motion
  115. :param s_start: starting node
  116. :param s_goal: end node
  117. :return: cost for this motion
  118. :note: cost function could be more complicate!
  119. """
  120. return 1
  121. def main():
  122. x_start = (5, 5)
  123. x_goal = (45, 25)
  124. bastar = BidirectionalAstar(x_start, x_goal, "euclidean")
  125. plot = plotting.Plotting(x_start, x_goal)
  126. path, visited_fore, visited_back = bastar.searching()
  127. plot.animation_bi_astar(path, visited_fore, visited_back, "Bidirectional-A*") # animation
  128. if __name__ == '__main__':
  129. main()