rrt*.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. from rrt_2D import env
  2. from rrt_2D import plotting
  3. import numpy as np
  4. import math
  5. class Node:
  6. def __init__(self, n):
  7. self.x = n[0]
  8. self.y = n[1]
  9. self.cost = 0.0
  10. self.parent = None
  11. class RRT:
  12. def __init__(self, xI, xG):
  13. self.xI = Node(xI)
  14. self.xG = Node(xG)
  15. self.expand_len = 1
  16. self.goal_sample_rate = 0.05
  17. self.connect_dist = 10
  18. self.iterations = 5000
  19. self.node_list = [self.xI]
  20. self.env = env.Env()
  21. self.plotting = plotting.Plotting(xI, xG)
  22. self.x_range = self.env.x_range
  23. self.y_range = self.env.y_range
  24. self.obs_circle = self.env.obs_circle
  25. self.obs_rectangle = self.env.obs_rectangle
  26. self.obs_boundary = self.env.obs_boundary
  27. self.path = self.planning()
  28. self.plotting.animation(self.node_list, self.path, False)
  29. def planning(self):
  30. for k in range(self.iterations):
  31. node_rand = self.random_state()
  32. node_near = self.nearest_neighbor(self.node_list, node_rand)
  33. node_new = self.new_state(node_near, node_rand)
  34. if not self.check_collision(node_new):
  35. neighbor_index = self.find_near_neighbor(node_new)
  36. node_new = self.choose_parent(node_new, neighbor_index)
  37. if node_new:
  38. self.node_list.append(node_new)
  39. self.rewire(node_new, neighbor_index)
  40. # if self.dis_to_goal(self.node_list[-1]) <= self.expand_len:
  41. # self.new_state(self.node_list[-1], self.xG)
  42. # return self.extract_path()
  43. index = self.search_best_goal_node()
  44. self.xG.parent = self.node_list[index]
  45. return self.extract_path()
  46. def random_state(self):
  47. if np.random.random() > self.goal_sample_rate:
  48. return Node((np.random.uniform(self.x_range[0], self.x_range[1]),
  49. np.random.uniform(self.y_range[0], self.y_range[1])))
  50. return self.xG
  51. def nearest_neighbor(self, node_list, n):
  52. return self.node_list[int(np.argmin([math.hypot(nd.x - n.x, nd.y - n.y)
  53. for nd in node_list]))]
  54. def new_state(self, node_start, node_goal):
  55. node_new = Node((node_start.x, node_start.y))
  56. dist, theta = self.get_distance_and_angle(node_new, node_goal)
  57. dist = min(self.expand_len, dist)
  58. node_new.x += dist * math.cos(theta)
  59. node_new.y += dist * math.sin(theta)
  60. node_new.parent = node_start
  61. return node_new
  62. def find_near_neighbor(self, node_new):
  63. n = len(self.node_list) + 1
  64. r = min(self.connect_dist * math.sqrt((math.log(n) / n)), self.expand_len)
  65. dist_table = [math.hypot(nd.x - node_new.x, nd.y - node_new.y) for nd in self.node_list]
  66. node_index = [dist_table.index(d) for d in dist_table if d <= r]
  67. return node_index
  68. def choose_parent(self, node_new, neighbor_index):
  69. if not neighbor_index:
  70. return None
  71. cost = []
  72. for i in neighbor_index:
  73. node_near = self.node_list[i]
  74. node_mid = self.new_state(node_near, node_new)
  75. if node_mid and not self.check_collision(node_mid):
  76. cost.append(self.update_cost(node_near, node_mid))
  77. else:
  78. cost.append(float("inf"))
  79. if min(cost) != float('inf'):
  80. index = int(np.argmin(cost))
  81. neighbor_min = neighbor_index[index]
  82. node_new = self.new_state(self.node_list[neighbor_min], node_new)
  83. node_new.cost = min(cost)
  84. return node_new
  85. return None
  86. def search_best_goal_node(self):
  87. dist_to_goal_list = [self.dis_to_goal(n) for n in self.node_list]
  88. goal_inds = [dist_to_goal_list.index(i) for i in dist_to_goal_list if i <= self.expand_len]
  89. return goal_inds[0]
  90. # safe_goal_inds = []
  91. # for goal_ind in goal_inds:
  92. # t_node = self.new_state(self.node_list[goal_ind], self.xG)
  93. # if self.check_collision(t_node):
  94. # safe_goal_inds.append(goal_ind)
  95. #
  96. # if not safe_goal_inds:
  97. # print('hahhah')
  98. # return None
  99. #
  100. # min_cost = min([self.node_list[i].cost for i in safe_goal_inds])
  101. # for i in safe_goal_inds:
  102. # if self.node_list[i].cost == min_cost:
  103. # self.xG.parent = self.node_list[i]
  104. def rewire(self, node_new, neighbor_index):
  105. for i in neighbor_index:
  106. node_near = self.node_list[i]
  107. node_edge = self.new_state(node_new, node_near)
  108. if not node_edge:
  109. continue
  110. node_edge.cost = self.update_cost(node_new, node_near)
  111. collision = self.check_collision(node_edge)
  112. improved_cost = node_near.cost > node_edge.cost
  113. if not collision and improved_cost:
  114. self.node_list[i] = node_edge
  115. self.propagate_cost_to_leaves(node_new)
  116. def update_cost(self, node_start, node_end):
  117. dist, theta = self.get_distance_and_angle(node_start, node_end)
  118. return node_start.cost + dist
  119. def propagate_cost_to_leaves(self, parent_node):
  120. for node in self.node_list:
  121. if node.parent == parent_node:
  122. node.cost = self.update_cost(parent_node, node)
  123. self.propagate_cost_to_leaves(node)
  124. def extract_path(self):
  125. path = [[self.xG.x, self.xG.y]]
  126. node = self.xG
  127. while node.parent is not None:
  128. path.append([node.x, node.y])
  129. node = node.parent
  130. path.append([node.x, node.y])
  131. return path
  132. def dis_to_goal(self, node_cal):
  133. return math.hypot(node_cal.x - self.xG.x, node_cal.y - self.xG.y)
  134. def check_collision(self, node_end):
  135. if node_end is None:
  136. return True
  137. for (ox, oy, r) in self.obs_circle:
  138. if math.hypot(node_end.x - ox, node_end.y - oy) <= r:
  139. return True
  140. for (ox, oy, w, h) in self.obs_rectangle:
  141. if 0 <= (node_end.x - ox) <= w and 0 <= (node_end.y - oy) <= h:
  142. return True
  143. for (ox, oy, w, h) in self.obs_boundary:
  144. if 0 <= (node_end.x - ox) <= w and 0 <= (node_end.y - oy) <= h:
  145. return True
  146. return False
  147. @staticmethod
  148. def get_distance_and_angle(node_start, node_end):
  149. dx = node_end.x - node_start.x
  150. dy = node_end.y - node_start.y
  151. return math.hypot(dx, dy), math.atan2(dy, dx)
  152. if __name__ == '__main__':
  153. x_Start = (2, 2) # Starting node
  154. x_Goal = (49, 28) # Goal node
  155. rrt = RRT(x_Start, x_Goal)