Anytime_Dstar3D.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. # check paper of
  2. # [Likhachev2005]
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. import os
  6. import sys
  7. from collections import defaultdict
  8. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../Search-based Planning/")
  9. from Search_3D.env3D import env
  10. from Search_3D.utils3D import getDist, heuristic_fun, getNearest, isinbound, \
  11. cost, children, StateSpace
  12. from Search_3D.plot_util3D import visualization
  13. from Search_3D import queue
  14. import time
  class Anytime_Dstar(object):
      """Anytime D* (AD*) planner on a 3D grid, after [Likhachev2005].

      The search runs backward from the goal ``xt`` (``rhs(xt) = 0`` seeds OPEN)
      and keys use the heuristic towards the start ``x0``.  ``epsilon`` inflates
      the heuristic in the key of over-consistent states.  ``Main`` plans once,
      then repeatedly moves an obstacle block, repairs the affected edge costs
      and replans.
      """

      def __init__(self, resolution=1):
          # direction -> step length for the 26-connected neighbourhood:
          # 1 along an axis, sqrt(2) for planar diagonals, sqrt(3) for space diagonals.
          self.Alldirec = {(1, 0, 0): 1, (0, 1, 0): 1, (0, 0, 1): 1,
                           (-1, 0, 0): 1, (0, -1, 0): 1, (0, 0, -1): 1,
                           (1, 1, 0): np.sqrt(2), (1, 0, 1): np.sqrt(2), (0, 1, 1): np.sqrt(2),
                           (-1, -1, 0): np.sqrt(2), (-1, 0, -1): np.sqrt(2), (0, -1, -1): np.sqrt(2),
                           (1, -1, 0): np.sqrt(2), (-1, 1, 0): np.sqrt(2), (1, 0, -1): np.sqrt(2),
                           (-1, 0, 1): np.sqrt(2), (0, 1, -1): np.sqrt(2), (0, -1, 1): np.sqrt(2),
                           (1, 1, 1): np.sqrt(3), (-1, -1, -1): np.sqrt(3),
                           (1, -1, -1): np.sqrt(3), (-1, 1, -1): np.sqrt(3), (-1, -1, 1): np.sqrt(3),
                           (1, 1, -1): np.sqrt(3), (1, -1, 1): np.sqrt(3), (-1, 1, 1): np.sqrt(3)}
          self.env = env(resolution=resolution)
          self.settings = 'CollisionChecking'  # for collision checking
          self.x0, self.xt = tuple(self.env.start), tuple(self.env.goal)
          self.OPEN = queue.MinheapPQ()
          self.g = {}  # g values; a missing entry means inf (see getg)
          self.h = {}  # heuristic cache towards x0 (see geth)
          # The search is seeded at the goal: rhs(xt) = 0.
          self.rhs = {self.xt: 0}
          self.OPEN.put(self.xt, self.key(self.xt))
          self.INCONS = set()  # inconsistent states that were already in CLOSED
          self.CLOSED = set()  # states expanded as over-consistent in the current search
          # children cache: state -> set of child states
          self.CHILDREN = {}
          # edge-cost cache: COST[xi][xj] = c(xi, xj)
          self.COST = defaultdict(lambda: defaultdict(dict))
          # for visualization
          self.V = set()  # every state ever added to CLOSED
          self.ind = 0
          self.Path = []
          self.done = False
          # epsilon (heuristic inflation) used in the key calculation
          self.epsilon = 1
          self.increment = 0.1
          self.decrement = 0.2

      def getcost(self, xi, xj):
          """Return the cached edge cost c(xi, xj), computing it on demand.

          The first time ``xi`` is queried, the costs to all of its children are
          filled in from ``children(self, xi, settings=1)`` (which yields
          ``(child, raw_cost)`` pairs).  A pair still missing afterwards, e.g.
          after the graph changed, is computed individually.
          """
          if xi not in self.COST:
              # Distinct loop names: reusing ``xj`` here clobbered the argument,
              # so the cost to the *last* child was returned instead of c(xi, xj).
              for child, childcost in children(self, xi, settings=1):
                  self.COST[xi][child] = cost(self, xi, child, childcost)
          # this might happen when there is a node changed.
          if xj not in self.COST[xi]:
              self.COST[xi][xj] = cost(self, xi, xj)
          return self.COST[xi][xj]

      def getchildren(self, xi):
          """Return the (cached) set of children of ``xi``."""
          if xi not in self.CHILDREN:
              self.CHILDREN[xi] = set(children(self, xi))
          return self.CHILDREN[xi]

      def geth(self, xi):
          """Return the (cached) heuristic from ``xi`` to the start ``x0``."""
          if xi not in self.h:
              self.h[xi] = heuristic_fun(self, xi, self.x0)
          return self.h[xi]

      def getg(self, xi):
          """Return g(xi), initialising unseen states to inf."""
          if xi not in self.g:
              self.g[xi] = np.inf
          return self.g[xi]

      def getrhs(self, xi):
          """Return rhs(xi), initialising unseen states to inf."""
          if xi not in self.rhs:
              self.rhs[xi] = np.inf
          return self.rhs[xi]

      def updatecost(self, range_changed=None, new=None, old=None, mode=False):
          """Refresh children and edge costs of CLOSED states touched by a change.

          A CLOSED state is refreshed when ``isinbound`` places it in ``old`` or
          ``new`` (the regions returned by ``env.move_block`` in ``Main``).
          ``range_changed`` is accepted for interface compatibility but unused.

          Returns:
              set: the states whose children/costs were recomputed.
          """
          CHANGED = set()
          for xi in self.CLOSED:
              if isinbound(old, xi, mode) or isinbound(new, xi, mode):
                  newchildren = set(children(self, xi))
                  self.CHILDREN[xi] = newchildren
                  for xj in newchildren:
                      self.COST[xi][xj] = cost(self, xi, xj)
                  CHANGED.add(xi)
          return CHANGED

      # TODO: a vectorised updatecost (isinbound(..., isarray=True) over all
      # CLOSED states at once) was sketched here but never finished.

      # --------------main functions for Anytime D star

      def key(self, s, epsilon=1):
          """Return the AD* priority key ``[k1, k2]`` of ``s``.

          Over-consistent states (g > rhs) get ``rhs + epsilon * h``; all other
          states get ``g + h`` (no inflation).  Keys compare lexicographically.
          """
          g, rhs = self.getg(s), self.getrhs(s)
          if g > rhs:
              return [rhs + epsilon * self.geth(s), rhs]
          return [g + self.geth(s), g]

      def UpdateState(self, s):
          """Recompute rhs(s) and put ``s`` into OPEN or INCONS if inconsistent."""
          if s not in self.CLOSED:
              # TODO if s is not visited before
              # NOTE(review): AD* resets g only for never-visited states; this
              # resets it for every state outside the current CLOSED set.  It
              # may be what re-queues states near the start after Main clears
              # CLOSED (ComputeorImprovePath stops when the start is popped), so
              # verify that interaction before changing it.
              self.g[s] = np.inf
          if s != self.xt:
              # A state without children is unreachable: rhs = inf.
              self.rhs[s] = min((self.getcost(s, s_p) + self.getg(s_p)
                                 for s_p in self.getchildren(s)), default=np.inf)
          self.OPEN.check_remove(s)
          if self.getg(s) != self.getrhs(s):
              if s not in self.CLOSED:
                  # Key with the current epsilon, consistent with the comparison
                  # in ComputeorImprovePath and the re-keying in Main.
                  self.OPEN.put(s, self.key(s, self.epsilon))
              else:
                  self.INCONS.add(s)

      def ComputeorImprovePath(self):
          """Expand states from OPEN until the start's key is reached and it is
          consistent, or until a state within one resolution of the start is popped."""
          while (self.OPEN.top_key() < self.key(self.x0, self.epsilon)
                 or self.getrhs(self.x0) != self.getg(self.x0)):
              s = self.OPEN.get()
              # stop as soon as the search reaches the start
              if getDist(s, tuple(self.env.start)) < self.env.resolution:
                  break
              if self.getg(s) > self.getrhs(s):
                  # over-consistent: make it consistent and close it
                  self.g[s] = self.rhs[s]
                  self.CLOSED.add(s)
                  self.V.add(s)
                  for s_p in self.getchildren(s):
                      self.UpdateState(s_p)
              else:
                  # under-consistent: invalidate g and update s and its neighbours
                  self.g[s] = np.inf
                  self.UpdateState(s)
                  for s_p in self.getchildren(s):
                      self.UpdateState(s_p)
              self.ind += 1

      def _adjust_epsilon(self, islargelychanged):
          # Inflate epsilon after a large change, otherwise relax it towards 1.
          # AD* requires epsilon >= 1, so the decrement is clamped at 1.
          if islargelychanged:
              self.epsilon += self.increment  # or replan from scratch
          elif self.epsilon > 1:
              self.epsilon = max(1, self.epsilon - self.decrement)

      def Main(self):
          """Plan once, then for 20 rounds move an obstacle block, repair the
          affected costs and replan, visualizing each round."""
          ischanged = False
          islargelychanged = False
          t = 0
          self.ComputeorImprovePath()
          # TODO publish current epsilon sub-optimal solution
          self.done = True
          self.ind = 0
          self.Path = self.path()
          visualization(self)
          while True:
              visualization(self)
              if t == 20:
                  break
              # change environment
              # new2,old2 = self.env.move_block(theta = [0,0,0.1*t], mode='rotation')
              new2, old2 = self.env.move_block(a=[0, 0, -0.2], mode='translation')
              ischanged = True
              # islargelychanged = True
              self.Path = []
              # update Cost with changed environment
              if ischanged:
                  # CHANGED = self.updatecost(True, new2, old2, mode='obb')
                  CHANGED = self.updatecost(True, new2, old2)
                  for u in CHANGED:
                      self.UpdateState(u)
                  self.ComputeorImprovePath()
                  ischanged = False
              self._adjust_epsilon(islargelychanged)
              # move states from INCONS into OPEN and re-key all of OPEN with
              # the current epsilon, then start a fresh CLOSED set
              for node in self.INCONS.union(self.OPEN.allnodes()):
                  self.OPEN.put(node, self.key(node, self.epsilon))
              self.INCONS = set()
              self.CLOSED = set()
              self.ComputeorImprovePath()
              # publish current epsilon sub optimal solution
              self.Path = self.path()
              t += 1

      def path(self, s_start=None):
          """Greedily extract a path after ComputeorImprovePath() returns.

          Starting at ``s_start`` (default: ``x0``), repeatedly move to the
          successor minimising c(s, s') + g(s') until within one resolution of
          the goal.  Stops early after about 100 steps or when a state has no
          candidate successor.

          Returns:
              list of ``[s, s_next]`` segments.
          """
          path = []
          s_goal = self.xt
          s = self.x0 if s_start is None else s_start
          ind = 0
          while getDist(s, s_goal) > self.env.resolution:
              if s == self.x0:
                  # The start is presumably never expanded (the search stops when
                  # it is popped), so step to nearby CLOSED states instead.
                  candidates = [i for i in self.CLOSED
                                if getDist(s, i) <= self.env.resolution * np.sqrt(3)]
              else:
                  candidates = list(self.getchildren(s))
              if not candidates:
                  break  # dead end: return the partial path
              snext = candidates[np.argmin([self.getcost(s, s_p) + self.getg(s_p)
                                            for s_p in candidates])]
              path.append([s, snext])
              s = snext
              if ind > 100:
                  break
              ind += 1
          return path
  if __name__ == '__main__':
      # Script entry point: build a planner on a unit-resolution grid and run it.
      planner = Anytime_Dstar(resolution=1)
      planner.Main()