rrt_star_smart.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. """
  2. RRT_STAR_SMART 2D
  3. @author: huiming zhou
  4. """
  5. import os
  6. import sys
  7. import math
  8. import random
  9. import numpy as np
  10. import matplotlib.pyplot as plt
  11. from scipy.spatial.transform import Rotation as Rot
  12. import matplotlib.patches as patches
  13. sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
  14. "/../../Sampling-based Planning/")
  15. from rrt_2D import env
  16. from rrt_2D import plotting
  17. from rrt_2D import utils
  18. class Node:
  19. def __init__(self, n):
  20. self.x = n[0]
  21. self.y = n[1]
  22. self.parent = None
  23. class RrtStarSmart:
  24. def __init__(self, x_start, x_goal, step_len,
  25. goal_sample_rate, search_radius, iter_max):
  26. self.x_start = Node(x_start)
  27. self.x_goal = Node(x_goal)
  28. self.step_len = step_len
  29. self.goal_sample_rate = goal_sample_rate
  30. self.search_radius = search_radius
  31. self.iter_max = iter_max
  32. self.env = env.Env()
  33. self.plotting = plotting.Plotting(x_start, x_goal)
  34. self.utils = utils.Utils()
  35. self.fig, self.ax = plt.subplots()
  36. self.delta = self.utils.delta
  37. self.x_range = self.env.x_range
  38. self.y_range = self.env.y_range
  39. self.obs_circle = self.env.obs_circle
  40. self.obs_rectangle = self.env.obs_rectangle
  41. self.obs_boundary = self.env.obs_boundary
  42. self.V = [self.x_start]
  43. self.beacons = []
  44. self.beacons_radius = 2
  45. self.path = None
  46. def planning(self):
  47. n = 0
  48. b = 3
  49. for k in range(self.iter_max):
  50. if (k - n) % b == 0:
  51. x_rand = self.Sample(self.beacons)
  52. else:
  53. x_rand = self.Sample()
  54. x_nearest = self.Nearest(self.V, x_rand)
  55. x_new = self.Steer(x_nearest, x_rand)
  56. if x_new and not self.utils.is_collision(x_nearest, x_new):
  57. X_near = self.Near(self.V, x_new)
  58. c_min = self.Cost(x_new)
  59. self.V.append(x_new)
  60. for x_near in X_near:
  61. c_new = self.Cost(x_near) + self.Line(x_near, x_new)
  62. if c_new < c_min:
  63. x_new.parent = x_near
  64. c_min = c_new
  65. for x_near in X_near:
  66. c_near = self.Cost(x_near)
  67. c_new = c_min + self.Line(x_new, x_near)
  68. if c_new < c_near:
  69. x_near.parent = x_new
  70. if self.InGoalRegion(x_new):
  71. self.X_soln.add(x_new)
  72. new_cost = self.Cost(x_new) + self.Line(x_new, self.x_goal)
  73. if new_cost < c_best:
  74. c_best = new_cost
  75. x_best = x_new
  76. if k % 20 == 0:
  77. self.animation(x_center=x_center, c_best=c_best, dist=dist, theta=theta)
  78. self.path = self.ExtractPath(x_best)
  79. self.animation(x_center=x_center, c_best=c_best, dist=dist, theta=theta)
  80. plt.plot([x for x, _ in self.path], [y for _, y in self.path], '-r')
  81. plt.pause(0.01)
  82. plt.show()
  83. def Steer(self, x_start, x_goal):
  84. dist, theta = self.get_distance_and_angle(x_start, x_goal)
  85. dist = min(self.step_len, dist)
  86. node_new = Node((x_start.x + dist * math.cos(theta),
  87. x_start.y + dist * math.sin(theta)))
  88. node_new.parent = x_start
  89. return node_new
  90. def Near(self, nodelist, node):
  91. n = len(nodelist) + 1
  92. r = 50 * math.sqrt((math.log(n) / n))
  93. dist_table = [(nd.x - node.x) ** 2 + (nd.y - node.y) ** 2 for nd in nodelist]
  94. X_near = [nodelist[ind] for ind in range(len(dist_table)) if dist_table[ind] <= r ** 2 and
  95. not self.utils.is_collision(node, nodelist[ind])]
  96. return X_near
  97. def Sample(self, goal=None):
  98. if goal in None:
  99. delta = self.utils.delta
  100. goal_sample_rate = self.goal_sample_rate
  101. if np.random.random() > goal_sample_rate:
  102. return Node((np.random.uniform(self.x_range[0] + delta, self.x_range[1] - delta),
  103. np.random.uniform(self.y_range[0] + delta, self.y_range[1] - delta)))
  104. return self.x_goal
  105. else:
  106. R = self.beacons_radius
  107. r = random.uniform(0, R)
  108. theta = random.uniform(0, 2 * math.pi)
  109. ind = random.randint(0, len(goal) - 1)
  110. return Node((goal[ind][0] + r * math.cos(theta),
  111. goal[ind][1] + r * math.sin(theta)))
  112. def SampleFreeSpace(self):
  113. delta = self.delta
  114. if np.random.random() > self.goal_sample_rate:
  115. return Node((np.random.uniform(self.x_range[0] + delta, self.x_range[1] - delta),
  116. np.random.uniform(self.y_range[0] + delta, self.y_range[1] - delta)))
  117. return self.x_goal
  118. def ExtractPath(self, node):
  119. path = [[self.x_goal.x, self.x_goal.y]]
  120. while node.parent:
  121. path.append([node.x, node.y])
  122. node = node.parent
  123. path.append([self.x_start.x, self.x_start.y])
  124. return path
  125. def InGoalRegion(self, node):
  126. if self.Line(node, self.x_goal) < self.step_len:
  127. return True
  128. return False
  129. @staticmethod
  130. def RotationToWorldFrame(x_start, x_goal, L):
  131. a1 = np.array([[(x_start.x - x_start.x) / L],
  132. [(x_goal.y - x_start.y) / L], [0.0]])
  133. e1 = np.array([[1.0], [0.0], [0.0]])
  134. M = a1 @ e1.T
  135. U, _, V_T = np.linalg.svd(M, True, True)
  136. C = U @ np.diag([1.0, 1.0, np.linalg.det(U) * np.linalg.det(V_T.T)]) @ V_T
  137. return C
  138. @staticmethod
  139. def SampleUnitNBall():
  140. theta, r = random.uniform(0.0, 2 * math.pi), random.random()
  141. x = r * math.cos(theta)
  142. y = r * math.sin(theta)
  143. return np.array([[x], [y], [0.0]])
  144. @staticmethod
  145. def Nearest(nodelist, n):
  146. return nodelist[int(np.argmin([(nd.x - n.x) ** 2 + (nd.y - n.y) ** 2
  147. for nd in nodelist]))]
  148. @staticmethod
  149. def Line(x_start, x_goal):
  150. return math.hypot(x_goal.x - x_start.x, x_goal.y - x_start.y)
  151. @staticmethod
  152. def Cost(node):
  153. cost = 0.0
  154. if node.parent is None:
  155. return cost
  156. while node.parent:
  157. cost += math.hypot(node.x - node.parent.x, node.y - node.parent.y)
  158. node = node.parent
  159. return cost
  160. @staticmethod
  161. def get_distance_and_angle(node_start, node_end):
  162. dx = node_end.x - node_start.x
  163. dy = node_end.y - node_start.y
  164. return math.hypot(dx, dy), math.atan2(dy, dx)
  165. def animation(self, x_center=None, c_best=None, dist=None, theta=None):
  166. plt.cla()
  167. self.plot_grid("Informed rrt*, N = " + str(self.iter_max))
  168. plt.gcf().canvas.mpl_connect(
  169. 'key_release_event',
  170. lambda event: [exit(0) if event.key == 'escape' else None])
  171. if c_best != np.inf:
  172. self.draw_ellipse(x_center, c_best, dist, theta)
  173. for node in self.V:
  174. if node.parent:
  175. plt.plot([node.x, node.parent.x], [node.y, node.parent.y], "-g")
  176. plt.pause(0.01)
  177. def plot_grid(self, name):
  178. for (ox, oy, w, h) in self.obs_boundary:
  179. self.ax.add_patch(
  180. patches.Rectangle(
  181. (ox, oy), w, h,
  182. edgecolor='black',
  183. facecolor='black',
  184. fill=True
  185. )
  186. )
  187. for (ox, oy, w, h) in self.obs_rectangle:
  188. self.ax.add_patch(
  189. patches.Rectangle(
  190. (ox, oy), w, h,
  191. edgecolor='black',
  192. facecolor='gray',
  193. fill=True
  194. )
  195. )
  196. for (ox, oy, r) in self.obs_circle:
  197. self.ax.add_patch(
  198. patches.Circle(
  199. (ox, oy), r,
  200. edgecolor='black',
  201. facecolor='gray',
  202. fill=True
  203. )
  204. )
  205. plt.plot(self.x_start.x, self.x_start.y, "bs", linewidth=3)
  206. plt.plot(self.x_goal.x, self.x_goal.y, "rs", linewidth=3)
  207. plt.title(name)
  208. plt.axis("equal")
  209. @staticmethod
  210. def draw_ellipse(x_center, c_best, dist, theta):
  211. a = math.sqrt(c_best ** 2 - dist ** 2) / 2.0
  212. b = c_best / 2.0
  213. angle = math.pi / 2.0 - theta
  214. cx = x_center[0]
  215. cy = x_center[1]
  216. t = np.arange(0, 2 * math.pi + 0.1, 0.1)
  217. x = [a * math.cos(it) for it in t]
  218. y = [b * math.sin(it) for it in t]
  219. rot = Rot.from_euler('z', -angle).as_dcm()[0:2, 0:2]
  220. fx = rot @ np.array([x, y])
  221. px = np.array(fx[0, :] + cx).flatten()
  222. py = np.array(fx[1, :] + cy).flatten()
  223. plt.plot(cx, cy, ".b")
  224. plt.plot(px, py, linestyle='--', color='darkorange', linewidth=2)
  225. def main():
  226. x_start = (18, 8) # Starting node
  227. x_goal = (37, 18) # Goal node
  228. rrt = RrtStarSmart(x_start, x_goal, 1.0, 0.10, 0, 1000)
  229. rrt.planning()
  230. if __name__ == '__main__':
  231. main()