# D_star.py (7.3 KB)
"""
D_star 2D

Interactive D* planner on a 2-D grid: a path is planned from start to goal,
and clicking on the figure adds an obstacle, after which the path is repaired.

@author: huiming zhou
"""
import os
import sys
import math
import matplotlib.pyplot as plt

# Put "<this file's directory>/../../Search-based Planning/" on the import path
# so that the Search_2D package imported below can be found. This must run
# before the Search_2D imports.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
                "/../../Search-based Planning/")

from Search_2D import plotting
from Search_2D import env
  13. class Dstar:
  14. def __init__(self, s_start, s_goal):
  15. self.s_start, self.s_goal = s_start, s_goal
  16. self.Env = env.Env()
  17. self.Plot = plotting.Plotting(self.s_start, self.s_goal)
  18. self.u_set = self.Env.motions
  19. self.obs = self.Env.obs
  20. self.x = self.Env.x_range
  21. self.y = self.Env.y_range
  22. self.fig = plt.figure()
  23. self.OPEN = set()
  24. self.t = {}
  25. self.PARENT = {}
  26. self.h = {}
  27. self.k = {}
  28. self.path = []
  29. self.visited = set()
  30. self.count = 0
  31. for i in range(self.Env.x_range):
  32. for j in range(self.Env.y_range):
  33. self.t[(i, j)] = 'NEW'
  34. self.k[(i, j)] = 0.0
  35. self.h[(i, j)] = float("inf")
  36. self.PARENT[(i, j)] = None
  37. self.h[self.s_goal] = 0.0
  38. def run(self, s_start, s_end):
  39. self.insert(s_end, 0)
  40. while True:
  41. self.process_state()
  42. if self.t[s_start] == 'CLOSED':
  43. break
  44. self.path = self.extract_path(s_start, s_end)
  45. self.Plot.plot_grid("Dynamic A* (D*)")
  46. self.plot_path(self.path)
  47. self.fig.canvas.mpl_connect('button_press_event', self.on_press)
  48. plt.show()
  49. def on_press(self, event):
  50. x, y = event.xdata, event.ydata
  51. if x < 0 or x > self.x - 1 or y < 0 or y > self.y - 1:
  52. print("Please choose right area!")
  53. else:
  54. x, y = int(x), int(y)
  55. print("Add obstacle at: x =", x, ",", "y =", y)
  56. self.obs.add((x, y))
  57. plt.plot(x, y, 'sk')
  58. s = self.s_start
  59. self.visited = set()
  60. self.count += 1
  61. while s != self.s_goal:
  62. if self.is_collision(s, self.PARENT[s]):
  63. self.modify(s)
  64. continue
  65. s = self.PARENT[s]
  66. self.path = self.extract_path(self.s_start, self.s_goal)
  67. self.plot_visited(self.visited)
  68. self.plot_path(self.path)
  69. self.fig.canvas.draw_idle()
  70. def extract_path(self, s_start, s_end):
  71. path = [s_start]
  72. s = s_start
  73. while True:
  74. s = self.PARENT[s]
  75. path.append(s)
  76. if s == s_end:
  77. return path
  78. def process_state(self):
  79. s = self.min_state()
  80. self.visited.add(s)
  81. if s is None:
  82. return -1
  83. k_old = self.get_k_min()
  84. self.delete(s)
  85. if k_old < self.h[s]:
  86. for s_n in self.get_neighbor(s):
  87. if self.h[s_n] <= k_old and self.h[s] > self.h[s_n] + self.cost(s_n, s):
  88. self.PARENT[s] = s_n
  89. self.h[s] = self.h[s_n] + self.cost(s_n, s)
  90. if k_old == self.h[s]:
  91. for s_n in self.get_neighbor(s):
  92. if self.t[s_n] == 'NEW' or \
  93. (self.PARENT[s_n] == s and self.h[s_n] != self.h[s] + self.cost(s, s_n)) or \
  94. (self.PARENT[s_n] != s and self.h[s_n] > self.h[s] + self.cost(s, s_n)):
  95. self.PARENT[s_n] = s
  96. self.insert(s_n, self.h[s] + self.cost(s, s_n))
  97. else:
  98. for s_n in self.get_neighbor(s):
  99. if self.t[s_n] == 'NEW' or \
  100. (self.PARENT[s_n] == s and self.h[s_n] != self.h[s] + self.cost(s, s_n)):
  101. self.PARENT[s_n] = s
  102. self.insert(s_n, self.h[s] + self.cost(s, s_n))
  103. else:
  104. if self.PARENT[s_n] != s and self.h[s_n] > self.h[s] + self.cost(s, s_n):
  105. self.insert(s, self.h[s])
  106. else:
  107. if self.PARENT[s_n] != s and \
  108. self.h[s] > self.h[s_n] + self.cost(s_n, s) and \
  109. self.t[s_n] == 'CLOSED' and \
  110. self.h[s_n] > k_old:
  111. self.insert(s_n, self.h[s_n])
  112. return self.get_k_min()
  113. def min_state(self):
  114. if not self.OPEN:
  115. return None
  116. return min(self.OPEN, key=lambda x: self.k[x])
  117. def get_k_min(self):
  118. if not self.OPEN:
  119. return -1
  120. return min([self.k[x] for x in self.OPEN])
  121. def insert(self, s, h_new):
  122. if self.t[s] == 'NEW':
  123. self.k[s] = h_new
  124. elif self.t[s] == 'OPEN':
  125. self.k[s] = min(self.k[s], h_new)
  126. elif self.t[s] == 'CLOSED':
  127. self.k[s] = min(self.h[s], h_new)
  128. self.h[s] = h_new
  129. self.t[s] = 'OPEN'
  130. self.OPEN.add(s)
  131. def delete(self, s):
  132. if self.t[s] == 'OPEN':
  133. self.t[s] = 'CLOSED'
  134. self.OPEN.remove(s)
  135. def modify(self, s):
  136. self.modify_cost(s)
  137. while True:
  138. k_min = self.process_state()
  139. if k_min >= self.h[s]:
  140. break
  141. def modify_cost(self, s):
  142. if self.t[s] == 'CLOSED':
  143. self.insert(s, self.h[self.PARENT[s]] + self.cost(s, self.PARENT[s]))
  144. def get_neighbor(self, s):
  145. nei_list = set()
  146. for u in self.u_set:
  147. s_next = tuple([s[i] + u[i] for i in range(2)])
  148. if s_next not in self.obs:
  149. nei_list.add(s_next)
  150. return nei_list
  151. def cost(self, s_start, s_goal):
  152. """
  153. Calculate Cost for this motion
  154. :param s_start: starting node
  155. :param s_goal: end node
  156. :return: Cost for this motion
  157. :note: Cost function could be more complicate!
  158. """
  159. if self.is_collision(s_start, s_goal):
  160. return float("inf")
  161. return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
  162. def is_collision(self, s_start, s_end):
  163. if s_start in self.obs or s_end in self.obs:
  164. return True
  165. if s_start[0] != s_end[0] and s_start[1] != s_end[1]:
  166. if s_end[0] - s_start[0] == s_start[1] - s_end[1]:
  167. s1 = (min(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
  168. s2 = (max(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
  169. else:
  170. s1 = (min(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
  171. s2 = (max(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
  172. if s1 in self.obs or s2 in self.obs:
  173. return True
  174. return False
  175. def plot_path(self, path):
  176. px = [x[0] for x in path]
  177. py = [x[1] for x in path]
  178. plt.plot(px, py, linewidth=2)
  179. plt.plot(self.s_start[0], self.s_start[1], "bs")
  180. plt.plot(self.s_goal[0], self.s_goal[1], "gs")
  181. def plot_visited(self, visited):
  182. color = ['gainsboro', 'lightgray', 'silver', 'darkgray',
  183. 'bisque', 'navajowhite', 'moccasin', 'wheat',
  184. 'powderblue', 'skyblue', 'lightskyblue', 'cornflowerblue']
  185. if self.count >= len(color) - 1:
  186. self.count = 0
  187. for x in visited:
  188. plt.plot(x[0], x[1], marker='s', color=color[self.count])
  189. def main():
  190. s_start = (5, 5)
  191. s_goal = (45, 25)
  192. dstar = Dstar(s_start, s_goal)
  193. dstar.run(s_start, s_goal)
  194. if __name__ == '__main__':
  195. main()