dynamic_rrt.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. """
  2. DYNAMIC_RRT_2D
  3. @author: huiming zhou
  4. """
  5. import os
  6. import sys
  7. import math
  8. import copy
  9. import numpy as np
  10. import matplotlib.pyplot as plt
  11. import matplotlib.patches as patches
  12. sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
  13. "/../../Sampling-based Planning/")
  14. from rrt_2D import env
  15. from rrt_2D import plotting
  16. from rrt_2D import utils
  17. class Node:
  18. def __init__(self, n):
  19. self.x = n[0]
  20. self.y = n[1]
  21. self.parent = None
  22. self.flag = "VALID"
  23. class Edge:
  24. def __init__(self, n_p, n_c):
  25. self.parent = n_p
  26. self.child = n_c
  27. self.flag = "VALID"
  28. class DynamicRrt:
  29. def __init__(self, s_start, s_goal, step_len, goal_sample_rate, waypoint_sample_rate, iter_max):
  30. self.s_start = Node(s_start)
  31. self.s_goal = Node(s_goal)
  32. self.step_len = step_len
  33. self.goal_sample_rate = goal_sample_rate
  34. self.waypoint_sample_rate = waypoint_sample_rate
  35. self.iter_max = iter_max
  36. self.vertex = [self.s_start]
  37. self.vertex_old = []
  38. self.vertex_new = []
  39. self.edges = []
  40. self.env = env.Env()
  41. self.plotting = plotting.Plotting(s_start, s_goal)
  42. self.utils = utils.Utils()
  43. self.fig, self.ax = plt.subplots()
  44. self.x_range = self.env.x_range
  45. self.y_range = self.env.y_range
  46. self.obs_circle = self.env.obs_circle
  47. self.obs_rectangle = self.env.obs_rectangle
  48. self.obs_boundary = self.env.obs_boundary
  49. self.obs_add = [0, 0, 0]
  50. self.path = []
  51. self.waypoint = []
  52. def planning(self):
  53. for i in range(self.iter_max):
  54. node_rand = self.generate_random_node(self.goal_sample_rate)
  55. node_near = self.nearest_neighbor(self.vertex, node_rand)
  56. node_new = self.new_state(node_near, node_rand)
  57. if node_new and not self.utils.is_collision(node_near, node_new):
  58. self.vertex.append(node_new)
  59. self.edges.append(Edge(node_near, node_new))
  60. dist, _ = self.get_distance_and_angle(node_new, self.s_goal)
  61. if dist <= self.step_len:
  62. self.new_state(node_new, self.s_goal)
  63. path = self.extract_path(node_new)
  64. self.plot_grid("Dynamic_RRT")
  65. self.plot_visited()
  66. self.plot_path(path)
  67. self.path = path
  68. self.waypoint = self.extract_waypoint(node_new)
  69. self.fig.canvas.mpl_connect('button_press_event', self.on_press)
  70. plt.show()
  71. return
  72. return None
  73. def on_press(self, event):
  74. x, y = event.xdata, event.ydata
  75. if x < 0 or x > 50 or y < 0 or y > 30:
  76. print("Please choose right area!")
  77. else:
  78. x, y = int(x), int(y)
  79. print("Add circle obstacle at: x =", x, ",", "y =", y)
  80. self.obs_add = [x, y, 2]
  81. self.obs_circle.append([x, y, 2])
  82. self.utils.update_obs(self.obs_circle, self.obs_boundary, self.obs_rectangle)
  83. self.InvalidateNodes()
  84. if self.is_path_invalid():
  85. print("Path is Replanning ...")
  86. path, waypoint = self.replanning()
  87. print("len_vertex: ", len(self.vertex))
  88. print("len_vertex_old: ", len(self.vertex_old))
  89. print("len_vertex_new: ", len(self.vertex_new))
  90. plt.cla()
  91. self.plot_grid("Dynamic_RRT")
  92. self.plot_vertex_old()
  93. self.plot_path(self.path, color='blue')
  94. self.plot_vertex_new()
  95. self.vertex_new = []
  96. self.plot_path(path)
  97. self.path = path
  98. self.waypoint = waypoint
  99. else:
  100. print("Trimming Invalid Nodes ...")
  101. self.TrimRRT()
  102. plt.cla()
  103. self.plot_grid("Dynamic_RRT")
  104. self.plot_visited(animation=False)
  105. self.plot_path(self.path)
  106. self.fig.canvas.draw_idle()
  107. def InvalidateNodes(self):
  108. for edge in self.edges:
  109. if self.is_collision_obs_add(edge.parent, edge.child):
  110. edge.child.flag = "INVALID"
  111. def is_path_invalid(self):
  112. for node in self.waypoint:
  113. if node.flag == "INVALID":
  114. return True
  115. def is_collision_obs_add(self, start, end):
  116. delta = self.utils.delta
  117. obs_add = self.obs_add
  118. if math.hypot(start.x - obs_add[0], start.y - obs_add[1]) <= obs_add[2] + delta:
  119. return True
  120. if math.hypot(end.x - obs_add[0], end.y - obs_add[1]) <= obs_add[2] + delta:
  121. return True
  122. o, d = self.utils.get_ray(start, end)
  123. if self.utils.is_intersect_circle(o, d, [obs_add[0], obs_add[1]], obs_add[2]):
  124. return True
  125. return False
  126. def replanning(self):
  127. self.TrimRRT()
  128. for i in range(self.iter_max):
  129. node_rand = self.generate_random_node_replanning(self.goal_sample_rate, self.waypoint_sample_rate)
  130. node_near = self.nearest_neighbor(self.vertex, node_rand)
  131. node_new = self.new_state(node_near, node_rand)
  132. if node_new and not self.utils.is_collision(node_near, node_new):
  133. self.vertex.append(node_new)
  134. self.vertex_new.append(node_new)
  135. self.edges.append(Edge(node_near, node_new))
  136. dist, _ = self.get_distance_and_angle(node_new, self.s_goal)
  137. if dist <= self.step_len:
  138. self.new_state(node_new, self.s_goal)
  139. path = self.extract_path(node_new)
  140. waypoint = self.extract_waypoint(node_new)
  141. print("path: ", len(path))
  142. print("waypoint: ", len(waypoint))
  143. return path, waypoint
  144. return None
  145. def TrimRRT(self):
  146. for i in range(1, len(self.vertex)):
  147. node = self.vertex[i]
  148. node_p = node.parent
  149. if node_p.flag == "INVALID":
  150. node.flag = "INVALID"
  151. self.vertex = [node for node in self.vertex if node.flag == "VALID"]
  152. self.vertex_old = copy.deepcopy(self.vertex)
  153. self.edges = [Edge(node.parent, node) for node in self.vertex[1:len(self.vertex)]]
  154. def generate_random_node(self, goal_sample_rate):
  155. delta = self.utils.delta
  156. if np.random.random() > goal_sample_rate:
  157. return Node((np.random.uniform(self.x_range[0] + delta, self.x_range[1] - delta),
  158. np.random.uniform(self.y_range[0] + delta, self.y_range[1] - delta)))
  159. return self.s_goal
  160. def generate_random_node_replanning(self, goal_sample_rate, waypoint_sample_rate):
  161. delta = self.utils.delta
  162. p = np.random.random()
  163. if p < goal_sample_rate:
  164. return self.s_goal
  165. elif goal_sample_rate < p < goal_sample_rate + waypoint_sample_rate:
  166. return self.waypoint[np.random.randint(0, len(self.waypoint) - 1)]
  167. else:
  168. return Node((np.random.uniform(self.x_range[0] + delta, self.x_range[1] - delta),
  169. np.random.uniform(self.y_range[0] + delta, self.y_range[1] - delta)))
  170. @staticmethod
  171. def nearest_neighbor(node_list, n):
  172. return node_list[int(np.argmin([math.hypot(nd.x - n.x, nd.y - n.y)
  173. for nd in node_list]))]
  174. def new_state(self, node_start, node_end):
  175. dist, theta = self.get_distance_and_angle(node_start, node_end)
  176. dist = min(self.step_len, dist)
  177. node_new = Node((node_start.x + dist * math.cos(theta),
  178. node_start.y + dist * math.sin(theta)))
  179. node_new.parent = node_start
  180. return node_new
  181. def extract_path(self, node_end):
  182. path = [(self.s_goal.x, self.s_goal.y)]
  183. node_now = node_end
  184. while node_now.parent is not None:
  185. node_now = node_now.parent
  186. path.append((node_now.x, node_now.y))
  187. return path
  188. def extract_waypoint(self, node_end):
  189. waypoint = [self.s_goal]
  190. node_now = node_end
  191. while node_now.parent is not None:
  192. node_now = node_now.parent
  193. waypoint.append(node_now)
  194. return waypoint
  195. @staticmethod
  196. def get_distance_and_angle(node_start, node_end):
  197. dx = node_end.x - node_start.x
  198. dy = node_end.y - node_start.y
  199. return math.hypot(dx, dy), math.atan2(dy, dx)
  200. def plot_grid(self, name):
  201. for (ox, oy, w, h) in self.obs_boundary:
  202. self.ax.add_patch(
  203. patches.Rectangle(
  204. (ox, oy), w, h,
  205. edgecolor='black',
  206. facecolor='black',
  207. fill=True
  208. )
  209. )
  210. for (ox, oy, w, h) in self.obs_rectangle:
  211. self.ax.add_patch(
  212. patches.Rectangle(
  213. (ox, oy), w, h,
  214. edgecolor='black',
  215. facecolor='gray',
  216. fill=True
  217. )
  218. )
  219. for (ox, oy, r) in self.obs_circle:
  220. self.ax.add_patch(
  221. patches.Circle(
  222. (ox, oy), r,
  223. edgecolor='black',
  224. facecolor='gray',
  225. fill=True
  226. )
  227. )
  228. plt.plot(self.s_start.x, self.s_start.y, "bs", linewidth=3)
  229. plt.plot(self.s_goal.x, self.s_goal.y, "gs", linewidth=3)
  230. plt.title(name)
  231. plt.axis("equal")
  232. def plot_visited(self, animation=True):
  233. if animation:
  234. count = 0
  235. for node in self.vertex:
  236. count += 1
  237. if node.parent:
  238. plt.plot([node.parent.x, node.x], [node.parent.y, node.y], "-g")
  239. plt.gcf().canvas.mpl_connect('key_release_event',
  240. lambda event:
  241. [exit(0) if event.key == 'escape' else None])
  242. if count % 10 == 0:
  243. plt.pause(0.001)
  244. else:
  245. for node in self.vertex:
  246. if node.parent:
  247. plt.plot([node.parent.x, node.x], [node.parent.y, node.y], "-g")
  248. def plot_vertex_old(self):
  249. for node in self.vertex_old:
  250. if node.parent:
  251. plt.plot([node.parent.x, node.x], [node.parent.y, node.y], "-g")
  252. def plot_vertex_new(self):
  253. count = 0
  254. for node in self.vertex_new:
  255. count += 1
  256. if node.parent:
  257. plt.plot([node.parent.x, node.x], [node.parent.y, node.y], color='darkorange')
  258. plt.gcf().canvas.mpl_connect('key_release_event',
  259. lambda event:
  260. [exit(0) if event.key == 'escape' else None])
  261. if count % 10 == 0:
  262. plt.pause(0.001)
  263. @staticmethod
  264. def plot_path(path, color='red'):
  265. plt.plot([x[0] for x in path], [x[1] for x in path], linewidth=2, color=color)
  266. plt.pause(0.01)
  267. def main():
  268. x_start = (2, 2) # Starting node
  269. x_goal = (49, 24) # Goal node
  270. drrt = DynamicRrt(x_start, x_goal, 0.5, 0.1, 0.6, 5000)
  271. drrt.planning()
  272. if __name__ == '__main__':
  273. main()