# Dstar3D.py (7.1 KB)

  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. import os
  4. import sys
  5. from collections import defaultdict
  6. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../Search-based Planning/")
  7. from Search_3D.env3D import env
  8. from Search_3D import Astar3D
  9. from Search_3D.utils3D import getDist, getRay, isinbound, isinball
  10. import pyrr
  11. def StateSpace(env, factor = 0):
  12. boundary = env.boundary
  13. resolution = env.resolution
  14. xmin,xmax = boundary[0]+factor*resolution,boundary[3]-factor*resolution
  15. ymin,ymax = boundary[1]+factor*resolution,boundary[4]-factor*resolution
  16. zmin,zmax = boundary[2]+factor*resolution,boundary[5]-factor*resolution
  17. xarr = np.arange(xmin,xmax,resolution).astype(float)
  18. yarr = np.arange(ymin,ymax,resolution).astype(float)
  19. zarr = np.arange(zmin,zmax,resolution).astype(float)
  20. g = set()
  21. for x in xarr:
  22. for y in yarr:
  23. for z in zarr:
  24. g.add((x,y,z))
  25. return g
  26. def getNearest(Space,pt):
  27. '''get the nearest point on the grid'''
  28. mindis,minpt = 1000,None
  29. for pts in Space:
  30. dis = getDist(pts,pt)
  31. if dis < mindis:
  32. mindis,minpt = dis,pts
  33. return minpt
  34. def isCollide(initparams, x, child):
  35. '''see if line intersects obstacle'''
  36. ray , dist = getRay(x, child) , getDist(x, child)
  37. if not isinbound(initparams.env.boundary,child):
  38. return True, dist
  39. for i in initparams.env.AABB:
  40. shot = pyrr.geometric_tests.ray_intersect_aabb(ray, i)
  41. if shot is not None:
  42. dist_wall = getDist(x, shot)
  43. if dist_wall <= dist: # collide
  44. return True, dist
  45. for i in initparams.env.balls:
  46. if isinball(i, child):
  47. return True, dist
  48. shot = pyrr.geometric_tests.ray_intersect_sphere(ray, i)
  49. if shot != []:
  50. dists_ball = [getDist(x, j) for j in shot]
  51. if all(dists_ball <= dist): # collide
  52. return True, dist
  53. return False, dist
  54. def children(initparams, x):
  55. # get the neighbor of a specific state
  56. allchild = []
  57. resolution = initparams.env.resolution
  58. for direc in initparams.Alldirec:
  59. child = tuple(map(np.add,x,np.multiply(direc,resolution)))
  60. if isinbound(initparams.env.boundary,child):
  61. allchild.append(child)
  62. return allchild
  63. def cost(initparams, x, y):
  64. # get the cost between two points,
  65. # do collision check here
  66. collide, dist = isCollide(initparams,x,y)
  67. if collide: return np.inf
  68. else: return dist
  69. def initcost(initparams):
  70. # initialize cost dictionary, could be modifed lateron
  71. c = defaultdict(lambda: defaultdict(dict)) # two key dicionary
  72. for xi in initparams.X:
  73. cdren = children(initparams, xi)
  74. for child in cdren:
  75. c[xi][child] = cost(initparams, xi, child)
  76. return c
  77. class D_star(object):
  78. def __init__(self,resolution = 1):
  79. self.Alldirec = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1], [0, 1, 1], [1, 1, 1],
  80. [-1, 0, 0], [0, -1, 0], [0, 0, -1], [-1, -1, 0], [-1, 0, -1], [0, -1, -1],
  81. [-1, -1, -1],
  82. [1, -1, 0], [-1, 1, 0], [1, 0, -1], [-1, 0, 1], [0, 1, -1], [0, -1, 1],
  83. [1, -1, -1], [-1, 1, -1], [-1, -1, 1], [1, 1, -1], [1, -1, 1], [-1, 1, 1]])
  84. self.env = env(resolution = resolution)
  85. self.X = StateSpace(self.env)
  86. self.x0, self.xt = getNearest(self.X, self.env.start), getNearest(self.X, self.env.goal)
  87. self.b = {} # back pointers every state has one except xt.
  88. self.OPEN = {} # OPEN list, here use a hashmap implementation. hash is point, key is value
  89. self.h = self.initH() # estimate from a point to the end point
  90. self.tag = self.initTag() # set all states to new
  91. # initialize cost set
  92. self.c = initcost(self)
  93. # put G (ending state) into the OPEN list
  94. self.OPEN[self.xt] = 0
  95. def initH(self):
  96. # h set, all initialzed h vals are 0 for all states.
  97. h = {}
  98. for xi in self.X:
  99. h[xi] = 0
  100. return h
  101. def initTag(self):
  102. # tag , New point (never been in the OPEN list)
  103. # Open point ( currently in OPEN )
  104. # Closed (currently in CLOSED)
  105. t = {}
  106. for xi in self.X:
  107. t[xi] = 'New'
  108. return t
  109. def get_kmin(self):
  110. # get the minimum of the k val in OPEN
  111. # -1 if it does not exist
  112. if self.OPEN:
  113. minv = np.inf
  114. for k,v in enumerate(self.OPEN):
  115. if v < minv: minv = v
  116. return minv
  117. return -1
  118. def min_state(self):
  119. # returns the state in OPEN with min k(.)
  120. # if empty, returns None and -1
  121. # it also removes this min value form the OPEN set.
  122. if self.OPEN:
  123. minv = np.inf
  124. for k,v in enumerate(self.OPEN):
  125. if v < minv: mink, minv = k, v
  126. return mink, self.OPEN.pop(mink)
  127. return None, -1
  128. def insert(self, x, h_new):
  129. # inserting a key and value into OPEN list (x, kx)
  130. # depending on following situations
  131. if self.tag[x] == 'New':
  132. kx = h_new
  133. if self.tag[x] == 'Open':
  134. kx = min(self.OPEN[x],h_new)
  135. if self.tag[x] == 'Closed':
  136. kx = min(self.h[x], h_new)
  137. self.OPEN[x] = kx
  138. self.h[x],self.tag[x] = h_new, 'Open'
  139. def process_state(self):
  140. x, kold = self.min_state()
  141. self.tag[x] = 'Closed'
  142. if x == None: return -1
  143. if kold < self.h[x]: # raised states
  144. for y in children(self,x):
  145. a = self.h[y] + self.c[y][x]
  146. if self.h[y] <= kold and self.h[x] > a:
  147. self.b[x], self.h[x] = y , a
  148. elif kold == self.h[x]:# lower
  149. for y in children(self,x):
  150. bb = self.h[x] + self.c[x][y]
  151. if self.tag[y] == 'New' or \
  152. (self.b[y] == x and self.h[y] != bb) or \
  153. (self.b[y] != x and self.h[y] > bb):
  154. self.b[y] = x
  155. self.insert(y, bb)
  156. else:
  157. for y in children(self,x):
  158. bb = self.h[x] + self.c[x][y]
  159. if self.tag[y] == 'New' or \
  160. (self.b[y] == x and self.h[y] != bb):
  161. self.b[y] = x
  162. self.insert(y, bb)
  163. else:
  164. if self.b[y] != x and self.h[y] > bb:
  165. self.insert(x, self.h[x])
  166. else:
  167. if self.b[y] != x and self.h[y] > bb and \
  168. self.tag[y] == 'Closed' and self.h[y] == kold:
  169. self.insert(y, self.h[y])
  170. return self.get_kmin()
  171. def modify_cost(self,x,y,cval):
  172. self.c[x][y] = cval # set the new cost to the cval
  173. if self.tag[x] == 'Closed': self.insert(x,self.h[x])
  174. return self.get_kmin()
  175. def run(self):
  176. # TODO: implementation of changing obstable in process
  177. pass
  178. if __name__ == '__main__':
  179. D = D_star(1)