# plotting.py -- matplotlib helpers for drawing the grid environment, paths and iteration differences.
import sys

import matplotlib.pyplot as plt

import env
  3. class Plotting():
  4. def __init__(self, xI, xG):
  5. self.xI, self.xG = xI, xG
  6. self.env = env.Env(self.xI, self.xG)
  7. self.obs = self.env.obs_map()
  8. self.lose = self.env.lose_map()
  9. def animation(self, path, name):
  10. """
  11. animation.
  12. :param path: optimal path
  13. :param name: tile of figure
  14. :return: an animation
  15. """
  16. plt.figure(1)
  17. self.plot_grid(name)
  18. self.plot_lose()
  19. self.plot_path(path)
  20. def plot_grid(self, name):
  21. """
  22. plot the obstacles in environment.
  23. :param name: title of figure
  24. :return: plot
  25. """
  26. obs_x = [self.obs[i][0] for i in range(len(self.obs))]
  27. obs_y = [self.obs[i][1] for i in range(len(self.obs))]
  28. plt.plot(self.xI[0], self.xI[1], "bs", ms=24)
  29. plt.plot(self.xG[0], self.xG[1], "gs", ms=24)
  30. plt.plot(obs_x, obs_y, "sk", ms=24)
  31. plt.title(name)
  32. plt.axis("equal")
  33. def plot_lose(self):
  34. """
  35. plot losing states in environment.
  36. :return: a plot
  37. """
  38. lose_x = [self.lose[i][0] for i in range(len(self.lose))]
  39. lose_y = [self.lose[i][1] for i in range(len(self.lose))]
  40. plt.plot(lose_x, lose_y, color='#A52A2A', marker='s', ms=24)
  41. def plot_visited(self, visited):
  42. """
  43. animation of order of visited nodes.
  44. :param visited: visited nodes
  45. :return: animation
  46. """
  47. visited.remove(self.xI)
  48. count = 0
  49. for x in visited:
  50. count += 1
  51. plt.plot(x[0], x[1], linewidth='3', color='#808080', marker='o')
  52. plt.gcf().canvas.mpl_connect('key_release_event', lambda event:
  53. [exit(0) if event.key == 'escape' else None])
  54. if count < len(visited) / 3:
  55. length = 15
  56. elif count < len(visited) * 2 / 3:
  57. length = 30
  58. else:
  59. length = 45
  60. if count % length == 0: plt.pause(0.001)
  61. def plot_path(self, path):
  62. path.remove(self.xI)
  63. path.remove(self.xG)
  64. for x in path:
  65. plt.plot(x[0], x[1], color='#808080', marker='o', ms=23)
  66. plt.gcf().canvas.mpl_connect('key_release_event', lambda event:
  67. [exit(0) if event.key == 'escape' else None])
  68. plt.pause(0.001)
  69. plt.show()
  70. plt.pause(0.5)
  71. def plot_diff(self, diff, name):
  72. plt.figure(2)
  73. plt.title(name, fontdict=None)
  74. plt.xlabel('iterations')
  75. plt.ylabel('difference of successive iterations')
  76. plt.grid('on')
  77. count = 0
  78. for x in diff:
  79. plt.plot(count, x, color='#808080', marker='o') # plot dots for animation
  80. plt.gcf().canvas.mpl_connect('key_release_event', lambda event:
  81. [exit(0) if event.key == 'escape' else None])
  82. plt.pause(0.07)
  83. count += 1
  84. plt.plot(diff, color='#808080')
  85. plt.pause(0.01)
  86. plt.show()