# rrt_star_smart.py (9.2 KB)
  1. """
  2. RRT_STAR_SMART 2D
  3. @author: huiming zhou
  4. """
  5. import os
  6. import sys
  7. import math
  8. import random
  9. import numpy as np
  10. import matplotlib.pyplot as plt
  11. from scipy.spatial.transform import Rotation as Rot
  12. import matplotlib.patches as patches
  13. sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
  14. "/../../Sampling-based Planning/")
  15. from rrt_2D import env
  16. from rrt_2D import plotting
  17. from rrt_2D import utils
  18. class Node:
  19. def __init__(self, n):
  20. self.x = n[0]
  21. self.y = n[1]
  22. self.parent = None
  23. class RrtStarSmart:
  24. def __init__(self, x_start, x_goal, step_len,
  25. goal_sample_rate, search_radius, iter_max):
  26. self.x_start = Node(x_start)
  27. self.x_goal = Node(x_goal)
  28. self.step_len = step_len
  29. self.goal_sample_rate = goal_sample_rate
  30. self.search_radius = search_radius
  31. self.iter_max = iter_max
  32. self.env = env.Env()
  33. self.plotting = plotting.Plotting(x_start, x_goal)
  34. self.utils = utils.Utils()
  35. self.fig, self.ax = plt.subplots()
  36. self.delta = self.utils.delta
  37. self.x_range = self.env.x_range
  38. self.y_range = self.env.y_range
  39. self.obs_circle = self.env.obs_circle
  40. self.obs_rectangle = self.env.obs_rectangle
  41. self.obs_boundary = self.env.obs_boundary
  42. self.V = [self.x_start]
  43. self.beacons = []
  44. self.beacons_radius = 2
  45. self.direct_cost_old = np.inf
  46. self.obs_vertex = self.utils.get_obs_vertex()
  47. self.path = None
  48. def planning(self):
  49. n = 0
  50. b = 2
  51. InitPathFlag = False
  52. self.ReformObsVertex()
  53. for k in range(self.iter_max):
  54. if k % 200 == 0:
  55. print(k)
  56. if (k - n) % b == 0 and len(self.beacons) > 0:
  57. x_rand = self.Sample(self.beacons)
  58. else:
  59. x_rand = self.Sample()
  60. x_nearest = self.Nearest(self.V, x_rand)
  61. x_new = self.Steer(x_nearest, x_rand)
  62. if x_new and not self.utils.is_collision(x_nearest, x_new):
  63. X_near = self.Near(self.V, x_new)
  64. self.V.append(x_new)
  65. if X_near:
  66. # choose parent
  67. cost_list = [self.Cost(x_near) + self.Line(x_near, x_new) for x_near in X_near]
  68. x_new.parent = X_near[int(np.argmin(cost_list))]
  69. # rewire
  70. c_min = self.Cost(x_new)
  71. for x_near in X_near:
  72. c_near = self.Cost(x_near)
  73. c_new = c_min + self.Line(x_new, x_near)
  74. if c_new < c_near:
  75. x_near.parent = x_new
  76. if not InitPathFlag and self.InitialPathFound(x_new):
  77. InitPathFlag = True
  78. n = k
  79. if InitPathFlag:
  80. self.PathOptimization(x_new)
  81. if k % 5 == 0:
  82. self.animation()
  83. self.path = self.ExtractPath()
  84. self.animation()
  85. plt.plot([x for x, _ in self.path], [y for _, y in self.path], '-r')
  86. plt.pause(0.01)
  87. plt.show()
  88. def PathOptimization(self, node):
  89. direct_cost_new = 0.0
  90. node_end = self.x_goal
  91. while node.parent:
  92. node_parent = node.parent
  93. if not self.utils.is_collision(node_parent, node_end):
  94. node_end.parent = node_parent
  95. else:
  96. direct_cost_new += self.Line(node, node_end)
  97. node_end = node
  98. node = node_parent
  99. if direct_cost_new < self.direct_cost_old:
  100. self.direct_cost_old = direct_cost_new
  101. self.UpdateBeacons()
  102. def UpdateBeacons(self):
  103. node = self.x_goal
  104. beacons = []
  105. while node.parent:
  106. near_vertex = [v for v in self.obs_vertex
  107. if (node.x - v[0]) ** 2 + (node.y - v[1]) ** 2 < 9]
  108. if len(near_vertex) > 0:
  109. for v in near_vertex:
  110. beacons.append(v)
  111. node = node.parent
  112. self.beacons = beacons
  113. def ReformObsVertex(self):
  114. obs_vertex = []
  115. for obs in self.obs_vertex:
  116. for vertex in obs:
  117. obs_vertex.append(vertex)
  118. self.obs_vertex = obs_vertex
  119. def Steer(self, x_start, x_goal):
  120. dist, theta = self.get_distance_and_angle(x_start, x_goal)
  121. dist = min(self.step_len, dist)
  122. node_new = Node((x_start.x + dist * math.cos(theta),
  123. x_start.y + dist * math.sin(theta)))
  124. node_new.parent = x_start
  125. return node_new
  126. def Near(self, nodelist, node):
  127. n = len(self.V) + 1
  128. r = 50 * math.sqrt((math.log(n) / n))
  129. dist_table = [(nd.x - node.x) ** 2 + (nd.y - node.y) ** 2 for nd in nodelist]
  130. X_near = [nodelist[ind] for ind in range(len(dist_table)) if dist_table[ind] <= r ** 2 and
  131. not self.utils.is_collision(node, nodelist[ind])]
  132. return X_near
  133. def Sample(self, goal=None):
  134. if goal is None:
  135. delta = self.utils.delta
  136. goal_sample_rate = self.goal_sample_rate
  137. if np.random.random() > goal_sample_rate:
  138. return Node((np.random.uniform(self.x_range[0] + delta, self.x_range[1] - delta),
  139. np.random.uniform(self.y_range[0] + delta, self.y_range[1] - delta)))
  140. return self.x_goal
  141. else:
  142. R = self.beacons_radius
  143. r = random.uniform(0, R)
  144. theta = random.uniform(0, 2 * math.pi)
  145. ind = random.randint(0, len(goal) - 1)
  146. return Node((goal[ind][0] + r * math.cos(theta),
  147. goal[ind][1] + r * math.sin(theta)))
  148. def SampleFreeSpace(self):
  149. delta = self.delta
  150. if np.random.random() > self.goal_sample_rate:
  151. return Node((np.random.uniform(self.x_range[0] + delta, self.x_range[1] - delta),
  152. np.random.uniform(self.y_range[0] + delta, self.y_range[1] - delta)))
  153. return self.x_goal
  154. def ExtractPath(self):
  155. path = []
  156. node = self.x_goal
  157. while node.parent:
  158. path.append([node.x, node.y])
  159. node = node.parent
  160. path.append([self.x_start.x, self.x_start.y])
  161. return path
  162. def InitialPathFound(self, node):
  163. if self.Line(node, self.x_goal) < self.step_len:
  164. return True
  165. return False
  166. @staticmethod
  167. def Nearest(nodelist, n):
  168. return nodelist[int(np.argmin([(nd.x - n.x) ** 2 + (nd.y - n.y) ** 2
  169. for nd in nodelist]))]
  170. @staticmethod
  171. def Line(x_start, x_goal):
  172. return math.hypot(x_goal.x - x_start.x, x_goal.y - x_start.y)
  173. @staticmethod
  174. def Cost(node):
  175. cost = 0.0
  176. if node.parent is None:
  177. return cost
  178. while node.parent:
  179. cost += math.hypot(node.x - node.parent.x, node.y - node.parent.y)
  180. node = node.parent
  181. return cost
  182. @staticmethod
  183. def get_distance_and_angle(node_start, node_end):
  184. dx = node_end.x - node_start.x
  185. dy = node_end.y - node_start.y
  186. return math.hypot(dx, dy), math.atan2(dy, dx)
  187. def animation(self):
  188. plt.cla()
  189. self.plot_grid("rrt*-Smart, N = " + str(self.iter_max))
  190. plt.gcf().canvas.mpl_connect(
  191. 'key_release_event',
  192. lambda event: [exit(0) if event.key == 'escape' else None])
  193. for node in self.V:
  194. if node.parent:
  195. plt.plot([node.x, node.parent.x], [node.y, node.parent.y], "-g")
  196. if self.beacons:
  197. theta = np.arange(0, 2 * math.pi, 0.1)
  198. r = self.beacons_radius
  199. for v in self.beacons:
  200. x = v[0] + r * np.cos(theta)
  201. y = v[1] + r * np.sin(theta)
  202. plt.plot(x, y, linestyle='--', linewidth=2, color='darkorange')
  203. plt.pause(0.01)
  204. def plot_grid(self, name):
  205. for (ox, oy, w, h) in self.obs_boundary:
  206. self.ax.add_patch(
  207. patches.Rectangle(
  208. (ox, oy), w, h,
  209. edgecolor='black',
  210. facecolor='black',
  211. fill=True
  212. )
  213. )
  214. for (ox, oy, w, h) in self.obs_rectangle:
  215. self.ax.add_patch(
  216. patches.Rectangle(
  217. (ox, oy), w, h,
  218. edgecolor='black',
  219. facecolor='gray',
  220. fill=True
  221. )
  222. )
  223. for (ox, oy, r) in self.obs_circle:
  224. self.ax.add_patch(
  225. patches.Circle(
  226. (ox, oy), r,
  227. edgecolor='black',
  228. facecolor='gray',
  229. fill=True
  230. )
  231. )
  232. plt.plot(self.x_start.x, self.x_start.y, "bs", linewidth=3)
  233. plt.plot(self.x_goal.x, self.x_goal.y, "rs", linewidth=3)
  234. plt.title(name)
  235. plt.axis("equal")
  236. def main():
  237. x_start = (18, 8) # Starting node
  238. x_goal = (37, 18) # Goal node
  239. rrt = RrtStarSmart(x_start, x_goal, 1.5, 0.10, 0, 1000)
  240. rrt.planning()
  241. if __name__ == '__main__':
  242. main()