utils3D.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
  1. import numpy as np
  2. from numpy.matlib import repmat
  3. import pyrr as pyrr
  4. from collections import deque
  5. import os
  6. import sys
  7. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../Sampling_based_Planning/")
  8. from rrt_3D.plot_util3D import visualization
  9. def getRay(x, y):
  10. direc = [y[0] - x[0], y[1] - x[1], y[2] - x[2]]
  11. return np.array([x, direc])
  12. def getAABB(blocks):
  13. AABB = []
  14. for i in blocks:
  15. AABB.append(np.array([np.add(i[0:3], -0), np.add(i[3:6], 0)])) # make AABBs alittle bit of larger
  16. return AABB
  17. def getDist(pos1, pos2):
  18. return np.sqrt(sum([(pos1[0] - pos2[0]) ** 2, (pos1[1] - pos2[1]) ** 2, (pos1[2] - pos2[2]) ** 2]))
  19. ''' The following utils can be used for rrt or rrt*,
  20. required param initparams should have
  21. env, environement generated from env3D
  22. V, node set
  23. E, edge set
  24. i, nodes added
  25. maxiter, maximum iteration allowed
  26. stepsize, leaf growth restriction
  27. '''
  28. def sampleFree(initparams, bias = 0.1):
  29. '''biased sampling'''
  30. x = np.random.uniform(initparams.env.boundary[0:3], initparams.env.boundary[3:6])
  31. i = np.random.random()
  32. if isinside(initparams, x):
  33. return sampleFree(initparams)
  34. else:
  35. if i < bias:
  36. return np.array(initparams.xt) + 1
  37. else:
  38. return x
  39. return x
  40. # ---------------------- Collision checking algorithms
  41. def isinside(initparams, x):
  42. '''see if inside obstacle'''
  43. for i in initparams.env.blocks:
  44. if isinbound(i, x):
  45. return True
  46. for i in initparams.env.OBB:
  47. if isinbound(i, x, mode = 'obb'):
  48. return True
  49. for i in initparams.env.balls:
  50. if isinball(i, x):
  51. return True
  52. return False
  53. def isinbound(i, x, mode = False, factor = 0, isarray = False):
  54. if mode == 'obb':
  55. return isinobb(i, x, isarray)
  56. if isarray:
  57. compx = (i[0] - factor <= x[:,0]) & (x[:,0] < i[3] + factor)
  58. compy = (i[1] - factor <= x[:,1]) & (x[:,1] < i[4] + factor)
  59. compz = (i[2] - factor <= x[:,2]) & (x[:,2] < i[5] + factor)
  60. return compx & compy & compz
  61. else:
  62. return i[0] - factor <= x[0] < i[3] + factor and i[1] - factor <= x[1] < i[4] + factor and i[2] - factor <= x[2] < i[5]
  63. def isinobb(i, x, isarray = False):
  64. # transform the point from {W} to {body}
  65. if isarray:
  66. pts = (i.T@np.column_stack((x, np.ones(len(x)))).T).T[:,0:3]
  67. block = [- i.E[0],- i.E[1],- i.E[2],+ i.E[0],+ i.E[1],+ i.E[2]]
  68. return isinbound(block, pts, isarray = isarray)
  69. else:
  70. pt = i.T@np.append(x,1)
  71. block = [- i.E[0],- i.E[1],- i.E[2],+ i.E[0],+ i.E[1],+ i.E[2]]
  72. return isinbound(block, pt)
  73. def isinball(i, x, factor = 0):
  74. if getDist(i[0:3], x) <= i[3] + factor:
  75. return True
  76. return False
  77. def lineSphere(p0, p1, ball):
  78. # https://cseweb.ucsd.edu/classes/sp19/cse291-d/Files/CSE291_13_CollisionDetection.pdf
  79. c, r = ball[0:3], ball[-1]
  80. line = [p1[0] - p0[0], p1[1] - p0[1], p1[2] - p0[2]]
  81. d1 = [c[0] - p0[0], c[1] - p0[1], c[2] - p0[2]]
  82. t = (1 / (line[0] * line[0] + line[1] * line[1] + line[2] * line[2])) * (
  83. line[0] * d1[0] + line[1] * d1[1] + line[2] * d1[2])
  84. if t <= 0:
  85. if (d1[0] * d1[0] + d1[1] * d1[1] + d1[2] * d1[2]) <= r ** 2: return True
  86. elif t >= 1:
  87. d2 = [c[0] - p1[0], c[1] - p1[1], c[2] - p1[2]]
  88. if (d2[0] * d2[0] + d2[1] * d2[1] + d2[2] * d2[2]) <= r ** 2: return True
  89. elif 0 < t < 1:
  90. x = [p0[0] + t * line[0], p0[1] + t * line[1], p0[2] + t * line[2]]
  91. k = [c[0] - x[0], c[1] - x[1], c[2] - x[2]]
  92. if (k[0] * k[0] + k[1] * k[1] + k[2] * k[2]) <= r ** 2: return True
  93. return False
  94. def lineAABB(p0, p1, dist, aabb):
  95. # https://www.gamasutra.com/view/feature/131790/simple_intersection_tests_for_games.php?print=1
  96. # aabb should have the attributes of P, E as center point and extents
  97. mid = [(p0[0] + p1[0]) / 2, (p0[1] + p1[1]) / 2, (p0[2] + p1[2]) / 2] # mid point
  98. I = [(p1[0] - p0[0]) / dist, (p1[1] - p0[1]) / dist, (p1[2] - p0[2]) / dist] # unit direction
  99. hl = dist / 2 # radius
  100. T = [aabb.P[0] - mid[0], aabb.P[1] - mid[1], aabb.P[2] - mid[2]]
  101. # do any of the principal axis form a separting axis?
  102. if abs(T[0]) > (aabb.E[0] + hl * abs(I[0])): return False
  103. if abs(T[1]) > (aabb.E[1] + hl * abs(I[1])): return False
  104. if abs(T[2]) > (aabb.E[2] + hl * abs(I[2])): return False
  105. # I.cross(s axis) ?
  106. r = aabb.E[1] * abs(I[2]) + aabb.E[2] * abs(I[1])
  107. if abs(T[1] * I[2] - T[2] * I[1]) > r: return False
  108. # I.cross(y axis) ?
  109. r = aabb.E[0] * abs(I[2]) + aabb.E[2] * abs(I[0])
  110. if abs(T[2] * I[0] - T[0] * I[2]) > r: return False
  111. # I.cross(z axis) ?
  112. r = aabb.E[0] * abs(I[1]) + aabb.E[1] * abs(I[0])
  113. if abs(T[0] * I[1] - T[1] * I[0]) > r: return False
  114. return True
  115. def lineOBB(p0, p1, dist, obb):
  116. # transform points to obb frame
  117. res = obb.T@np.column_stack([np.array([p0,p1]),[1,1]]).T
  118. # record old position and set the position to origin
  119. oldP, obb.P= obb.P, [0,0,0]
  120. # calculate segment-AABB testing
  121. ans = lineAABB(res[0:3,0],res[0:3,1],dist,obb)
  122. # reset the position
  123. obb.P = oldP
  124. return ans
  125. def isCollide(initparams, x, child, dist=None):
  126. '''see if line intersects obstacle'''
  127. '''specified for expansion in A* 3D lookup table'''
  128. if dist==None:
  129. dist = getDist(x, child)
  130. # check in bound
  131. if not isinbound(initparams.env.boundary, child):
  132. return True, dist
  133. # check collision in AABB
  134. for i in range(len(initparams.env.AABB)):
  135. if lineAABB(x, child, dist, initparams.env.AABB[i]):
  136. return True, dist
  137. # check collision in ball
  138. for i in initparams.env.balls:
  139. if lineSphere(x, child, i):
  140. return True, dist
  141. # check collision with obb
  142. for i in initparams.env.OBB:
  143. if lineOBB(x, child, dist, i):
  144. return True, dist
  145. return False, dist
  146. # ---------------------- leaf node extending algorithms
  147. def nearest(initparams, x, isset=False):
  148. V = np.array(initparams.V)
  149. if initparams.i == 0:
  150. return initparams.V[0]
  151. xr = repmat(x, len(V), 1)
  152. dists = np.linalg.norm(xr - V, axis=1)
  153. return tuple(initparams.V[np.argmin(dists)])
  154. def near(initparams, x):
  155. # s = np.array(s)
  156. V = np.array(initparams.V)
  157. if initparams.i == 0:
  158. return [initparams.V[0]]
  159. cardV = len(initparams.V)
  160. eta = initparams.eta
  161. gamma = initparams.gamma
  162. r = min(gamma * (np.log(cardV) / cardV ** (1/3)), eta)
  163. if initparams.done:
  164. r = 1
  165. xr = repmat(x, len(V), 1)
  166. inside = np.linalg.norm(xr - V, axis=1) < r
  167. nearpoints = V[inside]
  168. return np.array(nearpoints)
  169. def steer(initparams, x, y, DIST=False):
  170. # steer from s to y
  171. if np.equal(x, y).all():
  172. return x, 0.0
  173. dist, step = getDist(y, x), initparams.stepsize
  174. step = min(dist, step)
  175. increment = ((y[0] - x[0]) / dist * step, (y[1] - x[1]) / dist * step, (y[2] - x[2]) / dist * step)
  176. xnew = (x[0] + increment[0], x[1] + increment[1], x[2] + increment[2])
  177. # direc = (y - s) / np.linalg.norm(y - s)
  178. # xnew = s + initparams.stepsize * direc
  179. if DIST:
  180. return xnew, dist
  181. return xnew, dist
  182. def cost(initparams, x):
  183. '''here use the additive recursive cost function'''
  184. if x == initparams.x0:
  185. return 0
  186. return cost(initparams, initparams.Parent[x]) + getDist(x, initparams.Parent[x])
  187. def cost_from_set(initparams, x):
  188. '''here use a incremental cost set function'''
  189. if x == initparams.x0:
  190. return 0
  191. return initparams.COST[initparams.Parent[x]] + getDist(x, initparams.Parent[x])
  192. def path(initparams, Path=[], dist=0):
  193. x = initparams.xt
  194. while x != initparams.x0:
  195. x2 = initparams.Parent[x]
  196. Path.append(np.array([x, x2]))
  197. dist += getDist(x, x2)
  198. x = x2
  199. return Path, dist
  200. class edgeset(object):
  201. def __init__(self):
  202. self.E = {}
  203. def add_edge(self, edge):
  204. x, y = edge[0], edge[1]
  205. if x in self.E:
  206. self.E[x].add(y)
  207. else:
  208. self.E[x] = set()
  209. self.E[x].add(y)
  210. def remove_edge(self, edge):
  211. x, y = edge[0], edge[1]
  212. self.E[x].remove(y)
  213. def get_edge(self, nodes = None):
  214. edges = []
  215. if nodes is None:
  216. for v in self.E:
  217. for n in self.E[v]:
  218. # if (n,v) not in edges:
  219. edges.append((v, n))
  220. else:
  221. for v in nodes:
  222. for n in self.E[tuple(v)]:
  223. edges.append((v, n))
  224. return edges
  225. def isEndNode(self, node):
  226. return node not in self.E
  227. #------------------------ use a linked list to express the tree
  228. class Node:
  229. def __init__(self, data):
  230. self.pos = data
  231. self.Parent = None
  232. self.child = set()
  233. def tree_add_edge(node_in_tree, x):
  234. # add an edge at the specified parent
  235. node_to_add = Node(x)
  236. # node_in_tree = tree_bfs(head, xparent)
  237. node_in_tree.child.add(node_to_add)
  238. node_to_add.Parent = node_in_tree
  239. return node_to_add
  240. def tree_bfs(head, x):
  241. # searches s in order of bfs
  242. node = head
  243. Q = []
  244. Q.append(node)
  245. while Q:
  246. curr = Q.pop()
  247. if curr.pos == x:
  248. return curr
  249. for child_node in curr.child:
  250. Q.append(child_node)
  251. def tree_nearest(head, x):
  252. # find the node nearest to s
  253. D = np.inf
  254. min_node = None
  255. Q = []
  256. Q.append(head)
  257. while Q:
  258. curr = Q.pop()
  259. dist = getDist(curr.pos, x)
  260. # record the current best
  261. if dist < D:
  262. D, min_node = dist, curr
  263. # bfs
  264. for child_node in curr.child:
  265. Q.append(child_node)
  266. return min_node
  267. def tree_steer(initparams, node, x):
  268. # steer from node to s
  269. dist, step = getDist(node.pos, x), initparams.stepsize
  270. increment = ((node.pos[0] - x[0]) / dist * step, (node.pos[1] - x[1]) / dist * step, (node.pos[2] - x[2]) / dist * step)
  271. xnew = (x[0] + increment[0], x[1] + increment[1], x[2] + increment[2])
  272. return xnew
  273. def tree_print(head):
  274. Q = []
  275. Q.append(head)
  276. verts = []
  277. edge = []
  278. while Q:
  279. curr = Q.pop()
  280. # print(curr.pos)
  281. verts.append(curr.pos)
  282. if curr.Parent == None:
  283. pass
  284. else:
  285. edge.append([curr.pos, curr.Parent.pos])
  286. for child in curr.child:
  287. Q.append(child)
  288. return verts, edge
  289. def tree_path(initparams, end_node):
  290. path = []
  291. curr = end_node
  292. while curr.pos != initparams.x0:
  293. path.append([curr.pos, curr.Parent.pos])
  294. curr = curr.Parent
  295. return path
#---------------KD tree, used for nearest neighbor search
# NOTE(review): no tree is built yet; the class only provides the distance
# functions below (R^1, S^1 and a P^3 variant). Confirm before relying on it.
class kdTree:
    def __init__(self):
        # no state is stored
        pass
    def R1_dist(self, q, p):
        # absolute difference on the real line
        return abs(q-p)
    def S1_dist(self, q, p):
        # wrap-around distance on a circle of circumference 1
        # (assumes q and p are normalised to [0, 1) -- TODO confirm)
        return min(abs(q-p), 1- abs(q-p))
    def P3_dist(self, q, p):
        # cubes with antipodal points
        # q and p are unpacked as 3-coordinate points. The result is the
        # minimum of eight candidate Euclidean distances:
        #   d1      the direct distance,
        #   d2      every coordinate difference wrapped as 1 - |q_k - p_k|,
        #   d3..d8  q with two coordinates negated and the remaining one
        #           shifted by +1 or -1.
        # NOTE(review): presumably models the unit cube with antipodal faces
        # identified; the exact identification has not been verified here.
        q1, q2, q3 = q
        p1, p2, p3 = p
        d1 = np.sqrt((q1-p1)**2 + (q2-p2)**2 + (q3-p3)**2)
        d2 = np.sqrt((1-abs(q1-p1))**2 + (1-abs(q2-p2))**2 + (1-abs(q3-p3))**2)
        d3 = np.sqrt((-q1-p1)**2 + (-q2-p2)**2 + (q3+1-p3)**2)
        d4 = np.sqrt((-q1-p1)**2 + (-q2-p2)**2 + (q3-1-p3)**2)
        d5 = np.sqrt((-q1-p1)**2 + (q2+1-p2)**2 + (-q3-p3)**2)
        d6 = np.sqrt((-q1-p1)**2 + (q2-1-p2)**2 + (-q3-p3)**2)
        d7 = np.sqrt((q1+1-p1)**2 + (-q2-p2)**2 + (-q3-p3)**2)
        d8 = np.sqrt((q1-1-p1)**2 + (-q2-p2)**2 + (-q3-p3)**2)
        return min(d1,d2,d3,d4,d5,d6,d7,d8)
if __name__ == '__main__':
    # Demo: grow an RRT with the linked-tree helpers of this module, then print
    # the tree edges, plot the result and report the elapsed wall-clock time.
    from rrt_3D.env3D import env
    import time
    import matplotlib.pyplot as plt
    class rrt_demo:
        # Carries the attributes that the helper functions read from ``initparams``.
        def __init__(self):
            self.env = env()
            self.x0, self.xt = tuple(self.env.start), tuple(self.env.goal)
            self.stepsize = 0.5
            self.maxiter = 10000
            # ind counts loop iterations; i counts nodes added to the tree
            self.ind, self.i = 0, 0
            self.done = False
            self.Path = []
            self.V = []
            # root of the tree, placed at the start position
            self.head = Node(self.x0)
        def run(self):
            # Repeat sample -> nearest -> steer -> collision check -> extend until
            # a new node is within one step of the goal or maxiter iterations ran.
            while self.ind < self.maxiter:
                xrand = sampleFree(self) # O(1)
                nearest_node = tree_nearest(self.head, xrand) # O(N)
                xnew = tree_steer(self, nearest_node, xrand) # O(1)
                collide, _ = isCollide(self, nearest_node.pos, xnew) # O(num obs)
                if not collide:
                    new_node = tree_add_edge(nearest_node, xnew) # O(1)
                    # if the path is found
                    if getDist(xnew, self.xt) <= self.stepsize:
                        # attach the goal directly and extract the path back to x0
                        end_node = tree_add_edge(new_node, self.xt)
                        self.Path = tree_path(self, end_node)
                        break
                    # NOTE(review): the nesting of this increment (under
                    # `if not collide`) was reconstructed from a flattened paste -- confirm.
                    self.i += 1
                self.ind += 1
            self.done = True
            # flatten the tree into vertex / edge lists for plotting
            self.V, self.E = tree_print(self.head)
            print(self.E)
            visualization(self)
            plt.show()
    A = rrt_demo()
    st = time.time()
    A.run()
    print(time.time() - st)