rrt_connect.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. """
  2. RRT_CONNECT_2D
  3. @author: huiming zhou
  4. """
  5. import os
  6. import sys
  7. import math
  8. import copy
  9. import numpy as np
  10. import matplotlib.pyplot as plt
  11. sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
  12. "/../../Sampling-based Planning/")
  13. from rrt_2D import env
  14. from rrt_2D import plotting
  15. from rrt_2D import utils
  # RRT-Connect: tree node type and the bidirectional planner.
class Node:
    """A vertex of an RRT tree: a 2-D position plus a link to its parent."""

    def __init__(self, n):
        """Create a parentless node at position ``n`` (indexable as (x, y))."""
        self.x = n[0]
        self.y = n[1]
        self.parent = None  # the planner sets this when linking the node into a tree


class RrtConnect:
    """RRT-Connect planner: grows a start tree and a goal tree toward each other.

    In every iteration ``V1`` is extended one step toward a sample and ``V2``
    then greedily tries to reach the newly added node. The two lists are
    swapped at the end of an iteration whenever ``V2`` is smaller, so the
    smaller tree is always the one extended next.
    """

    def __init__(self, s_start, s_goal, step_len, goal_sample_rate, iter_max):
        """Set up both trees and the environment.

        :param s_start: (x, y) start position; root of the start tree.
        :param s_goal: (x, y) goal position; root of the goal tree.
        :param step_len: maximum length of a single tree edge.
        :param goal_sample_rate: probability that a sample is the root of the
            opposite tree instead of a uniform point in the workspace.
        :param iter_max: maximum number of planning iterations.
        """
        self.s_start = Node(s_start)
        self.s_goal = Node(s_goal)
        self.step_len = step_len
        self.goal_sample_rate = goal_sample_rate
        self.iter_max = iter_max
        self.V1 = [self.s_start]
        self.V2 = [self.s_goal]

        self.env = env.Env()
        self.plotting = plotting.Plotting(s_start, s_goal)
        self.utils = utils.Utils()

        # Workspace bounds and obstacles, cached from the environment.
        self.x_range = self.env.x_range
        self.y_range = self.env.y_range
        self.obs_circle = self.env.obs_circle
        self.obs_rectangle = self.env.obs_rectangle
        self.obs_boundary = self.env.obs_boundary

    def planning(self):
        """Run RRT-Connect until the trees meet or ``iter_max`` is exhausted.

        :return: list of (x, y) waypoints from ``s_start`` to ``s_goal``, or
            None if the trees did not connect.
        """
        for _ in range(self.iter_max):
            # Bias sampling toward the root of the *other* tree: because V1/V2
            # are swapped below, V2[0] is s_goal or s_start depending on turn.
            node_rand = self.generate_random_node(self.V2[0], self.goal_sample_rate)
            node_near = self.nearest_neighbor(self.V1, node_rand)
            node_new = self.new_state(node_near, node_rand)

            if not self.utils.is_collision(node_near, node_new):
                self.V1.append(node_new)
                node_prim = self._connect(node_new)

                # Only a fully collision-checked connection yields a path.
                if self.is_node_same(node_prim, node_new):
                    path = self.extract_path(node_new, node_prim)
                    # extract_path runs from V1's root to V2's root; undo any
                    # tree swapping so the result always starts at s_start.
                    if self.V1[0] is self.s_goal:
                        path.reverse()
                    return path

            # Extend the smaller tree next. Swap the references: the lists
            # hold linked nodes, and deep-copying them is O(n) per iteration.
            if len(self.V2) < len(self.V1):
                self.V1, self.V2 = self.V2, self.V1

        return None

    def _connect(self, node_target):
        """Greedily grow V2 toward ``node_target`` one step at a time.

        Every collision-free step is appended to V2; growth stops when the
        target is reached or the next edge collides.

        :return: the last V2 node reached. It has exactly the coordinates of
            ``node_target`` iff the trees connected.
        """
        node_prim = self.nearest_neighbor(self.V2, node_target)
        while not self.is_node_same(node_prim, node_target):
            node_next = self.new_state(node_prim, node_target)
            # Check the edge that is actually being added: node_prim -> node_next.
            if self.utils.is_collision(node_next, node_prim):
                break
            self.V2.append(node_next)
            node_prim = node_next
        return node_prim

    @staticmethod
    def change_node(node_new_prim, node_new_prim2):
        """Return a copy of ``node_new_prim2`` whose parent is ``node_new_prim``.

        No longer used by ``planning``; kept for backward compatibility.
        """
        node_new = Node((node_new_prim2.x, node_new_prim2.y))
        node_new.parent = node_new_prim
        return node_new

    @staticmethod
    def is_node_same(node_new_prim, node_new):
        """Return True if both nodes have exactly the same coordinates.

        Exact comparison is reliable here because ``new_state`` snaps onto its
        target instead of recomputing it through trigonometry.
        """
        return node_new_prim.x == node_new.x and node_new_prim.y == node_new.y

    def generate_random_node(self, sample_goal, goal_sample_rate):
        """Draw a sample point for tree extension.

        With probability ``goal_sample_rate`` return ``sample_goal`` itself;
        otherwise return a new Node drawn uniformly from the workspace, kept
        ``utils.delta`` inside the x/y range limits.
        """
        delta = self.utils.delta

        if np.random.random() > goal_sample_rate:
            return Node((np.random.uniform(self.x_range[0] + delta, self.x_range[1] - delta),
                         np.random.uniform(self.y_range[0] + delta, self.y_range[1] - delta)))

        return sample_goal

    @staticmethod
    def nearest_neighbor(node_list, n):
        """Return the node of ``node_list`` closest to ``n`` (the first one on ties).

        :raises ValueError: if ``node_list`` is empty.
        """
        return min(node_list, key=lambda nd: math.hypot(nd.x - n.x, nd.y - n.y))

    def new_state(self, node_start, node_end):
        """Step from ``node_start`` toward ``node_end`` by at most ``step_len``.

        If ``node_end`` is within ``step_len`` the new node lands exactly on its
        coordinates (``start + dist*cos(theta)`` could miss by a rounding error,
        which would defeat ``is_node_same``).

        :return: a new Node whose parent is ``node_start``.
        """
        dist, theta = self.get_distance_and_angle(node_start, node_end)

        if dist <= self.step_len:
            node_new = Node((node_end.x, node_end.y))
        else:
            node_new = Node((node_start.x + self.step_len * math.cos(theta),
                             node_start.y + self.step_len * math.sin(theta)))
        node_new.parent = node_start

        return node_new

    @staticmethod
    def extract_path(node_new, node_new_prim):
        """Join the two tree branches that meet at ``node_new``/``node_new_prim``.

        :return: (x, y) waypoints from the root of ``node_new``'s tree to the
            root of ``node_new_prim``'s tree. When the two meeting nodes share
            coordinates (the normal case), that point is listed only once.
        """
        path1 = RrtConnect._branch_to_root(node_new)
        path1.reverse()
        path2 = RrtConnect._branch_to_root(node_new_prim)

        if path1[-1] == path2[0]:
            path2 = path2[1:]

        return path1 + path2

    @staticmethod
    def _branch_to_root(node):
        """Return the (x, y) of ``node`` and of each of its ancestors, leaf first."""
        branch = []
        while node is not None:
            branch.append((node.x, node.y))
            node = node.parent
        return branch

    @staticmethod
    def get_distance_and_angle(node_start, node_end):
        """Return (Euclidean distance, heading in radians) from start to end."""
        dx = node_end.x - node_start.x
        dy = node_end.y - node_start.y
        return math.hypot(dx, dy), math.atan2(dy, dx)
  # Demo entry point: plan once and visualise the result.
def main():
    """Plan from (2, 2) to (49, 24) with RRT-Connect and plot the outcome.

    Both trees and the resulting path (None if planning failed) are passed
    to the planner's plotting helper.
    """
    start, goal = (2, 2), (49, 24)
    planner = RrtConnect(start, goal, step_len=0.8, goal_sample_rate=0.05, iter_max=5000)
    found_path = planner.planning()
    planner.plotting.animation_connect(planner.V1, planner.V2, found_path, "RRT_CONNECT")


if __name__ == '__main__':
    main()