Нелинейная функция и обратная ей на сплайнах

import numpy as np
import pylab as pl

class spline:
    
    def __init__(self, dx, mu):
        self.dx = dx
        self.mu = mu
        self.pt_cnt = int(2.0/dx) + 1
        #self.points = [i*dx - 1.0 for i in range(self.pt_cnt)]
        self.points = [np.sin((i*dx - 1.0)*pl.pi/2) for i in range(self.pt_cnt)]
        #self.points = [np.sin((i*dx - 1.0)*pl.pi) for i in range(self.pt_cnt)]
        #self.points = [np.tanh((i*dx - 1.0)*3.0) for i in range(self.pt_cnt)]

    def getC(self, u):
        u2 = u*u
        u3 = u2*u
        c = [(    -u3+2.0*u2-u    )*0.5,
             ( 3.0*u3-5.0*u2  +2.0)*0.5,
             (-3.0*u3+4.0*u2+u    )*0.5,
             (     u3    -u2      )*0.5]
        return c
    
    def getCdrv(self, u):
        u2 = u*u
        cdrv = [(-3.0*u2 +4.0*u-1.0)*0.5,
                ( 9.0*u2-10.0*u    )*0.5,
                (-9.0*u2 +8.0*u+1.0)*0.5,
                ( 3.0*u2 -2.0*u    )*0.5]
        return cdrv

    def getParIdx(self, x):
        if x < -1.0:
            return 0.0, 0
        elif x > 1.0:
            return 1.0, self.pt_cnt-4
        z = x/self.dx + (self.pt_cnt - 3)/2.0
        if z < self.pt_cnt - 4:
            i = int(z)
        else:
            i = self.pt_cnt - 4
        u = z - float(i)
        return u, i
    
    def get(self, x):
        u, i = self.getParIdx(x)
        c  = self.getC(u)
        y  = self.points[i+0] * c[0]
        y += self.points[i+1] * c[1]
        y += self.points[i+2] * c[2]
        y += self.points[i+3] * c[3]
        return y
        
    def getDrv(self, x):
        u, i   = self.getParIdx(x)
        dc     = self.getCdrv(u)
        dyZdx  = self.points[i+0] * dc[0]
        dyZdx += self.points[i+1] * dc[1]
        dyZdx += self.points[i+2] * dc[2]
        dyZdx += self.points[i+3] * dc[3]
        return dyZdx/self.dx
        
    def getInverse(self, y):
        #search for appropriate part
        approx = y
        for i, p in enumerate(self.points[0:self.pt_cnt-2]):
            if p < self.points[i+1]:
                if p <= y < self.points[i+1]:
                    approx = i*self.dx - 1.0
                    break
            elif p > self.points[i+1]:
                if self.points[i+1] <= y < p:
                    approx = i*self.dx - 1.0
                    break
        #starting from that approximation search for precise value
        xprev = approx
        diff  = 1.0
        x     = 0.0
        cnt   = 0
        while pl.absolute(diff) > 1.e-4 and cnt < 10:
            x     = xprev - (self.get(xprev) - y)/self.getDrv(xprev)
            diff  = x - xprev
            xprev = x
            cnt  += 1
        if cnt >= 10:
            print("No convergence for y = "+str(y))
            return approx
        return x
    
spl = spline(0.01, 0.01)

x = [t*0.001-1 for t in range(2000)]
y = []
for xi in x:
    y.append(spl.get(xi))

spl_x = [t*spl.dx - 1.0 for t in range(spl.pt_cnt)]

pl.plot(spl_x, spl.points, '.')
pl.hold(True)

pl.plot(x, y)
pl.show()

pl.figure()
y = []
for xi in x:
    y.append(spl.getInverse(xi))
pl.plot(x, y)
pl.show()

Разукрашено на tohtml.com.

Добавить комментарий

Ваш адрес email не будет опубликован. Обязательные поля помечены *

Этот сайт использует Akismet для борьбы со спамом. Узнайте, как обрабатываются ваши данные комментариев.