LP_Astar3D.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. import os
  4. import sys
  5. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../Search-based Planning/")
  6. from Search_3D.env3D import env
  7. from Search_3D import Astar3D
  8. from Search_3D.utils3D import getDist, getRay, g_Space, Heuristic, getNearest, isinbound, isinball, \
  9. cost, obstacleFree
  10. from Search_3D.plot_util3D import visualization
  11. import queue
  12. import pyrr
  13. import time
  14. class Lifelong_Astar(object):
  15. def __init__(self,resolution = 1):
  16. self.Alldirec = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1], [0, 1, 1], [1, 1, 1],
  17. [-1, 0, 0], [0, -1, 0], [0, 0, -1], [-1, -1, 0], [-1, 0, -1], [0, -1, -1],
  18. [-1, -1, -1],
  19. [1, -1, 0], [-1, 1, 0], [1, 0, -1], [-1, 0, 1], [0, 1, -1], [0, -1, 1],
  20. [1, -1, -1], [-1, 1, -1], [-1, -1, 1], [1, 1, -1], [1, -1, 1], [-1, 1, 1]])
  21. self.env = env(resolution=resolution)
  22. self.g = g_Space(self)
  23. self.start, self.goal = getNearest(self.g, self.env.start), getNearest(self.g, self.env.goal)
  24. self.x0, self.xt = self.start, self.goal
  25. self.v = g_Space(self) # rhs(.) = g(.) = inf
  26. self.v[self.start] = 0 # rhs(x0) = 0
  27. self.h = Heuristic(self.g, self.goal)
  28. self.OPEN = queue.QueuePrior() # store [point,priority]
  29. self.OPEN.put(self.x0, [self.h[self.x0],0])
  30. self.CLOSED = set()
  31. # used for A*
  32. self.done = False
  33. self.Path = []
  34. self.V = []
  35. self.ind = 0
  36. # initialize children list
  37. self.CHILDREN = {}
  38. self.getCHILDRENset()
  39. # initialize cost list
  40. self.COST = {}
  41. _ = self.costset()
  42. def costset(self):
  43. NodeToChange = set()
  44. for xi in self.CHILDREN.keys():
  45. children = self.CHILDREN[xi]
  46. toUpdate = [self.cost(xj,xi) for xj in children]
  47. if xi in self.COST:
  48. # if the old cost not equal to new cost
  49. diff = np.not_equal(self.COST[xi],toUpdate)
  50. cd = np.array(children)[diff]
  51. for i in cd:
  52. NodeToChange.add(tuple(i))
  53. self.COST[xi] = toUpdate
  54. else:
  55. self.COST[xi] = toUpdate
  56. return NodeToChange
  57. def getCOSTset(self,xi,xj):
  58. ind, children = 0, self.CHILDREN[xi]
  59. for i in children:
  60. if i == xj:
  61. return self.COST[xi][ind]
  62. ind += 1
  63. def children(self, x):
  64. allchild = []
  65. resolution = self.env.resolution
  66. for direc in self.Alldirec:
  67. child = tuple(map(np.add,x,np.multiply(direc,resolution)))
  68. if isinbound(self.env.boundary,child):
  69. allchild.append(child)
  70. return allchild
  71. def getCHILDRENset(self):
  72. for xi in self.g.keys():
  73. self.CHILDREN[xi] = self.children(xi)
  74. def isCollide(self, x, child):
  75. ray , dist = getRay(x, child) , getDist(x, child)
  76. if not isinbound(self.env.boundary,child):
  77. return True, dist
  78. for i in self.env.AABB_pyrr:
  79. shot = pyrr.geometric_tests.ray_intersect_aabb(ray, i)
  80. if shot is not None:
  81. dist_wall = getDist(x, shot)
  82. if dist_wall <= dist: # collide
  83. return True, dist
  84. for i in self.env.balls:
  85. if isinball(i, child):
  86. return True, dist
  87. shot = pyrr.geometric_tests.ray_intersect_sphere(ray, i)
  88. if shot != []:
  89. dists_ball = [getDist(x, j) for j in shot]
  90. if all(dists_ball <= dist): # collide
  91. return True, dist
  92. return False, dist
  93. def cost(self, x, y):
  94. collide, dist = self.isCollide(x, y)
  95. if collide: return np.inf
  96. else: return dist
  97. def key(self,xi,epsilion = 1):
  98. return [min(self.g[xi],self.v[xi]) + epsilion*self.h[xi],min(self.g[xi],self.v[xi])]
  99. def path(self):
  100. path = []
  101. x = self.xt
  102. start = self.x0
  103. ind = 0
  104. while x != start:
  105. j = x
  106. nei = self.CHILDREN[x]
  107. gset = [self.g[xi] for xi in nei]
  108. # collision check and make g cost inf
  109. for i in range(len(nei)):
  110. if self.isCollide(nei[i],j)[0]:
  111. gset[i] = np.inf
  112. parent = nei[np.argmin(gset)]
  113. path.append([x, parent])
  114. x = parent
  115. if ind > 100:
  116. break
  117. ind += 1
  118. return path
  119. #------------------Lifelong Plannning A*
  120. def UpdateMembership(self, xi, xparent=None):
  121. if xi != self.x0:
  122. self.v[xi] = min([self.g[j] + self.getCOSTset(xi,j) for j in self.CHILDREN[xi]])
  123. self.OPEN.check_remove(xi)
  124. if self.g[xi] != self.v[xi]:
  125. self.OPEN.put(xi,self.key(xi))
  126. def ComputePath(self):
  127. print('computing path ...')
  128. while self.key(self.xt) > self.OPEN.top_key() or self.v[self.xt] != self.g[self.xt]:
  129. xi = self.OPEN.get()
  130. # if g > rhs, overconsistent
  131. if self.g[xi] > self.v[xi]:
  132. self.g[xi] = self.v[xi]
  133. # add xi to expanded node set
  134. if xi not in self.CLOSED:
  135. self.V.append(xi)
  136. self.CLOSED.add(xi)
  137. else: # underconsistent and consistent
  138. self.g[xi] = np.inf
  139. self.UpdateMembership(xi)
  140. for xj in self.CHILDREN[xi]:
  141. self.UpdateMembership(xj)
  142. # visualization(self)
  143. self.ind += 1
  144. self.Path = self.path()
  145. self.done = True
  146. visualization(self)
  147. plt.pause(2)
  148. def change_env(self):
  149. self.env.New_block()
  150. self.done = False
  151. self.Path = []
  152. self.CLOSED = set()
  153. N = self.costset()
  154. for xi in N:
  155. self.UpdateMembership(xi)
  156. if __name__ == '__main__':
  157. sta = time.time()
  158. Astar = Lifelong_Astar(0.5)
  159. Astar.ComputePath()
  160. Astar.change_env()
  161. Astar.ComputePath()
  162. plt.show()
  163. print(time.time() - sta)