rrt_star.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. """
  2. RRT_star 2D
  3. @author: huiming zhou
  4. """
  5. import os
  6. import sys
  7. import math
  8. import numpy as np
  9. import matplotlib.pyplot as plt
  10. import matplotlib.patches as patches
  11. sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
  12. "/../../Sampling-based Planning/")
  13. from rrt_2D import env
  14. from rrt_2D import plotting
  15. from rrt_2D import utils
  16. from rrt_2D import queue
  17. class Node:
  18. def __init__(self, n):
  19. self.x = n[0]
  20. self.y = n[1]
  21. self.parent = None
  22. class RrtStar:
  23. def __init__(self, x_start, x_goal, step_len,
  24. goal_sample_rate, search_radius, iter_max):
  25. self.s_start = Node(x_start)
  26. self.s_goal = Node(x_goal)
  27. self.step_len = step_len
  28. self.goal_sample_rate = goal_sample_rate
  29. self.search_radius = search_radius
  30. self.iter_max = iter_max
  31. self.vertex = [self.s_start]
  32. self.path = []
  33. self.visited = []
  34. self.env = env.Env()
  35. self.plotting = plotting.Plotting(x_start, x_goal)
  36. self.utils = utils.Utils()
  37. # self.fig, self.ax = plt.subplots()
  38. self.x_range = self.env.x_range
  39. self.y_range = self.env.y_range
  40. self.obs_circle = self.env.obs_circle
  41. self.obs_rectangle = self.env.obs_rectangle
  42. self.obs_boundary = self.env.obs_boundary
  43. def planning(self):
  44. for k in range(self.iter_max):
  45. node_rand = self.generate_random_node(self.goal_sample_rate)
  46. node_near = self.nearest_neighbor(self.vertex, node_rand)
  47. node_new = self.new_state(node_near, node_rand)
  48. if k % 500 == 0:
  49. print(k)
  50. if node_new and not self.utils.is_collision(node_near, node_new):
  51. neighbor_index = self.find_near_neighbor(node_new)
  52. self.vertex.append(node_new)
  53. # if k % 20 == 0:
  54. # self.visited.append([[[node.x, node.parent.x], [node.y, node.parent.y]]
  55. # for node in self.vertex[1: len(self.vertex)]])
  56. if neighbor_index:
  57. self.choose_parent(node_new, neighbor_index)
  58. self.rewire(node_new, neighbor_index)
  59. index = self.search_goal_parent()
  60. self.path = self.extract_path(self.vertex[index])
  61. self.plotting.animation(self.vertex, self.path, "rrt*")
  62. def new_state(self, node_start, node_goal):
  63. dist, theta = self.get_distance_and_angle(node_start, node_goal)
  64. dist = min(self.step_len, dist)
  65. node_new = Node((node_start.x + dist * math.cos(theta),
  66. node_start.y + dist * math.sin(theta)))
  67. node_new.parent = node_start
  68. return node_new
  69. def choose_parent(self, node_new, neighbor_index):
  70. cost = [self.get_new_cost(self.vertex[i], node_new) for i in neighbor_index]
  71. cost_min_index = neighbor_index[int(np.argmin(cost))]
  72. node_new.parent = self.vertex[cost_min_index]
  73. def rewire(self, node_new, neighbor_index):
  74. for i in neighbor_index:
  75. node_neighbor = self.vertex[i]
  76. if self.cost(node_neighbor) > self.get_new_cost(node_new, node_neighbor):
  77. node_neighbor.parent = node_new
  78. def search_goal_parent(self):
  79. dist_list = [math.hypot(n.x - self.s_goal.x, n.y - self.s_goal.y) for n in self.vertex]
  80. node_index = [dist_list.index(i) for i in dist_list if i <= self.step_len]
  81. if node_index:
  82. cost_list = [dist_list[i] + self.cost(self.vertex[i]) for i in node_index
  83. if not self.utils.is_collision(self.vertex[i], self.s_goal)]
  84. return node_index[int(np.argmin(cost_list))]
  85. return len(self.vertex) - 1
  86. def get_new_cost(self, node_start, node_end):
  87. dist, _ = self.get_distance_and_angle(node_start, node_end)
  88. return self.cost(node_start) + dist
  89. def generate_random_node(self, goal_sample_rate):
  90. delta = self.utils.delta
  91. if np.random.random() > goal_sample_rate:
  92. return Node((np.random.uniform(self.x_range[0] + delta, self.x_range[1] - delta),
  93. np.random.uniform(self.y_range[0] + delta, self.y_range[1] - delta)))
  94. return self.s_goal
  95. @staticmethod
  96. def nearest_neighbor(node_list, n):
  97. return node_list[int(np.argmin([math.hypot(nd.x - n.x, nd.y - n.y)
  98. for nd in node_list]))]
  99. def find_near_neighbor(self, node_new):
  100. n = len(self.vertex) + 1
  101. r = min(self.search_radius * math.sqrt((math.log(n) / n)), self.step_len)
  102. dist_table = [math.hypot(nd.x - node_new.x, nd.y - node_new.y) for nd in self.vertex]
  103. dist_table_index = [ind for ind in range(len(dist_table)) if dist_table[ind] <= r and
  104. not self.utils.is_collision(node_new, self.vertex[ind])]
  105. return dist_table_index
  106. @staticmethod
  107. def cost(node_p):
  108. node = node_p
  109. cost = 0.0
  110. while node.parent:
  111. cost += math.hypot(node.x - node.parent.x, node.y - node.parent.y)
  112. node = node.parent
  113. return cost
  114. def animation(self, name):
  115. self.plot_grid(name)
  116. plt.pause(4)
  117. for edge_set in self.visited:
  118. plt.cla()
  119. self.plot_grid(name)
  120. for edges in edge_set:
  121. plt.plot(edges[0], edges[1], "-g")
  122. plt.pause(0.0001)
  123. if self.path:
  124. plt.plot([x[0] for x in self.path], [x[1] for x in self.path], '-r', linewidth=2)
  125. plt.pause(0.5)
  126. plt.show()
  127. def plot_grid(self, name):
  128. for (ox, oy, w, h) in self.obs_boundary:
  129. self.ax.add_patch(
  130. patches.Rectangle(
  131. (ox, oy), w, h,
  132. edgecolor='black',
  133. facecolor='black',
  134. fill=True
  135. )
  136. )
  137. for (ox, oy, w, h) in self.obs_rectangle:
  138. self.ax.add_patch(
  139. patches.Rectangle(
  140. (ox, oy), w, h,
  141. edgecolor='black',
  142. facecolor='gray',
  143. fill=True
  144. )
  145. )
  146. for (ox, oy, r) in self.obs_circle:
  147. self.ax.add_patch(
  148. patches.Circle(
  149. (ox, oy), r,
  150. edgecolor='black',
  151. facecolor='gray',
  152. fill=True
  153. )
  154. )
  155. plt.plot(self.s_start.x, self.s_start.y, "bs", linewidth=3)
  156. plt.plot(self.s_goal.x, self.s_goal.y, "gs", linewidth=3)
  157. plt.title(name)
  158. plt.axis("equal")
  159. def update_cost(self, parent_node):
  160. OPEN = queue.QueueFIFO()
  161. OPEN.put(parent_node)
  162. while not OPEN.empty():
  163. node = OPEN.get()
  164. if len(node.child) == 0:
  165. continue
  166. for node_c in node.child:
  167. node_c.Cost = self.get_new_cost(node, node_c)
  168. OPEN.put(node_c)
  169. def extract_path(self, node_end):
  170. path = [[self.s_goal.x, self.s_goal.y]]
  171. node = node_end
  172. while node.parent is not None:
  173. path.append([node.x, node.y])
  174. node = node.parent
  175. path.append([node.x, node.y])
  176. return path
  177. @staticmethod
  178. def get_distance_and_angle(node_start, node_end):
  179. dx = node_end.x - node_start.x
  180. dy = node_end.y - node_start.y
  181. return math.hypot(dx, dy), math.atan2(dy, dx)
  182. def main():
  183. x_start = (2, 2) # Starting node
  184. x_goal = (49, 24) # Goal node
  185. rrt_star = RrtStar(x_start, x_goal, 10, 0.10, 20, 15000)
  186. rrt_star.planning()
  187. if __name__ == '__main__':
  188. main()