Dstar3D.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. import os
  4. import sys
  5. from collections import defaultdict
  6. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../Search-based Planning/")
  7. from Search_3D.env3D import env
  8. from Search_3D import Astar3D
  9. from Search_3D.utils3D import StateSpace, getDist, getRay, isinbound, isinball
  10. import pyrr
  11. def getNearest(Space,pt):
  12. '''get the nearest point on the grid'''
  13. mindis,minpt = 1000,None
  14. for pts in Space:
  15. dis = getDist(pts,pt)
  16. if dis < mindis:
  17. mindis,minpt = dis,pts
  18. return minpt
  19. def isCollide(initparams, x, child):
  20. '''see if line intersects obstacle'''
  21. ray , dist = getRay(x, child) , getDist(x, child)
  22. if not isinbound(initparams.env.boundary,child):
  23. return True, dist
  24. for i in initparams.env.AABB:
  25. shot = pyrr.geometric_tests.ray_intersect_aabb(ray, i)
  26. if shot is not None:
  27. dist_wall = getDist(x, shot)
  28. if dist_wall <= dist: # collide
  29. return True, dist
  30. for i in initparams.env.balls:
  31. if isinball(i, child):
  32. return True, dist
  33. shot = pyrr.geometric_tests.ray_intersect_sphere(ray, i)
  34. if shot != []:
  35. dists_ball = [getDist(x, j) for j in shot]
  36. if all(dists_ball <= dist): # collide
  37. return True, dist
  38. return False, dist
  39. def children(initparams, x):
  40. # get the neighbor of a specific state
  41. allchild = []
  42. resolution = initparams.env.resolution
  43. for direc in initparams.Alldirec:
  44. child = tuple(map(np.add,x,np.multiply(direc,resolution)))
  45. if isinbound(initparams.env.boundary,child):
  46. allchild.append(child)
  47. return allchild
  48. def cost(initparams, x, y):
  49. # get the cost between two points,
  50. # do collision check here
  51. collide, dist = isCollide(initparams,x,y)
  52. if collide: return np.inf
  53. else: return dist
  54. def initcost(initparams):
  55. # initialize cost dictionary, could be modifed lateron
  56. c = defaultdict(lambda: defaultdict(dict)) # two key dicionary
  57. for xi in initparams.X:
  58. cdren = children(initparams, xi)
  59. for child in cdren:
  60. c[xi][child] = cost(initparams, xi, child)
  61. return c
  62. class D_star(object):
  63. def __init__(self,resolution = 1):
  64. self.Alldirec = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1], [0, 1, 1], [1, 1, 1],
  65. [-1, 0, 0], [0, -1, 0], [0, 0, -1], [-1, -1, 0], [-1, 0, -1], [0, -1, -1],
  66. [-1, -1, -1],
  67. [1, -1, 0], [-1, 1, 0], [1, 0, -1], [-1, 0, 1], [0, 1, -1], [0, -1, 1],
  68. [1, -1, -1], [-1, 1, -1], [-1, -1, 1], [1, 1, -1], [1, -1, 1], [-1, 1, 1]])
  69. self.env = env(resolution = resolution)
  70. self.X = StateSpace(self.env)
  71. self.x0, self.xt = getNearest(self.X, self.env.start), getNearest(self.X, self.env.goal)
  72. self.b = {} # back pointers every state has one except xt.
  73. self.OPEN = {} # OPEN list, here use a hashmap implementation. hash is point, key is value
  74. self.h = self.initH() # estimate from a point to the end point
  75. self.tag = self.initTag() # set all states to new
  76. # initialize cost set
  77. self.c = initcost(self)
  78. # put G (ending state) into the OPEN list
  79. self.OPEN[self.xt] = 0
  80. def initH(self):
  81. # h set, all initialzed h vals are 0 for all states.
  82. h = {}
  83. for xi in self.X:
  84. h[xi] = 0
  85. return h
  86. def initTag(self):
  87. # tag , New point (never been in the OPEN list)
  88. # Open point ( currently in OPEN )
  89. # Closed (currently in CLOSED)
  90. t = {}
  91. for xi in self.X:
  92. t[xi] = 'New'
  93. return t
  94. def get_kmin(self):
  95. # get the minimum of the k val in OPEN
  96. # -1 if it does not exist
  97. if self.OPEN:
  98. minv = np.inf
  99. for k,v in enumerate(self.OPEN):
  100. if v < minv: minv = v
  101. return minv
  102. return -1
  103. def min_state(self):
  104. # returns the state in OPEN with min k(.)
  105. # if empty, returns None and -1
  106. # it also removes this min value form the OPEN set.
  107. if self.OPEN:
  108. minv = np.inf
  109. for k,v in enumerate(self.OPEN):
  110. if v < minv: mink, minv = k, v
  111. return mink, self.OPEN.pop(mink)
  112. return None, -1
  113. def insert(self, x, h_new):
  114. # inserting a key and value into OPEN list (x, kx)
  115. # depending on following situations
  116. if self.tag[x] == 'New':
  117. kx = h_new
  118. if self.tag[x] == 'Open':
  119. kx = min(self.OPEN[x],h_new)
  120. if self.tag[x] == 'Closed':
  121. kx = min(self.h[x], h_new)
  122. self.OPEN[x] = kx
  123. self.h[x],self.tag[x] = h_new, 'Open'
  124. def process_state(self):
  125. x, kold = self.min_state()
  126. self.tag[x] = 'Closed'
  127. if x == None: return -1
  128. if kold < self.h[x]: # raised states
  129. for y in children(self,x):
  130. a = self.h[y] + self.c[y][x]
  131. if self.h[y] <= kold and self.h[x] > a:
  132. self.b[x], self.h[x] = y , a
  133. elif kold == self.h[x]:# lower
  134. for y in children(self,x):
  135. bb = self.h[x] + self.c[x][y]
  136. if self.tag[y] == 'New' or \
  137. (self.b[y] == x and self.h[y] != bb) or \
  138. (self.b[y] != x and self.h[y] > bb):
  139. self.b[y] = x
  140. self.insert(y, bb)
  141. else:
  142. for y in children(self,x):
  143. bb = self.h[x] + self.c[x][y]
  144. if self.tag[y] == 'New' or \
  145. (self.b[y] == x and self.h[y] != bb):
  146. self.b[y] = x
  147. self.insert(y, bb)
  148. else:
  149. if self.b[y] != x and self.h[y] > bb:
  150. self.insert(x, self.h[x])
  151. else:
  152. if self.b[y] != x and self.h[y] > bb and \
  153. self.tag[y] == 'Closed' and self.h[y] == kold:
  154. self.insert(y, self.h[y])
  155. return self.get_kmin()
  156. def modify_cost(self,x,y,cval):
  157. self.c[x][y] = cval # set the new cost to the cval
  158. if self.tag[x] == 'Closed': self.insert(x,self.h[x])
  159. return self.get_kmin()
  160. def run(self):
  161. # TODO: implementation of changing obstable in process
  162. pass
  163. if __name__ == '__main__':
  164. D = D_star(1)