# RRT.py (3.5 KB)

# (line-number gutter 1-120 from the original file listing)
  1. import env
  2. import plotting
  3. import numpy as np
  4. import math
  5. class Node:
  6. def __init__(self, n):
  7. self.x = n[0]
  8. self.y = n[1]
  9. self.path_x = []
  10. self.path_y = []
  11. self.parent = None
  12. class RRT:
  13. def __init__(self, xI, xG):
  14. self.xI = Node(xI)
  15. self.xG = Node(xG)
  16. self.expand_len = 0.8
  17. self.goal_sample_rate = 0.05
  18. self.iterations = 5000
  19. self.node_list = [self.xI]
  20. self.env = env.Env()
  21. self.plotting = plotting.Plotting(xI, xG)
  22. self.x_range = self.env.x_range
  23. self.y_range = self.env.y_range
  24. self.obs_circle = self.env.obs
  25. self.obs_rectangle = self.env.obs_boundary
  26. self.path = self.planning()
  27. self.plotting.animation(self.node_list, self.path)
  28. def planning(self):
  29. for i in range(self.iterations):
  30. node_rand = self.generate_random_node()
  31. node_near = self.get_nearest_node(self.node_list, node_rand)
  32. node_new = self.new_node(node_near, node_rand, self.expand_len)
  33. if not self.check_collision(node_new, self.obs_circle, self.obs_rectangle):
  34. self.node_list.append(node_new)
  35. if self.cal_dis_to_goal(self.node_list[-1]) <= self.expand_len:
  36. self.new_node(self.node_list[-1], self.xG, self.expand_len)
  37. return self.extract_path(self.node_list)
  38. return None
  39. def extract_path(self, nodelist):
  40. path = [(self.xG.x, self.xG.y)]
  41. node_now = nodelist[-1]
  42. while node_now.parent is not None:
  43. node_now = node_now.parent
  44. path.append((node_now.x, node_now.y))
  45. return path
  46. def cal_dis_to_goal(self, node_cal):
  47. return math.hypot(node_cal.x - self.xG.x, node_cal.y - self.xG.y)
  48. def new_node(self, node_start, node_goal, expand_len):
  49. new_node = Node((node_start.x, node_start.y))
  50. d, theta = self.calc_distance_and_angle(new_node, node_goal)
  51. new_node.path_x = [new_node.x]
  52. new_node.path_y = [new_node.y]
  53. if d < expand_len:
  54. expand_len = d
  55. new_node.x += expand_len * math.cos(theta)
  56. new_node.y += expand_len * math.sin(theta)
  57. new_node.path_x.append(new_node.x)
  58. new_node.path_y.append(new_node.y)
  59. new_node.parent = node_start
  60. return new_node
  61. def generate_random_node(self):
  62. if np.random.random() > self.goal_sample_rate:
  63. return Node((np.random.uniform(self.x_range[0], self.x_range[1]),
  64. np.random.uniform(self.y_range[0], self.y_range[1])))
  65. return self.xG
  66. def get_nearest_node(self, node_list, n):
  67. return self.node_list[int(np.argmin([math.hypot(nd.x - n.x, nd.y - n.y)
  68. for nd in node_list]))]
  69. @staticmethod
  70. def calc_distance_and_angle(from_node, to_node):
  71. dx = to_node.x - from_node.x
  72. dy = to_node.y - from_node.y
  73. return math.hypot(dx, dy), math.atan2(dy, dx)
  74. @staticmethod
  75. def check_collision(node_end, obs_circle, obs_rectangle):
  76. if node_end is None:
  77. return True
  78. for (ox, oy, r) in obs_circle:
  79. if math.hypot(node_end.x - ox, node_end.y - oy) <= r:
  80. return True
  81. for (ox, oy, w, h) in obs_rectangle:
  82. if 0 <= (node_end.x - ox) <= w and 0 <= (node_end.y - oy) <= h:
  83. return True
  84. return False
  85. if __name__ == '__main__':
  86. x_Start = (15, 5) # Starting node
  87. x_Goal = (45, 25) # Goal node
  88. rrt = RRT(x_Start, x_Goal)