Dijkstra.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. """
  2. Dijkstra 2D
  3. @author: huiming zhou
  4. """
  5. import os
  6. import sys
  7. import math
  8. sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
  9. "/../../Search_based_Planning/")
  10. from Search_based_Planning.Search_2D import queue, plotting, env
  11. class Dijkstra:
  12. def __init__(self, s_start, s_goal):
  13. self.s_start, self.s_goal = s_start, s_goal
  14. self.Env = env.Env()
  15. self.plotting = plotting.Plotting(self.s_start, self.s_goal)
  16. self.u_set = self.Env.motions # feasible input set
  17. self.obs = self.Env.obs # position of obstacles
  18. self.g = {self.s_start: 0, self.s_goal: float("inf")} # Cost to come
  19. self.OPEN = queue.QueuePrior() # priority queue / OPEN set
  20. self.OPEN.put(self.s_start, 0)
  21. self.CLOSED = [] # closed set & visited
  22. self.PARENT = {self.s_start: self.s_start}
  23. def searching(self):
  24. """
  25. Dijkstra Searching.
  26. :return: path, order of visited nodes in the planning
  27. """
  28. while not self.OPEN.empty():
  29. s = self.OPEN.get()
  30. self.CLOSED.append(s)
  31. if s == self.s_goal:
  32. break
  33. for s_n in self.get_neighbor(s):
  34. new_cost = self.g[s] + self.cost(s, s_n)
  35. if s_n not in self.g:
  36. self.g[s_n] = float("inf")
  37. if new_cost < self.g[s_n]:
  38. self.g[s_n] = new_cost
  39. self.OPEN.put(s_n, new_cost)
  40. self.PARENT[s_n] = s
  41. return self.extract_path(), self.CLOSED
  42. def get_neighbor(self, s):
  43. """
  44. find neighbors of state s that not in obstacles.
  45. :param s: state
  46. :return: neighbors
  47. """
  48. s_list = []
  49. for u in self.u_set:
  50. s_list.append(tuple([s[i] + u[i] for i in range(2)]))
  51. return s_list
  52. def extract_path(self):
  53. """
  54. Extract the path based on PARENT set.
  55. :return: The planning path
  56. """
  57. path = [self.s_goal]
  58. s = self.s_goal
  59. while True:
  60. s = self.PARENT[s]
  61. path.append(s)
  62. if s == self.s_start:
  63. break
  64. return list(path)
  65. def cost(self, s_start, s_goal):
  66. """
  67. Calculate Cost for this motion
  68. :param s_start: starting node
  69. :param s_goal: end node
  70. :return: Cost for this motion
  71. :note: Cost function could be more complicate!
  72. """
  73. if self.is_collision(s_start, s_goal):
  74. return float("inf")
  75. return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
  76. def is_collision(self, s_start, s_end):
  77. if s_start in self.obs or s_end in self.obs:
  78. return True
  79. if s_start[0] != s_end[0] and s_start[1] != s_end[1]:
  80. if s_end[0] - s_start[0] == s_start[1] - s_end[1]:
  81. s1 = (min(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
  82. s2 = (max(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
  83. else:
  84. s1 = (min(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
  85. s2 = (max(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
  86. if s1 in self.obs or s2 in self.obs:
  87. return True
  88. return False
  89. def main():
  90. s_start = (5, 5)
  91. s_goal = (45, 25)
  92. dijkstra = Dijkstra(s_start, s_goal)
  93. plot = plotting.Plotting(s_start, s_goal)
  94. path, visited = dijkstra.searching()
  95. plot.animation(path, visited, "Dijkstra's") # animation generate
  96. if __name__ == '__main__':
  97. main()