rrtstar3D.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. """
  2. RRT* (asymptotically optimal rapidly-exploring random tree) path planner in 3D.
  3. @author: yue qi
  4. """
  5. import numpy as np
  6. from numpy.matlib import repmat
  7. from env3D import env
  8. from collections import defaultdict
  9. import pyrr as pyrr
  10. from utils3D import getDist, sampleFree, nearest, steer, isCollide, near, visualization, cost, path
  11. import time
  12. class rrtstar():
  13. def __init__(self):
  14. self.env = env()
  15. self.Parent = defaultdict(lambda: defaultdict(dict))
  16. self.V = []
  17. self.E = []
  18. self.i = 0
  19. self.maxiter = 10000
  20. self.stepsize = 0.5
  21. self.Path = []
  22. def wireup(self,x,y):
  23. self.E.append([x,y]) # add edge
  24. self.Parent[str(x[0])][str(x[1])][str(x[2])] = y
  25. def removewire(self,xnear):
  26. xparent = self.Parent[str(xnear[0])][str(xnear[1])][str(xnear[2])]
  27. a = np.array([xnear,xparent])
  28. self.E = [xx for xx in self.E if not (xx==a).all()] # remove and replace old the connection
  29. def run(self):
  30. self.V.append(self.env.start)
  31. ind = 0
  32. xnew = self.env.start
  33. while ind < self.maxiter and getDist(xnew,self.env.goal) > 1:
  34. xrand = sampleFree(self)
  35. xnearest = nearest(self,xrand)
  36. xnew = steer(self,xnearest,xrand)
  37. if not isCollide(self,xnearest,xnew):
  38. Xnear = near(self,xnew)
  39. self.V.append(xnew) # add point
  40. # visualization(self)
  41. # minimal path and minimal cost
  42. xmin,cmin = xnearest,cost(self,xnearest) + getDist(xnearest,xnew)
  43. # connecting along minimal cost path
  44. if self.i == 0:
  45. c1 = cost(self,Xnear) + getDist(xnew,Xnear)
  46. if not isCollide(self,xnew,Xnear) and c1 < cmin:
  47. xmin,cmin = Xnear,c1
  48. self.wireup(xnew,xmin)
  49. else:
  50. for xnear in Xnear:
  51. c1 = cost(self,xnear) + getDist(xnew,xnear)
  52. if not isCollide(self,xnew,xnear) and c1 < cmin:
  53. xmin,cmin = xnear,c1
  54. self.wireup(xnew,xmin)
  55. # rewire
  56. for xnear in Xnear:
  57. c2 = cost(self,xnew) + getDist(xnew,xnear)
  58. if not isCollide(self,xnew,xnear) and c2 < cost(self,xnear):
  59. self.removewire(xnear)
  60. self.wireup(xnear,xnew)
  61. self.i += 1
  62. ind += 1
  63. if getDist(xnew,self.env.goal) <= 1:
  64. self.wireup(self.env.goal,xnew)
  65. self.Path,D = path(self)
  66. print('Total distance = '+str(D))
  67. visualization(self)
  68. if __name__ == '__main__':
  69. p = rrtstar()
  70. starttime = time.time()
  71. p.run()
  72. print('time used = ' + str(time.time()-starttime))