rrtstar3D.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. """
  2. This is rrt star code for 3D
  3. @author: yue qi
  4. """
  5. import numpy as np
  6. from numpy.matlib import repmat
  7. from collections import defaultdict
  8. import time
  9. import os
  10. import sys
  11. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../Sampling-based Planning/")
  12. from rrt_3D.env3D import env
  13. from rrt_3D.utils3D import getDist, sampleFree, nearest, steer, isCollide, near, visualization, cost, path, edgeset, hash3D, dehash
  14. class rrtstar():
  15. def __init__(self):
  16. self.env = env()
  17. self.Parent = {}
  18. self.E = edgeset()
  19. self.V = []
  20. self.i = 0
  21. self.maxiter = 4000 # at least 4000 in this env
  22. self.stepsize = 0.5
  23. self.gamma = 500
  24. self.eta = 1.1*self.stepsize
  25. self.Path = []
  26. self.done = False
  27. def wireup(self,x,y):
  28. self.E.add_edge([x,y]) # add edge
  29. self.Parent[hash3D(x)] = y
  30. def removewire(self,xnear):
  31. xparent = self.Parent[hash3D(xnear)]
  32. a = [xnear,xparent]
  33. self.E.remove_edge(a) # remove and replace old the connection
  34. #self.Parent.pop(hash3D(xnear), None)
  35. def reached(self):
  36. self.done = True
  37. xn = near(self,self.env.goal)
  38. c = [cost(self,x) for x in xn]
  39. xncmin = xn[np.argmin(c)]
  40. self.wireup(self.env.goal,xncmin)
  41. self.V.append(self.env.goal)
  42. self.Path,self.D = path(self)
  43. def run(self):
  44. self.V.append(self.env.start)
  45. ind = 0
  46. xnew = self.env.start
  47. print('start rrt*... ')
  48. while ind < self.maxiter:
  49. xrand = sampleFree(self)
  50. xnearest = nearest(self,xrand)
  51. xnew = steer(self,xnearest,xrand)
  52. if not isCollide(self,xnearest,xnew):
  53. Xnear = near(self,xnew)
  54. self.V.append(xnew) # add point
  55. # visualization(self)
  56. # minimal path and minimal cost
  57. xmin, cmin = xnearest, cost(self, xnearest) + getDist(xnearest, xnew)
  58. # connecting along minimal cost path
  59. if self.i == 0:
  60. c1 = cost(self, Xnear) + getDist(xnew, Xnear)
  61. if not isCollide(self, xnew, Xnear) and c1 < cmin:
  62. xmin, cmin = Xnear, c1
  63. self.wireup(xnew, xmin)
  64. else:
  65. for xnear in Xnear:
  66. c1 = cost(self, xnear) + getDist(xnew, xnear)
  67. if not isCollide(self, xnew, xnear) and c1 < cmin:
  68. xmin, cmin = xnear, c1
  69. self.wireup(xnew, xmin)
  70. # rewire
  71. for xnear in Xnear:
  72. c2 = cost(self, xnew) + getDist(xnew, xnear)
  73. if not isCollide(self, xnew, xnear) and c2 < cost(self, xnear):
  74. self.removewire(xnear)
  75. self.wireup(xnear, xnew)
  76. self.i += 1
  77. ind += 1
  78. # max sample reached
  79. self.reached()
  80. print('time used = ' + str(time.time()-starttime))
  81. print('Total distance = '+str(self.D))
  82. visualization(self)
  83. if __name__ == '__main__':
  84. p = rrtstar()
  85. starttime = time.time()
  86. p.run()