rrt3D.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. """
  2. RRT (rapidly-exploring random tree) path planner in 3D; despite the class name, no RRT* rewiring is performed.
  3. @author: yue qi
  4. """
  5. import numpy as np
  6. from numpy.matlib import repmat
  7. from env3D import env
  8. from collections import defaultdict
  9. import pyrr as pyrr
  10. from utils3D import getDist, sampleFree, nearest, steer, isCollide, near, visualization, cost, path
  11. import time
  class rrtstar:
      """Basic 3D RRT planner.

      Despite the class name, ``run`` performs plain RRT extension: each
      collision-free steered sample is attached to its nearest tree vertex,
      and no RRT* choose-parent / rewire step is applied.

      Attributes:
          env: planning environment built by ``env3D.env``; ``run`` reads its
              ``start`` and ``goal``.
          Parent: nested dict keyed by ``str`` of a node's three coordinates,
              mapping to that node's parent node.
          V: list of accepted tree vertices.
          E: list of ``[child, parent]`` edges.
          i: number of vertices added by ``run`` (the start is not counted).
          maxiter: maximum number of sampling iterations in ``run``.
          stepsize: presumably the extension length used by
              ``utils3D.steer`` -- confirm against that helper.
          goal_radius: a tree vertex within this distance of the goal may be
              connected directly to it.
          Path: first value returned by ``utils3D.path`` once the goal is
              connected; empty list otherwise.
      """

      def __init__(self, maxiter=10000, stepsize=0.5, goal_radius=1):
          self.env = env()
          # Three-level dict keyed by str(x), str(y), str(z) of a child node.
          self.Parent = defaultdict(lambda: defaultdict(dict))
          self.V = []
          self.E = []
          self.i = 0
          self.maxiter = maxiter
          self.stepsize = stepsize
          self.goal_radius = goal_radius
          self.Path = []

      def wireup(self, x, y):
          """Add the edge ``[x, y]`` and record ``y`` as the parent of ``x``."""
          self.E.append([x, y])  # add edge
          self.Parent[str(x[0])][str(x[1])][str(x[2])] = y

      def removewire(self, xnear):
          """Drop every edge from ``xnear`` to its currently recorded parent.

          The ``Parent`` entry itself is left in place so a caller can rewire.
          Raises ``KeyError`` if ``xnear`` has no recorded parent.
          """
          xparent = self.Parent[str(xnear[0])][str(xnear[1])][str(xnear[2])]
          a = np.array([xnear, xparent])
          # Elementwise compare of each [child, parent] pair against `a`.
          self.E = [xx for xx in self.E if not (xx == a).all()]  # remove and replace old the connection

      def _can_reach_goal(self, x):
          """Return True if accepted vertex ``x`` can be wired straight to the goal."""
          return (getDist(x, self.env.goal) <= self.goal_radius
                  and not isCollide(self, x, self.env.goal))

      def run(self):
          """Grow the tree from ``env.start`` until the goal is connected.

          Each iteration samples a free point, steers from its nearest tree
          vertex toward it and, if that segment is collision-free, adds the
          new vertex. The search stops once an *accepted* vertex is within
          ``goal_radius`` of the goal with a collision-free final segment, or
          after ``maxiter`` iterations. On success the goal is wired in,
          ``self.Path`` is filled from ``utils3D.path`` and the total distance
          is printed. ``visualization`` is called at the end either way.
          """
          self.V.append(self.env.start)
          xnew = self.env.start
          reached = self._can_reach_goal(xnew)
          ind = 0
          while ind < self.maxiter and not reached:
              xrand = sampleFree(self)
              xnearest = nearest(self, xrand)
              xnew = steer(self, xnearest, xrand)
              ind += 1
              if isCollide(self, xnearest, xnew):
                  # Rejected point: it is not in the tree, so it must not be
                  # used for the goal connection (previously it could be).
                  continue
              self.V.append(xnew)  # add point
              self.wireup(xnew, xnearest)
              self.i += 1
              reached = self._can_reach_goal(xnew)
          if reached:
              # xnew is the last accepted vertex (or the start itself).
              self.wireup(self.env.goal, xnew)
              self.Path, D = path(self)
              print('Total distance = ' + str(D))
          visualization(self)
  45. self.Path,D = path(self)
  46. print('Total distance = '+str(D))
  47. visualization(self)
  if __name__ == '__main__':
      # Build a planner on the default environment, run it, and report how
      # long the search took.
      p = rrtstar()
      # perf_counter is monotonic and high-resolution, unlike time.time,
      # which can jump if the system clock is adjusted during the run.
      starttime = time.perf_counter()
      p.run()
      print('time used = ' + str(time.perf_counter() - starttime))