# Dstar3D.py

import os
import sys
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np

sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../Search-based Planning/")
from Search_3D.env3D import env
from Search_3D import Astar3D
from Search_3D.utils3D import StateSpace, getDist, getNearest, getRay, isinbound, isinball, isCollide, children, cost, \
    initcost
from Search_3D.plot_util3D import visualization
  12. class D_star(object):
  13. def __init__(self, resolution=1):
  14. self.Alldirec = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1], [0, 1, 1], [1, 1, 1],
  15. [-1, 0, 0], [0, -1, 0], [0, 0, -1], [-1, -1, 0], [-1, 0, -1], [0, -1, -1],
  16. [-1, -1, -1],
  17. [1, -1, 0], [-1, 1, 0], [1, 0, -1], [-1, 0, 1], [0, 1, -1], [0, -1, 1],
  18. [1, -1, -1], [-1, 1, -1], [-1, -1, 1], [1, 1, -1], [1, -1, 1], [-1, 1, 1]])
  19. self.env = env(resolution=resolution)
  20. self.X = StateSpace(self.env)
  21. self.x0, self.xt = getNearest(self.X, self.env.start), getNearest(self.X, self.env.goal)
  22. self.b = defaultdict(lambda: defaultdict(dict)) # back pointers every state has one except xt.
  23. self.OPEN = {} # OPEN list, here use a hashmap implementation. hash is point, key is value
  24. self.h = self.initH() # estimate from a point to the end point
  25. self.tag = self.initTag() # set all states to new
  26. self.V = set() # vertice in closed
  27. # initialize cost set
  28. # self.c = initcost(self)
  29. # for visualization
  30. self.ind = 0
  31. self.Path = []
  32. self.done = False
  33. self.Obstaclemap = {}
  34. def update_obs(self):
  35. for xi in self.X:
  36. print('xi')
  37. self.Obstaclemap[xi] = False
  38. for aabb in self.env.blocks:
  39. self.Obstaclemap[xi] = isinbound(aabb, xi)
  40. if self.Obstaclemap[xi] == False:
  41. for ball in self.env.balls:
  42. self.Obstaclemap[xi] = isinball(ball, xi)
  43. def initH(self):
  44. # h set, all initialzed h vals are 0 for all states.
  45. h = {}
  46. for xi in self.X:
  47. h[xi] = 0
  48. return h
  49. def initTag(self):
  50. # tag , New point (never been in the OPEN list)
  51. # Open point ( currently in OPEN )
  52. # Closed (currently in CLOSED)
  53. t = {}
  54. for xi in self.X:
  55. t[xi] = 'New'
  56. return t
  57. def get_kmin(self):
  58. # get the minimum of the k val in OPEN
  59. # -1 if it does not exist
  60. if self.OPEN:
  61. return min([x for x in self.OPEN.values()])
  62. return -1
  63. def min_state(self):
  64. # returns the state in OPEN with min k(.)
  65. # if empty, returns None and -1
  66. # it also removes this min value form the OPEN set.
  67. if self.OPEN:
  68. mink = -1
  69. minv = np.inf
  70. for v, k in enumerate(self.OPEN):
  71. if v < minv:
  72. mink, minv = k, v
  73. return mink, self.OPEN.pop(mink)
  74. return None, -1
  75. def insert(self, x, h_new):
  76. # inserting a key and value into OPEN list (x, kx)
  77. # depending on following situations
  78. if self.tag[x] == 'New':
  79. kx = h_new
  80. if self.tag[x] == 'Open':
  81. kx = min(self.OPEN[x], h_new)
  82. if self.tag[x] == 'Closed':
  83. kx = min(self.h[x], h_new)
  84. self.OPEN[x] = kx
  85. self.h[x], self.tag[x] = h_new, 'Open'
  86. def process_state(self):
  87. x, kold = self.min_state()
  88. self.tag[x] = 'Closed'
  89. self.V.add(x)
  90. if x == None: return -1
  91. if kold < self.h[x]: # raised states
  92. for y in children(self, x):
  93. a = self.h[y] + cost(self, y, x)
  94. if self.h[y] <= kold and self.h[x] > a:
  95. self.b[x], self.h[x] = y, a
  96. if kold == self.h[x]: # lower
  97. for y in children(self, x):
  98. bb = self.h[x] + cost(self, x, y)
  99. if self.tag[y] == 'New' or \
  100. (self.b[y] == x and self.h[y] != bb) or \
  101. (self.b[y] != x and self.h[y] > bb):
  102. self.b[y] = x
  103. self.insert(y, bb)
  104. else:
  105. for y in children(self, x):
  106. bb = self.h[x] + cost(self, x, y)
  107. if self.tag[y] == 'New' or \
  108. (self.b[y] == x and self.h[y] != bb):
  109. self.b[y] = x
  110. self.insert(y, bb)
  111. else:
  112. if self.b[y] != x and self.h[y] > bb:
  113. self.insert(x, self.h[x])
  114. else:
  115. if self.b[y] != x and self.h[y] > bb and \
  116. self.tag[y] == 'Closed' and self.h[y] == kold:
  117. self.insert(y, self.h[y])
  118. return self.get_kmin()
  119. def modify_cost(self, x):
  120. # TODO: implement own function
  121. # self.c[x][y] = cval
  122. xparent = self.b[x]
  123. if self.tag[x] == 'Closed':
  124. self.insert(x, self.h[xparent] + cost(self, x, xparent))
  125. def modify(self, x):
  126. self.modify_cost(x)
  127. while True:
  128. kmin = self.process_state()
  129. if kmin >= self.h[x]:
  130. break
  131. def path(self, goal=None):
  132. path = []
  133. if not goal:
  134. x = self.x0
  135. else:
  136. x = goal
  137. start = self.xt
  138. while x != start:
  139. path.append([np.array(x), np.array(self.b[x])])
  140. x = self.b[x]
  141. return path
  142. def run(self):
  143. # put G (ending state) into the OPEN list
  144. self.OPEN[self.xt] = 0
  145. # first run
  146. while True:
  147. # TODO: self.x0 =
  148. self.process_state()
  149. visualization(self)
  150. if self.tag[self.x0] == "Closed":
  151. break
  152. self.ind += 1
  153. self.Path = self.path()
  154. self.done = True
  155. visualization(self)
  156. plt.pause(0.2)
  157. # plt.show()
  158. # when the environemnt changes over time
  159. for i in range(2):
  160. self.env.move_block(a=[0, 0, -1], s=0.5, block_to_move=1, mode='translation')
  161. visualization(self)
  162. plt.pause(0.2)
  163. s = tuple(self.env.start)
  164. count = 0
  165. count_obs = 0
  166. while s != self.xt:
  167. count += 1
  168. print(count)
  169. if s == tuple(self.env.start):
  170. sparent = self.b[self.x0]
  171. else:
  172. sparent = self.b[s]
  173. # self.update_obs()
  174. if cost(self, s, sparent) == np.inf:
  175. # print(s, " ", sparent)
  176. count_obs += 1
  177. print(count_obs)
  178. self.modify(s)
  179. continue
  180. self.ind += 1
  181. s = sparent
  182. print("test")
  183. self.Path = self.path()
  184. visualization(self)
  185. plt.pause(0.2)
  186. plt.show()
  187. if __name__ == '__main__':
  188. D = D_star(1)
  189. D.run()