# rrt.py
"""
RRT_2D: a basic Rapidly-exploring Random Tree (RRT) path planner on a 2-D map.
@author: huiming zhou
"""
  5. from rrt_2D import env
  6. from rrt_2D import plotting
  7. import numpy as np
  8. import math
  9. class Node:
  10. def __init__(self, n):
  11. self.x = n[0]
  12. self.y = n[1]
  13. self.parent = None
  14. class Rrt:
  15. def __init__(self, x_start, x_goal, expand_len, goal_sample_rate, iter_limit):
  16. self.xI = Node(x_start)
  17. self.xG = Node(x_goal)
  18. self.expand_len = expand_len
  19. self.goal_sample_rate = goal_sample_rate
  20. self.iter_limit = iter_limit
  21. self.vertex = [self.xI]
  22. self.env = env.Env()
  23. self.plotting = plotting.Plotting(x_start, x_goal)
  24. self.x_range = self.env.x_range
  25. self.y_range = self.env.y_range
  26. self.obs_circle = self.env.obs_circle
  27. self.obs_rectangle = self.env.obs_rectangle
  28. self.obs_boundary = self.env.obs_boundary
  29. def planning(self):
  30. for i in range(self.iter_limit):
  31. node_rand = self.random_state(self.goal_sample_rate)
  32. node_near = self.nearest_neighbor(self.vertex, node_rand)
  33. node_new = self.new_state(node_near, node_rand)
  34. if node_new and not self.check_collision(node_new):
  35. self.vertex.append(node_new)
  36. dist, _ = self.get_distance_and_angle(node_new, self.xG)
  37. if dist <= self.expand_len:
  38. self.new_state(node_new, self.xG)
  39. return self.extract_path(node_new)
  40. return None
  41. def random_state(self, goal_sample_rate):
  42. if np.random.random() > goal_sample_rate:
  43. return Node((np.random.uniform(self.x_range[0], self.x_range[1]),
  44. np.random.uniform(self.y_range[0], self.y_range[1])))
  45. return self.xG
  46. def nearest_neighbor(self, node_list, n):
  47. return self.vertex[int(np.argmin([math.hypot(nd.x - n.x, nd.y - n.y)
  48. for nd in node_list]))]
  49. def new_state(self, node_start, node_end):
  50. node_new = Node((node_start.x, node_start.y))
  51. dist, theta = self.get_distance_and_angle(node_new, node_end)
  52. dist = min(self.expand_len, dist)
  53. node_new.x += dist * math.cos(theta)
  54. node_new.y += dist * math.sin(theta)
  55. node_new.parent = node_start
  56. return node_new
  57. def extract_path(self, node_end):
  58. path = [(self.xG.x, self.xG.y)]
  59. node_now = node_end
  60. while node_now.parent is not None:
  61. node_now = node_now.parent
  62. path.append((node_now.x, node_now.y))
  63. return path
  64. def check_collision(self, node_end):
  65. for (ox, oy, r) in self.obs_circle:
  66. if math.hypot(node_end.x - ox, node_end.y - oy) <= r:
  67. return True
  68. for (ox, oy, w, h) in self.obs_rectangle:
  69. if 0 <= (node_end.x - ox) <= w and 0 <= (node_end.y - oy) <= h:
  70. return True
  71. for (ox, oy, w, h) in self.obs_boundary:
  72. if 0 <= (node_end.x - ox) <= w and 0 <= (node_end.y - oy) <= h:
  73. return True
  74. return False
  75. @staticmethod
  76. def get_distance_and_angle(node_start, node_end):
  77. dx = node_end.x - node_start.x
  78. dy = node_end.y - node_start.y
  79. return math.hypot(dx, dy), math.atan2(dy, dx)
  80. def main():
  81. x_start = (2, 2) # Starting node
  82. x_goal = (49, 28) # Goal node
  83. rrt = Rrt(x_start, x_goal, 0.4, 0.05, 2000)
  84. path = rrt.planning()
  85. if path:
  86. rrt.plotting.animation(rrt.vertex, path)
  87. else:
  88. print("No Path Found!")
  89. if __name__ == '__main__':
  90. main()