rrt3D.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. """
  2. This is RRT code for 3D (plain RRT: each new node is wired to its nearest tree node; no RRT* rewiring is performed here)
  3. @author: yue qi
  4. """
  5. import numpy as np
  6. from numpy.matlib import repmat
  7. from collections import defaultdict
  8. import time
  9. import matplotlib.pyplot as plt
  10. import os
  11. import sys
  12. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../Sampling-based Planning/")
  13. from rrt_3D.env3D import env
  14. from rrt_3D.utils3D import getDist, sampleFree, nearest, steer, isCollide, near, visualization, cost, path, edgeset
  class rrtstar():
      """Plain RRT planner in 3D.

      Despite the class name, no RRT* choose-parent / rewire step is done:
      every collision-free new node is wired directly to its nearest tree node.

      Attributes (set in ``__init__`` / ``run``):
          env: environment object from ``rrt_3D.env3D.env``; ``run`` reads
              its ``start`` and ``goal``.
          Parent: dict mapping a child node to its parent node.
          V: list of tree vertices.
          E: edge container from ``rrt_3D.utils3D.edgeset``.
          i: number of collision-free nodes added to the tree.
          ind: number of sampling iterations performed by ``run``.
          maxiter: iteration budget for ``run``.
          stepsize: steer step length; also the radius within which the goal
              is considered reachable from a tree node.
          Path: result of ``utils3D.path`` once the goal has been connected.
          done: True only once the goal has actually been connected.
      """

      def __init__(self, maxiter=10000, stepsize=0.5):
          """Create a planner with a fresh environment and an empty tree.

          Args:
              maxiter: maximum number of sampling iterations in ``run``.
              stepsize: steer step length and goal-connection radius.
          """
          self.env = env()
          self.Parent = {}
          self.V = []
          self.E = edgeset()
          self.i = 0
          self.maxiter = maxiter
          self.stepsize = stepsize
          self.Path = []
          self.done = False

      def wireup(self, x, y):
          """Record the tree edge from child ``x`` to parent ``y``.

          Adds ``[x, y]`` to ``self.E`` and sets ``self.Parent[x] = y``; ``x``
          is used as a dict key, so it must be hashable (e.g. a tuple).
          """
          self.E.add_edge([x, y])
          self.Parent[x] = y

      def _try_connect_goal(self, x):
          """Wire the goal to tree node ``x`` if it is close and reachable.

          On success, computes ``self.Path`` via ``path``, prints the total
          distance, sets ``self.done`` and returns True; otherwise returns False.
          """
          if getDist(x, self.env.goal) > self.stepsize:
              return False
          goal = tuple(self.env.goal)
          # The final segment must be collision-free too, not just short.
          collide, _ = isCollide(self, x, goal)
          if collide:
              return False
          self.wireup(goal, x)
          self.Path, D = path(self)
          print('Total distance = ' + str(D))
          self.done = True
          return True

      def run(self):
          """Grow the tree until the goal is connected or ``maxiter`` runs out.

          Each iteration samples a free point, steers from its nearest tree node
          toward it, and keeps the new node only if that edge is collision-free.
          The final tree is passed to ``visualization`` and shown with
          ``plt.show()`` whether or not the goal was reached.
          """
          start = tuple(self.env.start)
          self.V.append(start)
          self.ind = 0
          self.fig = plt.figure(figsize=(10, 8))
          # The start itself may already be within reach of the goal.
          self._try_connect_goal(start)
          # Terminate on the done flag rather than on the distance of the last
          # sampled node, which may have been rejected for colliding.
          while not self.done and self.ind < self.maxiter:
              xrand = sampleFree(self)
              xnearest = nearest(self, xrand)
              xnew = steer(self, xnearest, xrand)
              collide, _ = isCollide(self, xnearest, xnew)
              if not collide:
                  self.V.append(xnew)
                  self.wireup(xnew, xnearest)
                  self._try_connect_goal(xnew)
                  self.i += 1
              self.ind += 1
          if not self.done:
              print('goal not reached within ' + str(self.maxiter) + ' iterations')
          visualization(self)
          plt.show()
  if __name__ == '__main__':
      # Script entry point: build a planner, run it, and report elapsed time.
      p = rrtstar()
      # perf_counter is monotonic and high-resolution, unlike the wall clock
      # behind time.time(), so it is the right clock for measuring intervals.
      # NOTE(review): the measurement also includes the plt.show() call made
      # inside run(), which may block until the figure window is closed.
      starttime = time.perf_counter()
      p.run()
      print('time used = ' + str(time.perf_counter() - starttime))