plotting.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. """
  2. Plotting tools for Sampling-based algorithms
  3. @author: huiming zhou
  4. """
  5. import matplotlib.pyplot as plt
  6. import matplotlib.patches as patches
  7. import os
  8. import sys
  9. sys.path.append(os.path.dirname(os.path.abspath(__file__)) +
  10. "/../../Sampling_based_Planning/")
  11. from Sampling_based_Planning.rrt_2D import env
  12. class Plotting:
  13. def __init__(self, x_start, x_goal):
  14. self.xI, self.xG = x_start, x_goal
  15. self.env = env.Env()
  16. self.obs_bound = self.env.obs_boundary
  17. self.obs_circle = self.env.obs_circle
  18. self.obs_rectangle = self.env.obs_rectangle
  19. def animation(self, nodelist, path, name, animation=False):
  20. self.plot_grid(name)
  21. self.plot_visited(nodelist, animation)
  22. self.plot_path(path)
  23. def animation_connect(self, V1, V2, path, name):
  24. self.plot_grid(name)
  25. self.plot_visited_connect(V1, V2)
  26. self.plot_path(path)
  27. def plot_grid(self, name):
  28. fig, ax = plt.subplots()
  29. for (ox, oy, w, h) in self.obs_bound:
  30. ax.add_patch(
  31. patches.Rectangle(
  32. (ox, oy), w, h,
  33. edgecolor='black',
  34. facecolor='black',
  35. fill=True
  36. )
  37. )
  38. for (ox, oy, w, h) in self.obs_rectangle:
  39. ax.add_patch(
  40. patches.Rectangle(
  41. (ox, oy), w, h,
  42. edgecolor='black',
  43. facecolor='gray',
  44. fill=True
  45. )
  46. )
  47. for (ox, oy, r) in self.obs_circle:
  48. ax.add_patch(
  49. patches.Circle(
  50. (ox, oy), r,
  51. edgecolor='black',
  52. facecolor='gray',
  53. fill=True
  54. )
  55. )
  56. plt.plot(self.xI[0], self.xI[1], "bs", linewidth=3)
  57. plt.plot(self.xG[0], self.xG[1], "gs", linewidth=3)
  58. plt.title(name)
  59. plt.axis("equal")
  60. @staticmethod
  61. def plot_visited(nodelist, animation):
  62. if animation:
  63. count = 0
  64. for node in nodelist:
  65. count += 1
  66. if node.parent:
  67. plt.plot([node.parent.x, node.x], [node.parent.y, node.y], "-g")
  68. plt.gcf().canvas.mpl_connect('key_release_event',
  69. lambda event:
  70. [exit(0) if event.key == 'escape' else None])
  71. if count % 10 == 0:
  72. plt.pause(0.001)
  73. else:
  74. for node in nodelist:
  75. if node.parent:
  76. plt.plot([node.parent.x, node.x], [node.parent.y, node.y], "-g")
  77. @staticmethod
  78. def plot_visited_connect(V1, V2):
  79. len1, len2 = len(V1), len(V2)
  80. for k in range(max(len1, len2)):
  81. if k < len1:
  82. if V1[k].parent:
  83. plt.plot([V1[k].x, V1[k].parent.x], [V1[k].y, V1[k].parent.y], "-g")
  84. if k < len2:
  85. if V2[k].parent:
  86. plt.plot([V2[k].x, V2[k].parent.x], [V2[k].y, V2[k].parent.y], "-g")
  87. plt.gcf().canvas.mpl_connect('key_release_event',
  88. lambda event: [exit(0) if event.key == 'escape' else None])
  89. if k % 2 == 0:
  90. plt.pause(0.001)
  91. plt.pause(0.01)
  92. @staticmethod
  93. def plot_path(path):
  94. if len(path) != 0:
  95. plt.plot([x[0] for x in path], [x[1] for x in path], '-r', linewidth=2)
  96. plt.pause(0.01)
  97. plt.show()