Dstar3D.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. import os
  4. import sys
  5. from collections import defaultdict
  6. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../Search-based Planning/")
  7. from Search_3D.env3D import env
  8. from Search_3D import Astar3D
  9. from Search_3D.utils3D import StateSpace, getDist, getRay, isinbound, isinball
  10. import pyrr
  def isCollide(initparams, x, child):
      '''Check whether the straight segment from x to child hits an obstacle.

      Args:
          initparams: planner object exposing ``env`` with ``boundary``,
              ``AABB`` (boxes) and ``balls`` (spheres) attributes.
          x: start point of the segment.
          child: end point of the segment.

      Returns:
          (collide, dist): ``collide`` is True when ``child`` is out of the
          boundary, lies inside a ball, or the ray from x meets an AABB / ball
          no farther away than ``child``. ``dist`` is ``getDist(x, child)`` in
          every case.
      '''
      ray, dist = getRay(x, child), getDist(x, child)
      if not isinbound(initparams.env.boundary, child):
          return True, dist
      for box in initparams.env.AABB:
          shot = pyrr.geometric_tests.ray_intersect_aabb(ray, box)
          # The ray hits the box before (or exactly at) child: segment blocked.
          if shot is not None and getDist(x, shot) <= dist:
              return True, dist
      for ball in initparams.env.balls:
          if isinball(ball, child):
              return True, dist
          shot = pyrr.geometric_tests.ray_intersect_sphere(ray, ball)
          # `shot` is a list of intersection points. Comparing that list with a
          # float (`list <= float`) raises TypeError in Python 3, so test each
          # intersection distance individually.
          if len(shot) > 0 and all(getDist(x, p) <= dist for p in shot):
              return True, dist
      return False, dist
  def children(initparams, x):
      '''Return the grid neighbours of state x that lie inside the boundary.

      Each direction vector in ``initparams.Alldirec`` is scaled by
      ``initparams.env.resolution`` and added to x. Candidates outside
      ``initparams.env.boundary`` are dropped; the rest are returned as tuples
      in the order of ``Alldirec``.
      '''
      step = initparams.env.resolution
      candidates = (tuple(map(np.add, x, np.multiply(offset, step)))
                    for offset in initparams.Alldirec)
      return [cand for cand in candidates
              if isinbound(initparams.env.boundary, cand)]
  def cost(initparams, x, y):
      '''Return the edge cost from x to y.

      The cost is the distance reported by isCollide, or np.inf when
      isCollide reports a collision (the collision check happens here).
      '''
      blocked, dist = isCollide(initparams, x, y)
      return np.inf if blocked else dist
  def initcost(initparams):
      '''Build the initial edge-cost table for every state in initparams.X.

      Returns a two-level mapping ``c[x][y]`` holding ``cost(initparams, x, y)``
      for every in-bound neighbour y of x (as produced by ``children``). The
      table may be modified later (e.g. by D_star.modify_cost).
      '''
      # Reading an edge that was never computed yields np.inf ("no usable
      # edge"), the same value cost() uses for blocked edges. The previous
      # inner default, `dict`, produced {} here, which broke `h + c` arithmetic.
      # Note: as with any defaultdict, such a read also inserts the entry.
      c = defaultdict(lambda: defaultdict(lambda: np.inf))
      for xi in initparams.X:
          for child in children(initparams, xi):
              c[xi][child] = cost(initparams, xi, child)
      return c
  class D_star(object):
      '''D* planner (Stentz, 1994) on a 3D grid.

      The search runs backwards from the goal ``xt``. The state it keeps:

      - ``h[x]``: current cost-to-goal estimate.
      - ``b[x]``: back pointer toward the goal.
      - ``OPEN``: maps each open state to its key value k.
      - ``tag[x]``: one of 'New', 'Open' or 'Closed'.
      - ``c[x][y]``: edge costs built by ``initcost``.
      '''

      def __init__(self, resolution=1):
          # 26-connected neighbourhood: every non-zero offset in {-1, 0, 1}^3.
          self.Alldirec = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1], [0, 1, 1], [1, 1, 1],
                                    [-1, 0, 0], [0, -1, 0], [0, 0, -1], [-1, -1, 0], [-1, 0, -1], [0, -1, -1],
                                    [-1, -1, -1],
                                    [1, -1, 0], [-1, 1, 0], [1, 0, -1], [-1, 0, 1], [0, 1, -1], [0, -1, 1],
                                    [1, -1, -1], [-1, 1, -1], [-1, -1, 1], [1, 1, -1], [1, -1, 1], [-1, 1, 1]])
          self.env = env(resolution=resolution)
          self.X = StateSpace(self.env)
          # Snap start and goal onto the grid states.
          self.x0, self.xt = self._nearest(self.env.start), self._nearest(self.env.goal)
          self.b = {}  # back pointers; every state gets one except xt
          self.OPEN = {}  # OPEN list as a hashmap: state -> key value k
          self.h = self.initH()  # cost-to-goal estimates, all start at 0
          self.tag = self.initTag()  # every state starts as 'New'
          self.c = initcost(self)  # edge costs c[x][y]
          # Seed the backward search with the goal. insert() also marks it
          # 'Open'; assigning OPEN directly left its tag as 'New'.
          self.insert(self.xt, 0)

      def _nearest(self, pt):
          '''Return the state in self.X closest to pt (by getDist).'''
          # Replaces a call to getNearest, which this module never imported.
          return min(self.X, key=lambda s: getDist(s, pt))

      def initH(self):
          '''Return the initial h table: 0 for every state in self.X.'''
          return {xi: 0 for xi in self.X}

      def initTag(self):
          '''Return the initial tag table: 'New' for every state in self.X.

          Tags: 'New' (never on OPEN), 'Open' (currently on OPEN),
          'Closed' (removed from OPEN).
          '''
          return {xi: 'New' for xi in self.X}

      def get_kmin(self):
          '''Return the minimum key value on OPEN, or -1 if OPEN is empty.'''
          if not self.OPEN:
              return -1
          # Iterate the values. enumerate() over the dict yielded
          # (index, state) pairs, so states were compared against np.inf.
          return min(self.OPEN.values())

      def min_state(self):
          '''Remove and return (state, k) for the OPEN state with minimal k.

          Returns (None, -1) when OPEN is empty.
          '''
          if not self.OPEN:
              return None, -1
          xmin = min(self.OPEN, key=self.OPEN.get)
          return xmin, self.OPEN.pop(xmin)

      def insert(self, x, h_new):
          '''Put x on OPEN with the new cost-to-goal h_new (Stentz's INSERT).

          The key k(x) is chosen by the tag of x:

          - 'New': h_new.
          - 'Open': min(current key, h_new).
          - 'Closed': min(old h[x], h_new).

          Afterwards h[x] = h_new and tag[x] = 'Open'.

          Raises:
              ValueError: if tag[x] is not one of the three known tags.
          '''
          tag = self.tag[x]
          if tag == 'New':
              kx = h_new
          elif tag == 'Open':
              kx = min(self.OPEN[x], h_new)
          elif tag == 'Closed':
              kx = min(self.h[x], h_new)
          else:
              raise ValueError(f'unknown tag {tag!r} for state {x!r}')
          self.OPEN[x] = kx
          self.h[x], self.tag[x] = h_new, 'Open'

      def process_state(self):
          '''Expand the OPEN state with the smallest key (Stentz's PROCESS-STATE).

          Returns the new minimum key on OPEN (see get_kmin), or -1 if OPEN
          was already empty.
          '''
          x, kold = self.min_state()
          # Check before touching tag; otherwise tag[None] would be created.
          if x is None:
              return -1
          self.tag[x] = 'Closed'
          if kold < self.h[x]:
              # RAISE state: try to lower h[x] through a neighbour whose
              # cost-to-goal is known to be optimal (h[y] <= kold).
              for y in children(self, x):
                  a = self.h[y] + self.c[y][x]
                  if self.h[y] <= kold and self.h[x] > a:
                      self.b[x], self.h[x] = y, a
          # A separate `if`, not `elif`. A RAISE state whose h[x] was restored
          # to kold above must continue as a LOWER state. One that could not
          # be restored must run the `else` branch to propagate the raise.
          if kold == self.h[x]:
              # LOWER state: propagate the optimal cost to the neighbours.
              for y in children(self, x):
                  bb = self.h[x] + self.c[x][y]
                  by = self.b.get(y)  # the goal has no back pointer
                  if self.tag[y] == 'New' or \
                          (by == x and self.h[y] != bb) or \
                          (by != x and self.h[y] > bb):
                      self.b[y] = x
                      self.insert(y, bb)
          else:
              for y in children(self, x):
                  bb = self.h[x] + self.c[x][y]
                  by = self.b.get(y)
                  if self.tag[y] == 'New' or (by == x and self.h[y] != bb):
                      # y depends on x: pass the new cost on to y.
                      self.b[y] = x
                      self.insert(y, bb)
                  elif by != x and self.h[y] > bb:
                      # x could lower y: reopen x so it is processed as LOWER.
                      self.insert(x, self.h[x])
                  elif by != x and self.h[x] > self.h[y] + self.c[y][x] and \
                          self.tag[y] == 'Closed' and self.h[y] > kold:
                      # y could lower x but is not yet optimal: reopen y.
                      self.insert(y, self.h[y])
          return self.get_kmin()

      def modify_cost(self, x, y, cval):
          '''Set c[x][y] = cval and reopen x if it is Closed (Stentz's MODIFY-COST).

          Returns get_kmin() so the caller can keep calling process_state.
          '''
          self.c[x][y] = cval
          if self.tag[x] == 'Closed':
              self.insert(x, self.h[x])
          return self.get_kmin()

      def run(self):
          # TODO: implementation of changing obstable in process
          pass
  if __name__ == '__main__':
      # Build the planner on a unit-resolution grid.
      D = D_star(resolution=1)