Former-commit-id:a02aeb236c
[formerly9f19e3f712
] [formerly06a8b51d6d
[formerly 64fa9254b946eae7e61bbc3f513b7c3696c4f54f]] Former-commit-id:06a8b51d6d
Former-commit-id:3360eb6c5f
487 lines
18 KiB
Python
Executable file
487 lines
18 KiB
Python
Executable file
# This module provides interpolation for functions defined on a grid.
|
|
#
|
|
# Written by Konrad Hinsen <hinsen@cnrs-orleans.fr>
|
|
# last revision: 2008-8-18
|
|
#
|
|
|
|
"""
|
|
Interpolation of functions defined on a grid
|
|
"""
|
|
|
|
from Scientific import N
|
|
import Polynomial
|
|
from Scientific.indexing import index_expression
|
|
from Scientific_interpolation import _interpolate
|
|
import operator
|
|
|
|
#
|
|
# General interpolating functions.
|
|
#
|
|
class InterpolatingFunction:

    """X{Function} defined by values on a X{grid} using X{interpolation}

    An interpolating function of M{n} variables with M{m}-dimensional values
    is defined by an M{(n+m)}-dimensional array of values and M{n}
    one-dimensional arrays that define the variables values
    corresponding to the grid points. The grid does not have to be
    equidistant.

    An InterpolatingFunction object has attributes C{real} and C{imag}
    like a complex function (even if its values are real).

    Derived objects (slices, derivatives, integrals, ...) are built through
    the class attribute C{_constructor}, which is assigned right after the
    class definition.
    """

    def __init__(self, axes, values, default = None, period = None):
        """
        @param axes: a sequence of one-dimensional arrays, one for each
                     variable, specifying the values of the variables at
                     the grid points
        @type axes: sequence of N.array

        @param values: the function values on the grid
        @type values: N.array

        @param default: the value of the function outside the grid. A value
                        of C{None} means that the function is undefined outside
                        the grid and that any attempt to evaluate it there
                        raises an exception.
        @type default: number or C{None}

        @param period: the period for each of the variables, or C{None} for
                       variables in which the function is not periodic.
        @type period: sequence of numbers or C{None}

        @raise ValueError: if there are more axes than dimensions in
                           C{values}, if C{period} does not have one entry
                           per axis, or if a period does not exceed the
                           extent of its axis
        """
        if len(axes) > len(values.shape):
            raise ValueError('Inconsistent arguments')
        self.axes = list(axes)
        # Concatenation of the shapes of all axes (one entry per variable).
        self.shape = sum([axis.shape for axis in self.axes], ())
        self.values = values
        self.default = default
        if period is None:
            period = len(self.axes)*[None]
        self.period = period
        if len(self.period) != len(self.axes):
            raise ValueError('Inconsistent arguments')
        for a, p in zip(self.axes, self.period):
            # The period must be strictly larger than the axis extent.
            if p is not None and a[0]+p <= a[-1]:
                raise ValueError('Period too short')

    def __call__(self, *points):
        """
        @param points: one coordinate per variable of the function
        @returns: the function value obtained by linear interpolation,
                  or the default value if the point lies outside the
                  grid and a default value is defined
        @rtype: number
        @raise TypeError: if the number of arguments (C{len(points)})
            does not match the number of variables of the function
        @raise ValueError: if the evaluation point is outside of the
            domain of definition and no default value is defined
        """
        if len(points) != len(self.axes):
            raise TypeError('Wrong number of arguments')
        if len(points) == 1:
            # Fast Pyrex implementation for the important special case
            # of a function of one variable with all arrays of type double.
            period = self.period[0]
            if period is None:
                period = 0.
            try:
                return _interpolate(points[0], self.axes[0],
                                    self.values, period)
            except Exception:
                # Run the Python version if anything goes wrong, but let
                # KeyboardInterrupt and SystemExit propagate.
                pass
        try:
            neighbours = [_lookup(pt, ax, per)
                          for pt, ax, per
                          in zip(points, self.axes, self.period)]
        except ValueError:
            if self.default is not None:
                return self.default
            # No default value: report the out-of-grid error unchanged.
            raise
        # One two-element slice per variable selects the enclosing grid cell.
        slices = sum([item[0] for item in neighbours], ())
        values = self.values[slices]
        # Collapse the cell one variable at a time by linear interpolation.
        for item in neighbours:
            weight = item[1]
            values = (1.-weight)*values[0]+weight*values[1]
        return values

    def __len__(self):
        """
        @returns: number of grid points along the first variable, i.e.
                  the number of valid integer indices for L{__getitem__}
        @rtype: C{int}
        """
        return len(self.axes[0])

    def __getitem__(self, i):
        """
        @param i: any indexing expression possible for C{N.array}
                  that does not use C{N.NewAxis}
        @type i: indexing expression
        @returns: an InterpolatingFunction whose number of variables
                  is reduced. For an integer index applied to a function
                  of a single variable, a tuple C{(x, value)} of the grid
                  coordinate and the function value at that grid point.
        @rtype: L{InterpolatingFunction} or C{tuple}
        @raise TypeError: if i is not an allowed index expression
        """
        if isinstance(i, int):
            if len(self.axes) == 1:
                return (self.axes[0][i], self.values[i])
            else:
                return self._constructor(self.axes[1:], self.values[i])
        elif isinstance(i, slice):
            axes = [self.axes[0][i]] + self.axes[1:]
            return self._constructor(axes, self.values[i])
        elif isinstance(i, tuple):
            axes = []
            rest = self.axes[:]
            for item in i:
                # Integer indices remove a variable, anything else
                # restricts the corresponding axis.
                if not isinstance(item, int):
                    axes.append(rest[0][item])
                del rest[0]
            axes = axes + rest
            return self._constructor(axes, self.values[i])
        else:
            raise TypeError("illegal index type")

    def __getslice__(self, i, j):
        """
        @param i: lower slice index
        @type i: C{int}
        @param j: upper slice index
        @type j: C{int}
        @returns: an InterpolatingFunction with the same number of
                  variables whose first axis is restricted to the
                  grid points C{i:j}
        @rtype: L{InterpolatingFunction}
        """
        axes = [self.axes[0][i:j]] + self.axes[1:]
        return self._constructor(axes, self.values[i:j])

    def __getattr__(self, attr):
        """
        Provide the computed attributes C{real} and C{imag}.

        @param attr: the name of the requested attribute
        @type attr: C{str}
        @returns: a new InterpolatingFunction on the same grid holding the
                  real or imaginary part of the values and of the default
                  value; the periods are kept
        @rtype: L{InterpolatingFunction}
        @raise AttributeError: for any name other than C{real} or C{imag}
        """
        if attr == 'real':
            values = self.values
            try:
                values = values.real
            except ValueError:
                pass
            default = self.default
            try:
                default = default.real
            except AttributeError:
                # default is None or a type without a real part
                pass
            return self._constructor(self.axes, values, default, self.period)
        elif attr == 'imag':
            try:
                values = self.values.imag
            except ValueError:
                values = 0*self.values
            default = self.default
            try:
                default = self.default.imag
            except AttributeError:
                try:
                    default = 0*self.default
                except TypeError:
                    # default is None: the result has no default either
                    default = None
            return self._constructor(self.axes, values, default, self.period)
        else:
            raise AttributeError(attr)

    def selectInterval(self, first, last, variable=0):
        """
        @param first: lower limit of an axis interval
        @type first: C{float}
        @param last: upper limit of an axis interval
        @type last: C{float}
        @param variable: the index of the variable of the function
                         along which the interval restriction is applied
        @type variable: C{int}
        @returns: a new InterpolatingFunction whose grid is restricted
                  to the points with C{first <= x <= last}. The result
                  has no default value and is not periodic.
        @rtype: L{InterpolatingFunction}
        """
        x = self.axes[variable]
        c = N.logical_and(N.greater_equal(x, first),
                          N.less_equal(x, last))
        i_axes = self.axes[:variable] + [N.compress(c, x)] + \
                 self.axes[variable+1:]
        i_values = N.compress(c, self.values, variable)
        return self._constructor(i_axes, i_values, None, None)

    def derivative(self, variable = 0):
        """
        @param variable: the index of the variable of the function
                         with respect to which the X{derivative} is taken
        @type variable: C{int}
        @returns: a new InterpolatingFunction containing the numerical
                  derivative, obtained from finite differences between
                  neighbouring grid points and defined on the midpoints
                  of the original grid intervals
        @rtype: L{InterpolatingFunction}
        """
        diffaxis = self.axes[variable]
        # Index expression that appends trailing unit axes to the grid
        # spacings so that they can be combined with the value differences.
        ai = index_expression[::] + \
             (len(self.values.shape)-variable-1) * index_expression[N.NewAxis]
        period = self.period[variable]
        if period is None:
            ui = variable*index_expression[::] + \
                 index_expression[1::] + index_expression[...]
            li = variable*index_expression[::] + \
                 index_expression[:-1:] + index_expression[...]
            d_values = (self.values[ui]-self.values[li]) / \
                       (diffaxis[1:]-diffaxis[:-1])[ai]
            diffaxis = 0.5*(diffaxis[1:]+diffaxis[:-1])
        else:
            # Periodic variable: the upper neighbour of the last grid point
            # is the first grid point shifted by one period.
            upper_indices = list(range(1, len(diffaxis))) + [0]
            u = N.take(self.values, upper_indices, axis=variable)
            l = self.values
            ua = N.concatenate((diffaxis[1:], period+diffaxis[0:1]))
            la = diffaxis
            d_values = (u-l)/(ua-la)[ai]
            diffaxis = 0.5*(ua+la)
        d_axes = self.axes[:variable]+[diffaxis]+self.axes[variable+1:]
        d_default = None
        if self.default is not None:
            d_default = 0.
        return self._constructor(d_axes, d_values, d_default, self.period)

    def _trapezoidTerms(self, variable):
        """
        @param variable: the index of the integration variable
        @type variable: C{int}
        @returns: for each grid interval along C{variable}, the sum of the
                  function values at both ends multiplied by the interval
                  length (twice the trapezoid-rule contribution)
        @rtype: N.array
        @raise ValueError: if the function is periodic in C{variable}
        """
        if self.period[variable] is not None:
            raise ValueError('Integration over periodic variables not defined')
        intaxis = self.axes[variable]
        trailing = (len(self.values.shape)-variable-1) * \
                   index_expression[N.NewAxis]
        ui = variable*index_expression[::] + \
             index_expression[1::] + index_expression[...]
        li = variable*index_expression[::] + \
             index_expression[:-1:] + index_expression[...]
        uai = index_expression[1::] + trailing
        lai = index_expression[:-1:] + trailing
        return (self.values[ui]+self.values[li]) * \
               (intaxis[uai]-intaxis[lai])

    def integral(self, variable = 0):
        """
        @param variable: the index of the variable of the function
                         with respect to which the X{integration} is performed
        @type variable: C{int}
        @returns: a new InterpolatingFunction containing the numerical
                  X{integral}. The integration constant is defined such that
                  the integral at the first grid point is zero. The result
                  has no default value; periods of the other variables
                  are kept.
        @rtype: L{InterpolatingFunction}
        @raise ValueError: if the function is periodic in C{variable}
        """
        i_values = 0.5*N.add.accumulate(self._trapezoidTerms(variable),
                                        variable)
        # Prepend a layer of zeros: the integral vanishes at the first
        # grid point.
        s = list(self.values.shape)
        s[variable] = 1
        z = N.zeros(tuple(s))
        return self._constructor(self.axes,
                                 N.concatenate((z, i_values), variable),
                                 None, self.period)

    def definiteIntegral(self, variable = 0):
        """
        @param variable: the index of the variable of the function
                         with respect to which the X{integration} is performed
        @type variable: C{int}
        @returns: a new InterpolatingFunction containing the numerical
                  X{integral} over the whole grid range of C{variable}.
                  If the original function has only one free variable,
                  the definite integral is a number. Periods of the
                  remaining variables are kept.
        @rtype: L{InterpolatingFunction} or number
        @raise ValueError: if the function is periodic in C{variable}
        """
        i_values = 0.5*N.add.reduce(self._trapezoidTerms(variable), variable)
        if len(self.axes) == 1:
            return i_values
        else:
            i_axes = self.axes[:variable] + self.axes[variable+1:]
            periods = list(self.period)
            i_period = periods[:variable] + periods[variable+1:]
            return self._constructor(i_axes, i_values, None, i_period)

    def fitPolynomial(self, order):
        """
        @param order: the order of the X{polynomial} to be fitted
        @type order: C{int}
        @returns: a polynomial whose coefficients have been obtained
                  by a X{least-squares} fit to the grid values
        @rtype: L{Scientific.Functions.Polynomial}
        @raise ValueError: if the function is periodic in any variable
        """
        for p in self.period:
            if p is not None:
                raise ValueError('Polynomial fit not possible ' +
                                 'for periodic function')
        points = _combinations(self.axes)
        return Polynomial._fitPolynomial(order, points,
                                         N.ravel(self.values))

    def __abs__(self):
        """
        @returns: a new InterpolatingFunction on the same grid holding the
                  absolute values; the default value (if any) is replaced
                  by its absolute value and the periods are kept
        @rtype: L{InterpolatingFunction}
        """
        values = abs(self.values)
        try:
            default = abs(self.default)
        except TypeError:
            # no default value (None)
            default = self.default
        return self._constructor(self.axes, values, default, self.period)

    def _mathfunc(self, function):
        """
        @param function: an elementwise function applicable to both the
                         value array and the default value
        @type function: callable
        @returns: a new InterpolatingFunction on the same grid, with the
                  same periods, whose values and default value (if any)
                  have been transformed by C{function}
        @rtype: L{InterpolatingFunction}
        """
        if self.default is None:
            default = None
        else:
            default = function(self.default)
        return self._constructor(self.axes, function(self.values), default,
                                 self.period)

    def exp(self):
        return self._mathfunc(N.exp)

    def log(self):
        return self._mathfunc(N.log)

    def sqrt(self):
        return self._mathfunc(N.sqrt)

    def sin(self):
        return self._mathfunc(N.sin)

    def cos(self):
        return self._mathfunc(N.cos)

    def tan(self):
        return self._mathfunc(N.tan)

    def sinh(self):
        return self._mathfunc(N.sinh)

    def cosh(self):
        return self._mathfunc(N.cosh)

    def tanh(self):
        return self._mathfunc(N.tanh)

    def arcsin(self):
        return self._mathfunc(N.arcsin)

    def arccos(self):
        return self._mathfunc(N.arccos)

    def arctan(self):
        return self._mathfunc(N.arctan)

# Class used to build every derived function object.
InterpolatingFunction._constructor = InterpolatingFunction
|
|
|
|
#
|
|
# Interpolating function on data in netCDF file
|
|
#
|
|
class NetCDFInterpolatingFunction(InterpolatingFunction):

    """Function defined by values on a grid in a X{netCDF} file

    A subclass of L{InterpolatingFunction}. The axes and values are the
    variable objects of the opened netCDF file, which is kept in the
    attribute C{file}.
    """

    def __init__(self, filename, axesnames, variablename, default = None,
                 period = None):
        """
        @param filename: the name of the netCDF file
        @type filename: C{str}

        @param axesnames: the names of the netCDF variables that contain the
                          axes information
        @type axesnames: sequence of C{str}

        @param variablename: the name of the netCDF variable that contains
                             the data values
        @type variablename: C{str}

        @param default: the value of the function outside the grid. A value
                        of C{None} means that the function is undefined outside
                        the grid and that any attempt to evaluate it there
                        raises an exception.
        @type default: number or C{None}

        @param period: the period for each of the variables, or C{None} for
                       variables in which the function is not periodic.
        @type period: sequence of numbers or C{None}

        @raise ValueError: if C{period} does not have one entry per axis
                           or if a period does not exceed the extent of
                           its axis
        """
        # Imported here so that the module can be used without netCDF
        # support as long as this class is not instantiated.
        from Scientific.IO.NetCDF import NetCDFFile
        self.file = NetCDFFile(filename, 'r')
        variables = self.file.variables
        self.axes = [variables[name] for name in axesnames]
        self.values = variables[variablename]
        self.default = default
        # Concatenation of the shapes of all axis variables.
        self.shape = sum([axis.shape for axis in self.axes], ())
        if period is None:
            period = [None]*len(self.axes)
        self.period = period
        if len(self.period) != len(self.axes):
            raise ValueError('Inconsistent arguments')
        for axis, axis_period in zip(self.axes, self.period):
            if axis_period is None:
                continue
            # The period must be strictly larger than the axis extent.
            if axis[0]+axis_period <= axis[-1]:
                raise ValueError('Period too short')

# Derived functions are built as plain in-memory interpolating functions.
NetCDFInterpolatingFunction._constructor = InterpolatingFunction
|
|
|
|
|
|
# Helper functions
|
|
|
|
def _lookup(point, axis, period):
    """
    Locate the grid interval around C{point} along a single axis.

    @param point: the coordinate at which the function is evaluated
    @param axis: the grid coordinates along this variable
    @param period: the period of this variable, or C{None} if the
                   function is not periodic in it
    @returns: a tuple C{(index, weight)}. C{index} is an index expression
              holding one slice that selects two grid points; C{weight}
              is the interpolation weight of the second selected point,
              the first one getting C{1.-weight}.
    @raise ValueError: if the variable is not periodic and the point
                       lies outside the grid
    """
    if period is not None:
        # Shift the point by a multiple of the period, relative to the
        # first grid point.
        folded = axis[0] + (point-axis[0]) % period
        upper = N.int_sum(N.less_equal(axis, folded))
        lower = upper-1
        if upper != len(axis):
            fraction = (folded-axis[lower])/(axis[upper]-axis[lower])
            return index_expression[lower:upper+1:1], fraction
        # The point lies between the last grid point and the first grid
        # point of the next period. The slice picks the first and the last
        # grid point (in that order), hence the reversed weight.
        fraction = (folded-axis[lower])/(axis[0]+period-axis[lower])
        return index_expression[0:lower+1:lower], 1.-fraction

    # Number of grid points that are not larger than the point.
    upper = N.int_sum(N.less_equal(axis, point))
    if upper == len(axis):
        # At or beyond the last grid point: accepted only if the point
        # coincides with the last grid point within 1.e-9.
        if N.fabs(point - axis[upper-1]) < 1.e-9:
            return index_expression[upper-2:upper:1], 1.
        raise ValueError('Point outside grid of values')
    if upper == 0:
        raise ValueError('Point outside grid of values')
    lower = upper-1
    fraction = (point-axis[lower])/(axis[upper]-axis[lower])
    return index_expression[lower:upper+1:1], fraction
|
|
|
|
def _combinations(axes):
|
|
if len(axes) == 1:
|
|
return map(lambda x: (x,), axes[0])
|
|
else:
|
|
rest = _combinations(axes[1:])
|
|
l = []
|
|
for x in axes[0]:
|
|
for y in rest:
|
|
l.append((x,)+y)
|
|
return l
|
|
|
|
|
|
# Test code
|
|
|
|
if __name__ == '__main__':

    # Disabled examples for non-periodic functions of one and two variables.
##    axis = N.arange(0,1.1,0.1)
##    values = N.sqrt(axis)
##    s = InterpolatingFunction((axis,), values)
##    print s(0.22), N.sqrt(0.22)
##    sd = s.derivative()
##    print sd(0.35), 0.5/N.sqrt(0.35)
##    si = s.integral()
##    print si(0.42), (0.42**1.5)/1.5
##    print s.definiteIntegral()
##    values = N.sin(axis[:,N.NewAxis])*N.cos(axis)
##    sc = InterpolatingFunction((axis,axis),values)
##    print sc(0.23, 0.77), N.sin(0.23)*N.cos(0.77)

    # Periodic example: a sine sampled at 20 points over one period is
    # compared, together with its numerical derivative, to the exact
    # sine and cosine at points reaching beyond the sampled period.
    axis = N.arange(20)*(2.*N.pi)/20.
    values = N.sin(axis)
    s = InterpolatingFunction((axis,), values, period=(2.*N.pi,))
    c = s.derivative()
    for x in N.arange(0., 15., 1.):
        exact_sine, interpolated_sine = N.sin(x), s(x)
        exact_cosine, interpolated_cosine = N.cos(x), c(x)
        print(x)
        print('%s %s' % (exact_sine, interpolated_sine))
        print('%s %s' % (exact_cosine, interpolated_cosine))
|