import numpy as np


def _half_sine(t):
    """Default knot function: ``sin(t*pi/2)``."""
    return np.sin(t*np.pi/2)


class spline:
    """Uniform Catmull-Rom spline through samples of a function on [-1, 1].

    Knot ``k`` (``0 <= k < pt_cnt``) lies at ``x_k = k*dx - 1.0`` and stores
    ``points[k] = func(x_k)``.  Segment ``i`` blends ``points[i:i+4]`` and
    interpolates between knots ``i+1`` (local parameter u=0) and ``i+2``
    (u=1).  The first and last knot intervals have no outer neighbour; they
    are covered by extrapolating the adjacent segment (see ``getParIdx``).

    Args:
        dx: knot spacing.  Must be positive and small enough to give at
            least four knots.  If ``2/dx`` is not an integer, the last knot
            falls short of ``x = 1``.
        mu: stored as ``self.mu``; no method of this class reads it.
        func: callable giving the knot values; defaults to ``sin(x*pi/2)``.
            Other shapes the original author experimented with: ``x``,
            ``sin(x*pi)`` and ``tanh(3*x)``.

    Raises:
        ValueError: if ``dx`` is not positive or yields fewer than 4 knots.
    """

    def __init__(self, dx, mu, func=None):
        if dx <= 0:
            raise ValueError("dx must be positive")
        if func is None:
            func = _half_sine
        self.dx = dx
        self.mu = mu
        self.pt_cnt = int(2.0/dx) + 1
        if self.pt_cnt < 4:
            # Every segment needs four control points.  With fewer, the
            # clamped index pt_cnt - 4 would be negative and silently wrap.
            raise ValueError("dx too large: at least 4 knots are required")
        self.points = [func(i*dx - 1.0) for i in range(self.pt_cnt)]

    def getC(self, u):
        """Return the four Catmull-Rom basis weights at local parameter ``u``."""
        u2 = u*u
        u3 = u2*u
        return [(-u3 + 2.0*u2 - u)*0.5,
                (3.0*u3 - 5.0*u2 + 2.0)*0.5,
                (-3.0*u3 + 4.0*u2 + u)*0.5,
                (u3 - u2)*0.5]

    def getCdrv(self, u):
        """Return the derivatives d/du of the ``getC`` weights at ``u``."""
        u2 = u*u
        return [(-3.0*u2 + 4.0*u - 1.0)*0.5,
                (9.0*u2 - 10.0*u)*0.5,
                (-9.0*u2 + 8.0*u + 1.0)*0.5,
                (3.0*u2 - 2.0*u)*0.5]

    def getParIdx(self, x):
        """Map ``x`` to ``(u, i)``: local parameter and first control index.

        ``x`` outside [-1, 1] is clamped to the nearest end, so both
        ``get`` and ``getDrv`` report boundary values there.  ``i`` is
        clamped to ``[0, pt_cnt - 4]``.  This clamping means the first knot
        interval is reached with ``u`` in ``[-1, 0)`` and the last one with
        ``u > 1``, i.e. by extrapolating the neighbouring segment.
        """
        x = min(max(x, -1.0), 1.0)
        # Fractional position measured from knot 1 (u=0 of segment 0).
        # Derived from the actual knot layout x_k = k*dx - 1, so it stays
        # correct when 2/dx is not an integer.
        z = (x + 1.0)/self.dx - 1.0
        # floor + clamp instead of int(): int(-1.0) == -1 would index
        # points[-1], i.e. silently wrap to the last knot.
        i = min(max(int(np.floor(z)), 0), self.pt_cnt - 4)
        return z - i, i

    def get(self, x):
        """Evaluate the spline at ``x`` (clamped to [-1, 1])."""
        u, i = self.getParIdx(x)
        c = self.getC(u)
        return sum(p*w for p, w in zip(self.points[i:i + 4], c))

    def getDrv(self, x):
        """Evaluate the spline's derivative dy/dx at ``x`` (clamped to [-1, 1])."""
        u, i = self.getParIdx(x)
        dc = self.getCdrv(u)
        dy_du = sum(p*w for p, w in zip(self.points[i:i + 4], dc))
        # u advances by 1 per knot spacing, so dy/dx = (dy/du) / dx.
        return dy_du/self.dx

    def getInverse(self, y, tol=1.e-4, max_iter=10):
        """Solve ``get(x) == y`` for ``x`` with Newton's method.

        The start value comes from the first pair of adjacent knots whose
        values bracket ``y``, interpolated linearly between them.  If no
        pair brackets ``y``, ``y`` itself is used as the start value.

        Args:
            y: target spline value.
            tol: stop once a Newton step is no larger than this (in ``x``).
            max_iter: maximum number of Newton steps.

        Returns:
            The converged ``x``.  On failure (no convergence within
            ``max_iter`` steps, zero derivative, or a non-finite iterate) a
            message is printed and the start value is returned instead.
        """
        approx = y
        for i, (p, q) in enumerate(zip(self.points[:-1], self.points[1:])):
            # Half-open bracket [min, max) in either direction.  Flat pairs
            # (p == q) cannot bracket anything and are skipped.
            if p != q and min(p, q) <= y < max(p, q):
                approx = i*self.dx - 1.0 + self.dx*(y - p)/(q - p)
                break

        x = approx
        for _ in range(max_iter):
            drv = self.getDrv(x)
            if drv == 0:
                break  # flat spot: the Newton step is undefined
            x_new = x - (self.get(x) - y)/drv
            if not np.isfinite(x_new):
                break
            step = x_new - x
            x = x_new
            if abs(step) <= tol:
                return x
        print("No convergence for y = "+str(y))
        return approx


def main():
    """Plot the spline over [-1, 1] and its numerical inverse."""
    # Imported here so the spline class can be used without matplotlib.
    import pylab as pl

    spl = spline(0.01, 0.01)
    x = [t*0.001 - 1 for t in range(2000)]
    y = [spl.get(xi) for xi in x]

    spl_x = [t*spl.dx - 1.0 for t in range(spl.pt_cnt)]
    pl.plot(spl_x, spl.points, '.')
    # pl.hold() was removed in Matplotlib 3.0; successive plot() calls
    # already draw onto the same axes.
    pl.plot(x, y)
    pl.show()

    pl.figure()
    y = [spl.getInverse(xi) for xi in x]
    pl.plot(x, y)
    pl.show()


if __name__ == "__main__":
    main()
# Syntax-highlighted with tohtml.com.