rrt3D.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. """
  2. This is plain RRT (rapidly-exploring random tree) code for 3D; no RRT* rewiring is performed.
  3. @author: yue qi
  4. """
  5. import numpy as np
  6. from numpy.matlib import repmat
  7. from collections import defaultdict
  8. import time
  9. import matplotlib.pyplot as plt
  10. import os
  11. import sys
  12. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../Sampling_based_Planning/")
  13. from rrt_3D.env3D import env
  14. from rrt_3D.utils3D import getDist, sampleFree, nearest, steer, isCollide, near, visualization, cost, path, edgeset
  15. class rrt():
  16. def __init__(self):
  17. self.env = env()
  18. self.Parent = {}
  19. self.V = []
  20. self.E = edgeset()
  21. self.i = 0
  22. self.maxiter = 10000
  23. self.stepsize = 0.5
  24. self.Path = []
  25. self.done = False
  26. self.x0 = tuple(self.env.start)
  27. self.xt = tuple(self.env.goal)
  28. self.Flag = None
  29. self.ind = 0
  30. self.fig = plt.figure(figsize=(10, 8))
  31. def wireup(self, x, y):
  32. self.E.add_edge([x, y]) # add edge
  33. self.Parent[x] = y
  34. def run(self, Reversed = True, xrobot = None):
  35. if Reversed:
  36. if xrobot is None:
  37. self.x0 = tuple(self.env.goal)
  38. self.xt = tuple(self.env.start)
  39. else:
  40. self.x0 = tuple(self.env.goal)
  41. self.xt = xrobot
  42. xnew = self.env.goal
  43. else:
  44. xnew = self.env.start
  45. self.V.append(self.x0)
  46. while self.ind < self.maxiter:
  47. xrand = sampleFree(self)
  48. xnearest = nearest(self, xrand)
  49. xnew = steer(self, xnearest, xrand)
  50. collide, _ = isCollide(self, xnearest, xnew)
  51. if not collide:
  52. self.V.append(xnew) # add point
  53. if self.Flag is not None:
  54. self.Flag[xnew] = 'Valid'
  55. self.wireup(xnew, xnearest)
  56. if getDist(xnew, self.xt) <= self.stepsize:
  57. self.wireup(self.xt, xnew)
  58. self.Path, D = path(self)
  59. print('Total distance = ' + str(D))
  60. if self.Flag is not None:
  61. self.Flag[self.xt] = 'Valid'
  62. break
  63. # visualization(self)
  64. self.i += 1
  65. self.ind += 1
  66. # if the goal is really reached
  67. # self.done = True
  68. # visualization(self)
  69. # plt.show()
  70. if __name__ == '__main__':
  71. p = rrt()
  72. starttime = time.time()
  73. p.run()
  74. print('time used = ' + str(time.time() - starttime))