utils3D.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416
  1. import numpy as np
  2. from numpy.matlib import repmat
  3. import pyrr as pyrr
  4. from collections import deque
  5. import os
  6. import sys
  7. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../Sampling_based_Planning/")
  8. from rrt_3D.plot_util3D import visualization
  9. def getRay(x, y):
  10. direc = [y[0] - x[0], y[1] - x[1], y[2] - x[2]]
  11. return np.array([x, direc])
  12. def getAABB(blocks):
  13. AABB = []
  14. for i in blocks:
  15. AABB.append(np.array([np.add(i[0:3], -0), np.add(i[3:6], 0)])) # make AABBs alittle bit of larger
  16. return AABB
  17. def getDist(pos1, pos2):
  18. return np.sqrt(sum([(pos1[0] - pos2[0]) ** 2, (pos1[1] - pos2[1]) ** 2, (pos1[2] - pos2[2]) ** 2]))
''' The following utils can be used for RRT or RRT*.
The required parameter initparams should have:
    env, environment generated from env3D
    V, node set
    E, edge set
    i, number of nodes added
    maxiter, maximum number of iterations allowed
    stepsize, leaf growth restriction
'''
  28. def sampleFree(initparams, bias = 0.1):
  29. '''biased sampling'''
  30. x = np.random.uniform(initparams.env.boundary[0:3], initparams.env.boundary[3:6])
  31. i = np.random.random()
  32. if isinside(initparams, x):
  33. return sampleFree(initparams)
  34. else:
  35. if i < bias:
  36. return np.array(initparams.xt) + 1
  37. else:
  38. return x
  39. return x
  40. # ---------------------- Collision checking algorithms
  41. def isinside(initparams, x):
  42. '''see if inside obstacle'''
  43. for i in initparams.env.blocks:
  44. if isinbound(i, x):
  45. return True
  46. for i in initparams.env.OBB:
  47. if isinbound(i, x, mode = 'obb'):
  48. return True
  49. for i in initparams.env.balls:
  50. if isinball(i, x):
  51. return True
  52. return False
  53. def isinbound(i, x, mode = False, factor = 0, isarray = False):
  54. if mode == 'obb':
  55. return isinobb(i, x, isarray)
  56. if isarray:
  57. compx = (i[0] - factor <= x[:,0]) & (x[:,0] < i[3] + factor)
  58. compy = (i[1] - factor <= x[:,1]) & (x[:,1] < i[4] + factor)
  59. compz = (i[2] - factor <= x[:,2]) & (x[:,2] < i[5] + factor)
  60. return compx & compy & compz
  61. else:
  62. return i[0] - factor <= x[0] < i[3] + factor and i[1] - factor <= x[1] < i[4] + factor and i[2] - factor <= x[2] < i[5]
  63. def isinobb(i, x, isarray = False):
  64. # transform the point from {W} to {body}
  65. if isarray:
  66. pts = (i.T@np.column_stack((x, np.ones(len(x)))).T).T[:,0:3]
  67. block = [- i.E[0],- i.E[1],- i.E[2],+ i.E[0],+ i.E[1],+ i.E[2]]
  68. return isinbound(block, pts, isarray = isarray)
  69. else:
  70. pt = i.T@np.append(x,1)
  71. block = [- i.E[0],- i.E[1],- i.E[2],+ i.E[0],+ i.E[1],+ i.E[2]]
  72. return isinbound(block, pt)
  73. def isinball(i, x, factor = 0):
  74. if getDist(i[0:3], x) <= i[3] + factor:
  75. return True
  76. return False
  77. def lineSphere(p0, p1, ball):
  78. # https://cseweb.ucsd.edu/classes/sp19/cse291-d/Files/CSE291_13_CollisionDetection.pdf
  79. c, r = ball[0:3], ball[-1]
  80. line = [p1[0] - p0[0], p1[1] - p0[1], p1[2] - p0[2]]
  81. d1 = [c[0] - p0[0], c[1] - p0[1], c[2] - p0[2]]
  82. t = (1 / (line[0] * line[0] + line[1] * line[1] + line[2] * line[2])) * (
  83. line[0] * d1[0] + line[1] * d1[1] + line[2] * d1[2])
  84. if t <= 0:
  85. if (d1[0] * d1[0] + d1[1] * d1[1] + d1[2] * d1[2]) <= r ** 2: return True
  86. elif t >= 1:
  87. d2 = [c[0] - p1[0], c[1] - p1[1], c[2] - p1[2]]
  88. if (d2[0] * d2[0] + d2[1] * d2[1] + d2[2] * d2[2]) <= r ** 2: return True
  89. elif 0 < t < 1:
  90. x = [p0[0] + t * line[0], p0[1] + t * line[1], p0[2] + t * line[2]]
  91. k = [c[0] - x[0], c[1] - x[1], c[2] - x[2]]
  92. if (k[0] * k[0] + k[1] * k[1] + k[2] * k[2]) <= r ** 2: return True
  93. return False
  94. def lineAABB(p0, p1, dist, aabb):
  95. # https://www.gamasutra.com/view/feature/131790/simple_intersection_tests_for_games.php?print=1
  96. # aabb should have the attributes of P, E as center point and extents
  97. mid = [(p0[0] + p1[0]) / 2, (p0[1] + p1[1]) / 2, (p0[2] + p1[2]) / 2] # mid point
  98. I = [(p1[0] - p0[0]) / dist, (p1[1] - p0[1]) / dist, (p1[2] - p0[2]) / dist] # unit direction
  99. hl = dist / 2 # radius
  100. T = [aabb.P[0] - mid[0], aabb.P[1] - mid[1], aabb.P[2] - mid[2]]
  101. # do any of the principal axis form a separting axis?
  102. if abs(T[0]) > (aabb.E[0] + hl * abs(I[0])): return False
  103. if abs(T[1]) > (aabb.E[1] + hl * abs(I[1])): return False
  104. if abs(T[2]) > (aabb.E[2] + hl * abs(I[2])): return False
  105. # I.cross(s axis) ?
  106. r = aabb.E[1] * abs(I[2]) + aabb.E[2] * abs(I[1])
  107. if abs(T[1] * I[2] - T[2] * I[1]) > r: return False
  108. # I.cross(y axis) ?
  109. r = aabb.E[0] * abs(I[2]) + aabb.E[2] * abs(I[0])
  110. if abs(T[2] * I[0] - T[0] * I[2]) > r: return False
  111. # I.cross(z axis) ?
  112. r = aabb.E[0] * abs(I[1]) + aabb.E[1] * abs(I[0])
  113. if abs(T[0] * I[1] - T[1] * I[0]) > r: return False
  114. return True
  115. def lineOBB(p0, p1, dist, obb):
  116. # transform points to obb frame
  117. res = obb.T@np.column_stack([np.array([p0,p1]),[1,1]]).T
  118. # record old position and set the position to origin
  119. oldP, obb.P= obb.P, [0,0,0]
  120. # calculate segment-AABB testing
  121. ans = lineAABB(res[0:3,0],res[0:3,1],dist,obb)
  122. # reset the position
  123. obb.P = oldP
  124. return ans
  125. def isCollide(initparams, x, child, dist=None):
  126. '''see if line intersects obstacle'''
  127. '''specified for expansion in A* 3D lookup table'''
  128. if dist==None:
  129. dist = getDist(x, child)
  130. # check in bound
  131. if not isinbound(initparams.env.boundary, child):
  132. return True, dist
  133. # check collision in AABB
  134. for i in range(len(initparams.env.AABB)):
  135. if lineAABB(x, child, dist, initparams.env.AABB[i]):
  136. return True, dist
  137. # check collision in ball
  138. for i in initparams.env.balls:
  139. if lineSphere(x, child, i):
  140. return True, dist
  141. # check collision with obb
  142. for i in initparams.env.OBB:
  143. if lineOBB(x, child, dist, i):
  144. return True, dist
  145. return False, dist
  146. # ---------------------- leaf node extending algorithms
  147. def nearest(initparams, x, isset=False):
  148. V = np.array(initparams.V)
  149. if initparams.i == 0:
  150. return initparams.V[0]
  151. xr = repmat(x, len(V), 1)
  152. dists = np.linalg.norm(xr - V, axis=1)
  153. return tuple(initparams.V[np.argmin(dists)])
  154. def near(initparams, x):
  155. # s = np.array(s)
  156. V = np.array(initparams.V)
  157. if initparams.i == 0:
  158. return [initparams.V[0]]
  159. cardV = len(initparams.V)
  160. eta = initparams.eta
  161. gamma = initparams.gamma
  162. # min{γRRT∗ (log(card (V ))/ card (V ))1/d, η}
  163. r = min(gamma * ((np.log(cardV) / cardV) ** (1/3)), eta)
  164. if initparams.done:
  165. r = 1
  166. xr = repmat(x, len(V), 1)
  167. inside = np.linalg.norm(xr - V, axis=1) < r
  168. nearpoints = V[inside]
  169. return np.array(nearpoints)
  170. def steer(initparams, x, y, DIST=False):
  171. # steer from s to y
  172. if np.equal(x, y).all():
  173. return x, 0.0
  174. dist, step = getDist(y, x), initparams.stepsize
  175. step = min(dist, step)
  176. increment = ((y[0] - x[0]) / dist * step, (y[1] - x[1]) / dist * step, (y[2] - x[2]) / dist * step)
  177. xnew = (x[0] + increment[0], x[1] + increment[1], x[2] + increment[2])
  178. # direc = (y - s) / np.linalg.norm(y - s)
  179. # xnew = s + initparams.stepsize * direc
  180. if DIST:
  181. return xnew, dist
  182. return xnew, dist
  183. def cost(initparams, x):
  184. '''here use the additive recursive cost function'''
  185. if x == initparams.x0:
  186. return 0
  187. return cost(initparams, initparams.Parent[x]) + getDist(x, initparams.Parent[x])
  188. def cost_from_set(initparams, x):
  189. '''here use a incremental cost set function'''
  190. if x == initparams.x0:
  191. return 0
  192. return initparams.COST[initparams.Parent[x]] + getDist(x, initparams.Parent[x])
  193. def path(initparams, Path=[], dist=0):
  194. x = initparams.xt
  195. while x != initparams.x0:
  196. x2 = initparams.Parent[x]
  197. Path.append(np.array([x, x2]))
  198. dist += getDist(x, x2)
  199. x = x2
  200. return Path, dist
  201. class edgeset(object):
  202. def __init__(self):
  203. self.E = {}
  204. def add_edge(self, edge):
  205. x, y = edge[0], edge[1]
  206. if x in self.E:
  207. self.E[x].add(y)
  208. else:
  209. self.E[x] = set()
  210. self.E[x].add(y)
  211. def remove_edge(self, edge):
  212. x, y = edge[0], edge[1]
  213. self.E[x].remove(y)
  214. def get_edge(self, nodes = None):
  215. edges = []
  216. if nodes is None:
  217. for v in self.E:
  218. for n in self.E[v]:
  219. # if (n,v) not in edges:
  220. edges.append((v, n))
  221. else:
  222. for v in nodes:
  223. for n in self.E[tuple(v)]:
  224. edges.append((v, n))
  225. return edges
  226. def isEndNode(self, node):
  227. return node not in self.E
  228. #------------------------ use a linked list to express the tree
  229. class Node:
  230. def __init__(self, data):
  231. self.pos = data
  232. self.Parent = None
  233. self.child = set()
  234. def tree_add_edge(node_in_tree, x):
  235. # add an edge at the specified parent
  236. node_to_add = Node(x)
  237. # node_in_tree = tree_bfs(head, xparent)
  238. node_in_tree.child.add(node_to_add)
  239. node_to_add.Parent = node_in_tree
  240. return node_to_add
  241. def tree_bfs(head, x):
  242. # searches s in order of bfs
  243. node = head
  244. Q = []
  245. Q.append(node)
  246. while Q:
  247. curr = Q.pop()
  248. if curr.pos == x:
  249. return curr
  250. for child_node in curr.child:
  251. Q.append(child_node)
  252. def tree_nearest(head, x):
  253. # find the node nearest to s
  254. D = np.inf
  255. min_node = None
  256. Q = []
  257. Q.append(head)
  258. while Q:
  259. curr = Q.pop()
  260. dist = getDist(curr.pos, x)
  261. # record the current best
  262. if dist < D:
  263. D, min_node = dist, curr
  264. # bfs
  265. for child_node in curr.child:
  266. Q.append(child_node)
  267. return min_node
  268. def tree_steer(initparams, node, x):
  269. # steer from node to s
  270. dist, step = getDist(node.pos, x), initparams.stepsize
  271. increment = ((node.pos[0] - x[0]) / dist * step, (node.pos[1] - x[1]) / dist * step, (node.pos[2] - x[2]) / dist * step)
  272. xnew = (x[0] + increment[0], x[1] + increment[1], x[2] + increment[2])
  273. return xnew
  274. def tree_print(head):
  275. Q = []
  276. Q.append(head)
  277. verts = []
  278. edge = []
  279. while Q:
  280. curr = Q.pop()
  281. # print(curr.pos)
  282. verts.append(curr.pos)
  283. if curr.Parent == None:
  284. pass
  285. else:
  286. edge.append([curr.pos, curr.Parent.pos])
  287. for child in curr.child:
  288. Q.append(child)
  289. return verts, edge
  290. def tree_path(initparams, end_node):
  291. path = []
  292. curr = end_node
  293. while curr.pos != initparams.x0:
  294. path.append([curr.pos, curr.Parent.pos])
  295. curr = curr.Parent
  296. return path
  297. #---------------KD tree, used for nearest neighbor search
  298. class kdTree:
  299. def __init__(self):
  300. pass
  301. def R1_dist(self, q, p):
  302. return abs(q-p)
  303. def S1_dist(self, q, p):
  304. return min(abs(q-p), 1- abs(q-p))
  305. def P3_dist(self, q, p):
  306. # cubes with antipodal points
  307. q1, q2, q3 = q
  308. p1, p2, p3 = p
  309. d1 = np.sqrt((q1-p1)**2 + (q2-p2)**2 + (q3-p3)**2)
  310. d2 = np.sqrt((1-abs(q1-p1))**2 + (1-abs(q2-p2))**2 + (1-abs(q3-p3))**2)
  311. d3 = np.sqrt((-q1-p1)**2 + (-q2-p2)**2 + (q3+1-p3)**2)
  312. d4 = np.sqrt((-q1-p1)**2 + (-q2-p2)**2 + (q3-1-p3)**2)
  313. d5 = np.sqrt((-q1-p1)**2 + (q2+1-p2)**2 + (-q3-p3)**2)
  314. d6 = np.sqrt((-q1-p1)**2 + (q2-1-p2)**2 + (-q3-p3)**2)
  315. d7 = np.sqrt((q1+1-p1)**2 + (-q2-p2)**2 + (-q3-p3)**2)
  316. d8 = np.sqrt((q1-1-p1)**2 + (-q2-p2)**2 + (-q3-p3)**2)
  317. return min(d1,d2,d3,d4,d5,d6,d7,d8)
  318. if __name__ == '__main__':
  319. from rrt_3D.env3D import env
  320. import time
  321. import matplotlib.pyplot as plt
  322. class rrt_demo:
  323. def __init__(self):
  324. self.env = env()
  325. self.x0, self.xt = tuple(self.env.start), tuple(self.env.goal)
  326. self.stepsize = 0.5
  327. self.maxiter = 10000
  328. self.ind, self.i = 0, 0
  329. self.done = False
  330. self.Path = []
  331. self.V = []
  332. self.head = Node(self.x0)
  333. def run(self):
  334. while self.ind < self.maxiter:
  335. xrand = sampleFree(self) # O(1)
  336. nearest_node = tree_nearest(self.head, xrand) # O(N)
  337. xnew = tree_steer(self, nearest_node, xrand) # O(1)
  338. collide, _ = isCollide(self, nearest_node.pos, xnew) # O(num obs)
  339. if not collide:
  340. new_node = tree_add_edge(nearest_node, xnew) # O(1)
  341. # if the path is found
  342. if getDist(xnew, self.xt) <= self.stepsize:
  343. end_node = tree_add_edge(new_node, self.xt)
  344. self.Path = tree_path(self, end_node)
  345. break
  346. self.i += 1
  347. self.ind += 1
  348. self.done = True
  349. self.V, self.E = tree_print(self.head)
  350. print(self.E)
  351. visualization(self)
  352. plt.show()
  353. A = rrt_demo()
  354. st = time.time()
  355. A.run()
  356. print(time.time() - st)