# dijkstra.py
"""
Dijkstra 2D
@author: huiming zhou
"""
  5. import os
  6. import sys
  7. import math
  8. sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
  9. "/../../Search-based Planning/")
  10. from Search_2D import queue
  11. from Search_2D import plotting
  12. from Search_2D import env
  13. class Dijkstra:
  14. def __init__(self, s_start, s_goal):
  15. self.s_start, self.s_goal = s_start, s_goal
  16. self.Env = env.Env()
  17. self.plotting = plotting.Plotting(self.s_start, self.s_goal)
  18. self.u_set = self.Env.motions # feasible input set
  19. self.obs = self.Env.obs # position of obstacles
  20. self.g = {self.s_start: 0, self.s_goal: float("inf")} # cost to come
  21. self.OPEN = queue.QueuePrior() # priority queue / OPEN set
  22. self.OPEN.put(self.s_start, 0)
  23. self.CLOSED = [] # closed set & visited
  24. self.PARENT = {self.s_start: self.s_start}
  25. def searching(self):
  26. """
  27. Dijkstra Searching.
  28. :return: path, order of visited nodes in the planning
  29. """
  30. while not self.OPEN.empty():
  31. s = self.OPEN.get()
  32. self.CLOSED.append(s)
  33. if s == self.s_goal:
  34. break
  35. for s_n in self.get_neighbor(s):
  36. new_cost = self.g[s] + self.cost(s, s_n)
  37. if s_n not in self.g:
  38. self.g[s_n] = float("inf")
  39. if new_cost < self.g[s_n]:
  40. self.g[s_n] = new_cost
  41. self.OPEN.put(s_n, new_cost)
  42. self.PARENT[s_n] = s
  43. return self.extract_path(), self.CLOSED
  44. def get_neighbor(self, s):
  45. """
  46. find neighbors of state s that not in obstacles.
  47. :param s: state
  48. :return: neighbors
  49. """
  50. s_list = []
  51. for u in self.u_set:
  52. s_list.append(tuple([s[i] + u[i] for i in range(2)]))
  53. return s_list
  54. def extract_path(self):
  55. """
  56. Extract the path based on PARENT set.
  57. :return: The planning path
  58. """
  59. path = [self.s_goal]
  60. s = self.s_goal
  61. while True:
  62. s = self.PARENT[s]
  63. path.append(s)
  64. if s == self.s_start:
  65. break
  66. return list(path)
  67. def cost(self, s_start, s_goal):
  68. """
  69. Calculate cost for this motion
  70. :param s_start: starting node
  71. :param s_goal: end node
  72. :return: cost for this motion
  73. :note: cost function could be more complicate!
  74. """
  75. if self.is_collision(s_start, s_goal):
  76. return float("inf")
  77. return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
  78. def is_collision(self, s_start, s_end):
  79. if s_start in self.obs or s_end in self.obs:
  80. return True
  81. if s_start[0] != s_end[0] and s_start[1] != s_end[1]:
  82. if s_end[0] - s_start[0] == s_start[1] - s_end[1]:
  83. s1 = (min(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
  84. s2 = (max(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
  85. else:
  86. s1 = (min(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
  87. s2 = (max(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
  88. if s1 in self.obs or s2 in self.obs:
  89. return True
  90. return False
  91. def main():
  92. s_start = (5, 5)
  93. s_goal = (45, 25)
  94. dijkstra = Dijkstra(s_start, s_goal)
  95. plot = plotting.Plotting(s_start, s_goal)
  96. path, visited = dijkstra.searching()
  97. plot.animation(path, visited, "Dijkstra's") # animation generate
  98. if __name__ == '__main__':
  99. main()