# coding: utf-8
# Copyright 2018 Ahmed Ratnani, Yaman Güçlü
import numpy as np
from scipy.sparse import csc_matrix, csr_matrix, dia_matrix
from sympde.topology.space import BasicFunctionSpace
from psydac.linalg.stencil import StencilVectorSpace
from psydac.linalg.direct_solvers import BandedSolver, SparseSolver
from psydac.fem.basic import FemSpace, FemField
from psydac.core.bsplines import (
find_span,
basis_funs,
collocation_matrix,
histopolation_matrix,
breakpoints,
greville,
make_knots,
elevate_knots,
basis_integrals,
)
from psydac.utilities.utils import unroll_edges, refine_array_1d
from psydac.ddm.cart import DomainDecomposition, CartDecomposition
__all__ = ('SplineSpace',)
#===============================================================================
[docs]
class SplineSpace( FemSpace ):
"""
A 1D spline finite element space.

Exactly one of `knots` and `grid` must be given; `multiplicity` may only be
combined with `grid`.

Parameters
----------
degree : int
    Polynomial degree.

knots : array_like
    Coordinates of knots (clamped or extended by periodicity).

grid : array_like
    Coordinates of the grid. Used to construct the knots sequence, if not given.

multiplicity : int
    Multiplicity of the knots in the knot sequence. Must be >= 1.

parent_multiplicity : int
    Multiplicity of the parent knot sequence, if the space is reduced space.
    Must be >= 1; defaults to the multiplicity of this space.

periodic : bool
    True if domain is periodic, False otherwise.
    Default: False

dirichlet : tuple, list
    True if using homogeneous dirichlet boundary conditions, False
    otherwise. Must be specified for each bound
    Default: (False, False)

basis : str
    Set to "B" for B-splines (have partition of unity)
    Set to "M" for M-splines (have unit integrals)

pads : int
    Padding used when creating the space of spline coefficients.
    If None (or any other falsy value), the degree is used instead.
"""
def __init__(self, degree, knots=None, grid=None, multiplicity=None, parent_multiplicity=None,
periodic=False, dirichlet=(False, False), basis='B', pads=None):
# Validate the arguments first: `basis` must be 'B' or 'M', exactly one of
# `knots` / `grid` must be given, `multiplicity` cannot accompany `knots`,
# and both multiplicities (when given) must be at least 1.
if basis not in ['B', 'M']:
raise ValueError(" only options for basis functions are B or M ")
if (knots is not None) and (grid is not None):
raise ValueError( 'Cannot provide both grid and knots.' )
if (knots is None) and (grid is None):
raise ValueError('Either knots or grid must be provided.')
if (knots is not None) and (multiplicity is not None):
raise ValueError( 'Cannot provide both knots and multiplicity.' )
if (multiplicity is not None) and multiplicity<1:
raise ValueError('multiplicity should be >=1')
if (parent_multiplicity is not None) and parent_multiplicity<1:
raise ValueError('parent_multiplicity should be >=1')
# Only a grid was given: build the knot sequence from it (multiplicity
# defaults to 1).
if knots is None:
if multiplicity is None:multiplicity = 1
knots = make_knots( grid, degree, periodic, multiplicity )
# Only knots were given: recover the breakpoints from them.
if grid is None:
grid = breakpoints(knots, degree)
# Infer the multiplicity from the knots: `indices` marks the positions where
# consecutive knots (first and last `degree` entries excluded) differ by
# more than 1e-15.
# NOTE(review): indentation was lost in this listing, so it cannot be
# determined here whether this inference runs only when `grid` was None or
# unconditionally -- confirm against the original file.
indices = np.where(np.diff(knots[degree:len(knots)-degree])>1e-15)[0]
if len(indices)>0:
multiplicity = np.diff(indices).max(initial=1)
else:
multiplicity = max(1,len(knots[degree+1:-degree-1]))
# Without an explicit parent multiplicity, the space is its own parent.
if parent_multiplicity is None:
parent_multiplicity = multiplicity
assert parent_multiplicity >= multiplicity
# TODO: verify that user-provided knots make sense in periodic case
# Number of basis function in space (= cardinality)
if periodic:
nbasis = len(knots) - 2*degree - 2 + multiplicity
else:
# One basis function is removed for each boundary flagged in `dirichlet`.
defect = 0
if dirichlet[0]: defect += 1
if dirichlet[1]: defect += 1
nbasis = len(knots) - degree - 1 - defect
# Coefficients to convert B-splines to M-splines (if needed)
if basis == 'M':
scaling_array = 1 / basis_integrals(knots, degree)
else:
scaling_array = None
# Store attributes in object
self._degree = degree
# Falls back to the degree when `pads` is None (or any falsy value, e.g. 0).
self._pads = pads or degree
self._knots = knots
self._periodic = periodic
self._multiplicity = multiplicity
self._dirichlet = dirichlet
self._basis = basis
self._nbasis = nbasis
self._breaks = grid
self._ncells = len(grid) - 1
self._greville = greville(knots, degree, periodic, multiplicity = multiplicity)
# Greville points of the knot sequence returned by elevate_knots, with degree+1.
self._ext_greville = greville(elevate_knots(knots, degree, periodic, multiplicity=multiplicity), degree+1, periodic, multiplicity = multiplicity)
self._scaling_array = scaling_array
self._parent_multiplicity = parent_multiplicity
# `self.domain` reads `self._breaks` and `self.ext_greville` reads
# `self._ext_greville`, so this line must stay after those assignments.
self._histopolation_grid = unroll_edges(self.domain, self.ext_greville)
# Create space of spline coefficients
domain_decomposition = DomainDecomposition([self._ncells], [periodic])
# The index range passed here is [0, nbasis-1].
# NOTE(review): presumably a single, non-distributed 1D block -- confirm
# against the CartDecomposition signature.
cart = CartDecomposition(domain_decomposition, [nbasis], [np.array([0])],[np.array([nbasis-1])], [self._pads], [multiplicity])
self._vector_space = StencilVectorSpace( cart )
# Store flag: object NOT YET prepared for interpolation / histopolation
self._interpolation_ready = False
self._histopolation_ready = False
# No symbolic space attached yet; it can be set later through the
# `symbolic_space` property.
self._symbolic_space = None
# ...
# ...
@property
def histopolation_grid(self):
    """
    Edge coordinates used for histopolation.

    Returns the array stored at construction time, obtained by passing the
    domain boundaries and the extended Greville points to `unroll_edges`.
    It holds the N+1 points x[i] delimiting the N 1D edges (x[i], x[i+1]),
    where N is the number of basis functions (cardinality of the space).

    For a non-periodic space x is simply the array of extended Greville
    points. For a periodic space the edges are "unrolled" so that each one
    is a positive, well-defined interval with x[i] < x[i+1].
    """
    edges = self._histopolation_grid
    return edges
# ...
[docs]
def init_interpolation( self, dtype=float ):
    """
    Compute the 1D collocation matrix and factorize it, in preparation
    for the calculation of a spline interpolant given the values at the
    Greville points.

    The collocation matrix is always kept in `self.imat` (mirroring what
    `init_histopolation` does with `self.hmat`), the solver is stored in
    `self._interpolator`, and the space is flagged as ready for
    interpolation.

    Parameters
    ----------
    dtype : data-type, optional
        Type of the entries of the banded matrix handed to the banded
        solver (non-periodic case only). Default: float.
    """
    imat = collocation_matrix(
        knots         = self.knots,
        degree        = self.degree,
        periodic      = self.periodic,
        normalization = self.basis,
        xgrid         = self.greville,
        multiplicity  = self.multiplicity
    )

    # Keep the assembled matrix available whatever the periodicity, so that
    # `imat` behaves like `hmat` in init_histopolation.
    self.imat = imat

    if self.periodic:
        # Convert to CSC format and compute sparse LU decomposition
        self._interpolator = SparseSolver( csc_matrix( imat ) )
    else:
        # Convert to LAPACK banded format (see DGBTRF function):
        # entry (i, j) goes to row u+l+i-j of a (1+u+2*l)-row array,
        # where l and u are taken from the extreme diagonal offsets.
        dmat = dia_matrix( imat )
        l    = abs( dmat.offsets.min() )
        u    = dmat.offsets.max()
        cmat = csr_matrix( dmat )
        bmat = np.zeros( (1+u+2*l, cmat.shape[1]), dtype=dtype )
        for i, j in zip( *cmat.nonzero() ):
            bmat[u+l+i-j, j] = cmat[i, j]
        self._interpolator = BandedSolver( u, l, bmat )

    # Store flag
    self._interpolation_ready = True
# ...
[docs]
def init_histopolation( self, dtype=float):
    """
    Build and factorize the 1D histopolation matrix.

    This prepares the computation of a spline from its integrals over the
    cells delimited by the extended Greville points. The matrix is kept in
    `self.hmat`, the solver in `self._histopolator`, and the space is
    flagged as ready for histopolation.

    Parameters
    ----------
    dtype : data-type, optional
        Type of the entries of the banded matrix handed to the banded
        solver (non-periodic case only). Default: float.
    """
    matrix = histopolation_matrix(
        knots=self.knots,
        degree=self.degree,
        periodic=self.periodic,
        normalization=self.basis,
        xgrid=self.ext_greville,
        multiplicity=self._multiplicity,
    )
    self.hmat = matrix

    if self.periodic:
        # Periodic case: sparse LU decomposition of the CSC matrix
        solver = SparseSolver(csc_matrix(matrix))
    else:
        # Non-periodic case: LAPACK banded storage (see DGBTRF function).
        # Entry (row, col) is written to band[upper + lower + row - col, col].
        diagonals = dia_matrix(matrix)
        lower = abs(diagonals.offsets.min())
        upper = diagonals.offsets.max()
        compressed = csr_matrix(diagonals)
        band = np.zeros((1 + upper + 2*lower, compressed.shape[1]), dtype=dtype)
        rows, cols = compressed.nonzero()
        for row, col in zip(rows, cols):
            band[upper + lower + row - col, col] = compressed[row, col]
        solver = BandedSolver(upper, lower, band)

    self._histopolator = solver

    # Store flag
    self._histopolation_ready = True
#--------------------------------------------------------------------------
# Abstract interface: read-only attributes
#--------------------------------------------------------------------------
@property
def ldim( self ):
    """Parametric dimension of the space: always 1."""
    parametric_dimension = 1
    return parametric_dimension
@property
def periodic( self ):
    """Periodicity flag given at construction: True for a periodic domain, False otherwise."""
    flag = self._periodic
    return flag
@property
def pads( self ):
    """
    Padding for potential parallel assembly.

    This is the `pads` constructor argument, or the degree when that
    argument was not given.
    """
    padding = self._pads
    return padding
@property
def mapping( self ):
    """Mapping of the space: always None, as an identity mapping is assumed for now."""
    no_mapping = None
    return no_mapping
@property
def vector_space( self ):
    """
    Vector space of the spline coefficients.

    This is the `StencilVectorSpace` object created at construction.
    """
    space = self._vector_space
    return space
@property
def is_product(self):
    """Always False for this space."""
    answer = False
    return answer
@property
def symbolic_space( self ):
    """Symbolic space attached to this object, or None if none has been set."""
    attached = self._symbolic_space
    return attached

@symbolic_space.setter
def symbolic_space( self, symbolic_space ):
    # Only instances of BasicFunctionSpace are accepted.
    is_valid = isinstance(symbolic_space, BasicFunctionSpace)
    assert is_valid
    self._symbolic_space = symbolic_space
#--------------------------------------------------------------------------
# Abstract interface: evaluation methods
#--------------------------------------------------------------------------
[docs]
def eval_field(self, field, *eta , weights=None):
    """
    Evaluate a field of this space at a single location.

    Parameters
    ----------
    field : FemField
        Field to evaluate; its `space` attribute must be this object.

    *eta : one positional argument
        Evaluation location (exactly one coordinate must be given).

    weights : optional
        If not None, the spline coefficients active at `eta` are multiplied
        by `weights[index]`, `index` being the slice of active basis
        functions. Any object supporting this slicing can be used,
        including NumPy arrays.

    Returns
    -------
    Dot product of the (possibly weighted) active coefficients with the
    values of the active basis functions at `eta`.
    """
    assert isinstance( field, FemField )
    assert field.space is self
    assert len(eta) == 1

    x = eta[0]
    span = find_span( self.knots, self.degree, x )
    basis_array = basis_funs( self.knots, self.degree, x, span )

    # Indices of the degree+1 basis functions that are evaluated at x
    index = slice(span - self.degree, span + 1)

    # Rescale B-splines to M-splines if needed
    if self.basis == 'M':
        basis_array *= self._scaling_array[index]

    # Work on a copy: the weighting below must not modify the field
    coeffs = field.coeffs[index].copy()

    # Explicit None test: the truth value of a NumPy array with more than
    # one entry is ambiguous, so `if weights:` raises ValueError for it.
    if weights is not None:
        coeffs *= weights[index]

    return np.dot(coeffs, basis_array)
# ...
[docs]
def eval_field_gradient( self, field, *eta , weights=None):
    """
    Gradient evaluation: not implemented.

    The arguments are checked (`field` must be a FemField of this space and
    exactly one coordinate must be given), then NotImplementedError is
    raised.
    """
    assert isinstance( field, FemField )
    assert field.space is self
    n_coords = len( eta )
    assert n_coords == 1
    raise NotImplementedError()
#--------------------------------------------------------------------------
# Other properties
#--------------------------------------------------------------------------
@property
def is_scalar( self ):
    """Always True: only scalar fields are implemented for now."""
    answer = True
    return answer
@property
def basis( self ):
    """Kind of basis functions chosen at construction: 'B' (B-splines) or 'M' (M-splines)."""
    kind = self._basis
    return kind
@property
def interpolation_grid( self ):
    """
    Grid associated with the kind of basis in use.

    Returns the Greville points for a 'B' basis and the extended Greville
    points for an 'M' basis; any other kind raises NotImplementedError.
    """
    kind = self.basis
    if kind == 'B':
        return self.greville
    if kind == 'M':
        return self.ext_greville
    raise NotImplementedError()
@property
def nbasis( self ):
    """Cardinality of the spline space, i.e. its number of basis functions."""
    count = self._nbasis
    return count
@property
def degree( self ):
    """Polynomial degree of the splines, as given at construction."""
    p = self._degree
    return p
@property
def ncells( self ):
    """Number of cells in the domain (number of breakpoints minus one)."""
    count = self._ncells
    return count
@property
def dirichlet( self ):
    """
    Dirichlet flags given at construction, one for each bound: True where
    homogeneous Dirichlet boundary conditions are used, False otherwise.
    """
    flags = self._dirichlet
    return flags
@property
def knots( self ):
    """
    Knot sequence of the space: the knots given at construction, or those
    built from the grid when no knots were given.
    """
    sequence = self._knots
    return sequence
@property
def multiplicity( self ):
    """Multiplicity of the knots, as stored at construction."""
    value = self._multiplicity
    return value
@property
def parent_multiplicity( self ):
    """
    Multiplicity of the parent knot sequence; it equals the multiplicity of
    this space when no parent multiplicity was given at construction.
    """
    value = self._parent_multiplicity
    return value
@property
def breaks( self ):
    """
    Breakpoints of the space: the grid given at construction, or the
    breakpoints extracted from the knots when no grid was given.
    """
    points = self._breaks
    return points
@property
def domain( self ):
    """Domain boundaries [a,b]: the first and last breakpoints, as a tuple (a, b)."""
    points = self.breaks
    a, b = points[0], points[-1]
    return a, b
@property
def greville( self ):
    """
    Coordinates of all Greville points of the space, computed at
    construction. They are the points used for interpolation.
    """
    points = self._greville
    return points
@property
def ext_greville( self ):
    """
    Greville coordinates of the 'extended' space with degree p+1, computed
    at construction from the elevated knot sequence. They are the points
    used for histopolation.
    """
    points = self._ext_greville
    return points
@property
def scaling_array(self):
    """
    Coefficients converting B-splines to M-splines.

    If self.basis=='M', this is the array used to rescale B-splines to
    M-splines (the reciprocals of the basis integrals); its length is
    (len(knots)-degree-1). If self.basis=='B', this is None.
    """
    factors = self._scaling_array
    return factors
#--------------------------------------------------------------------------
# Other methods
#--------------------------------------------------------------------------
[docs]
def compute_interpolant( self, values, field ):
    """
    Update the spline coefficients of a field so that it interpolates a
    function $f(x)$ at the Greville points.

    The collocation matrix is built and factorized on first use.

    Parameters
    ----------
    values : array_like (nbasis,)
        Function values $f(x_i)$ at the 'nbasis' Greville points $x_i$,
        to be interpolated.

    field : FemField
        Input/output argument: spline of this space whose coefficients
        are overwritten so that it interpolates the given values.
    """
    assert len( values ) == self.nbasis
    assert isinstance( field, FemField )
    assert field.space is self

    # Lazy initialization of the interpolation solver
    if not self._interpolation_ready:
        self.init_interpolation()

    size = self.nbasis
    coeffs = field.coeffs
    solution = self._interpolator.solve( values )
    coeffs[0:size] = solution
    coeffs.update_ghost_regions()
# ...
[docs]
def compute_histopolant( self, values, field ):
    """
    Update the spline coefficients of a field so that its integrals
    between the extended Greville points match the given values.

    The histopolation matrix is built and factorized on first use.

    Parameters
    ----------
    values : array_like (nbasis,)
        Integral values between the 'nbasis' extended Greville cells
        $[x_i, x_{i+1}]$, to be matched by the spline.

    field : FemField
        Input/output argument: spline of this space whose coefficients
        are overwritten so that it matches the given integral values.
    """
    assert len( values ) == self.nbasis
    assert isinstance( field, FemField )
    assert field.space is self

    # Lazy initialization of the histopolation solver
    if not self._histopolation_ready:
        self.init_histopolation()

    size = self.nbasis
    coeffs = field.coeffs
    solution = self._histopolator.solve( values )
    coeffs[0:size] = solution
    coeffs.update_ghost_regions()
# ...
[docs]
def refine(self, ncells):
    """
    Build a refined 1D spline space having the requested number of cells.

    Parameters
    ----------
    ncells : int
        Number of cells of refined space. Must be multiple of self.ncells.

    Returns
    -------
    SplineSpace
        Refined 1D spline space which contains the original space. When
        `ncells` equals the current number of cells, the object itself is
        returned.

    Raises
    ------
    ValueError
        If `ncells` is not an integer, is smaller than self.ncells, or is
        not a multiple of self.ncells.
    """
    prefix = "Wrong number of cells: "

    # Sanity checks: the first violated condition is the one reported
    if int(ncells) != ncells:
        raise ValueError(prefix + f"{ncells} is not an integer")
    if ncells < self.ncells:
        raise ValueError(prefix + f"{ncells} is smaller than minimum value {self.ncells}")
    if ncells % self.ncells != 0:
        raise ValueError(prefix + f"{ncells} is not multiple of {self.ncells}")

    # Nothing to refine
    if ncells == self.ncells:
        return self

    factor = ncells // self.ncells
    fine_grid = refine_array_1d(self.breaks, factor)

    # Every other option is inherited from the current space
    options = dict(
        grid=fine_grid,
        multiplicity=self.multiplicity,
        parent_multiplicity=self.parent_multiplicity,
        periodic=self.periodic,
        dirichlet=self.dirichlet,
        basis=self.basis,
        pads=self.pads,
    )
    return SplineSpace(self.degree, **options)
# ...
def __str__(self):
    """Pretty printing: one line each for ldim, nbasis and degree, after a leading newline."""
    rows = (
        '',
        f'> ldim :: {self.ldim}',
        f'> nbasis :: {self.nbasis} ',
        f'> degree :: {self.degree}',
    )
    return '\n'.join(rows)
[docs]
def draw(self):
    """
    Plot the basis functions of the space with matplotlib.

    One curve is drawn per unit coefficient vector, sampled at 200 equally
    spaced points between the first and the last knot, and labelled
    'N0', 'N1', ... The number of curves is nbasis, plus the degree when
    the space is periodic. The figure is shown with `plt.show()`.
    """
    from scipy.interpolate import BSpline
    import matplotlib.pyplot as plt

    degree = self.degree
    count = self.nbasis + degree*self.periodic
    knots = self.knots

    fig, ax = plt.subplots()
    samples = np.linspace(knots[0], knots[-1], 200)

    for k in range(count):
        # Unit coefficient vector selecting the k-th basis function
        coeffs = [0]*count
        coeffs[k] = 1
        curve = BSpline(knots, coeffs, degree)
        ax.plot(samples, curve(samples), label='N{}'.format(k))

    ax.grid(True)
    ax.legend()
    plt.show()