dynamic_rrt.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. """
  2. DYNAMIC_RRT_2D
  3. @author: huiming zhou
  4. """
  5. import os
  6. import sys
  7. import math
  8. import numpy as np
  9. import matplotlib.pyplot as plt
  10. import matplotlib.patches as patches
  11. sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
  12. "/../../Sampling-based Planning/")
  13. from rrt_2D import env
  14. from rrt_2D import plotting
  15. from rrt_2D import utils
  16. class Node:
  17. def __init__(self, n):
  18. self.x = n[0]
  19. self.y = n[1]
  20. self.parent = None
  21. class Edge:
  22. def __init__(self, n_p, n_c):
  23. self.parent = n_p
  24. self.child = n_c
  25. class DynamicRrt:
  26. def __init__(self, s_start, s_goal, step_len, goal_sample_rate, waypoint_sample_rate, iter_max):
  27. self.s_start = Node(s_start)
  28. self.s_goal = Node(s_goal)
  29. self.step_len = step_len
  30. self.goal_sample_rate = goal_sample_rate
  31. self.waypoint_sample_rate = waypoint_sample_rate
  32. self.iter_max = iter_max
  33. self.vertex = [self.s_start]
  34. self.edges = set()
  35. self.env = env.Env()
  36. self.plotting = plotting.Plotting(s_start, s_goal)
  37. self.utils = utils.Utils()
  38. self.fig, self.ax = plt.subplots()
  39. self.x_range = self.env.x_range
  40. self.y_range = self.env.y_range
  41. self.obs_circle = self.env.obs_circle
  42. self.obs_rectangle = self.env.obs_rectangle
  43. self.obs_boundary = self.env.obs_boundary
  44. self.path = []
  45. self.waypoint = []
  46. def planning(self):
  47. for i in range(self.iter_max):
  48. node_rand = self.generate_random_node(self.goal_sample_rate)
  49. node_near = self.nearest_neighbor(self.vertex, node_rand)
  50. node_new = self.new_state(node_near, node_rand)
  51. if node_new and not self.utils.is_collision(node_near, node_new):
  52. self.vertex.append(node_new)
  53. dist, _ = self.get_distance_and_angle(node_new, self.s_goal)
  54. if dist <= self.step_len:
  55. self.new_state(node_new, self.s_goal)
  56. path = self.extract_path(node_new)
  57. self.plot_grid("Extended_RRT")
  58. self.plot_visited()
  59. self.plot_path(path)
  60. self.path = path
  61. self.waypoint = self.extract_waypoint(node_new)
  62. self.fig.canvas.mpl_connect('button_press_event', self.on_press)
  63. plt.show()
  64. return
  65. return None
  66. def on_press(self, event):
  67. x, y = event.xdata, event.ydata
  68. if x < 0 or x > 50 or y < 0 or y > 30:
  69. print("Please choose right area!")
  70. else:
  71. x, y = int(x), int(y)
  72. print("Add circle obstacle at: x =", x, ",", "y =", y)
  73. self.obs_circle.append([x, y, 2])
  74. self.utils.update_obs(self.obs_circle, self.obs_boundary, self.obs_rectangle)
  75. path, waypoint = self.replanning()
  76. plt.cla()
  77. self.plot_grid("Extended_RRT")
  78. self.plot_path(self.path, color='blue')
  79. self.plot_visited()
  80. self.plot_path(path)
  81. self.path = path
  82. self.waypoint = waypoint
  83. self.fig.canvas.draw_idle()
  84. def replanning(self):
  85. self.vertex = [self.s_start]
  86. for i in range(self.iter_max):
  87. node_rand = self.generate_random_node_replanning(self.goal_sample_rate, self.waypoint_sample_rate)
  88. node_near = self.nearest_neighbor(self.vertex, node_rand)
  89. node_new = self.new_state(node_near, node_rand)
  90. if node_new and not self.utils.is_collision(node_near, node_new):
  91. self.vertex.append(node_new)
  92. dist, _ = self.get_distance_and_angle(node_new, self.s_goal)
  93. if dist <= self.step_len:
  94. self.new_state(node_new, self.s_goal)
  95. path = self.extract_path(node_new)
  96. waypoint = self.extract_waypoint(node_new)
  97. return path, waypoint
  98. return None
  99. def generate_random_node(self, goal_sample_rate):
  100. delta = self.utils.delta
  101. if np.random.random() > goal_sample_rate:
  102. return Node((np.random.uniform(self.x_range[0] + delta, self.x_range[1] - delta),
  103. np.random.uniform(self.y_range[0] + delta, self.y_range[1] - delta)))
  104. return self.s_goal
  105. def generate_random_node_replanning(self, goal_sample_rate, waypoint_sample_rate):
  106. delta = self.utils.delta
  107. p = np.random.random()
  108. if p < goal_sample_rate:
  109. return self.s_goal
  110. elif goal_sample_rate < p < goal_sample_rate + waypoint_sample_rate:
  111. return self.waypoint[np.random.randint(0, len(self.path) - 1)]
  112. else:
  113. return Node((np.random.uniform(self.x_range[0] + delta, self.x_range[1] - delta),
  114. np.random.uniform(self.y_range[0] + delta, self.y_range[1] - delta)))
  115. @staticmethod
  116. def nearest_neighbor(node_list, n):
  117. return node_list[int(np.argmin([math.hypot(nd.x - n.x, nd.y - n.y)
  118. for nd in node_list]))]
  119. def new_state(self, node_start, node_end):
  120. dist, theta = self.get_distance_and_angle(node_start, node_end)
  121. dist = min(self.step_len, dist)
  122. node_new = Node((node_start.x + dist * math.cos(theta),
  123. node_start.y + dist * math.sin(theta)))
  124. node_new.parent = node_start
  125. return node_new
  126. def extract_path(self, node_end):
  127. path = [(self.s_goal.x, self.s_goal.y)]
  128. node_now = node_end
  129. while node_now.parent is not None:
  130. node_now = node_now.parent
  131. path.append((node_now.x, node_now.y))
  132. return path
  133. def extract_waypoint(self, node_end):
  134. waypoint = [self.s_goal]
  135. node_now = node_end
  136. while node_now.parent is not None:
  137. node_now = node_now.parent
  138. waypoint.append(node_now)
  139. return waypoint
  140. @staticmethod
  141. def get_distance_and_angle(node_start, node_end):
  142. dx = node_end.x - node_start.x
  143. dy = node_end.y - node_start.y
  144. return math.hypot(dx, dy), math.atan2(dy, dx)
  145. def plot_grid(self, name):
  146. for (ox, oy, w, h) in self.obs_boundary:
  147. self.ax.add_patch(
  148. patches.Rectangle(
  149. (ox, oy), w, h,
  150. edgecolor='black',
  151. facecolor='black',
  152. fill=True
  153. )
  154. )
  155. for (ox, oy, w, h) in self.obs_rectangle:
  156. self.ax.add_patch(
  157. patches.Rectangle(
  158. (ox, oy), w, h,
  159. edgecolor='black',
  160. facecolor='gray',
  161. fill=True
  162. )
  163. )
  164. for (ox, oy, r) in self.obs_circle:
  165. self.ax.add_patch(
  166. patches.Circle(
  167. (ox, oy), r,
  168. edgecolor='black',
  169. facecolor='gray',
  170. fill=True
  171. )
  172. )
  173. plt.plot(self.s_start.x, self.s_start.y, "bs", linewidth=3)
  174. plt.plot(self.s_goal.x, self.s_goal.y, "gs", linewidth=3)
  175. plt.title(name)
  176. plt.axis("equal")
  177. def plot_visited(self):
  178. animation = True
  179. if animation:
  180. count = 0
  181. for node in self.vertex:
  182. count += 1
  183. if node.parent:
  184. plt.plot([node.parent.x, node.x], [node.parent.y, node.y], "-g")
  185. plt.gcf().canvas.mpl_connect('key_release_event',
  186. lambda event:
  187. [exit(0) if event.key == 'escape' else None])
  188. if count % 10 == 0:
  189. plt.pause(0.001)
  190. else:
  191. for node in self.vertex:
  192. if node.parent:
  193. plt.plot([node.parent.x, node.x], [node.parent.y, node.y], "-g")
  194. @staticmethod
  195. def plot_path(path, color='red'):
  196. plt.plot([x[0] for x in path], [x[1] for x in path], linewidth=2, color=color)
  197. plt.pause(0.01)
  198. def main():
  199. x_start = (2, 2) # Starting node
  200. x_goal = (49, 24) # Goal node
  201. drrt = DynamicRrt(x_start, x_goal, 0.5, 0.1, 0.6, 5000)
  202. drrt.planning()
  203. if __name__ == '__main__':
  204. main()