plotting.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. """
  2. Plotting tools for RRT_2D
  3. @author: huiming zhou
  4. """
  5. import matplotlib.pyplot as plt
  6. import matplotlib.patches as patches
  7. from rrt_2D import env
  8. class Plotting:
  9. def __init__(self, x_start, x_goal):
  10. self.xI, self.xG = x_start, x_goal
  11. self.env = env.Env()
  12. self.obs_bound = self.env.obs_boundary
  13. self.obs_circle = self.env.obs_circle
  14. self.obs_rectangle = self.env.obs_rectangle
  15. def animation(self, nodelist, path, animation=False):
  16. self.plot_grid("RRT")
  17. self.plot_visited(nodelist, animation)
  18. self.plot_path(path)
  19. def plot_grid(self, name):
  20. fig, ax = plt.subplots()
  21. for (ox, oy, w, h) in self.obs_bound:
  22. ax.add_patch(
  23. patches.Rectangle(
  24. (ox, oy), w, h,
  25. edgecolor='black',
  26. facecolor='black',
  27. fill=True
  28. )
  29. )
  30. for (ox, oy, w, h) in self.obs_rectangle:
  31. ax.add_patch(
  32. patches.Rectangle(
  33. (ox, oy), w, h,
  34. edgecolor='black',
  35. facecolor='gray',
  36. fill=True
  37. )
  38. )
  39. for (ox, oy, r) in self.obs_circle:
  40. ax.add_patch(
  41. patches.Circle(
  42. (ox, oy), r,
  43. edgecolor='black',
  44. facecolor='gray',
  45. fill=True
  46. )
  47. )
  48. plt.plot(self.xI[0], self.xI[1], "bs", linewidth=3)
  49. plt.plot(self.xG[0], self.xG[1], "gs", linewidth=3)
  50. plt.title(name)
  51. plt.axis("equal")
  52. @staticmethod
  53. def plot_visited(nodelist, animation):
  54. if animation:
  55. for node in nodelist:
  56. if node.parent:
  57. plt.plot([node.parent.x, node.x], [node.parent.y, node.y], "-g")
  58. plt.gcf().canvas.mpl_connect('key_release_event',
  59. lambda event: [exit(0) if event.key == 'escape' else None])
  60. plt.pause(0.001)
  61. else:
  62. for node in nodelist:
  63. if node.parent:
  64. plt.plot([node.parent.x, node.x], [node.parent.y, node.y], "-g")
  65. @staticmethod
  66. def plot_path(path):
  67. plt.plot([x[0] for x in path], [x[1] for x in path], '-r', linewidth=2)
  68. plt.pause(0.01)
  69. plt.show()