rrtstar3D.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. """
  2. This is rrt star code for 3D
  3. @author: yue qi
  4. """
  5. import numpy as np
  6. from numpy.matlib import repmat
  7. from rrt_3D.env3D import env
  8. from collections import defaultdict
  9. import pyrr as pyrr
  10. from rrt_3D.utils3D import getDist, sampleFree, nearest, steer, isCollide, near, visualization, cost, path
  11. import time
  12. class rrtstar():
  13. def __init__(self):
  14. self.env = env()
  15. self.Parent = defaultdict(lambda: defaultdict(dict))
  16. self.V = []
  17. self.E = []
  18. self.i = 0
  19. self.maxiter = 10000
  20. self.stepsize = 0.5
  21. self.Path = []
  22. def wireup(self,x,y):
  23. self.E.append([x,y]) # add edge
  24. self.Parent[str(x[0])][str(x[1])][str(x[2])] = y
  25. def removewire(self,xnear):
  26. xparent = self.Parent[str(xnear[0])][str(xnear[1])][str(xnear[2])]
  27. a = np.array([xnear,xparent])
  28. self.E = [xx for xx in self.E if not (xx==a).all()] # remove and replace old the connection
  29. def run(self):
  30. self.V.append(self.env.start)
  31. ind = 0
  32. xnew = self.env.start
  33. while ind < self.maxiter and getDist(xnew,self.env.goal) > 1:
  34. xrand = sampleFree(self)
  35. xnearest = nearest(self,xrand)
  36. xnew = steer(self,xnearest,xrand)
  37. if not isCollide(self,xnearest,xnew):
  38. Xnear = near(self,xnew)
  39. self.V.append(xnew) # add point
  40. # visualization(self)
  41. # minimal path and minimal cost
  42. xmin,cmin = xnearest,cost(self,xnearest) + getDist(xnearest,xnew)
  43. # connecting along minimal cost path
  44. if self.i == 0:
  45. c1 = cost(self,Xnear) + getDist(xnew,Xnear)
  46. if not isCollide(self,xnew,Xnear) and c1 < cmin:
  47. xmin,cmin = Xnear,c1
  48. self.wireup(xnew,xmin)
  49. else:
  50. for xnear in Xnear:
  51. c1 = cost(self,xnear) + getDist(xnew,xnear)
  52. if not isCollide(self,xnew,xnear) and c1 < cmin:
  53. xmin,cmin = xnear,c1
  54. self.wireup(xnew,xmin)
  55. # rewire
  56. for xnear in Xnear:
  57. c2 = cost(self,xnew) + getDist(xnew,xnear)
  58. if not isCollide(self,xnew,xnear) and c2 < cost(self,xnear):
  59. self.removewire(xnear)
  60. self.wireup(xnear,xnew)
  61. self.i += 1
  62. ind += 1
  63. if getDist(xnew,self.env.goal) <= 1:
  64. self.wireup(self.env.goal,xnew)
  65. self.Path,D = path(self)
  66. print('Total distance = '+str(D))
  67. visualization(self)
  68. if __name__ == '__main__':
  69. p = rrtstar()
  70. starttime = time.time()
  71. p.run()
  72. print('time used = ' + str(time.time()-starttime))