utils3D.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414
  1. import numpy as np
  2. from numpy.matlib import repmat
  3. import pyrr as pyrr
  4. from collections import deque
  5. import os
  6. import sys
  7. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../Sampling-based Planning/")
  8. from rrt_3D.plot_util3D import visualization
  9. def getRay(x, y):
  10. direc = [y[0] - x[0], y[1] - x[1], y[2] - x[2]]
  11. return np.array([x, direc])
  12. def getAABB(blocks):
  13. AABB = []
  14. for i in blocks:
  15. AABB.append(np.array([np.add(i[0:3], -0), np.add(i[3:6], 0)])) # make AABBs alittle bit of larger
  16. return AABB
  17. def getDist(pos1, pos2):
  18. return np.sqrt(sum([(pos1[0] - pos2[0]) ** 2, (pos1[1] - pos2[1]) ** 2, (pos1[2] - pos2[2]) ** 2]))
''' The following utils can be used for RRT or RRT*.
The required parameter object initparams should have:
env, environment generated from env3D
V, node set
E, edge set
i, number of nodes added
maxiter, maximum number of iterations allowed
stepsize, leaf growth restriction
'''
  28. def sampleFree(initparams, bias = 0.1):
  29. '''biased sampling'''
  30. x = np.random.uniform(initparams.env.boundary[0:3], initparams.env.boundary[3:6])
  31. i = np.random.random()
  32. if isinside(initparams, x):
  33. return sampleFree(initparams)
  34. else:
  35. if i < bias:
  36. return np.array(initparams.xt) + 1
  37. else:
  38. return x
  39. return x
  40. # ---------------------- Collision checking algorithms
  41. def isinside(initparams, x):
  42. '''see if inside obstacle'''
  43. for i in initparams.env.blocks:
  44. if isinbound(i, x):
  45. return True
  46. for i in initparams.env.OBB:
  47. if isinbound(i, x, mode = 'obb'):
  48. return True
  49. for i in initparams.env.balls:
  50. if isinball(i, x):
  51. return True
  52. return False
  53. def isinbound(i, x, mode = False, factor = 0, isarray = False):
  54. if mode == 'obb':
  55. return isinobb(i, x, isarray)
  56. if isarray:
  57. compx = (i[0] - factor <= x[:,0]) & (x[:,0] < i[3] + factor)
  58. compy = (i[1] - factor <= x[:,1]) & (x[:,1] < i[4] + factor)
  59. compz = (i[2] - factor <= x[:,2]) & (x[:,2] < i[5] + factor)
  60. return compx & compy & compz
  61. else:
  62. return i[0] - factor <= x[0] < i[3] + factor and i[1] - factor <= x[1] < i[4] + factor and i[2] - factor <= x[2] < i[5]
  63. def isinobb(i, x, isarray = False):
  64. # transform the point from {W} to {body}
  65. if isarray:
  66. pts = (i.T@np.column_stack((x, np.ones(len(x)))).T).T[:,0:3]
  67. block = [- i.E[0],- i.E[1],- i.E[2],+ i.E[0],+ i.E[1],+ i.E[2]]
  68. return isinbound(block, pts, isarray = isarray)
  69. else:
  70. pt = i.T@np.append(x,1)
  71. block = [- i.E[0],- i.E[1],- i.E[2],+ i.E[0],+ i.E[1],+ i.E[2]]
  72. return isinbound(block, pt)
  73. def isinball(i, x, factor = 0):
  74. if getDist(i[0:3], x) <= i[3] + factor:
  75. return True
  76. return False
  77. def lineSphere(p0, p1, ball):
  78. # https://cseweb.ucsd.edu/classes/sp19/cse291-d/Files/CSE291_13_CollisionDetection.pdf
  79. c, r = ball[0:3], ball[-1]
  80. line = [p1[0] - p0[0], p1[1] - p0[1], p1[2] - p0[2]]
  81. d1 = [c[0] - p0[0], c[1] - p0[1], c[2] - p0[2]]
  82. t = (1 / (line[0] * line[0] + line[1] * line[1] + line[2] * line[2])) * (
  83. line[0] * d1[0] + line[1] * d1[1] + line[2] * d1[2])
  84. if t <= 0:
  85. if (d1[0] * d1[0] + d1[1] * d1[1] + d1[2] * d1[2]) <= r ** 2: return True
  86. elif t >= 1:
  87. d2 = [c[0] - p1[0], c[1] - p1[1], c[2] - p1[2]]
  88. if (d2[0] * d2[0] + d2[1] * d2[1] + d2[2] * d2[2]) <= r ** 2: return True
  89. elif 0 < t < 1:
  90. x = [p0[0] + t * line[0], p0[1] + t * line[1], p0[2] + t * line[2]]
  91. k = [c[0] - x[0], c[1] - x[1], c[2] - x[2]]
  92. if (k[0] * k[0] + k[1] * k[1] + k[2] * k[2]) <= r ** 2: return True
  93. return False
  94. def lineAABB(p0, p1, dist, aabb):
  95. # https://www.gamasutra.com/view/feature/131790/simple_intersection_tests_for_games.php?print=1
  96. # aabb should have the attributes of P, E as center point and extents
  97. mid = [(p0[0] + p1[0]) / 2, (p0[1] + p1[1]) / 2, (p0[2] + p1[2]) / 2] # mid point
  98. I = [(p1[0] - p0[0]) / dist, (p1[1] - p0[1]) / dist, (p1[2] - p0[2]) / dist] # unit direction
  99. hl = dist / 2 # radius
  100. T = [aabb.P[0] - mid[0], aabb.P[1] - mid[1], aabb.P[2] - mid[2]]
  101. # do any of the principal axis form a separting axis?
  102. if abs(T[0]) > (aabb.E[0] + hl * abs(I[0])): return False
  103. if abs(T[1]) > (aabb.E[1] + hl * abs(I[1])): return False
  104. if abs(T[2]) > (aabb.E[2] + hl * abs(I[2])): return False
  105. # I.cross(x axis) ?
  106. r = aabb.E[1] * abs(I[2]) + aabb.E[2] * abs(I[1])
  107. if abs(T[1] * I[2] - T[2] * I[1]) > r: return False
  108. # I.cross(y axis) ?
  109. r = aabb.E[0] * abs(I[2]) + aabb.E[2] * abs(I[0])
  110. if abs(T[2] * I[0] - T[0] * I[2]) > r: return False
  111. # I.cross(z axis) ?
  112. r = aabb.E[0] * abs(I[1]) + aabb.E[1] * abs(I[0])
  113. if abs(T[0] * I[1] - T[1] * I[0]) > r: return False
  114. return True
  115. def lineOBB(p0, p1, dist, obb):
  116. # transform points to obb frame
  117. res = obb.T@np.column_stack([np.array([p0,p1]),[1,1]]).T
  118. # record old position and set the position to origin
  119. oldP, obb.P= obb.P, [0,0,0]
  120. # calculate segment-AABB testing
  121. ans = lineAABB(res[0:3,0],res[0:3,1],dist,obb)
  122. # reset the position
  123. obb.P = oldP
  124. return ans
  125. def isCollide(initparams, x, child, dist=None):
  126. '''see if line intersects obstacle'''
  127. '''specified for expansion in A* 3D lookup table'''
  128. if dist==None:
  129. dist = getDist(x, child)
  130. # check in bound
  131. if not isinbound(initparams.env.boundary, child):
  132. return True, dist
  133. # check collision in AABB
  134. for i in range(len(initparams.env.AABB)):
  135. if lineAABB(x, child, dist, initparams.env.AABB[i]):
  136. return True, dist
  137. # check collision in ball
  138. for i in initparams.env.balls:
  139. if lineSphere(x, child, i):
  140. return True, dist
  141. # check collision with obb
  142. for i in initparams.env.OBB:
  143. if lineOBB(x, child, dist, i):
  144. return True, dist
  145. return False, dist
  146. # ---------------------- leaf node extending algorithms
  147. def nearest(initparams, x, isset=False):
  148. V = np.array(initparams.V)
  149. if initparams.i == 0:
  150. return initparams.V[0]
  151. xr = repmat(x, len(V), 1)
  152. dists = np.linalg.norm(xr - V, axis=1)
  153. return tuple(initparams.V[np.argmin(dists)])
  154. def near(initparams, x):
  155. # x = np.array(x)
  156. V = np.array(initparams.V)
  157. if initparams.i == 0:
  158. return [initparams.V[0]]
  159. cardV = len(initparams.V)
  160. eta = initparams.eta
  161. gamma = initparams.gamma
  162. r = min(gamma * (np.log(cardV) / cardV ** (1/3)), eta)
  163. if initparams.done:
  164. r = 1
  165. xr = repmat(x, len(V), 1)
  166. inside = np.linalg.norm(xr - V, axis=1) < r
  167. nearpoints = V[inside]
  168. return np.array(nearpoints)
  169. def steer(initparams, x, y, DIST=False):
  170. # steer from x to y
  171. if np.equal(x, y).all():
  172. return x, 0.0
  173. dist, step = getDist(y, x), initparams.stepsize
  174. increment = ((y[0] - x[0]) / dist * step, (y[1] - x[1]) / dist * step, (y[2] - x[2]) / dist * step)
  175. xnew = (x[0] + increment[0], x[1] + increment[1], x[2] + increment[2])
  176. # direc = (y - x) / np.linalg.norm(y - x)
  177. # xnew = x + initparams.stepsize * direc
  178. if DIST:
  179. return xnew, dist
  180. return xnew, dist
  181. def cost(initparams, x):
  182. '''here use the additive recursive cost function'''
  183. if x == initparams.x0:
  184. return 0
  185. return cost(initparams, initparams.Parent[x]) + getDist(x, initparams.Parent[x])
  186. def cost_from_set(initparams, x):
  187. '''here use a incremental cost set function'''
  188. if x == initparams.x0:
  189. return 0
  190. return initparams.COST[initparams.Parent[x]] + getDist(x, initparams.Parent[x])
  191. def path(initparams, Path=[], dist=0):
  192. x = initparams.xt
  193. while x != initparams.x0:
  194. x2 = initparams.Parent[x]
  195. Path.append(np.array([x, x2]))
  196. dist += getDist(x, x2)
  197. x = x2
  198. return Path, dist
  199. class edgeset(object):
  200. def __init__(self):
  201. self.E = {}
  202. def add_edge(self, edge):
  203. x, y = edge[0], edge[1]
  204. if x in self.E:
  205. self.E[x].add(y)
  206. else:
  207. self.E[x] = set()
  208. self.E[x].add(y)
  209. def remove_edge(self, edge):
  210. x, y = edge[0], edge[1]
  211. self.E[x].remove(y)
  212. def get_edge(self, nodes = None):
  213. edges = []
  214. if nodes is None:
  215. for v in self.E:
  216. for n in self.E[v]:
  217. # if (n,v) not in edges:
  218. edges.append((v, n))
  219. else:
  220. for v in nodes:
  221. for n in self.E[tuple(v)]:
  222. edges.append((v, n))
  223. return edges
  224. def isEndNode(self, node):
  225. return node not in self.E
  226. #------------------------ use a linked list to express the tree
  227. class Node:
  228. def __init__(self, data):
  229. self.pos = data
  230. self.Parent = None
  231. self.child = set()
  232. def tree_add_edge(node_in_tree, x):
  233. # add an edge at the specified parent
  234. node_to_add = Node(x)
  235. # node_in_tree = tree_bfs(head, xparent)
  236. node_in_tree.child.add(node_to_add)
  237. node_to_add.Parent = node_in_tree
  238. return node_to_add
  239. def tree_bfs(head, x):
  240. # searches x in order of bfs
  241. node = head
  242. Q = []
  243. Q.append(node)
  244. while Q:
  245. curr = Q.pop()
  246. if curr.pos == x:
  247. return curr
  248. for child_node in curr.child:
  249. Q.append(child_node)
  250. def tree_nearest(head, x):
  251. # find the node nearest to x
  252. D = np.inf
  253. min_node = None
  254. Q = []
  255. Q.append(head)
  256. while Q:
  257. curr = Q.pop()
  258. dist = getDist(curr.pos, x)
  259. # record the current best
  260. if dist < D:
  261. D, min_node = dist, curr
  262. # bfs
  263. for child_node in curr.child:
  264. Q.append(child_node)
  265. return min_node
  266. def tree_steer(initparams, node, x):
  267. # steer from node to x
  268. dist, step = getDist(node.pos, x), initparams.stepsize
  269. increment = ((node.pos[0] - x[0]) / dist * step, (node.pos[1] - x[1]) / dist * step, (node.pos[2] - x[2]) / dist * step)
  270. xnew = (x[0] + increment[0], x[1] + increment[1], x[2] + increment[2])
  271. return xnew
  272. def tree_print(head):
  273. Q = []
  274. Q.append(head)
  275. verts = []
  276. edge = []
  277. while Q:
  278. curr = Q.pop()
  279. # print(curr.pos)
  280. verts.append(curr.pos)
  281. if curr.Parent == None:
  282. pass
  283. else:
  284. edge.append([curr.pos, curr.Parent.pos])
  285. for child in curr.child:
  286. Q.append(child)
  287. return verts, edge
  288. def tree_path(initparams, end_node):
  289. path = []
  290. curr = end_node
  291. while curr.pos != initparams.x0:
  292. path.append([curr.pos, curr.Parent.pos])
  293. curr = curr.Parent
  294. return path
  295. #---------------KD tree, used for nearest neighbor search
  296. class kdTree:
  297. def __init__(self):
  298. pass
  299. def R1_dist(self, q, p):
  300. return abs(q-p)
  301. def S1_dist(self, q, p):
  302. return min(abs(q-p), 1- abs(q-p))
  303. def P3_dist(self, q, p):
  304. # cubes with antipodal points
  305. q1, q2, q3 = q
  306. p1, p2, p3 = p
  307. d1 = np.sqrt((q1-p1)**2 + (q2-p2)**2 + (q3-p3)**2)
  308. d2 = np.sqrt((1-abs(q1-p1))**2 + (1-abs(q2-p2))**2 + (1-abs(q3-p3))**2)
  309. d3 = np.sqrt((-q1-p1)**2 + (-q2-p2)**2 + (q3+1-p3)**2)
  310. d4 = np.sqrt((-q1-p1)**2 + (-q2-p2)**2 + (q3-1-p3)**2)
  311. d5 = np.sqrt((-q1-p1)**2 + (q2+1-p2)**2 + (-q3-p3)**2)
  312. d6 = np.sqrt((-q1-p1)**2 + (q2-1-p2)**2 + (-q3-p3)**2)
  313. d7 = np.sqrt((q1+1-p1)**2 + (-q2-p2)**2 + (-q3-p3)**2)
  314. d8 = np.sqrt((q1-1-p1)**2 + (-q2-p2)**2 + (-q3-p3)**2)
  315. return min(d1,d2,d3,d4,d5,d6,d7,d8)
if __name__ == '__main__':
    # Demo: grow a tree of Node objects with the helpers of this file until
    # the goal is reached or maxiter iterations have run, then plot it and
    # print the elapsed time.
    # NOTE(review): the indentation of this block was reconstructed from
    # flattened source; confirm the nesting of `self.i += 1` and
    # `self.ind += 1` against the original file.
    from rrt_3D.env3D import env
    import time
    import matplotlib.pyplot as plt
    class rrt_demo:
        def __init__(self):
            self.env = env()
            # start and goal as tuples
            self.x0, self.xt = tuple(self.env.start), tuple(self.env.goal)
            self.stepsize = 0.5
            self.maxiter = 10000
            # ind counts loop iterations, i counts nodes added
            self.ind, self.i = 0, 0
            self.done = False
            self.Path = []
            self.V = []
            # root of the tree, placed at the start position
            self.head = Node(self.x0)
        def run(self):
            while self.ind < self.maxiter:
                xrand = sampleFree(self) # O(1)
                nearest_node = tree_nearest(self.head, xrand) # O(N)
                xnew = tree_steer(self, nearest_node, xrand) # O(1)
                collide, _ = isCollide(self, nearest_node.pos, xnew) # O(num obs)
                if not collide:
                    new_node = tree_add_edge(nearest_node, xnew) # O(1)
                    # if the path is found
                    if getDist(xnew, self.xt) <= self.stepsize:
                        # connect the goal and trace the path back to the start
                        end_node = tree_add_edge(new_node, self.xt)
                        self.Path = tree_path(self, end_node)
                        break
                    self.i += 1
                self.ind += 1
            self.done = True
            # flatten the tree into vertex and edge lists
            self.V, self.E = tree_print(self.head)
            print(self.E)
            visualization(self)
            plt.show()
    A = rrt_demo()
    st = time.time()
    A.run()
    # elapsed wall-clock time of the run in seconds
    print(time.time() - st)