# rrt.py (3.4 KB)

  1. from rrt_2D import env
  2. from rrt_2D import plotting
  3. import numpy as np
  4. import math
  5. class Node:
  6. def __init__(self, n):
  7. self.x = n[0]
  8. self.y = n[1]
  9. self.parent = None
  10. class RRT:
  11. def __init__(self, xI, xG):
  12. self.xI = Node(xI)
  13. self.xG = Node(xG)
  14. self.expand_len = 0.4
  15. self.goal_sample_rate = 0.05
  16. self.iterations = 5000
  17. self.node_list = [self.xI]
  18. self.env = env.Env()
  19. self.plotting = plotting.Plotting(xI, xG)
  20. self.x_range = self.env.x_range
  21. self.y_range = self.env.y_range
  22. self.obs_circle = self.env.obs_circle
  23. self.obs_rectangle = self.env.obs_rectangle
  24. self.obs_boundary = self.env.obs_boundary
  25. self.path = self.planning()
  26. self.plotting.animation(self.node_list, self.path)
  27. def planning(self):
  28. for i in range(self.iterations):
  29. node_rand = self.random_state()
  30. node_near = self.nearest_neighbor(self.node_list, node_rand)
  31. node_new = self.new_state(node_near, node_rand)
  32. if not self.check_collision(node_new):
  33. self.node_list.append(node_new)
  34. if self.dis_to_goal(self.node_list[-1]) <= self.expand_len:
  35. self.new_state(self.node_list[-1], self.xG)
  36. return self.extract_path(self.node_list)
  37. return None
  38. def random_state(self):
  39. if np.random.random() > self.goal_sample_rate:
  40. return Node((np.random.uniform(self.x_range[0], self.x_range[1]),
  41. np.random.uniform(self.y_range[0], self.y_range[1])))
  42. return self.xG
  43. def nearest_neighbor(self, node_list, n):
  44. return self.node_list[int(np.argmin([math.hypot(nd.x - n.x, nd.y - n.y)
  45. for nd in node_list]))]
  46. def new_state(self, node_start, node_end):
  47. node_new = Node((node_start.x, node_start.y))
  48. dist, theta = self.get_distance_and_angle(node_new, node_end)
  49. dist = min(self.expand_len, dist)
  50. node_new.x += dist * math.cos(theta)
  51. node_new.y += dist * math.sin(theta)
  52. node_new.parent = node_start
  53. return node_new
  54. def extract_path(self, nodelist):
  55. path = [(self.xG.x, self.xG.y)]
  56. node_now = nodelist[-1]
  57. while node_now.parent is not None:
  58. node_now = node_now.parent
  59. path.append((node_now.x, node_now.y))
  60. return path
  61. def dis_to_goal(self, node_cal):
  62. return math.hypot(node_cal.x - self.xG.x, node_cal.y - self.xG.y)
  63. def check_collision(self, node_end):
  64. if node_end is None:
  65. return True
  66. for (ox, oy, r) in self.obs_circle:
  67. if math.hypot(node_end.x - ox, node_end.y - oy) <= r:
  68. return True
  69. for (ox, oy, w, h) in self.obs_rectangle:
  70. if 0 <= (node_end.x - ox) <= w and 0 <= (node_end.y - oy) <= h:
  71. return True
  72. for (ox, oy, w, h) in self.obs_boundary:
  73. if 0 <= (node_end.x - ox) <= w and 0 <= (node_end.y - oy) <= h:
  74. return True
  75. return False
  76. @staticmethod
  77. def get_distance_and_angle(node_start, node_end):
  78. dx = node_end.x - node_start.x
  79. dy = node_end.y - node_start.y
  80. return math.hypot(dx, dy), math.atan2(dy, dx)
  81. if __name__ == '__main__':
  82. x_Start = (2, 2) # Starting node
  83. x_Goal = (49, 28) # Goal node
  84. rrt = RRT(x_Start, x_Goal)