plotting.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. """
  2. Plotting tools for RRT_2D
  3. @author: huiming zhou
  4. """
  5. import matplotlib.pyplot as plt
  6. import matplotlib.patches as patches
  7. import os
  8. import sys
  9. sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
  10. "/../../Sampling-based Planning/")
  11. from rrt_2D import env
  12. class Plotting:
  13. def __init__(self, x_start, x_goal):
  14. self.xI, self.xG = x_start, x_goal
  15. self.env = env.Env()
  16. self.obs_bound = self.env.obs_boundary
  17. self.obs_circle = self.env.obs_circle
  18. self.obs_rectangle = self.env.obs_rectangle
  19. def animation(self, nodelist, path, name, animation=False):
  20. self.plot_grid(name)
  21. self.plot_visited(nodelist, animation)
  22. self.plot_path(path)
  23. def animation_connect(self, V1, V2, path, name):
  24. self.plot_grid(name)
  25. self.plot_visited_connect(V1, V2)
  26. self.plot_path(path)
  27. def plot_grid(self, name):
  28. fig, ax = plt.subplots()
  29. for (ox, oy, w, h) in self.obs_bound:
  30. ax.add_patch(
  31. patches.Rectangle(
  32. (ox, oy), w, h,
  33. edgecolor='black',
  34. facecolor='black',
  35. fill=True
  36. )
  37. )
  38. for (ox, oy, w, h) in self.obs_rectangle:
  39. ax.add_patch(
  40. patches.Rectangle(
  41. (ox, oy), w, h,
  42. edgecolor='black',
  43. facecolor='gray',
  44. fill=True
  45. )
  46. )
  47. for (ox, oy, r) in self.obs_circle:
  48. ax.add_patch(
  49. patches.Circle(
  50. (ox, oy), r,
  51. edgecolor='black',
  52. facecolor='gray',
  53. fill=True
  54. )
  55. )
  56. plt.plot(self.xI[0], self.xI[1], "bs", linewidth=3)
  57. plt.plot(self.xG[0], self.xG[1], "gs", linewidth=3)
  58. plt.title(name)
  59. plt.axis("equal")
  60. @staticmethod
  61. def plot_visited(nodelist, animation):
  62. if animation:
  63. count = 0
  64. for node in nodelist:
  65. count += 1
  66. if node.parent:
  67. plt.plot([node.parent.x, node.x], [node.parent.y, node.y], "-g")
  68. plt.gcf().canvas.mpl_connect('key_release_event',
  69. lambda event: [exit(0) if event.key == 'escape' else None])
  70. if count % 10 == 0:
  71. plt.pause(0.001)
  72. else:
  73. for node in nodelist:
  74. if node.parent:
  75. plt.plot([node.parent.x, node.x], [node.parent.y, node.y], "-g")
  76. @staticmethod
  77. def plot_visited_connect(V1, V2):
  78. len1, len2 = len(V1), len(V2)
  79. for k in range(max(len1, len2)):
  80. if k < len1:
  81. if V1[k].parent:
  82. plt.plot([V1[k].x, V1[k].parent.x], [V1[k].y, V1[k].parent.y], "-g")
  83. if k < len2:
  84. if V2[k].parent:
  85. plt.plot([V2[k].x, V2[k].parent.x], [V2[k].y, V2[k].parent.y], "-g")
  86. plt.gcf().canvas.mpl_connect('key_release_event',
  87. lambda event: [exit(0) if event.key == 'escape' else None])
  88. if k % 2 == 0:
  89. plt.pause(0.001)
  90. plt.pause(0.01)
  91. @staticmethod
  92. def plot_path(path):
  93. plt.plot([x[0] for x in path], [x[1] for x in path], '-r', linewidth=2)
  94. plt.pause(0.01)
  95. plt.show()