cubic_spline.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. u"""
  4. Cubic Spline library on python
  5. author Atsushi Sakai
  6. usage: see test codes as below
  7. license: MIT
  8. """
  9. import math
  10. import numpy as np
  11. import bisect
  12. class Spline:
  13. u"""
  14. Cubic Spline class
  15. """
  16. def __init__(self, x, y):
  17. self.b, self.c, self.d, self.w = [], [], [], []
  18. self.x = x
  19. self.y = y
  20. self.nx = len(x) # dimension of s
  21. h = np.diff(x)
  22. # calc coefficient cBest
  23. self.a = [iy for iy in y]
  24. # calc coefficient cBest
  25. A = self.__calc_A(h)
  26. B = self.__calc_B(h)
  27. self.c = np.linalg.solve(A, B)
  28. # print(self.c1)
  29. # calc spline coefficient b and d
  30. for i in range(self.nx - 1):
  31. self.d.append((self.c[i + 1] - self.c[i]) / (3.0 * h[i]))
  32. tb = (self.a[i + 1] - self.a[i]) / h[i] - h[i] * \
  33. (self.c[i + 1] + 2.0 * self.c[i]) / 3.0
  34. self.b.append(tb)
  35. def calc(self, t):
  36. u"""
  37. Calc position
  38. if t is outside of the input s, return None
  39. """
  40. if t < self.x[0]:
  41. return None
  42. elif t > self.x[-1]:
  43. return None
  44. i = self.__search_index(t)
  45. dx = t - self.x[i]
  46. result = self.a[i] + self.b[i] * dx + \
  47. self.c[i] * dx ** 2.0 + self.d[i] * dx ** 3.0
  48. return result
  49. def calcd(self, t):
  50. u"""
  51. Calc first derivative
  52. if t is outside of the input s, return None
  53. """
  54. if t < self.x[0]:
  55. return None
  56. elif t > self.x[-1]:
  57. return None
  58. i = self.__search_index(t)
  59. dx = t - self.x[i]
  60. result = self.b[i] + 2.0 * self.c[i] * dx + 3.0 * self.d[i] * dx ** 2.0
  61. return result
  62. def calcdd(self, t):
  63. u"""
  64. Calc second derivative
  65. """
  66. if t < self.x[0]:
  67. return None
  68. elif t > self.x[-1]:
  69. return None
  70. i = self.__search_index(t)
  71. dx = t - self.x[i]
  72. result = 2.0 * self.c[i] + 6.0 * self.d[i] * dx
  73. return result
  74. def __search_index(self, x):
  75. u"""
  76. search data segment index
  77. """
  78. return bisect.bisect(self.x, x) - 1
  79. def __calc_A(self, h):
  80. u"""
  81. calc matrix A for spline coefficient cBest
  82. """
  83. A = np.zeros((self.nx, self.nx))
  84. A[0, 0] = 1.0
  85. for i in range(self.nx - 1):
  86. if i != (self.nx - 2):
  87. A[i + 1, i + 1] = 2.0 * (h[i] + h[i + 1])
  88. A[i + 1, i] = h[i]
  89. A[i, i + 1] = h[i]
  90. A[0, 1] = 0.0
  91. A[self.nx - 1, self.nx - 2] = 0.0
  92. A[self.nx - 1, self.nx - 1] = 1.0
  93. # print(A)
  94. return A
  95. def __calc_B(self, h):
  96. u"""
  97. calc matrix B for spline coefficient cBest
  98. """
  99. B = np.zeros(self.nx)
  100. for i in range(self.nx - 2):
  101. B[i + 1] = 3.0 * (self.a[i + 2] - self.a[i + 1]) / \
  102. h[i + 1] - 3.0 * (self.a[i + 1] - self.a[i]) / h[i]
  103. # print(B)
  104. return B
  105. class Spline2D:
  106. u"""
  107. 2D Cubic Spline class
  108. """
  109. def __init__(self, x, y):
  110. self.s = self.__calc_s(x, y)
  111. self.sx = Spline(self.s, x)
  112. self.sy = Spline(self.s, y)
  113. def __calc_s(self, x, y):
  114. dx = np.diff(x)
  115. dy = np.diff(y)
  116. self.ds = [math.sqrt(idx ** 2 + idy ** 2)
  117. for (idx, idy) in zip(dx, dy)]
  118. s = [0]
  119. s.extend(np.cumsum(self.ds))
  120. return s
  121. def calc_position(self, s):
  122. u"""
  123. calc position
  124. """
  125. x = self.sx.calc(s)
  126. y = self.sy.calc(s)
  127. return x, y
  128. def calc_curvature(self, s):
  129. u"""
  130. calc curvature
  131. """
  132. dx = self.sx.calcd(s)
  133. ddx = self.sx.calcdd(s)
  134. dy = self.sy.calcd(s)
  135. ddy = self.sy.calcdd(s)
  136. k = (ddy * dx - ddx * dy) / (dx ** 2 + dy ** 2)
  137. return k
  138. def calc_yaw(self, s):
  139. u"""
  140. calc yaw
  141. """
  142. dx = self.sx.calcd(s)
  143. dy = self.sy.calcd(s)
  144. yaw = math.atan2(dy, dx)
  145. return yaw
  146. def calc_spline_course(x, y, ds=0.1):
  147. sp = Spline2D(x, y)
  148. s = np.arange(0, sp.s[-1], ds)
  149. rx, ry, ryaw, rk = [], [], [], []
  150. for i_s in s:
  151. ix, iy = sp.calc_position(i_s)
  152. rx.append(ix)
  153. ry.append(iy)
  154. ryaw.append(sp.calc_yaw(i_s))
  155. rk.append(sp.calc_curvature(i_s))
  156. return rx, ry, ryaw, rk, s
  157. def test_spline2d():
  158. print("Spline 2D test")
  159. import matplotlib.pyplot as plt
  160. x = [-2.5, 0.0, 2.5, 5.0, 7.5, 3.0, -1.0]
  161. y = [0.7, -6, 5, 6.5, 0.0, 5.0, -2.0]
  162. sp = Spline2D(x, y)
  163. s = np.arange(0, sp.s[-1], 0.1)
  164. rx, ry, ryaw, rk = [], [], [], []
  165. for i_s in s:
  166. ix, iy = sp.calc_position(i_s)
  167. rx.append(ix)
  168. ry.append(iy)
  169. ryaw.append(sp.calc_yaw(i_s))
  170. rk.append(sp.calc_curvature(i_s))
  171. flg, ax = plt.subplots(1)
  172. plt.plot(x, y, "xb", label="input")
  173. plt.plot(rx, ry, "-r", label="spline")
  174. plt.grid(True)
  175. plt.axis("equal")
  176. plt.xlabel("s[m]")
  177. plt.ylabel("y[m]")
  178. plt.legend()
  179. flg, ax = plt.subplots(1)
  180. plt.plot(s, [math.degrees(iyaw) for iyaw in ryaw], "-r", label="yaw")
  181. plt.grid(True)
  182. plt.legend()
  183. plt.xlabel("line length[m]")
  184. plt.ylabel("yaw angle[deg]")
  185. flg, ax = plt.subplots(1)
  186. plt.plot(s, rk, "-r", label="curvature")
  187. plt.grid(True)
  188. plt.legend()
  189. plt.xlabel("line length[m]")
  190. plt.ylabel("curvature [1/m]")
  191. plt.show()
  192. def test_spline():
  193. print("Spline test")
  194. import matplotlib.pyplot as plt
  195. x = [-0.5, 0.0, 0.5, 1.0, 1.5]
  196. y = [3.2, 2.7, 6, 5, 6.5]
  197. spline = Spline(x, y)
  198. rx = np.arange(-2.0, 4, 0.01)
  199. ry = [spline.calc(i) for i in rx]
  200. plt.plot(x, y, "xb")
  201. plt.plot(rx, ry, "-r")
  202. plt.grid(True)
  203. plt.axis("equal")
  204. plt.show()
  205. if __name__ == '__main__':
  206. test_spline()
  207. # test_spline2d()