plot_util3D.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. # plotting
  2. import matplotlib.pyplot as plt
  3. from mpl_toolkits.mplot3d import Axes3D
  4. from mpl_toolkits.mplot3d.art3d import Poly3DCollection
  5. import mpl_toolkits.mplot3d as plt3d
  6. from mpl_toolkits.mplot3d import proj3d
  7. import numpy as np
  8. def CreateSphere(center,r):
  9. u = np.linspace(0,2* np.pi,30)
  10. v = np.linspace(0,np.pi,30)
  11. x = np.outer(np.cos(u),np.sin(v))
  12. y = np.outer(np.sin(u),np.sin(v))
  13. z = np.outer(np.ones(np.size(u)),np.cos(v))
  14. x, y, z = r*x + center[0], r*y + center[1], r*z + center[2]
  15. return (x,y,z)
  16. def draw_Spheres(ax,balls):
  17. for i in balls:
  18. (xs,ys,zs) = CreateSphere(i[0:3],i[-1])
  19. ax.plot_wireframe(xs, ys, zs, alpha=0.15,color="b")
  20. def draw_block_list(ax, blocks ,color=None,alpha=0.15):
  21. '''
  22. drawing the blocks on the graph
  23. '''
  24. v = np.array([[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0], [0, 0, 1], [1, 0, 1], [1, 1, 1], [0, 1, 1]],
  25. dtype='float')
  26. f = np.array([[0, 1, 5, 4], [1, 2, 6, 5], [2, 3, 7, 6], [3, 0, 4, 7], [0, 1, 2, 3], [4, 5, 6, 7]])
  27. n = blocks.shape[0]
  28. d = blocks[:, 3:6] - blocks[:, :3]
  29. vl = np.zeros((8 * n, 3))
  30. fl = np.zeros((6 * n, 4), dtype='int64')
  31. for k in range(n):
  32. vl[k * 8:(k + 1) * 8, :] = v * d[k] + blocks[k, :3]
  33. fl[k * 6:(k + 1) * 6, :] = f + k * 8
  34. if type(ax) is Poly3DCollection:
  35. ax.set_verts(vl[fl])
  36. else:
  37. pc = Poly3DCollection(vl[fl], alpha=alpha, linewidths=1, edgecolors='k')
  38. pc.set_facecolor(color)
  39. h = ax.add_collection3d(pc)
  40. return h
  41. def obb_verts(obb):
  42. # 0.017004013061523438 for 1000 iters
  43. ori_body = np.array([[1,1,1],[-1,1,1],[-1,-1,1],[1,-1,1],\
  44. [1,1,-1],[-1,1,-1],[-1,-1,-1],[1,-1,-1]])
  45. # P + (ori * E)
  46. ori_body = np.multiply(ori_body,obb.E)
  47. # obb.O is orthornormal basis in {W}, aka rotation matrix in SO(3)
  48. verts = (obb.O@ori_body.T).T + obb.P
  49. return verts
  50. def draw_obb(ax, OBB, color=None,alpha=0.15):
  51. f = np.array([[0, 1, 5, 4], [1, 2, 6, 5], [2, 3, 7, 6], [3, 0, 4, 7], [0, 1, 2, 3], [4, 5, 6, 7]])
  52. n = OBB.shape[0]
  53. vl = np.zeros((8 * n, 3))
  54. fl = np.zeros((6 * n, 4), dtype='int64')
  55. for k in range(n):
  56. vl[k * 8:(k + 1) * 8, :] = obb_verts(OBB[k])
  57. fl[k * 6:(k + 1) * 6, :] = f + k * 8
  58. if type(ax) is Poly3DCollection:
  59. ax.set_verts(vl[fl])
  60. else:
  61. pc = Poly3DCollection(vl[fl], alpha=alpha, linewidths=1, edgecolors='k')
  62. pc.set_facecolor(color)
  63. h = ax.add_collection3d(pc)
  64. return h
  65. def draw_line(ax,SET,visibility=1,color=None):
  66. if SET != []:
  67. for i in SET:
  68. xs = i[0][0], i[1][0]
  69. ys = i[0][1], i[1][1]
  70. zs = i[0][2], i[1][2]
  71. line = plt3d.art3d.Line3D(xs, ys, zs, alpha=visibility, color=color)
  72. ax.add_line(line)
  73. def visualization(initparams):
  74. if initparams.ind % 20 == 0 or initparams.done:
  75. V = np.array(list(initparams.V))
  76. E = initparams.E
  77. Path = np.array(initparams.Path)
  78. start = initparams.env.start
  79. goal = initparams.env.goal
  80. edges = E.get_edge()
  81. # generate axis objects
  82. ax = plt.subplot(111, projection='3d')
  83. #ax.view_init(elev=0.+ 0.03*initparams.ind/(2*np.pi), azim=90 + 0.03*initparams.ind/(2*np.pi))
  84. #ax.view_init(elev=0., azim=90.)
  85. ax.view_init(elev=8., azim=120.)
  86. #ax.view_init(elev=-8., azim=180)
  87. ax.clear()
  88. # drawing objects
  89. draw_Spheres(ax, initparams.env.balls)
  90. draw_block_list(ax, initparams.env.blocks)
  91. if initparams.env.OBB is not None:
  92. draw_obb(ax,initparams.env.OBB)
  93. draw_block_list(ax, np.array([initparams.env.boundary]),alpha=0)
  94. draw_line(ax,edges,visibility=0.25)
  95. draw_line(ax,Path,color='r')
  96. if len(V) > 0:
  97. ax.scatter3D(V[:, 0], V[:, 1], V[:, 2], s=2, color='g',)
  98. ax.plot(start[0:1], start[1:2], start[2:], 'go', markersize=7, markeredgecolor='k')
  99. ax.plot(goal[0:1], goal[1:2], goal[2:], 'ro', markersize=7, markeredgecolor='k')
  100. # adjust the aspect ratio
  101. xmin, xmax = initparams.env.boundary[0], initparams.env.boundary[3]
  102. ymin, ymax = initparams.env.boundary[1], initparams.env.boundary[4]
  103. zmin, zmax = initparams.env.boundary[2], initparams.env.boundary[5]
  104. dx, dy, dz = xmax-xmin, ymax-ymin, zmax-zmin
  105. ax.get_proj = make_get_proj(ax,1*dx, 1*dy, 2*dy)
  106. plt.xlabel('x')
  107. plt.ylabel('y')
  108. plt.pause(0.0001)
  109. def make_get_proj(self, rx, ry, rz):
  110. '''
  111. Return a variation on :func:`~mpl_toolkit.mplot2d.axes3d.Axes3D.getproj` that
  112. makes the box aspect ratio equal to *rx:ry:rz*, using an axes object *self*.
  113. '''
  114. rm = max(rx, ry, rz)
  115. kx = rm / rx; ky = rm / ry; kz = rm / rz
  116. # Copied directly from mpl_toolkit/mplot3d/axes3d.py. New or modified lines are
  117. # marked by ##
  118. def get_proj():
  119. relev, razim = np.pi * self.elev/180, np.pi * self.azim/180
  120. xmin, xmax = self.get_xlim3d()
  121. ymin, ymax = self.get_ylim3d()
  122. zmin, zmax = self.get_zlim3d()
  123. # transform to uniform world coordinates 0-1.0,0-1.0,0-1.0
  124. worldM = proj3d.world_transformation(xmin, xmax,
  125. ymin, ymax,
  126. zmin, zmax)
  127. ratio = 0.5
  128. # adjust the aspect ratio ##
  129. aspectM = proj3d.world_transformation(-kx + 1, kx, ##
  130. -ky + 1, ky, ##
  131. -kz + 1, kz) ##
  132. # look into the middle of the new coordinates
  133. R = np.array([0.5, 0.5, 0.5])
  134. xp = R[0] + np.cos(razim) * np.cos(relev) * self.dist *ratio
  135. yp = R[1] + np.sin(razim) * np.cos(relev) * self.dist *ratio
  136. zp = R[2] + np.sin(relev) * self.dist *ratio
  137. E = np.array((xp, yp, zp))
  138. self.eye = E
  139. self.vvec = R - E
  140. self.vvec = self.vvec / np.linalg.norm(self.vvec)
  141. if abs(relev) > np.pi/2:
  142. # upside down
  143. V = np.array((0, 0, -1))
  144. else:
  145. V = np.array((0, 0, 1))
  146. zfront, zback = -self.dist *ratio, self.dist *ratio
  147. viewM = proj3d.view_transformation(E, R, V)
  148. perspM = proj3d.persp_transformation(zfront, zback)
  149. M0 = np.dot(viewM, np.dot(aspectM, worldM)) ##
  150. M = np.dot(perspM, M0)
  151. return M
  152. return get_proj
  153. if __name__ == '__main__':
  154. pass