# plotting.py (2.4 KB)
import sys

import matplotlib.pyplot as plt

import env
  3. class Plotting:
  4. def __init__(self, xI, xG):
  5. self.xI, self.xG = xI, xG
  6. self.env = env.Env()
  7. self.obs = self.env.obs_map()
  8. def animation(self, path, visited, name):
  9. self.plot_grid(name)
  10. self.plot_visited(visited)
  11. self.plot_path(path)
  12. plt.show()
  13. def plot_grid(self, name):
  14. obs_x = [self.obs[i][0] for i in range(len(self.obs))]
  15. obs_y = [self.obs[i][1] for i in range(len(self.obs))]
  16. plt.plot(self.xI[0], self.xI[1], "bs")
  17. plt.plot(self.xG[0], self.xG[1], "gs")
  18. plt.plot(obs_x, obs_y, "sk")
  19. plt.title(name)
  20. plt.axis("equal")
  21. def plot_visited(self, visited, cl='gray'):
  22. if self.xI in visited:
  23. visited.remove(self.xI)
  24. if self.xG in visited:
  25. visited.remove(self.xG)
  26. count = 0
  27. for x in visited:
  28. count += 1
  29. plt.plot(x[0], x[1], linewidth='3', color=cl, marker='o')
  30. plt.gcf().canvas.mpl_connect('key_release_event',
  31. lambda event: [exit(0) if event.key == 'escape' else None])
  32. if count < len(visited) / 3:
  33. length = 15
  34. elif count < len(visited) * 2 / 3:
  35. length = 25
  36. else:
  37. length = 35
  38. if count % length == 0:
  39. plt.pause(0.001)
  40. plt.pause(0.01)
  41. def plot_path(self, path, cl='r', flag=False):
  42. if self.xI in path:
  43. path.remove(self.xI)
  44. if self.xG in path:
  45. path.remove(self.xG)
  46. path_x = [path[i][0] for i in range(len(path))]
  47. path_y = [path[i][1] for i in range(len(path))]
  48. if not flag:
  49. plt.plot(path_x, path_y, linewidth='3', color='r', marker='o')
  50. else:
  51. plt.plot(path_x, path_y, linewidth='3', color=cl, marker='o')
  52. plt.pause(0.01)
  53. def animation_ara_star(self, path, visited, name):
  54. self.plot_grid(name)
  55. cl_v, cl_p = self.color_list()
  56. for k in range(len(path)):
  57. self.plot_visited(visited[k], cl_v[k])
  58. self.plot_path(path[k], cl_p[k], True)
  59. plt.pause(0.5)
  60. plt.show()
  61. @staticmethod
  62. def color_list():
  63. cl_v = ['silver', 'wheat', 'lightskyblue', 'plum', 'slategray']
  64. cl_p = ['gray', 'orange', 'deepskyblue', 'red', 'm']
  65. return cl_v, cl_p