plotting.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import matplotlib.pyplot as plt
  2. import matplotlib.patches as patches
  3. from rrt_2D import env
  4. class Plotting:
  5. def __init__(self, xI, xG):
  6. self.xI, self.xG = xI, xG
  7. self.env = env.Env()
  8. self.obs_bound = self.env.obs_boundary
  9. self.obs_circle = self.env.obs_circle
  10. self.obs_rectangle = self.env.obs_rectangle
  11. def animation(self, nodelist, path, animation=False):
  12. if path is None:
  13. print("No path found!")
  14. return
  15. self.plot_grid("RRT")
  16. self.plot_visited(nodelist, animation)
  17. self.plot_path(path)
  18. def plot_grid(self, name):
  19. fig, ax = plt.subplots()
  20. for (ox, oy, w, h) in self.obs_bound:
  21. ax.add_patch(
  22. patches.Rectangle(
  23. (ox, oy), w, h,
  24. edgecolor='black',
  25. facecolor='black',
  26. fill=True
  27. )
  28. )
  29. for (ox, oy, w, h) in self.obs_rectangle:
  30. ax.add_patch(
  31. patches.Rectangle(
  32. (ox, oy), w, h,
  33. edgecolor='black',
  34. facecolor='gray',
  35. fill=True
  36. )
  37. )
  38. for (ox, oy, r) in self.obs_circle:
  39. ax.add_patch(
  40. patches.Circle(
  41. (ox, oy), r,
  42. edgecolor='black',
  43. facecolor='gray',
  44. fill=True
  45. )
  46. )
  47. plt.plot(self.xI[0], self.xI[1], "bs", linewidth=3)
  48. plt.plot(self.xG[0], self.xG[1], "gs", linewidth=3)
  49. plt.title(name)
  50. plt.axis("equal")
  51. @staticmethod
  52. def plot_visited(nodelist, animation):
  53. if animation:
  54. for node in nodelist:
  55. if node.parent:
  56. plt.plot([node.parent.x, node.x], [node.parent.y, node.y], "-g")
  57. plt.gcf().canvas.mpl_connect('key_release_event',
  58. lambda event: [exit(0) if event.key == 'escape' else None])
  59. plt.pause(0.001)
  60. else:
  61. for node in nodelist:
  62. if node.parent:
  63. plt.plot([node.parent.x, node.x], [node.parent.y, node.y], "-g")
  64. @staticmethod
  65. def plot_path(path):
  66. plt.plot([x[0] for x in path], [x[1] for x in path], '-r', linewidth=2)
  67. plt.pause(0.01)
  68. plt.show()