dynamic_rrt.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. """
  2. DYNAMIC_RRT_2D
  3. @author: huiming zhou
  4. """
  5. import os
  6. import sys
  7. import math
  8. import copy
  9. import numpy as np
  10. import matplotlib.pyplot as plt
  11. import matplotlib.patches as patches
  12. sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
  13. "/../../Sampling-based Planning/")
  14. from rrt_2D import env
  15. from rrt_2D import plotting
  16. from rrt_2D import utils
  17. class Node:
  18. def __init__(self, n):
  19. self.x = n[0]
  20. self.y = n[1]
  21. self.parent = None
  22. self.flag = "VALID"
  23. class Edge:
  24. def __init__(self, n_p, n_c):
  25. self.parent = n_p
  26. self.child = n_c
  27. self.flag = "VALID"
  28. class DynamicRrt:
  29. def __init__(self, s_start, s_goal, step_len, goal_sample_rate, waypoint_sample_rate, iter_max):
  30. self.s_start = Node(s_start)
  31. self.s_goal = Node(s_goal)
  32. self.step_len = step_len
  33. self.goal_sample_rate = goal_sample_rate
  34. self.waypoint_sample_rate = waypoint_sample_rate
  35. self.iter_max = iter_max
  36. self.vertex = [self.s_start]
  37. self.vertex_old = []
  38. self.vertex_new = []
  39. self.edges = []
  40. self.env = env.Env()
  41. self.plotting = plotting.Plotting(s_start, s_goal)
  42. self.utils = utils.Utils()
  43. self.fig, self.ax = plt.subplots()
  44. self.x_range = self.env.x_range
  45. self.y_range = self.env.y_range
  46. self.obs_circle = self.env.obs_circle
  47. self.obs_rectangle = self.env.obs_rectangle
  48. self.obs_boundary = self.env.obs_boundary
  49. self.obs_add = [0, 0, 0]
  50. self.path = []
  51. self.waypoint = []
  52. def planning(self):
  53. for i in range(self.iter_max):
  54. node_rand = self.generate_random_node(self.goal_sample_rate)
  55. node_near = self.nearest_neighbor(self.vertex, node_rand)
  56. node_new = self.new_state(node_near, node_rand)
  57. if node_new and not self.utils.is_collision(node_near, node_new):
  58. self.vertex.append(node_new)
  59. self.edges.append(Edge(node_near, node_new))
  60. dist, _ = self.get_distance_and_angle(node_new, self.s_goal)
  61. if dist <= self.step_len:
  62. self.new_state(node_new, self.s_goal)
  63. path = self.extract_path(node_new)
  64. self.plot_grid("Dynamic_RRT")
  65. self.plot_visited()
  66. self.plot_path(path)
  67. self.path = path
  68. self.waypoint = self.extract_waypoint(node_new)
  69. self.fig.canvas.mpl_connect('button_press_event', self.on_press)
  70. plt.show()
  71. return
  72. return None
  73. def on_press(self, event):
  74. x, y = event.xdata, event.ydata
  75. if x < 0 or x > 50 or y < 0 or y > 30:
  76. print("Please choose right area!")
  77. else:
  78. x, y = int(x), int(y)
  79. print("Add circle obstacle at: x =", x, ",", "y =", y)
  80. self.obs_add = [x, y, 2]
  81. self.obs_circle.append([x, y, 2])
  82. self.utils.update_obs(self.obs_circle, self.obs_boundary, self.obs_rectangle)
  83. self.InvalidateNodes()
  84. if self.is_path_invalid():
  85. print("Path is Replanning ...")
  86. path, waypoint = self.replanning()
  87. print("len_vertex: ", len(self.vertex))
  88. print("len_vertex_old: ", len(self.vertex_old))
  89. print("len_vertex_new: ", len(self.vertex_new))
  90. plt.cla()
  91. self.plot_grid("Dynamic_RRT")
  92. self.plot_vertex_old()
  93. self.plot_path(self.path, color='blue')
  94. self.plot_vertex_new()
  95. self.vertex_new = []
  96. self.plot_path(path)
  97. self.path = path
  98. self.waypoint = waypoint
  99. else:
  100. print("Trimming Invalid Nodes ...")
  101. self.TrimRRT()
  102. plt.cla()
  103. self.plot_grid("Dynamic_RRT")
  104. self.plot_visited(animation=False)
  105. self.plot_path(self.path)
  106. self.fig.canvas.draw_idle()
  107. def InvalidateNodes(self):
  108. for edge in self.edges:
  109. if self.is_collision_obs_add(edge.parent, edge.child):
  110. edge.child.flag = "INVALID"
  111. edge.flag = "INVALID"
  112. def is_path_invalid(self):
  113. for node in self.waypoint:
  114. if node.flag == "INVALID":
  115. return True
  116. def is_collision_obs_add(self, start, end):
  117. delta = self.utils.delta
  118. obs_add = self.obs_add
  119. if math.hypot(start.x - obs_add[0], start.y - obs_add[1]) <= obs_add[2] + delta:
  120. return True
  121. if math.hypot(end.x - obs_add[0], end.y - obs_add[1]) <= obs_add[2] + delta:
  122. return True
  123. o, d = self.utils.get_ray(start, end)
  124. if self.utils.is_intersect_circle(o, d, [obs_add[0], obs_add[1]], obs_add[2]):
  125. return True
  126. return False
  127. def replanning(self):
  128. self.TrimRRT()
  129. for i in range(self.iter_max):
  130. node_rand = self.generate_random_node_replanning(self.goal_sample_rate, self.waypoint_sample_rate)
  131. node_near = self.nearest_neighbor(self.vertex, node_rand)
  132. node_new = self.new_state(node_near, node_rand)
  133. if node_new and not self.utils.is_collision(node_near, node_new):
  134. self.vertex.append(node_new)
  135. self.vertex_new.append(node_new)
  136. self.edges.append(Edge(node_near, node_new))
  137. dist, _ = self.get_distance_and_angle(node_new, self.s_goal)
  138. if dist <= self.step_len:
  139. self.new_state(node_new, self.s_goal)
  140. path = self.extract_path(node_new)
  141. waypoint = self.extract_waypoint(node_new)
  142. print("path: ", len(path))
  143. print("waypoint: ", len(waypoint))
  144. return path, waypoint
  145. return None
  146. def TrimRRT(self):
  147. for i in range(1, len(self.vertex)):
  148. node = self.vertex[i]
  149. node_p = node.parent
  150. if node_p.flag == "INVALID":
  151. node.flag = "INVALID"
  152. self.vertex = [node for node in self.vertex if node.flag == "VALID"]
  153. self.vertex_old = copy.deepcopy(self.vertex)
  154. self.edges = [Edge(node.parent, node) for node in self.vertex[1:len(self.vertex)]]
  155. def generate_random_node(self, goal_sample_rate):
  156. delta = self.utils.delta
  157. if np.random.random() > goal_sample_rate:
  158. return Node((np.random.uniform(self.x_range[0] + delta, self.x_range[1] - delta),
  159. np.random.uniform(self.y_range[0] + delta, self.y_range[1] - delta)))
  160. return self.s_goal
  161. def generate_random_node_replanning(self, goal_sample_rate, waypoint_sample_rate):
  162. delta = self.utils.delta
  163. p = np.random.random()
  164. if p < goal_sample_rate:
  165. return self.s_goal
  166. elif goal_sample_rate < p < goal_sample_rate + waypoint_sample_rate:
  167. return self.waypoint[np.random.randint(0, len(self.waypoint) - 1)]
  168. else:
  169. return Node((np.random.uniform(self.x_range[0] + delta, self.x_range[1] - delta),
  170. np.random.uniform(self.y_range[0] + delta, self.y_range[1] - delta)))
  171. @staticmethod
  172. def nearest_neighbor(node_list, n):
  173. return node_list[int(np.argmin([math.hypot(nd.x - n.x, nd.y - n.y)
  174. for nd in node_list]))]
  175. def new_state(self, node_start, node_end):
  176. dist, theta = self.get_distance_and_angle(node_start, node_end)
  177. dist = min(self.step_len, dist)
  178. node_new = Node((node_start.x + dist * math.cos(theta),
  179. node_start.y + dist * math.sin(theta)))
  180. node_new.parent = node_start
  181. return node_new
  182. def extract_path(self, node_end):
  183. path = [(self.s_goal.x, self.s_goal.y)]
  184. node_now = node_end
  185. while node_now.parent is not None:
  186. node_now = node_now.parent
  187. path.append((node_now.x, node_now.y))
  188. return path
  189. def extract_waypoint(self, node_end):
  190. waypoint = [self.s_goal]
  191. node_now = node_end
  192. while node_now.parent is not None:
  193. node_now = node_now.parent
  194. waypoint.append(node_now)
  195. return waypoint
  196. @staticmethod
  197. def get_distance_and_angle(node_start, node_end):
  198. dx = node_end.x - node_start.x
  199. dy = node_end.y - node_start.y
  200. return math.hypot(dx, dy), math.atan2(dy, dx)
  201. def plot_grid(self, name):
  202. for (ox, oy, w, h) in self.obs_boundary:
  203. self.ax.add_patch(
  204. patches.Rectangle(
  205. (ox, oy), w, h,
  206. edgecolor='black',
  207. facecolor='black',
  208. fill=True
  209. )
  210. )
  211. for (ox, oy, w, h) in self.obs_rectangle:
  212. self.ax.add_patch(
  213. patches.Rectangle(
  214. (ox, oy), w, h,
  215. edgecolor='black',
  216. facecolor='gray',
  217. fill=True
  218. )
  219. )
  220. for (ox, oy, r) in self.obs_circle:
  221. self.ax.add_patch(
  222. patches.Circle(
  223. (ox, oy), r,
  224. edgecolor='black',
  225. facecolor='gray',
  226. fill=True
  227. )
  228. )
  229. plt.plot(self.s_start.x, self.s_start.y, "bs", linewidth=3)
  230. plt.plot(self.s_goal.x, self.s_goal.y, "gs", linewidth=3)
  231. plt.title(name)
  232. plt.axis("equal")
  233. def plot_visited(self, animation=True):
  234. if animation:
  235. count = 0
  236. for node in self.vertex:
  237. count += 1
  238. if node.parent:
  239. plt.plot([node.parent.x, node.x], [node.parent.y, node.y], "-g")
  240. plt.gcf().canvas.mpl_connect('key_release_event',
  241. lambda event:
  242. [exit(0) if event.key == 'escape' else None])
  243. if count % 10 == 0:
  244. plt.pause(0.001)
  245. else:
  246. for node in self.vertex:
  247. if node.parent:
  248. plt.plot([node.parent.x, node.x], [node.parent.y, node.y], "-g")
  249. def plot_vertex_old(self):
  250. for node in self.vertex_old:
  251. if node.parent:
  252. plt.plot([node.parent.x, node.x], [node.parent.y, node.y], "-g")
  253. def plot_vertex_new(self):
  254. count = 0
  255. for node in self.vertex_new:
  256. count += 1
  257. if node.parent:
  258. plt.plot([node.parent.x, node.x], [node.parent.y, node.y], color='darkorange')
  259. plt.gcf().canvas.mpl_connect('key_release_event',
  260. lambda event:
  261. [exit(0) if event.key == 'escape' else None])
  262. if count % 10 == 0:
  263. plt.pause(0.001)
  264. @staticmethod
  265. def plot_path(path, color='red'):
  266. plt.plot([x[0] for x in path], [x[1] for x in path], linewidth=2, color=color)
  267. plt.pause(0.01)
  268. def main():
  269. x_start = (2, 2) # Starting node
  270. x_goal = (49, 24) # Goal node
  271. drrt = DynamicRrt(x_start, x_goal, 0.5, 0.1, 0.6, 5000)
  272. drrt.planning()
  273. if __name__ == '__main__':
  274. main()