Dijkstra.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. """
  2. Dijkstra 2D
  3. @author: huiming zhou
  4. """
  5. import os
  6. import sys
  7. import math
  8. import heapq
  9. sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
  10. "/../../Search_based_Planning/")
  11. from Search_based_Planning.Search_2D import plotting, env
  12. class Dijkstra:
  13. def __init__(self, s_start, s_goal):
  14. self.s_start = s_start
  15. self.s_goal = s_goal
  16. self.Env = env.Env()
  17. self.plotting = plotting.Plotting(self.s_start, self.s_goal)
  18. self.u_set = self.Env.motions # feasible input set
  19. self.obs = self.Env.obs # position of obstacles
  20. self.OPEN = [] # priority queue / OPEN set
  21. self.CLOSED = [] # closed set & visited
  22. self.PARENT = dict() # record parent
  23. self.g = dict() # Cost to come
  24. def searching(self):
  25. """
  26. Dijkstra Searching.
  27. :return: path, visited order
  28. """
  29. self.PARENT[self.s_start] = self.s_start
  30. self.g[self.s_start] = 0
  31. self.g[self.s_goal] = math.inf
  32. heapq.heappush(self.OPEN, (0, self.s_start))
  33. while self.OPEN:
  34. _, s = heapq.heappop(self.OPEN)
  35. self.CLOSED.append(s)
  36. if s == self.s_goal:
  37. break
  38. for s_n in self.get_neighbor(s):
  39. new_cost = self.g[s] + self.cost(s, s_n)
  40. if s_n not in self.g:
  41. self.g[s_n] = math.inf
  42. if new_cost < self.g[s_n]:
  43. self.g[s_n] = new_cost
  44. heapq.heappush(self.OPEN, (new_cost, s_n))
  45. self.PARENT[s_n] = s
  46. return self.extract_path(), self.CLOSED
  47. def get_neighbor(self, s):
  48. """
  49. find neighbors of state s that not in obstacles.
  50. :param s: state
  51. :return: neighbors
  52. """
  53. return [(s[0] + u[0], s[1] + u[1]) for u in self.u_set]
  54. def extract_path(self):
  55. """
  56. Extract the path based on PARENT set.
  57. :return: The planning path
  58. """
  59. path = [self.s_goal]
  60. s = self.s_goal
  61. while True:
  62. s = self.PARENT[s]
  63. path.append(s)
  64. if s == self.s_start:
  65. break
  66. return list(path)
  67. def cost(self, s_start, s_goal):
  68. """
  69. Calculate Cost for this motion
  70. :param s_start: starting node
  71. :param s_goal: end node
  72. :return: Cost for this motion
  73. """
  74. if self.is_collision(s_start, s_goal):
  75. return math.inf
  76. return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
  77. def is_collision(self, s_start, s_end):
  78. """
  79. check if the line segment (s_start, s_end) is collision.
  80. :param s_start: start node
  81. :param s_end: end node
  82. :return: True: is collision / False: not collision
  83. """
  84. if s_start in self.obs or s_end in self.obs:
  85. return True
  86. if s_start[0] != s_end[0] and s_start[1] != s_end[1]:
  87. if s_end[0] - s_start[0] == s_start[1] - s_end[1]:
  88. s1 = (min(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
  89. s2 = (max(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
  90. else:
  91. s1 = (min(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
  92. s2 = (max(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
  93. if s1 in self.obs or s2 in self.obs:
  94. return True
  95. return False
  96. def main():
  97. s_start = (5, 5)
  98. s_goal = (45, 25)
  99. dijkstra = Dijkstra(s_start, s_goal)
  100. plot = plotting.Plotting(s_start, s_goal)
  101. path, visited = dijkstra.searching()
  102. plot.animation(path, visited, "Dijkstra's") # animation generate
  103. if __name__ == '__main__':
  104. main()