rrtstar3D.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
"""
RRT* (rapidly-exploring random tree star) path planning in a 3D environment.
@author: yue qi
"""
  5. import numpy as np
  6. from numpy.matlib import repmat
  7. from collections import defaultdict
  8. import time
  9. import matplotlib.pyplot as plt
  10. import os
  11. import sys
  12. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../Sampling-based Planning/")
  13. from rrt_3D.env3D import env
  14. from rrt_3D.utils3D import getDist, sampleFree, nearest, steer, isCollide, near, visualization, cost, path, edgeset
  15. class rrtstar():
  16. def __init__(self):
  17. self.env = env()
  18. self.Parent = {}
  19. self.E = edgeset()
  20. self.V = []
  21. self.i = 0
  22. self.maxiter = 5000 # at least 2000 in this env
  23. self.stepsize = 0.5
  24. self.gamma = 500
  25. self.eta = 2*self.stepsize
  26. self.Path = []
  27. self.done = False
  28. self.x0 = tuple(self.env.start)
  29. self.xt = tuple(self.env.goal)
  30. self.V.append(self.x0)
  31. self.ind = 0
  32. def wireup(self,x,y):
  33. self.E.add_edge([x,y]) # add edge
  34. self.Parent[x] = y
  35. def removewire(self,xnear):
  36. xparent = self.Parent[xnear]
  37. a = [xnear,xparent]
  38. self.E.remove_edge(a) # remove and replace old the connection
  39. def reached(self):
  40. self.done = True
  41. goal = self.xt
  42. xn = near(self,self.env.goal)
  43. c = [cost(self,tuple(x)) for x in xn]
  44. xncmin = xn[np.argmin(c)]
  45. self.wireup(goal , tuple(xncmin))
  46. self.V.append(goal)
  47. self.Path,self.D = path(self)
  48. def run(self):
  49. xnew = self.x0
  50. print('start rrt*... ')
  51. self.fig = plt.figure(figsize = (10,8))
  52. while self.ind < self.maxiter:
  53. xrand = sampleFree(self)
  54. xnearest = nearest(self,xrand)
  55. xnew = steer(self,xnearest,xrand)
  56. collide, _ = isCollide(self,xnearest,xnew)
  57. if not collide:
  58. Xnear = near(self,xnew)
  59. self.V.append(xnew) # add point
  60. # visualization(self)
  61. # minimal path and minimal Cost
  62. xmin, cmin = xnearest, cost(self, xnearest) + getDist(xnearest, xnew)
  63. # connecting along minimal Cost path
  64. for xnear in Xnear:
  65. xnear = tuple(xnear)
  66. c1 = cost(self, xnear) + getDist(xnew, xnear)
  67. collide, _ = isCollide(self, xnew, xnear)
  68. if not collide and c1 < cmin:
  69. xmin, cmin = xnear, c1
  70. self.wireup(xnew, xmin)
  71. # rewire
  72. for xnear in Xnear:
  73. xnear = tuple(xnear)
  74. c2 = cost(self, xnew) + getDist(xnew, xnear)
  75. collide, _ = isCollide(self, xnew, xnear)
  76. if not collide and c2 < cost(self, xnear):
  77. self.removewire(xnear)
  78. self.wireup(xnear, xnew)
  79. self.i += 1
  80. self.ind += 1
  81. # max sample reached
  82. self.reached()
  83. print('time used = ' + str(time.time()-starttime))
  84. print('Total distance = '+str(self.D))
  85. visualization(self)
  86. plt.show()
  87. if __name__ == '__main__':
  88. p = rrtstar()
  89. starttime = time.time()
  90. p.run()