rrtstar3D.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. """
  2. This is rrt star code for 3D
  3. @author: yue qi
  4. """
  5. import numpy as np
  6. from numpy.matlib import repmat
  7. from collections import defaultdict
  8. import time
  9. import os
  10. import sys
  11. sys.path.append(os.path.dirname(os.path.abspath(__file__))+"/../../Sampling-based Planning/")
  12. from rrt_3D.env3D import env
  13. from rrt_3D.utils3D import getDist, sampleFree, nearest, steer, isCollide, near, visualization, cost, path, edgeset
  14. class rrtstar():
  15. def __init__(self):
  16. self.env = env()
  17. self.Parent = defaultdict(lambda: defaultdict(dict))
  18. self.E = edgeset()
  19. self.V = []
  20. self.i = 0
  21. self.maxiter = 10000
  22. self.stepsize = 0.5
  23. self.Path = []
  24. def wireup(self,x,y):
  25. self.E.add_edge([x,y]) # add edge
  26. self.Parent[str(x[0])][str(x[1])][str(x[2])] = y
  27. def removewire(self,xnear):
  28. xparent = self.Parent[str(xnear[0])][str(xnear[1])][str(xnear[2])]
  29. a = [xnear,xparent]
  30. self.E.remove_edge(a) # remove and replace old the connection
  31. def run(self):
  32. self.V.append(self.env.start)
  33. ind = 0
  34. xnew = self.env.start
  35. while ind < self.maxiter and getDist(xnew,self.env.goal) > 1:
  36. #while ind < self.maxiter:
  37. xrand = sampleFree(self)
  38. xnearest = nearest(self,xrand)
  39. xnew = steer(self,xnearest,xrand)
  40. if not isCollide(self,xnearest,xnew):
  41. Xnear = near(self,xnew)
  42. self.V.append(xnew) # add point
  43. # visualization(self)
  44. # minimal path and minimal cost
  45. xmin,cmin = xnearest,cost(self,xnearest) + getDist(xnearest,xnew)
  46. # connecting along minimal cost path
  47. if self.i == 0:
  48. c1 = cost(self,Xnear) + getDist(xnew,Xnear)
  49. if not isCollide(self,xnew,Xnear) and c1 < cmin:
  50. xmin,cmin = Xnear,c1
  51. self.wireup(xnew,xmin)
  52. else:
  53. for xnear in Xnear:
  54. c1 = cost(self,xnear) + getDist(xnew,xnear)
  55. if not isCollide(self,xnew,xnear) and c1 < cmin:
  56. xmin,cmin = xnear,c1
  57. self.wireup(xnew,xmin)
  58. # rewire
  59. for xnear in Xnear:
  60. c2 = cost(self,xnew) + getDist(xnew,xnear)
  61. if not isCollide(self,xnew,xnear) and c2 < cost(self,xnear):
  62. self.removewire(xnear)
  63. self.wireup(xnear,xnew)
  64. self.i += 1
  65. ind += 1
  66. # when the goal is reached
  67. if getDist(xnew,self.env.goal) <= 1:
  68. self.wireup(self.env.goal,xnew)
  69. self.Path,self.D = path(self)
  70. visualization(self)
  71. print('Total distance = '+str(self.D))
  72. if __name__ == '__main__':
  73. p = rrtstar()
  74. starttime = time.time()
  75. p.run()
  76. print('time used = ' + str(time.time()-starttime))