rrtstar3D.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
"""
RRT* (asymptotically optimal Rapidly-exploring Random Tree) path planner for a 3D environment.
@author: yue qi
"""
  5. import numpy as np
  6. from numpy.matlib import repmat
  7. from collections import defaultdict
  8. import time
  9. import matplotlib.pyplot as plt
  10. import os
  11. import sys
  12. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../Sampling-based Planning/")
  13. from rrt_3D.env3D import env
  14. from rrt_3D.utils3D import getDist, sampleFree, nearest, steer, isCollide, near, visualization, cost, path, edgeset, \
  15. hash3D, dehash
  16. class rrtstar():
  17. def __init__(self):
  18. self.env = env()
  19. self.Parent = {}
  20. self.E = edgeset()
  21. self.V = []
  22. self.i = 0
  23. self.maxiter = 10000 # at least 4000 in this env
  24. self.stepsize = 0.5
  25. self.gamma = 500
  26. self.eta = 2 * self.stepsize
  27. self.Path = []
  28. self.done = False
  29. def wireup(self, x, y):
  30. self.E.add_edge([x, y]) # add edge
  31. self.Parent[hash3D(x)] = y
  32. def removewire(self, xnear):
  33. xparent = self.Parent[hash3D(xnear)]
  34. a = [xnear, xparent]
  35. self.E.remove_edge(a) # remove and replace old the connection
  36. def reached(self):
  37. self.done = True
  38. xn = near(self, self.env.goal)
  39. c = [cost(self, x) for x in xn]
  40. xncmin = xn[np.argmin(c)]
  41. self.wireup(self.env.goal, xncmin)
  42. self.V.append(self.env.goal)
  43. self.Path, self.D = path(self)
  44. def run(self):
  45. self.V.append(self.env.start)
  46. self.ind = 0
  47. xnew = self.env.start
  48. print('start rrt*... ')
  49. self.fig = plt.figure(figsize=(10, 8))
  50. while self.ind < self.maxiter:
  51. xrand = sampleFree(self)
  52. xnearest = nearest(self, xrand)
  53. xnew = steer(self, xnearest, xrand)
  54. if not isCollide(self, xnearest, xnew):
  55. Xnear = near(self, xnew)
  56. self.V.append(xnew) # add point
  57. visualization(self)
  58. # minimal path and minimal cost
  59. xmin, cmin = xnearest, cost(self, xnearest) + getDist(xnearest, xnew)
  60. # connecting along minimal cost path
  61. for xnear in Xnear:
  62. c1 = cost(self, xnear) + getDist(xnew, xnear)
  63. if not isCollide(self, xnew, xnear) and c1 < cmin:
  64. xmin, cmin = xnear, c1
  65. self.wireup(xnew, xmin)
  66. # rewire
  67. for xnear in Xnear:
  68. c2 = cost(self, xnew) + getDist(xnew, xnear)
  69. if not isCollide(self, xnew, xnear) and c2 < cost(self, xnear):
  70. self.removewire(xnear)
  71. self.wireup(xnear, xnew)
  72. self.i += 1
  73. self.ind += 1
  74. # max sample reached
  75. self.reached()
  76. print('time used = ' + str(time.time() - starttime))
  77. print('Total distance = ' + str(self.D))
  78. visualization(self)
  79. plt.show()
  80. if __name__ == '__main__':
  81. p = rrtstar()
  82. starttime = time.time()
  83. p.run()