plotting.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. """
  2. Plotting tools for RRT_2D
  3. @author: huiming zhou
  4. """
  5. import matplotlib.pyplot as plt
  6. import matplotlib.patches as patches
  7. import os
  8. import sys
  9. sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
  10. "/../../Sampling-based Planning/")
  11. from rrt_2D import env
  class Plotting:
      """Matplotlib drawing helpers for a 2-D RRT search.

      Draws the environment (boundary and obstacles), the start and goal
      points, the explored tree edges and the final path.
      """

      def __init__(self, x_start, x_goal):
          """Store start/goal and load the obstacle lists from ``env.Env``.

          :param x_start: start point, read as ``x_start[0]``, ``x_start[1]``.
          :param x_goal: goal point, read as ``x_goal[0]``, ``x_goal[1]``.
          """
          self.xI, self.xG = x_start, x_goal
          self.env = env.Env()
          # Obstacle collections exposed by env.Env; their element layouts
          # are the tuples unpacked in plot_grid.
          self.obs_bound = self.env.obs_boundary
          self.obs_circle = self.env.obs_circle
          self.obs_rectangle = self.env.obs_rectangle

      def animation(self, nodelist, path, animation=False, name="RRT"):
          """Draw the map, the search tree and the final path.

          :param nodelist: tree nodes; each must expose ``x``, ``y`` and
              ``parent`` (a falsy parent marks the root).
          :param path: sequence of points indexable as ``p[0]``, ``p[1]``;
              may be empty or ``None`` (nothing is drawn for it then).
          :param animation: if True, draw the tree edge by edge with a short
              pause after each edge.
          :param name: figure title; defaults to the previously hard-coded
              "RRT" so existing callers are unaffected.
          """
          self.plot_grid(name)
          self.plot_visited(nodelist, animation)
          self.plot_path(path)

      def plot_grid(self, name):
          """Open a new figure with boundary, obstacles, start and goal.

          :param name: title of the figure.
          """
          _, ax = plt.subplots()

          # Boundary walls are drawn solid black.
          for (ox, oy, w, h) in self.obs_bound:
              ax.add_patch(
                  patches.Rectangle(
                      (ox, oy), w, h,
                      edgecolor='black',
                      facecolor='black',
                      fill=True
                  )
              )

          # Interior rectangular obstacles are drawn gray.
          for (ox, oy, w, h) in self.obs_rectangle:
              ax.add_patch(
                  patches.Rectangle(
                      (ox, oy), w, h,
                      edgecolor='black',
                      facecolor='gray',
                      fill=True
                  )
              )

          # Circular obstacles: (center_x, center_y, radius).
          for (ox, oy, r) in self.obs_circle:
              ax.add_patch(
                  patches.Circle(
                      (ox, oy), r,
                      edgecolor='black',
                      facecolor='gray',
                      fill=True
                  )
              )

          # Start as a blue square, goal as a green square.
          plt.plot(self.xI[0], self.xI[1], "bs", linewidth=3)
          plt.plot(self.xG[0], self.xG[1], "gs", linewidth=3)

          plt.title(name)
          plt.axis("equal")

      @staticmethod
      def _on_key_release(event):
          """Key-release callback: exit the process (status 0) on Escape."""
          if event.key == 'escape':
              sys.exit(0)

      @staticmethod
      def plot_visited(nodelist, animation):
          """Draw every tree edge (parent -> node) as a green line.

          :param nodelist: nodes exposing ``x``, ``y`` and ``parent``; nodes
              with a falsy parent (the root) produce no edge.
          :param animation: if True, pause briefly after each edge so the
              growth is visible, and let Escape terminate the program.
          """
          if animation:
              # Register the Escape handler once, not per edge, so the
              # figure's callback registry does not grow with the tree size.
              plt.gcf().canvas.mpl_connect('key_release_event',
                                           Plotting._on_key_release)

          for node in nodelist:
              if not node.parent:
                  continue
              plt.plot([node.parent.x, node.x], [node.parent.y, node.y], "-g")
              if animation:
                  plt.pause(0.001)

      @staticmethod
      def plot_path(path):
          """Draw ``path`` as a thick red line, then display the figure.

          :param path: sequence of points indexable as ``p[0]``, ``p[1]``,
              or ``None`` when no path was found (only the figure is shown).
          """
          # `is not None` rather than truthiness so array-like paths work.
          if path is not None:
              plt.plot([p[0] for p in path], [p[1] for p in path],
                       '-r', linewidth=2)
          plt.pause(0.01)
          plt.show()