rrt_star.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. """
  2. RRT_star 2D
  3. @author: huiming zhou
  4. """
  5. import math
  6. import numpy as np
  7. import os
  8. import sys
  9. sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
  10. "/../../Sampling-based Planning/")
  11. from rrt_2D import env
  12. from rrt_2D import plotting
  13. from rrt_2D import utils
  14. class Node:
  15. def __init__(self, n):
  16. self.x = n[0]
  17. self.y = n[1]
  18. self.cost = 0.0
  19. self.parent = None
  20. class RrtStar:
  21. def __init__(self, x_start, x_goal, step_len,
  22. goal_sample_rate, search_radius, iter_max):
  23. self.xI = Node(x_start)
  24. self.xG = Node(x_goal)
  25. self.step_len = step_len
  26. self.goal_sample_rate = goal_sample_rate
  27. self.search_radius = search_radius
  28. self.iter_max = iter_max
  29. self.vertex = [self.xI]
  30. self.env = env.Env()
  31. self.plotting = plotting.Plotting(x_start, x_goal)
  32. self.utils = utils.Utils()
  33. self.x_range = self.env.x_range
  34. self.y_range = self.env.y_range
  35. self.obs_circle = self.env.obs_circle
  36. self.obs_rectangle = self.env.obs_rectangle
  37. self.obs_boundary = self.env.obs_boundary
  38. def planning(self):
  39. for k in range(self.iter_max):
  40. node_rand = self.random_state(self.goal_sample_rate)
  41. node_near = self.nearest_neighbor(self.vertex, node_rand)
  42. node_new = self.new_state(node_near, node_rand)
  43. if node_new and not self.utils.is_collision(node_near, node_new):
  44. neighbor_index = self.find_near_neighbor(node_new)
  45. if neighbor_index:
  46. node_new = self.choose_parent(node_new, neighbor_index)
  47. self.vertex.append(node_new)
  48. self.rewire(node_new, neighbor_index)
  49. index = self.search_goal_parent()
  50. return self.extract_path(self.vertex[index])
  51. def check_collision(self, node_end):
  52. for (ox, oy, r) in self.obs_circle:
  53. if math.hypot(node_end.x - ox, node_end.y - oy) <= r:
  54. return True
  55. for (ox, oy, w, h) in self.obs_rectangle:
  56. if 0 <= (node_end.x - ox) <= w and 0 <= (node_end.y - oy) <= h:
  57. return True
  58. for (ox, oy, w, h) in self.obs_boundary:
  59. if 0 <= (node_end.x - ox) <= w and 0 <= (node_end.y - oy) <= h:
  60. return True
  61. return False
  62. def random_state(self, goal_sample_rate):
  63. delta = self.utils.delta
  64. if np.random.random() > goal_sample_rate:
  65. return Node((np.random.uniform(self.x_range[0] + delta, self.x_range[1] - delta),
  66. np.random.uniform(self.y_range[0] + delta, self.y_range[1] - delta)))
  67. return self.xG
  68. def nearest_neighbor(self, node_list, n):
  69. return self.vertex[int(np.argmin([math.hypot(nd.x - n.x, nd.y - n.y)
  70. for nd in node_list]))]
  71. def new_state(self, node_start, node_goal):
  72. dist, theta = self.get_distance_and_angle(node_start, node_goal)
  73. dist = min(self.step_len, dist)
  74. node_new = Node((node_start.x + dist * math.cos(theta),
  75. node_start.y + dist * math.sin(theta)))
  76. node_new.parent = node_start
  77. return node_new
  78. def find_near_neighbor(self, node_new):
  79. n = len(self.vertex) + 1
  80. r = min(self.search_radius * math.sqrt((math.log(n) / n)), self.step_len)
  81. dist_table = [math.hypot(nd.x - node_new.x, nd.y - node_new.y) for nd in self.vertex]
  82. return [dist_table.index(d) for d in dist_table if d <= r]
  83. def choose_parent(self, node_new, neighbor_index):
  84. cost = []
  85. for i in neighbor_index:
  86. node_neighbor = self.vertex[i]
  87. cost.append(self.get_new_cost(node_neighbor, node_new))
  88. cost_min_index = neighbor_index[int(np.argmin(cost))]
  89. node_new = self.new_state(self.vertex[cost_min_index], node_new)
  90. node_new.cost = min(cost)
  91. return node_new
  92. def search_goal_parent(self):
  93. dist_list = [math.hypot(n.x - self.xG.x, n.y - self.xG.y) for n in self.vertex]
  94. node_index = [dist_list.index(i) for i in dist_list if i <= self.step_len]
  95. if node_index:
  96. cost_list = [dist_list[i] + self.vertex[i].cost for i in node_index]
  97. return node_index[int(np.argmin(cost_list))]
  98. return None
  99. def rewire(self, node_new, neighbor_index):
  100. for i in neighbor_index:
  101. node_neighbor = self.vertex[i]
  102. new_cost = self.get_new_cost(node_new, node_neighbor)
  103. if node_neighbor.cost > new_cost:
  104. self.vertex[i] = self.new_state(node_new, node_neighbor)
  105. self.propagate_cost_to_leaves(node_new)
  106. def get_new_cost(self, node_start, node_end):
  107. dist, _ = self.get_distance_and_angle(node_start, node_end)
  108. return node_start.cost + dist
  109. def propagate_cost_to_leaves(self, parent_node):
  110. for node in self.vertex:
  111. if node.parent == parent_node:
  112. node.cost = self.get_new_cost(parent_node, node)
  113. self.propagate_cost_to_leaves(node)
  114. def extract_path(self, node_end):
  115. path = [[self.xG.x, self.xG.y]]
  116. node = node_end
  117. while node.parent is not None:
  118. path.append([node.x, node.y])
  119. node = node.parent
  120. path.append([node.x, node.y])
  121. return path
  122. @staticmethod
  123. def get_distance_and_angle(node_start, node_end):
  124. dx = node_end.x - node_start.x
  125. dy = node_end.y - node_start.y
  126. return math.hypot(dx, dy), math.atan2(dy, dx)
  127. def main():
  128. x_start = (2, 2) # Starting node
  129. x_goal = (49, 28) # Goal node
  130. rrt_star = RrtStar(x_start, x_goal, 1.2, 0.15, 10, 20000)
  131. path = rrt_star.planning()
  132. if path:
  133. rrt_star.plotting.animation(rrt_star.vertex, path)
  134. else:
  135. print("No Path Found!")
  136. if __name__ == '__main__':
  137. main()