extended_rrt.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. """
  2. EXTENDED_RRT_2D
  3. @author: huiming zhou
  4. """
  5. import os
  6. import sys
  7. import math
  8. import numpy as np
  9. import matplotlib.pyplot as plt
  10. import matplotlib.patches as patches
  11. sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
  12. "/../../Sampling-based Planning/")
  13. from rrt_2D import env
  14. from rrt_2D import plotting
  15. from rrt_2D import utils
  16. class Node:
  17. def __init__(self, n):
  18. self.x = n[0]
  19. self.y = n[1]
  20. self.parent = None
  21. class ExtendedRrt:
  22. def __init__(self, s_start, s_goal, step_len, goal_sample_rate, waypoint_sample_rate, iter_max):
  23. self.s_start = Node(s_start)
  24. self.s_goal = Node(s_goal)
  25. self.step_len = step_len
  26. self.goal_sample_rate = goal_sample_rate
  27. self.waypoint_sample_rate = waypoint_sample_rate
  28. self.iter_max = iter_max
  29. self.vertex = [self.s_start]
  30. self.env = env.Env()
  31. self.plotting = plotting.Plotting(s_start, s_goal)
  32. self.utils = utils.Utils()
  33. self.fig, self.ax = plt.subplots()
  34. self.x_range = self.env.x_range
  35. self.y_range = self.env.y_range
  36. self.obs_circle = self.env.obs_circle
  37. self.obs_rectangle = self.env.obs_rectangle
  38. self.obs_boundary = self.env.obs_boundary
  39. self.path = []
  40. self.waypoint = []
  41. def planning(self):
  42. for i in range(self.iter_max):
  43. node_rand = self.generate_random_node(self.goal_sample_rate)
  44. node_near = self.nearest_neighbor(self.vertex, node_rand)
  45. node_new = self.new_state(node_near, node_rand)
  46. if node_new and not self.utils.is_collision(node_near, node_new):
  47. self.vertex.append(node_new)
  48. dist, _ = self.get_distance_and_angle(node_new, self.s_goal)
  49. if dist <= self.step_len:
  50. self.new_state(node_new, self.s_goal)
  51. path = self.extract_path(node_new)
  52. self.plot_grid("Extended_RRT")
  53. self.plot_visited()
  54. self.plot_path(path)
  55. self.path = path
  56. self.waypoint = self.extract_waypoint(node_new)
  57. self.fig.canvas.mpl_connect('button_press_event', self.on_press)
  58. plt.show()
  59. return
  60. return None
  61. def on_press(self, event):
  62. x, y = event.xdata, event.ydata
  63. if x < 0 or x > 50 or y < 0 or y > 30:
  64. print("Please choose right area!")
  65. else:
  66. x, y = int(x), int(y)
  67. print("Add circle obstacle at: x =", x, ",", "y =", y)
  68. self.obs_circle.append([x, y, 2])
  69. self.utils.update_obs(self.obs_circle, self.obs_boundary, self.obs_rectangle)
  70. path, waypoint = self.replanning()
  71. plt.cla()
  72. self.plot_grid("Extended_RRT")
  73. self.plot_path(self.path, color='blue')
  74. self.plot_visited()
  75. self.plot_path(path)
  76. self.path = path
  77. self.waypoint = waypoint
  78. self.fig.canvas.draw_idle()
  79. def replanning(self):
  80. self.vertex = [self.s_start]
  81. for i in range(self.iter_max):
  82. node_rand = self.generate_random_node_replanning(self.goal_sample_rate, self.waypoint_sample_rate)
  83. node_near = self.nearest_neighbor(self.vertex, node_rand)
  84. node_new = self.new_state(node_near, node_rand)
  85. if node_new and not self.utils.is_collision(node_near, node_new):
  86. self.vertex.append(node_new)
  87. dist, _ = self.get_distance_and_angle(node_new, self.s_goal)
  88. if dist <= self.step_len:
  89. self.new_state(node_new, self.s_goal)
  90. path = self.extract_path(node_new)
  91. waypoint = self.extract_waypoint(node_new)
  92. return path, waypoint
  93. return None
  94. def generate_random_node(self, goal_sample_rate):
  95. delta = self.utils.delta
  96. if np.random.random() > goal_sample_rate:
  97. return Node((np.random.uniform(self.x_range[0] + delta, self.x_range[1] - delta),
  98. np.random.uniform(self.y_range[0] + delta, self.y_range[1] - delta)))
  99. return self.s_goal
  100. def generate_random_node_replanning(self, goal_sample_rate, waypoint_sample_rate):
  101. delta = self.utils.delta
  102. p = np.random.random()
  103. if p < goal_sample_rate:
  104. return self.s_goal
  105. elif goal_sample_rate < p < goal_sample_rate + waypoint_sample_rate:
  106. return self.waypoint[np.random.randint(0, len(self.path) - 1)]
  107. else:
  108. return Node((np.random.uniform(self.x_range[0] + delta, self.x_range[1] - delta),
  109. np.random.uniform(self.y_range[0] + delta, self.y_range[1] - delta)))
  110. @staticmethod
  111. def nearest_neighbor(node_list, n):
  112. return node_list[int(np.argmin([math.hypot(nd.x - n.x, nd.y - n.y)
  113. for nd in node_list]))]
  114. def new_state(self, node_start, node_end):
  115. dist, theta = self.get_distance_and_angle(node_start, node_end)
  116. dist = min(self.step_len, dist)
  117. node_new = Node((node_start.x + dist * math.cos(theta),
  118. node_start.y + dist * math.sin(theta)))
  119. node_new.parent = node_start
  120. return node_new
  121. def extract_path(self, node_end):
  122. path = [(self.s_goal.x, self.s_goal.y)]
  123. node_now = node_end
  124. while node_now.parent is not None:
  125. node_now = node_now.parent
  126. path.append((node_now.x, node_now.y))
  127. return path
  128. def extract_waypoint(self, node_end):
  129. waypoint = [self.s_goal]
  130. node_now = node_end
  131. while node_now.parent is not None:
  132. node_now = node_now.parent
  133. waypoint.append(node_now)
  134. return waypoint
  135. @staticmethod
  136. def get_distance_and_angle(node_start, node_end):
  137. dx = node_end.x - node_start.x
  138. dy = node_end.y - node_start.y
  139. return math.hypot(dx, dy), math.atan2(dy, dx)
  140. def plot_grid(self, name):
  141. for (ox, oy, w, h) in self.obs_boundary:
  142. self.ax.add_patch(
  143. patches.Rectangle(
  144. (ox, oy), w, h,
  145. edgecolor='black',
  146. facecolor='black',
  147. fill=True
  148. )
  149. )
  150. for (ox, oy, w, h) in self.obs_rectangle:
  151. self.ax.add_patch(
  152. patches.Rectangle(
  153. (ox, oy), w, h,
  154. edgecolor='black',
  155. facecolor='gray',
  156. fill=True
  157. )
  158. )
  159. for (ox, oy, r) in self.obs_circle:
  160. self.ax.add_patch(
  161. patches.Circle(
  162. (ox, oy), r,
  163. edgecolor='black',
  164. facecolor='gray',
  165. fill=True
  166. )
  167. )
  168. plt.plot(self.s_start.x, self.s_start.y, "bs", linewidth=3)
  169. plt.plot(self.s_goal.x, self.s_goal.y, "gs", linewidth=3)
  170. plt.title(name)
  171. plt.axis("equal")
  172. def plot_visited(self):
  173. animation = True
  174. if animation:
  175. count = 0
  176. for node in self.vertex:
  177. count += 1
  178. if node.parent:
  179. plt.plot([node.parent.x, node.x], [node.parent.y, node.y], "-g")
  180. plt.gcf().canvas.mpl_connect('key_release_event',
  181. lambda event:
  182. [exit(0) if event.key == 'escape' else None])
  183. if count % 10 == 0:
  184. plt.pause(0.001)
  185. else:
  186. for node in self.vertex:
  187. if node.parent:
  188. plt.plot([node.parent.x, node.x], [node.parent.y, node.y], "-g")
  189. @staticmethod
  190. def plot_path(path, color='red'):
  191. plt.plot([x[0] for x in path], [x[1] for x in path], linewidth=2, color=color)
  192. plt.pause(0.01)
  193. def main():
  194. x_start = (2, 2) # Starting node
  195. x_goal = (49, 24) # Goal node
  196. errt = ExtendedRrt(x_start, x_goal, 0.5, 0.1, 0.6, 5000)
  197. errt.planning()
  198. if __name__ == '__main__':
  199. main()