瀏覽代碼

updated plotting

391311qy 5 年之前
父節點
當前提交
c398cefee7

二進制
Sampling-based Planning/rrt_3D/__pycache__/env3D.cpython-37.pyc


二進制
Sampling-based Planning/rrt_3D/__pycache__/plot_util3D.cpython-37.pyc


二進制
Sampling-based Planning/rrt_3D/__pycache__/utils3D.cpython-37.pyc


+ 47 - 62
Sampling-based Planning/rrt_3D/plot_util3D.py

@@ -9,13 +9,10 @@ import numpy as np
 def CreateSphere(center,r):
     u = np.linspace(0,2* np.pi,30)
     v = np.linspace(0,np.pi,30)
-    x=np.outer(np.cos(u),np.sin(v))
-    y=np.outer(np.sin(u),np.sin(v))
-    z=np.outer(np.ones(np.size(u)),np.cos(v))
-    # shift and scale sphere
-    x = r*x + center[0]
-    y = r*y + center[1]
-    z = r*z + center[2]
+    x = np.outer(np.cos(u),np.sin(v))
+    y = np.outer(np.sin(u),np.sin(v))
+    z = np.outer(np.ones(np.size(u)),np.cos(v))
+    x, y, z = r*x + center[0], r*y + center[1], r*z + center[2]
     return (x,y,z)
 
 def draw_Spheres(ax,balls):
@@ -23,79 +20,67 @@ def draw_Spheres(ax,balls):
         (xs,ys,zs) = CreateSphere(i[0:3],i[-1])
         ax.plot_wireframe(xs, ys, zs, alpha=0.15,color="b")
 
-def draw_block_list(ax, blocks):
+def draw_block_list(ax, blocks ,color=None,alpha=0.15):
     '''
-    Subroutine used by draw_map() to display the environment blocks
+    drawing the blocks on the graph
     '''
     v = np.array([[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0], [0, 0, 1], [1, 0, 1], [1, 1, 1], [0, 1, 1]],
                  dtype='float')
     f = np.array([[0, 1, 5, 4], [1, 2, 6, 5], [2, 3, 7, 6], [3, 0, 4, 7], [0, 1, 2, 3], [4, 5, 6, 7]])
-    # clr = blocks[:,6:]/255
     n = blocks.shape[0]
     d = blocks[:, 3:6] - blocks[:, :3]
     vl = np.zeros((8 * n, 3))
     fl = np.zeros((6 * n, 4), dtype='int64')
-    # fcl = np.zeros((6*n,3))
     for k in range(n):
         vl[k * 8:(k + 1) * 8, :] = v * d[k] + blocks[k, :3]
         fl[k * 6:(k + 1) * 6, :] = f + k * 8
-        # fcl[k*6:(k+1)*6,:] = clr[k,:]
-
     if type(ax) is Poly3DCollection:
         ax.set_verts(vl[fl])
     else:
-        pc = Poly3DCollection(vl[fl], alpha=0.15, linewidths=1, edgecolors='k')
-        # pc.set_facecolor(fcl)
+        pc = Poly3DCollection(vl[fl], alpha=alpha, linewidths=1, edgecolors='k')
+        pc.set_facecolor(color)
         h = ax.add_collection3d(pc)
         return h
 
-def visualization(initparams):
-    V = np.array(initparams.V)
-    E = initparams.E
-    Path = np.array(initparams.Path)
-    start = initparams.env.start
-    goal = initparams.env.goal
-    ax = plt.subplot(111, projection='3d',adjustable='box')
-    ax.view_init(elev=0., azim=90)
-    ax.clear()
-    draw_Spheres(ax, initparams.env.balls)
-    draw_block_list(ax, initparams.env.blocks)
-    edges = E.get_edge()
-    if edges != []:
-        for i in edges:
+def draw_line(ax,SET,visibility=1,color=None):
+    if SET != []:
+        for i in SET:
             xs = i[0][0], i[1][0]
             ys = i[0][1], i[1][1]
             zs = i[0][2], i[1][2]
-            line = plt3d.art3d.Line3D(xs, ys, zs, alpha=0.25)
+            line = plt3d.art3d.Line3D(xs, ys, zs, alpha=visibility, color=color)
             ax.add_line(line)
 
-    if Path != []:
-        for i in Path:
-            xs = i[0][0], i[1][0]
-            ys = i[0][1], i[1][1]
-            zs = i[0][2], i[1][2]
-            line = plt3d.art3d.Line3D(xs, ys, zs, color='r')
-            ax.add_line(line)
-
-    ax.plot(start[0:1], start[1:2], start[2:], 'go', markersize=7, markeredgecolor='k')
-    ax.plot(goal[0:1], goal[1:2], goal[2:], 'ro', markersize=7, markeredgecolor='k')
-    ax.scatter3D(V[:, 0], V[:, 1], V[:, 2], s=2, color='g',)
-
-    xmin, xmax = initparams.env.boundary[0], initparams.env.boundary[3]
-    ymin, ymax = initparams.env.boundary[1], initparams.env.boundary[4]
-    zmin, zmax = initparams.env.boundary[2], initparams.env.boundary[5]
-    dx, dy, dz = xmax-xmin, ymax-ymin, zmax-zmin
-    ax.set_xlim3d(xmin, xmax)
-    ax.set_ylim3d(ymin, ymax)
-    ax.set_zlim3d(zmin, zmax)
-    ax.get_proj = make_get_proj(ax,1*dx, 1*dy, 2*dy)
-    #ax.dist = 5
-    plt.xlabel('x')
-    plt.ylabel('y')
-    if not Path != []:
+def visualization(initparams):
+    if initparams.ind % 10 == 0 or initparams.done:
+        V = np.array(initparams.V)
+        E = initparams.E
+        Path = np.array(initparams.Path)
+        start = initparams.env.start
+        goal = initparams.env.goal
+        edges = E.get_edge()
+        # generate axis objects
+        ax = plt.subplot(111, projection='3d')
+        ax.view_init(elev=0., azim=90)
+        ax.clear()
+        # drawing objects
+        draw_Spheres(ax, initparams.env.balls)
+        draw_block_list(ax, initparams.env.blocks)
+        draw_block_list(ax, np.array([initparams.env.boundary]),alpha=0)
+        draw_line(ax,edges,visibility=0.25)
+        draw_line(ax,Path,color='r')
+        ax.scatter3D(V[:, 0], V[:, 1], V[:, 2], s=2, color='g',)
+        ax.plot(start[0:1], start[1:2], start[2:], 'go', markersize=7, markeredgecolor='k')
+        ax.plot(goal[0:1], goal[1:2], goal[2:], 'ro', markersize=7, markeredgecolor='k') 
+        # adjust the aspect ratio
+        xmin, xmax = initparams.env.boundary[0], initparams.env.boundary[3]
+        ymin, ymax = initparams.env.boundary[1], initparams.env.boundary[4]
+        zmin, zmax = initparams.env.boundary[2], initparams.env.boundary[5]
+        dx, dy, dz = xmax-xmin, ymax-ymin, zmax-zmin
+        ax.get_proj = make_get_proj(ax,1*dx, 1*dy, 2*dy)
+        plt.xlabel('x')
+        plt.ylabel('y')
         plt.pause(0.001)
-    else:
-        plt.show()
 
 def make_get_proj(self, rx, ry, rz):
     '''
@@ -104,7 +89,7 @@ def make_get_proj(self, rx, ry, rz):
     '''
 
     rm = max(rx, ry, rz)
-    kx = rm / rx; ky = rm / ry; kz = rm / rz;
+    kx = rm / rx; ky = rm / ry; kz = rm / rz
 
     # Copied directly from mpl_toolkit/mplot3d/axes3d.py. New or modified lines are
     # marked by ##
@@ -119,7 +104,7 @@ def make_get_proj(self, rx, ry, rz):
         worldM = proj3d.world_transformation(xmin, xmax,
                                              ymin, ymax,
                                              zmin, zmax)
-
+        ratio = 0.5
         # adjust the aspect ratio                          ##
         aspectM = proj3d.world_transformation(-kx + 1, kx, ##
                                               -ky + 1, ky, ##
@@ -128,9 +113,9 @@ def make_get_proj(self, rx, ry, rz):
         # look into the middle of the new coordinates
         R = np.array([0.5, 0.5, 0.5])
 
-        xp = R[0] + np.cos(razim) * np.cos(relev) * self.dist
-        yp = R[1] + np.sin(razim) * np.cos(relev) * self.dist
-        zp = R[2] + np.sin(relev) * self.dist
+        xp = R[0] + np.cos(razim) * np.cos(relev) * self.dist *ratio
+        yp = R[1] + np.sin(razim) * np.cos(relev) * self.dist *ratio
+        zp = R[2] + np.sin(relev) * self.dist *ratio
         E = np.array((xp, yp, zp))
 
         self.eye = E
@@ -142,7 +127,7 @@ def make_get_proj(self, rx, ry, rz):
             V = np.array((0, 0, -1))
         else:
             V = np.array((0, 0, 1))
-        zfront, zback = -self.dist, self.dist
+        zfront, zback = -self.dist *ratio, self.dist *ratio
 
         viewM = proj3d.view_transformation(E, R, V)
         perspM = proj3d.persp_transformation(zfront, zback)

+ 11 - 6
Sampling-based Planning/rrt_3D/rrt3D.py

@@ -6,6 +6,7 @@ import numpy as np
 from numpy.matlib import repmat
 from collections import defaultdict
 import time
+import matplotlib.pyplot as plt
 
 import os
 import sys
@@ -13,7 +14,7 @@ import sys
 sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../Sampling-based Planning/")
 
 from rrt_3D.env3D import env
-from rrt_3D.utils3D import getDist, sampleFree, nearest, steer, isCollide, near, visualization, cost, path, edgeset
+from rrt_3D.utils3D import getDist, sampleFree, nearest, steer, isCollide, near, visualization, cost, path, edgeset, hash3D, dehash
 
 
 class rrtstar():
@@ -26,16 +27,18 @@ class rrtstar():
         self.maxiter = 10000
         self.stepsize = 0.5
         self.Path = []
+        self.done = False
 
     def wireup(self, x, y):
-        self.E.add_edge([x, y])  # add edge
-        self.Parent[str(x[0])][str(x[1])][str(x[2])] = y
+        self.E.add_edge([x,y]) # add edge
+        self.Parent[hash3D(x)] = y
 
     def run(self):
         self.V.append(self.env.start)
-        ind = 0
+        self.ind = 0
+        self.fig = plt.figure(figsize = (10,8))
         xnew = self.env.start
-        while ind < self.maxiter and getDist(xnew, self.env.goal) > 1:
+        while self.ind < self.maxiter and getDist(xnew, self.env.goal) > 1:
             xrand = sampleFree(self)
             xnearest = nearest(self, xrand)
             xnew = steer(self, xnearest, xrand)
@@ -44,12 +47,14 @@ class rrtstar():
                 self.wireup(xnew, xnearest)
                 # visualization(self)
                 self.i += 1
-            ind += 1
+            self.ind += 1
             if getDist(xnew, self.env.goal) <= 1:
                 self.wireup(self.env.goal, xnew)
                 self.Path, D = path(self)
                 print('Total distance = ' + str(D))
+        self.done = True
         visualization(self)
+        plt.show()
 
 
 if __name__ == '__main__':

+ 7 - 5
Sampling-based Planning/rrt_3D/rrtstar3D.py

@@ -6,6 +6,7 @@ import numpy as np
 from numpy.matlib import repmat
 from collections import defaultdict
 import time
+import matplotlib.pyplot as plt
 
 import os
 import sys
@@ -37,7 +38,6 @@ class rrtstar():
         xparent = self.Parent[hash3D(xnear)]
         a = [xnear,xparent]
         self.E.remove_edge(a) # remove and replace old the connection
-        #self.Parent.pop(hash3D(xnear), None)
 
     def reached(self):
         self.done = True
@@ -50,17 +50,18 @@ class rrtstar():
 
     def run(self):
         self.V.append(self.env.start)
-        ind = 0
+        self.ind = 0
         xnew = self.env.start
         print('start rrt*... ')
-        while ind < self.maxiter:
+        self.fig = plt.figure(figsize = (10,8))
+        while self.ind < self.maxiter:
             xrand    = sampleFree(self)
             xnearest = nearest(self,xrand)
             xnew     = steer(self,xnearest,xrand)
             if not isCollide(self,xnearest,xnew):
                 Xnear = near(self,xnew)
                 self.V.append(xnew) # add point
-                # visualization(self)
+                visualization(self)
                 # minimal path and minimal cost
                 xmin, cmin = xnearest, cost(self, xnearest) + getDist(xnearest, xnew)
                 # connecting along minimal cost path
@@ -82,12 +83,13 @@ class rrtstar():
                             self.removewire(xnear)
                             self.wireup(xnear, xnew)
                 self.i += 1
-            ind += 1
+            self.ind += 1
         # max sample reached
         self.reached()
         print('time used = ' + str(time.time()-starttime))
         print('Total distance = '+str(self.D))
         visualization(self)
+        plt.show()
         
 
 if __name__ == '__main__':

+ 8 - 2
Sampling-based Planning/rrt_3D/utils3D.py

@@ -53,11 +53,17 @@ def isinside(initparams, x):
             return True
     return False
 
+def isinbound(i, x):
+    if i[0] <= x[0] < i[3] and i[1] <= x[1] < i[4] and i[2] <= x[2] < i[5]:
+        return True
+    return False
 
 def isCollide(initparams, x, y):
     '''see if line intersects obstacle'''
     ray = getRay(x, y)
     dist = getDist(x, y)
+    if not isinbound(initparams.env.boundary,y):
+        return True
     for i in getAABB(initparams.env.blocks):
         shot = pyrr.geometric_tests.ray_intersect_aabb(ray, i)
         if shot is not None:
@@ -67,8 +73,8 @@ def isCollide(initparams, x, y):
     for i in initparams.env.balls:
         shot = pyrr.geometric_tests.ray_intersect_sphere(ray, i)
         if shot != []:
-            dists_wall = [getDist(x, j) for j in shot]
-            if all(dists_wall <= dist):  # collide
+            dists_ball = [getDist(x, j) for j in shot]
+            if all(dists_ball <= dist):  # collide
                 return True
     return False