yue qi 5 rokov pred
rodič
commit
0a8f49ac61

+ 125 - 65
Sampling_based_Planning/rrt_3D/BIT_star3D.py

@@ -33,107 +33,167 @@ class BIT_star:
     def __init__(self):
         self.env = env()
         self.xstart, self.xgoal = tuple(self.env.start), tuple(self.env.goal)
-        self.maxiter = 1000
+        self.maxiter = 1000 # used for determining how many batches needed
         # radius calc
         self.eta = 1 # bigger or equal to 1
         self.n = 1000
-        self.Xf_hat = 1 # TODO
         self.nn = 1 # TODO
+        
+        self.edgeCost = {} # corresponding to c
+        self.heuristic_edgeCost = {} # correspoinding to c_hat
 
     def run(self):
-        V = {self.xstart}
-        E = set()
-        T = (V, E) # tree
-        Xsamples = {self.xgoal}
-        QE = set()
-        QV = set()
-        r = np.inf
+        self.V = {self.xstart}
+        self.E = set()
+        self.Parent = {}
+        self.T = (self.V, self.E) # tree
+        self.Xsamples = {self.xgoal}
+        self.QE = set()
+        self.QV = set()
+        self.r = np.inf
         ind = 0
         while True:
-            if len(QE) == 0 and len(QV) == 0:
-                Xsamples, V, E = self.Prune(self.g_T(self.xgoal), Xsamples, V, E)
-                Vold = copy.deepcopy(V)
-                QV = copy.deepcopy(V)
-                r = self.radius(len(V) + len(Xsamples))
-            while self.BestQueueValue(QV) <= self.BestQueueValue(QE):
-                QV, QE = self.ExpandVertex(self.BestInQueue(QV), QV, QE, Xsamples, Vold, E, V, r)
-            (vm, xm) = self.BestInQueue(QE)
-            QE.difference_update({(vm, xm)})
+            # for the first round
+            if len(self.QE) == 0 and len(self.QV) == 0:
+                self.Prune(self.g_T(self.xgoal))
+                self.Xsamples = self.Sample(m, self.g_T(self.xgoal)) # sample function
+                self.Vold = copy.deepcopy(self.V)
+                self.QV = copy.deepcopy(self.V)
+                self.r = self.radius(len(self.V) + len(self.Xsamples))
+            while self.BestQueueValue(self.QV, mode = 'QV') <= self.BestQueueValue(self.QE, mode = 'QE'):
+                self.ExpandVertex(self.BestInQueue(self.QV, mode = 'QV'))
+            (vm, xm) = self.BestInQueue(self.QE, mode = 'QE')
+            self.QE.difference_update({(vm, xm)})
             if self.g_T(vm) + self.c_hat(vm, xm) + self.h_hat(xm) < self.g_T(self.xgoal):
                 if self.g_hat(vm) + self.c(vm, xm) + self.h_hat(xm) < self.g_T(self.xgoal):
                     if self.g_T(vm) + self.c(vm, xm) < self.g_T(xm):
-                        if xm in V:
-                            E.difference_update({(v, x) for (v, x) in E if x == xm})
+                        if xm in self.V:
+                            self.E.difference_update({(v, x) for (v, x) in self.E if x == xm})
                         else:
-                            Xsamples.difference_update({xm})
-                            V.add(xm)
-                            QV.add(xm)
-                        E.add((vm, xm))
-                        QE.difference_update({(v, x) for (v, x) in QE if x == xm and self.g_T(v) + self.c_hat(v, x) >= self.g_T(x)})
+                            self.Xsamples.difference_update({xm})
+                            self.V.add(xm)
+                            self.QV.add(xm)
+                        self.E.add((vm, xm))
+                        self.Parent[vm] = xm # add parent or update parent
+                        self.QE.difference_update({(v, x) for (v, x) in self.QE if x == xm and self.g_T(v) + self.c_hat(v, x) >= self.g_T(x)})
 
             else:
-                QE = set()
-                QV = set()
+                self.QE = set()
+                self.QV = set()
             ind += 1
             if ind > self.maxiter:
                 break
-            return T
-
-    def ExpandVertex(self, v , QV, QE, Xsamples, Vold, E, V, r):
-        QV.difference_update({v})
-        Xnear = {x for x in Xsamples if getDist(x, v) <= r}
-        QE = {(v, x) for v in V for x in Xnear if self.g_hat(v) + self.c_hat(v, x) + self.h_hat(x) < self.g_T(self.xgoal)}
-        if v not in Vold:
-            Vnear = {w for w in V if getDist(w, v) <= r}
-            QE.update({(v,w) for v in V for w in Vnear if \
-                ((v,w) not in E) and \
+            return self.T
+
+    def Sample(self, m, cost):
+        # TODO need the informed rrt
+        pass
+    
+    def ExpandVertex(self, v):
+        self.QV.difference_update({v})
+        Xnear = {x for x in self.Xsamples if getDist(x, v) <= self.r}
+        self.QE.update({(v, x) for v in self.V for x in Xnear if self.g_hat(v) + self.c_hat(v, x) + self.h_hat(x) < self.g_T(self.xgoal)})
+        if v not in self.Vold:
+            Vnear = {w for w in self.V if getDist(w, v) <= self.r}
+            self.QE.update({(v,w) for v in self.V for w in Vnear if \
+                ((v,w) not in self.E) and \
                 (self.g_hat(v) + self.c_hat(v, w) + self.h_hat(w) < self.g_T(self.xgoal)) and \
                 (self.g_T(v) + self.c_hat(v, w) < self.g_T(w))})
-        return QV, QE
 
-    def Prune(self, c, Xsamples, V, E):
-        Xsamples = {x for x in Xsamples if self.f_hat(x) >= c}
-        V.difference_update({v for v in V if self.f_hat(v) >=c})
-        E.difference_update({(v, w) for (v, w) in E if (self.f_hat(v) > c) or (self.f_hat(w) > c)})
-        Xsamples.update({v for v in V if self.g_T(v) == np.inf})
-        V.difference_update({v for v in V if self.g_T(v) == np.inf})
-        return Xsamples, V, E
+    def Prune(self, c):
+        self.Xsamples = {x for x in self.Xsamples if self.f_hat(x) >= c}
+        self.V.difference_update({v for v in self.V if self.f_hat(v) >= c})
+        self.E.difference_update({(v, w) for (v, w) in self.E if (self.f_hat(v) > c) or (self.f_hat(w) > c)})
+        self.Xsamples.update({v for v in self.V if self.g_T(v) == np.inf})
+        self.V.difference_update({v for v in self.V if self.g_T(v) == np.inf})
 
     def radius(self, q):
         return 2 * self.eta * (1 + 1/self.n) ** (1/self.n) * \
-            (self.Lambda(self.Xf_hat) / self.Zeta ) ** (1/self.n) * \
+            (self.Lambda(self.Xf_hat(self.V)) / self.Zeta ) ** (1/self.n) * \
             (np.log(q) / q) ** (1/self.n)
 
     def Lambda(self, inputset):
         # lebesgue measure of a set, defined as 
         # mu: L(Rn) --> [0, inf], e.g. volume
-        pass 
+        return len(inputset)
 
-    def Zeta(self):
-        # unit ball
-        pass
+    def Xf_hat(self, X):
+        # the X is a set, defined as {x in X | fhat(x) <= cbest}
+        # where cbest is current best cost.
+        cbest = self.g_T(self.xgoal)
+        return {x for x in X if self.f_hat(x) <= cbest}
 
-    def BestInQueue(self, inputset):
-        pass
+    def Zeta(self):
+        # Lebesgue measure of a n dimensional unit ball
+        # since it's the 3D, use volume
+        return 4/3 * np.pi
+
+    def BestInQueue(self, inputset, mode):
+        # returns the best vertex in the vertex queue given this ordering
+        # mode = 'QE' or 'QV'
+        _, best_state = self.find_best(inputset, mode)
+        return best_state
+
+    def BestQueueValue(self, inputset, mode):
+        # returns the best value in the vertex queue given this ordering
+        # mode = 'QE' or 'QV'
+        best_val, _ = self.find_best(inputset, mode)
+        return best_val
+
+    def find_best(self, inputset, mode):
+        min_val, min_state = np.inf, None
+        for state in inputset:
+            if mode == 'QE':
+                curr_val = self.g_T(state[0]) + self.c_hat(state[0], state[1]) + self.h_hat(state[1])
+            elif mode == 'QV':
+                curr_val = self.g_T(state) + self.h_hat(state)
+            if curr_val < min_val:
+                min_val, min_state = curr_val, state
+        return min_val, min_state
+    
+    def g_hat(self, v):
+        return getDist(self.xstart, v)
 
-    def BestQueueValue(self, inputset):
-        pass
+    def h_hat(self, v):
+        return getDist(self.xgoal, v)
 
-    def g_hat(self, v):
-        pass
+    def f_hat(self, v):
+        # f = g + h: estimate cost
+        return self.g_hat(v) + self.h_hat(v)
 
     def c(self, v, w):
-        pass
+        # admissible estimate of the cost of an edge between state v, w
+        if (v,w) in self.edgeCost:
+            pass
+        else:
+            collide, dist = isCollide(self, v, w)
+            if collide:
+                self.edgeCost[(v,w)] = np.inf
+            else: 
+                self.edgeCost[(v,w)] = dist
+        return self.edgeCost[(v,w)]
 
     def c_hat(self, v, w):
-        pass
-
-    def f_hat(self, v):
-        pass
-
-    def h_hat(self, v):
-        pass
+        # c_hat < c < np.inf
+        # heuristic estimate of the edge cost, since c is expensive
+        if (v,w) in self.heuristic_edgeCost:
+            pass
+        else:
+            self.heuristic_edgeCost[(v,w)] = getDist(v, w)
+        return self.heuristic_edgeCost[(v,w)]
 
     def g_T(self, v):
-        pass
+        # represent cost-to-come from the start in the tree, 
+        # if the state is not in tree, or unreachable, return inf
+        if v in self.Parent:
+            cost_to_come = 0
+            while v != self.xstart:
+                cost_to_come += self.c(v, self.Parent[v])
+                v = self.Parent[v]
+            return cost_to_come
+        elif v == self.xstart:
+            return 0
+        else:
+            return np.inf
+         
     

+ 33 - 0
Sampling_based_Planning/rrt_3D/queue.py

@@ -26,6 +26,11 @@ class MinheapPQ:
         heapq.heappush(self.pq, entry)
         self.nodes.add(item)
 
+    def put_set(self, dictin):
+        '''add a new dict into the priority queue'''
+        for item, priority in enumerate(dictin):
+            self.put(item, priority)
+
     def check_remove(self, item):
         if item not in self.entry_finder:
             return
@@ -33,6 +38,34 @@ class MinheapPQ:
         entry[-1] = self.REMOVED
         self.nodes.remove(item)
 
+    def check_remove_set(self, set_input):
+        if len(set_input) == 0:
+            return
+        for item in set_input:
+            if item not in self.entry_finder:
+                continue
+            entry = self.entry_finder.pop(item)
+            entry[-1] = self.REMOVED
+            self.nodes.remove(item)
+
+    def priority_filtering(self, threshold, mode):
+        # mode: bigger: check and remove those key vals bigger than threshold
+        if mode == 'lowpass':
+            for entry in self.enumerate():
+                item = entry[2]
+                if entry[0] >= threshold: # priority
+                    _ = self.entry_finder.pop(item)
+                    entry[-1] = self.REMOVED
+                    self.nodes.remove(item)
+        # mode: smaller: check and remove those key vals smaller than threshold
+        elif mode == 'highpass':
+            for entry in self.enumerate():
+                item = entry[2]
+                if entry[0] <= threshold: # priority
+                    _ = self.entry_finder.pop(item)
+                    entry[-1] = self.REMOVED
+                    self.nodes.remove(item)
+        
     def get(self):
         """Remove and return the lowest priority task. Raise KeyError if empty."""
         while self.pq: