dynamic_rrt3D.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. """
  2. This is dynamic rrt code for 3D
  3. @author: yue qi
  4. """
  5. import numpy as np
  6. from numpy.matlib import repmat
  7. from collections import defaultdict
  8. import copy
  9. import time
  10. import matplotlib.pyplot as plt
  11. import os
  12. import sys
  13. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../Sampling-based Planning/")
  14. from rrt_3D.env3D import env
  15. from rrt_3D.utils3D import getDist, sampleFree, nearest, steer, isCollide, near, cost, path, edgeset, isinbound, isinside
  16. from rrt_3D.rrt3D import rrt
  17. from rrt_3D.plot_util3D import make_get_proj, draw_block_list, draw_Spheres, draw_obb, draw_line, make_transparent
  class dynamic_rrt_3D():
      """Dynamic RRT planner in 3D.

      The tree is grown with ``GrowRRT``. When obstacles move, ``InvalidateNodes``
      flags nodes whose incoming edge now collides, ``TrimRRT`` removes them and
      their descendants, and ``GrowRRT`` repairs the tree (``RegrowRRT``). Method
      names mirror the Dynamic RRT (DRRT) algorithm.

      Tree representation:
          V      -- list of nodes (tuples); V[0] is the root ``x0``.
          Parent -- maps child node -> parent node (the root has no entry).
          Edge   -- set of (child, parent) tuples.
          flag   -- maps node -> 'Valid' / 'Invalid'.
      """

      def __init__(self):
          self.env = env()
          self.x0, self.xt = tuple(self.env.start), tuple(self.env.goal)  # tree root / tree target
          self.qrobot = self.x0
          self.current = tuple(self.env.start)
          self.stepsize = 0.25
          self.maxiter = 10000
          self.GoalProb = 0.05  # probability of targeting self.xt directly
          self.WayPointProb = 0.05  # probability of targeting an existing tree node
          self.V = []  # vertices; V[0] is the root
          self.Parent = {}  # child -> parent
          self.Edge = set()  # (child, parent) tuples
          self.Path = []  # list of np.array([child, parent]) edges from xt back to x0
          self.flag = {}  # node -> 'Valid' / 'Invalid'
          self.ind = 0  # iteration counter of the current GrowRRT call
          self.i = 0  # incremented per successful extension (NOTE(review): may be read by utils3D helpers)
          self.done = False  # forces visualization() to redraw; set by Main()

      # -------- Dynamic RRT algorithm
      def RegrowRRT(self):
          """Remove invalidated branches, then grow the tree back to ``self.xt``."""
          self.TrimRRT()
          self.GrowRRT()

      def TrimRRT(self):
          """Drop every 'Invalid' node and all of its descendants from the tree.

          Invalidity is propagated parent -> child in a single forward pass. This
          assumes every parent appears before its children in ``self.V``, which
          ``AddNode`` guarantees by appending children after their parent. The root
          ``V[0]`` is always kept: everything else hangs off it, and losing it would
          leave ``self.V`` without ``x0``.
          """
          print('trimming...')
          S = self.V[:1]  # keep the root (empty list if the tree is empty)
          for qi in self.V[1:]:
              if self.flag[self.Parent[qi]] == 'Invalid':
                  self.flag[qi] = 'Invalid'
              if self.flag[qi] != 'Invalid':
                  S.append(qi)
          self.CreateTreeFromNodes(S)
          print('trimming complete...')

      def InvalidateNodes(self, obstacle):
          """Flag the child endpoint of every edge that now collides as 'Invalid'."""
          Edges = self.FindAffectedEdges(obstacle)
          for edge in Edges:
              qe = self.ChildEndpointNode(edge)
              self.flag[qe] = 'Invalid'

      # -------- Extend RRT algorithm -----
      def initRRT(self):
          """Start the tree with the root ``self.x0``."""
          self.V.append(self.x0)
          self.flag[self.x0] = 'Valid'

      def GrowRRT(self):
          """Grow the tree until a node reaches ``self.xt``.

          Runs at most ``maxiter + 1`` iterations. A node counts as reaching the
          goal when it lies within ``stepsize`` of ``self.xt`` and the short edge to
          it is collision-free. The goal is then attached to that node. If the
          iteration budget runs out first, ``self.xt`` is not added to the tree.
          """
          print('growing')
          distance_threshold = self.stepsize
          self.ind = 0
          while self.ind <= self.maxiter:
              qtarget = self.ChooseTarget()
              qnearest = self.Nearest(None, qtarget)
              qnew, collide = self.Extend(qnearest, qtarget)
              # A zero-length extension (qnew == qnearest) would make the node its
              # own parent in AddNode and create a cycle in self.Parent; skip it.
              if not collide and qnew != qnearest:
                  self.AddNode(qnearest, qnew)
                  if getDist(qnew, self.xt) < distance_threshold and self._ConnectGoal(qnew):
                      break
                  self.i += 1
              self.ind += 1
          print('growing complete...')

      def _ConnectGoal(self, q):
          """Try to attach ``self.xt`` to tree node ``q``; return True on success."""
          if q != self.xt:
              # Attach the goal to the node that actually reached it, and only if
              # that last short edge is collision-free.
              blocked, _ = isCollide(self, q, self.xt)
              if blocked:
                  return False
              self.AddNode(q, self.xt)
          self.flag[self.xt] = 'Valid'
          return True

      def ChooseTarget(self):
          """Pick the next extension target.

          With probability ``GoalProb`` return ``self.xt``. With probability
          ``WayPointProb`` return a uniformly chosen existing tree node. Otherwise
          return a random obstacle-free state.
          """
          p = np.random.uniform()
          if p < self.GoalProb:
              return self.xt
          if p < self.GoalProb + self.WayPointProb and self.V:
              # randint's upper bound is exclusive: len(self.V) makes every node eligible
              return self.V[np.random.randint(len(self.V))]
          return tuple(self.RandomState())

      def RandomState(self):
          """Generate a random, obstacle-free state (delegates to utils3D.sampleFree)."""
          xrand = sampleFree(self, bias=0)
          return xrand

      def AddNode(self, nearest, extended):
          """Append ``extended`` to the tree as a 'Valid' child of ``nearest``."""
          self.V.append(extended)
          self.Parent[extended] = nearest
          self.Edge.add((extended, nearest))
          self.flag[extended] = 'Valid'

      def Nearest(self, tree, target):
          """Return the tree node nearest to ``target``; ``tree`` is currently unused."""
          # TODO use kdTree to speed up search
          return nearest(self, target, isset=True)

      def Extend(self, nearest, target):
          """Steer from ``nearest`` toward ``target``; return (new node, collided?)."""
          extended, dist = steer(self, nearest, target, DIST=True)
          # NOTE(review): collision is checked on nearest -> target using the dist
          # returned by steer; whether that covers exactly the extended segment
          # depends on utils3D.steer / isCollide -- confirm.
          collide, _ = isCollide(self, nearest, target, dist)
          return extended, collide

      # -------- Main function
      def Main(self, max_moves=8):
          """Plan once, then move an obstacle ``max_moves`` times and replan when needed.

          The tree is rooted at the goal (``x0``) and grown toward the robot's
          start (``xt``). After each obstacle move, affected nodes are invalidated
          and the tree is regrown only if the current path touches one of them.
          """
          # qstart = qgoal, qgoal = qrobot
          self.x0 = tuple(self.env.goal)
          self.xt = tuple(self.env.start)
          self.initRRT()
          self.GrowRRT()
          # self.path() builds a fresh list each call (the imported utils3D.path
          # would keep appending to a shared default list).
          self.Path, _ = self.path()
          self.done = True
          self.visualization()
          plt.show()
          for _ in range(max_moves):
              # move the block while the robot is moving
              new, _old = self.env.move_block(a=[0, 0, -0.2], mode='translation')
              self.InvalidateNodes(new)
              self.done = True
              self.visualization()
              plt.show()
              # regrow only if the solution path contains an invalid node
              if self.PathisInvalid(self.Path):
                  self.done = False
                  self.RegrowRRT()
                  self.Path, _ = self.path()

      # -------- Additional utility functions
      def FindAffectedEdges(self, obstacle):
          """Return every (child, parent) edge for which ``isCollide`` now reports a hit.

          ``obstacle`` is accepted for interface compatibility but is not used:
          every edge is re-checked.
          """
          return [e for e in self.Edge if isCollide(self, e[0], e[1])[0]]

      def ChildEndpointNode(self, edge):
          """Return the child node of a (child, parent) edge."""
          return edge[0]

      def CreateTreeFromNodes(self, Nodes):
          """Rebuild V, Parent and Edge so they contain only ``Nodes``.

          Parents are copied from the current ``self.Parent``. Nodes without a
          parent (the root) are kept in ``V`` and get no Parent/Edge entry.
          """
          self.V = list(Nodes)
          Parent = {v: self.Parent[v] for v in Nodes if v in self.Parent}
          self.Parent = Parent
          self.Edge = set(Parent.items())

      def PathisInvalid(self, path):
          """Return True if any endpoint of any edge in ``path`` is flagged 'Invalid'.

          ``path`` is a sequence of [child, parent] pairs as produced by ``path()``.
          """
          return any(
              self.flag[tuple(child)] == 'Invalid' or self.flag[tuple(parent)] == 'Invalid'
              for child, parent in path
          )

      def path(self, Path=None, dist=0):
          """Walk Parent links from ``self.xt`` back to ``self.x0``.

          Returns ``(Path, dist)``. ``Path`` is the list of np.array([node, parent])
          edges and ``dist`` is the accumulated edge length. A list passed in is
          appended to. Otherwise a new list is created on every call; the old
          mutable default was shared between calls.
          """
          if Path is None:
              Path = []
          x = self.xt
          while x != self.x0:
              x2 = self.Parent[x]
              Path.append(np.array([x, x2]))
              dist += getDist(x, x2)
              x = x2
          return Path, dist

      # -------- Visualization specialized for dynamic RRT
      def visualization(self):
          """Draw obstacles, tree edges (green), current path (red), start and goal.

          Redraws only every 100th GrowRRT iteration, or whenever ``self.done`` is set.
          """
          if self.ind % 100 == 0 or self.done:
              Path = np.array(self.Path)
              start = self.env.start
              goal = self.env.goal
              edges = [[child, parent] for child, parent in self.Parent.items()]
              ax = plt.subplot(111, projection='3d')
              ax.view_init(elev=8., azim=120.)
              ax.clear()
              # drawing objects
              draw_Spheres(ax, self.env.balls)
              draw_block_list(ax, self.env.blocks)
              if self.env.OBB is not None:
                  draw_obb(ax, self.env.OBB)
              draw_block_list(ax, np.array([self.env.boundary]), alpha=0)
              draw_line(ax, edges, visibility=0.75, color='g')
              draw_line(ax, Path, color='r')
              ax.plot(start[0:1], start[1:2], start[2:], 'go', markersize=7, markeredgecolor='k')
              ax.plot(goal[0:1], goal[1:2], goal[2:], 'ro', markersize=7, markeredgecolor='k')
              # adjust the aspect ratio
              xmin, xmax = self.env.boundary[0], self.env.boundary[3]
              ymin, ymax = self.env.boundary[1], self.env.boundary[4]
              dx, dy = xmax - xmin, ymax - ymin
              # NOTE(review): the z extent uses 2 * dy rather than the z span -- confirm intended.
              ax.get_proj = make_get_proj(ax, 1 * dx, 1 * dy, 2 * dy)
              make_transparent(ax)
              ax.set_axis_off()
              plt.pause(0.0001)
  if __name__ == '__main__':
      # Named `planner` so the module-level `rrt` class imported from rrt_3D.rrt3D
      # is not shadowed.
      planner = dynamic_rrt_3D()
      planner.Main()