import re
import string
import random
from sympy import Symbol, IndexedBase, Indexed, Idx
from sympy import Mul, Pow, Function, Tuple
from sympy import sqrt as sympy_sqrt, Range
from sympy.utilities.iterables import cartes
from sympde.topology.space import ScalarFunction
from sympde.topology.space import VectorFunction
from sympde.topology.space import IndexedVectorFunction
from sympde.topology.space import element_of
from sympde.topology import Mapping
from sympde.topology import Boundary
from sympde.topology.derivatives import _partial_derivatives
from sympde.topology.derivatives import _logical_partial_derivatives
from sympde.topology.derivatives import get_atom_derivatives
from sympde.topology.derivatives import get_index_derivatives
from sympde.topology.derivatives import get_atom_logical_derivatives
from sympde.topology.derivatives import get_index_logical_derivatives
from sympde.topology import LogicalExpr
from sympde.topology import SymbolicExpr
from sympde.core import Constant
from psydac.pyccel.ast.core import Variable, IndexedVariable
from psydac.pyccel.ast.core import For
from psydac.pyccel.ast.core import Assign
from psydac.pyccel.ast.core import AugAssign
from psydac.pyccel.ast.core import Product
from psydac.pyccel.ast.core import _atomic
from psydac.pyccel.ast.core import Comment
from psydac.pyccel.ast.core import String
from psydac.pyccel.ast.core import AnnotatedArgument
# Public names exported by `from <module> import *`, listed alphabetically.
# NOTE(review): 'build_pythran_types_header' is exported here but no definition
# with that name is visible in this file -- confirm it still exists, otherwise
# a star-import of this module will fail.
__all__ = (
'build_pyccel_type_annotations',
'build_pythran_types_header',
'compute_atoms_expr',
'compute_atoms_expr_field',
'compute_atoms_expr_mapping',
'compute_boundary_jacobian',
'compute_normal_vector',
'compute_tangent_vector',
'filter_loops',
'filter_product',
'fusion_loops',
'get_name',
'is_mapping',
'logical2physical',
'math_atoms_as_str',
'random_string',
'rationalize_eval_mapping',
'select_loops',
'variables',
)
#==============================================================================
[docs]
def random_string( n ):
    """
    Return a random string of length `n`.

    Characters are drawn from the lowercase ASCII letters and the decimal
    digits, using the operating-system entropy source
    (`random.SystemRandom`).

    Parameters
    ----------
    n : int
        Number of characters in the result.

    Returns
    -------
    str
        Random string of `n` characters.
    """
    alphabet = string.ascii_lowercase + string.digits
    rng = random.SystemRandom()
    picked = [rng.choice(alphabet) for _ in range(n)]
    return ''.join(picked)
#==============================================================================
[docs]
def is_mapping(expr):
    """
    Tell whether `expr` is a mapping, a component of a mapping, or a logical
    partial derivative (possibly nested) of one of those.

    Parameters
    ----------
    expr : object
        Expression to be inspected.

    Returns
    -------
    bool
        True if, once all enclosing logical partial derivatives are removed,
        the expression is a `Mapping` or an `Indexed` whose base is a
        `Mapping`; False otherwise.
    """
    # Strip every enclosing logical derivative operator
    while isinstance(expr, _logical_partial_derivatives):
        expr = expr.args[0]

    if isinstance(expr, Mapping):
        return True

    return isinstance(expr, Indexed) and isinstance(expr.base, Mapping)
#==============================================================================
[docs]
def logical2physical(expr):
    """
    Replace logical partial derivative operators by their physical
    counterparts.

    The operators in `_logical_partial_derivatives` and
    `_partial_derivatives` are paired position by position. Nested logical
    derivatives are converted recursively; anything that is not a logical
    partial derivative is returned unchanged.

    Parameters
    ----------
    expr : object
        Expression to convert.

    Returns
    -------
    object
        The converted expression, or `expr` itself when its outermost
        operator is not a logical partial derivative.
    """
    if not isinstance(expr, _logical_partial_derivatives):
        return expr

    physical_of = dict(zip(_logical_partial_derivatives, _partial_derivatives))
    inner = logical2physical(expr.args[0])
    return physical_of[type(expr)](inner)
#==============================================================================
def _get_name(atom):
    """
    Return the name of a scalar/vector function as a string.

    For an `IndexedVectorFunction` the name of its base is returned.

    Raises
    ------
    TypeError
        If `atom` is not a `ScalarFunction`, a `VectorFunction` or an
        `IndexedVectorFunction`.
    """
    if isinstance(atom, (ScalarFunction, VectorFunction)):
        return str(atom.name)

    if isinstance(atom, IndexedVectorFunction):
        return str(atom.base.name)

    raise TypeError('> Wrong type')
#==============================================================================
[docs]
def compute_atoms_expr(atomic_exprs, indices_quad, indices_test,
                       indices_trial, basis_trial,
                       basis_test, cords, test_function,
                       is_linear,
                       mapping):
    """
    This function computes atomic expressions needed
    to evaluate the Kernel final expression

    Parameters
    ----------
    atomic_exprs : <list>
        list of atoms

    indices_quad : <list>
        list of quadrature indices used in the quadrature loops

    indices_test : <list>
        list of test_functions indices used in the for loops of the basis functions

    indices_trial : <list>
        list of trial_functions indices used in the for loops of the basis functions

    basis_trial : <list>
        list of trial basis functions in each dimension

    basis_test : <list>
        list of test basis functions in each dimension

    cords : <list>
        list of coordinates Symbols (not used by this function)

    test_function : <iterable>
        test functions; an atom whose underlying function has the same name
        as one of them is evaluated with the test basis

    is_linear : <boolean>
        variable to determine if we are in the linear case
        (the test basis is then used for every atom)

    mapping : <Mapping>
        Mapping object (any falsy value means "no mapping")

    Returns
    -------
    inits : <list>
        list of assignments of the atomic expression evaluated in the quadrature points

    map_stmts : <list>
        list of assignments of atomic expression in case of mapping

    Raises
    ------
    TypeError
        If `atomic_exprs` is not a list/tuple/Tuple, or if one of its items
        is not a partial derivative or a scalar/vector function.
    """
    cls = (_partial_derivatives,
           VectorFunction,
           ScalarFunction,
           IndexedVectorFunction)

    if not isinstance(atomic_exprs, (list, tuple, Tuple)):
        raise TypeError('Expecting a list of atoms')

    for atom in atomic_exprs:
        if not isinstance(atom, cls):
            raise TypeError('atom must be of type {}'.format(str(cls)))

    # If there is a mapping, compute [dx(u), dy(u), dz(u)] as functions
    # of [dx1(u), dx2(u), dx3(u)], and store results into intermediate
    # variables [u_x, u_y, u_z]. (Same thing is done for higher derivatives.)
    #
    # Accordingly, we create a new list of atoms where all partial derivatives
    # are taken with respect to the logical coordinates.
    map_stmts = []
    if mapping:
        new_atoms = set()
        get_index = get_index_logical_derivatives
        get_atom  = get_atom_logical_derivatives

        for atom in atomic_exprs:
            if isinstance(atom, _partial_derivatives):
                lhs   = SymbolicExpr(atom)
                rhs_p = LogicalExpr(mapping, atom)

                # we look for new_atoms that must be added to atomic_exprs
                # because we need them in the maps stmts
                logical_atoms = _atomic(rhs_p, cls=_logical_partial_derivatives)
                for a in logical_atoms:
                    ls = _atomic(a, Symbol)
                    assert len(ls) == 1
                    if isinstance(ls[0], cls):
                        new_atoms.add(a)

                rhs = SymbolicExpr(rhs_p)
                map_stmts += [Assign(lhs, rhs)]
            else:
                new_atoms.add(atom)
    else:
        new_atoms = atomic_exprs
        get_index = get_index_derivatives
        get_atom  = get_atom_derivatives

    # Create a list of statements for initialization of the point values,
    # for each of the atoms in our (possibly new) list.
    inits = []

    # Names of the test functions do not depend on the atom: they are computed
    # once, on first use, so that nothing is evaluated when there is no atom.
    test_names = None

    for atom in new_atoms:
        orders = [*get_index(atom).values()]
        a      = get_atom(atom)

        if test_names is None:
            test_names = {_get_name(f) for f in test_function}

        if _get_name(a) in test_names or is_linear:
            basis = basis_test
            idxs  = indices_test
        else:
            basis = basis_trial
            idxs  = indices_trial

        # Product of the 1D basis values b[index, derivative order, quad index]
        args = [b[i, d, q] for b, i, d, q in zip(basis, idxs, orders, indices_quad)]
        lhs  = SymbolicExpr(atom)
        rhs  = Mul(*args)
        inits += [Assign(lhs, rhs)]

    # Return the initialization statements, and the additional initialization
    # of intermediate variables in case of mapping
    return inits, map_stmts
#==============================================================================
[docs]
def compute_atoms_expr_field(atomic_exprs, indices_quad,
                             idxs, basis,
                             test_function, mapping):
    """
    Build the statements needed to evaluate the atoms of a field
    (scalar or vector) at the quadrature points.

    Parameters
    ----------
    atomic_exprs : <list>
        list of atoms

    indices_quad : <list>
        list of quadrature indices used in the quadrature loops

    idxs : <list>
        list of basis functions indices used in the for loops of the basis functions

    basis : <list>
        list of basis functions in each dimension

    test_function : <Symbol>
        symbol substituted for the field in each atom, in order to name the
        variable holding the basis point value

    mapping : <Mapping>
        Mapping object (any falsy value means "no mapping")

    Returns
    -------
    inits : <list>
        list of assignments of the atomic expression evaluated in the quadrature points

    updates : <list>
        list of augmented assignments which are updated in each loop iteration

    map_stmts : <list>
        list of assignments of atomic expression in case of mapping

    new_atoms : <tuple>
        updated collection of atomic expressions (some were introduced in case of a mapping)

    Raises
    ------
    TypeError
        If an atom contains neither a scalar nor a vector function.
    """
    cls = (_partial_derivatives,
           ScalarFunction,
           IndexedVectorFunction,
           VectorFunction)

    map_stmts = []

    if not mapping:
        candidates = atomic_exprs
        get_index  = get_index_derivatives
        get_atom   = get_atom_derivatives
    else:
        # With a mapping, physical derivatives [dx(u), dy(u), dz(u)] are
        # written in terms of the logical ones [dx1(u), dx2(u), dx3(u)] and
        # stored into intermediate variables; the atoms to be evaluated
        # become the logical derivatives appearing in those expressions.
        candidates = set()
        get_index  = get_index_logical_derivatives
        get_atom   = get_atom_logical_derivatives

        for atom in atomic_exprs:
            if not isinstance(atom, _partial_derivatives):
                candidates.add(atom)
                continue

            target   = SymbolicExpr(atom)
            pullback = LogicalExpr(mapping, atom)

            # Collect the logical derivatives required by the map statements
            for a in _atomic(pullback, cls=_logical_partial_derivatives):
                ls = _atomic(a, Symbol)
                assert len(ls) == 1
                if isinstance(ls[0], cls):
                    candidates.add(a)

            map_stmts.append(Assign(target, SymbolicExpr(pullback)))

    # Make sure that we only pick one between 'dx1(dx2(u))' and 'dx2(dx1(u))':
    # atoms are keyed by the name of their symbolic expression
    by_name = {}
    for a in candidates:
        by_name[SymbolicExpr(a).name] = a
    new_atoms = tuple(by_name.values())

    inits   = []
    updates = []
    for atom in new_atoms:

        # Extract field, compute name of coefficient variable, and get base
        scalar_fields = atom.atoms(ScalarFunction)
        if scalar_fields:
            field = scalar_fields.pop()
            base  = field
        elif atom.atoms(VectorFunction):
            field = atom.atoms(IndexedVectorFunction).pop()
            base  = field.base
        else:
            raise TypeError('atom must be either scalar or vector field')
        field_name = 'coeff_' + SymbolicExpr(field).name

        # Variable storing the point value of the basis function
        test_fun = SymbolicExpr(atom.subs(base, test_function))

        # Initialization: product of the 1D basis values
        orders  = [*get_index(atom).values()]
        factors = [b[i, d, q] for b, i, d, q in zip(basis, idxs, orders, indices_quad)]
        inits.append(Assign(test_fun, Mul(*factors)))

        # Accumulation: <atom>_values[quad] += coeff[idxs] * basis value
        coeff = IndexedBase(field_name)[idxs]
        val   = IndexedBase(SymbolicExpr(atom).name + '_values')[indices_quad]
        updates.append(AugAssign(val, '+', Mul(coeff, test_fun)))

    return inits, updates, map_stmts, new_atoms
#==============================================================================
# TODO: merge into 'compute_atoms_expr_field'
[docs]
def compute_atoms_expr_mapping(atomic_exprs, indices_quad,
                               idxs, basis,
                               test_function):
    """
    Build the statements needed to evaluate the atoms of a mapping
    (components and their logical derivatives) at the quadrature points.

    Parameters
    ----------
    atomic_exprs : <list>
        list of atoms

    indices_quad : <list>
        list of quadrature indices used in the quadrature loops

    idxs : <list>
        list of basis functions indices used in the for loops of the basis functions

    basis : <list>
        list of basis functions in each dimension

    test_function : <Symbol>
        symbol substituted for the mapping component in each atom, in order
        to name the variable holding the basis point value

    Returns
    -------
    inits : <list>
        list of assignments of the atomic expression evaluated in the quadrature points

    updates : <list>
        list of augmented assignments which are updated in each loop iteration
    """
    inits   = []
    updates = []

    for atom in atomic_exprs:
        # Underlying mapping component and the name of its coefficient array
        element    = get_atom_logical_derivatives(atom)
        coeff_name = 'coeff_' + SymbolicExpr(element).name

        # Variable storing the point value of the basis function
        test_fun = SymbolicExpr(atom.subs(element, test_function))

        # Initialization: product of the 1D basis values
        orders  = [*get_index_logical_derivatives(atom).values()]
        factors = [b[i, d, q] for b, i, d, q in zip(basis, idxs, orders, indices_quad)]
        inits.append(Assign(test_fun, Mul(*factors)))

        # Accumulation: <atom>_values[quad] += coeff[idxs] * basis value
        val  = IndexedBase(SymbolicExpr(atom).name + '_values')[indices_quad]
        expr = IndexedBase(coeff_name)[idxs] * test_fun
        updates.append(AugAssign(val, '+', expr))

    return inits, updates
#==============================================================================
[docs]
def rationalize_eval_mapping(mapping, nderiv, space, indices_quad):
    """
    Build the statements that divide the point values of a mapping (and of
    its logical derivatives) by a weight function `w`.

    The generated code first copies every needed array entry
    `<name>_values[indices_quad]` into a local variable `<name>`, then
    overwrites the arrays of the mapping components with the quotient `u / w`
    and, depending on `nderiv`, with its first and second derivatives
    (quotient rule). Presumably this is meant for rational (NURBS-like)
    mappings -- confirm against the caller.

    Parameters
    ----------
    mapping : <Mapping>
        Mapping object; its components are accessed as `mapping[i]`

    nderiv : <int>
        highest order of derivatives to be handled (0, 1 or 2)

    space : object
        function space; provides the number of logical dimensions `ldim` and
        is used to create the weight function named 'w'

    indices_quad : <list>
        list of quadrature indices used to index the `*_values` arrays

    Returns
    -------
    stmts : <list>
        list of statements (comments and assignments)

    Raises
    ------
    NotImplementedError
        If `nderiv` is greater than 2.
    """
    # Reject unsupported orders once, before any work is done
    if nderiv > 2:
        raise NotImplementedError('TODO')

    M   = mapping
    dim = space.ldim
    ops = _logical_partial_derivatives[:dim]

    # ... mapping components and their derivatives
    elements = [M[i] for i in range(dim)]

    if nderiv > 0:
        elements += [d(M[i]) for d in ops for i in range(dim)]

    if nderiv > 1:
        # Only one ordering of each pair of derivatives is generated
        elements += [d1(d2(M[i])) for e, d1 in enumerate(ops)
                                  for d2 in ops[:e+1]
                                  for i in range(dim)]
    # ...

    # ... weights and their derivatives
    # TODO check if 'w' exist already
    weights = element_of(space, name='w')

    weights_elements = [weights]

    if nderiv > 0:
        weights_elements += [d(weights) for d in ops]

    if nderiv > 1:
        weights_elements += [d1(d2(weights)) for e, d1 in enumerate(ops)
                                             for d2 in ops[:e+1]]
    # ...

    def values_at_quad(name):
        # Entry of the array '<name>_values' at the current quadrature point
        return IndexedBase(name + '_values')[indices_quad]

    stmts = []

    # declarations: local copies of the (non-rationalized) point values
    stmts += [Comment('declarations')]
    for atom in elements + weights_elements:
        atom_name = SymbolicExpr(atom).name
        stmts += [Assign(atom_name, values_at_quad(atom_name))]

    # assignments
    stmts += [Comment('rationalize')]

    # The weight symbol does not depend on the loop variables
    w = SymbolicExpr(weights)

    # 0 order terms: u / w
    for i in range(dim):
        u = SymbolicExpr(M[i])
        stmts += [Assign(values_at_quad(u.name), u / w)]

    # 1 order terms: d(u/w) = du/w - u*dw/w**2
    if nderiv >= 1:
        for d in ops:
            dw = SymbolicExpr(d(weights))

            for i in range(dim):
                u  = SymbolicExpr(  M[i] )
                du = SymbolicExpr(d(M[i]))

                stmts += [Assign(values_at_quad(du.name),
                                 du / w - u * dw / w**2)]

    # 2 order terms: quotient rule applied twice
    if nderiv >= 2:
        for e, d1 in enumerate(ops):
            for d2 in ops[:e+1]:
                d1w   = SymbolicExpr(   d1(weights) )
                d2w   = SymbolicExpr(   d2(weights) )
                d1d2w = SymbolicExpr(d1(d2(weights)))

                for i in range(dim):
                    u     = SymbolicExpr(      M[i]  )
                    d1u   = SymbolicExpr(   d1(M[i]) )
                    d2u   = SymbolicExpr(   d2(M[i]) )
                    d1d2u = SymbolicExpr(d1(d2(M[i])))

                    stmts += [Assign(values_at_quad(d1d2u.name),
                                     d1d2u / w - u * d1d2w / w**2
                                     - d1w * d2u / w**2 - d2w * d1u / w**2
                                     + 2 * u * d1w * d2w / w**3)]

    return stmts
#==============================================================================
[docs]
def filter_product(indices, args, boundary):
    """
    Multiply the entries `args[i][indices[i]]` over all directions, skipping
    the direction perpendicular to the boundary (if any).

    Parameters
    ----------
    indices : <list>
        one index per direction

    args : <list>
        one indexable object per direction

    boundary : <Boundary> or falsy
        boundary whose `axis` is excluded from the product; any falsy value
        means that all directions are kept

    Returns
    -------
    sympy expression
        product (`Mul`) of the selected entries

    Raises
    ------
    TypeError
        If `boundary` is truthy but is not a `Boundary`.
    """
    masked_axis = None
    if boundary:
        if not isinstance(boundary, Boundary):
            raise TypeError('boundary must be a Boundary object')
        # The boundary gives the perpendicular direction, which has to be
        # removed from the product
        masked_axis = boundary.axis

    factors = [args[i][indices[i]]
               for i in range(len(indices)) if i != masked_axis]
    return Mul(*factors)
#==============================================================================
# TODO remove it later
[docs]
def filter_loops(indices, ranges, body, boundary, boundary_basis=False):
    """
    Wrap `body` into nested `For` loops (last direction innermost), reducing
    the loop over the direction perpendicular to the boundary to a single
    iteration, then fuse the loops with `fusion_loops`.

    Parameters
    ----------
    indices : <list>
        loop index of each direction

    ranges : <list>
        range of each direction (objects with `start` and `stop`)

    body : <list>
        statements of the innermost loop

    boundary : <Boundary> or falsy
        boundary under consideration; any falsy value means that all loops
        run over their full range

    boundary_basis : <boolean>
        not used by this function

    Returns
    -------
    <list>
        the (possibly fused) loops

    Raises
    ------
    TypeError
        If `boundary` is truthy but is not a `Boundary`.
    ValueError
        If `boundary.ext` is neither -1 nor 1.
    """
    masked_axes = []
    masked_exts = []
    if boundary:
        if not isinstance(boundary, Boundary):
            raise TypeError
        masked_axes = [boundary.axis]
        masked_exts = [boundary.ext]

    # Build the loop nest from the innermost (last) direction outwards
    for axis in reversed(range(len(indices))):
        rng    = ranges[axis]
        index  = indices[axis]
        lo, hi = rng.start, rng.stop

        if axis in masked_axes:
            side = masked_exts[masked_axes.index(axis)]
            if side == -1:
                # lower boundary: keep the first iteration only
                hi = lo + 1
            elif side == 1:
                # upper boundary: keep the last iteration only
                lo = hi - 1
            else:
                raise ValueError('> Wrong value for ext. It should be -1 or 1')

        body = [For(index, Range(lo, hi), body)]

    return fusion_loops(body)
#==============================================================================
[docs]
def select_loops(indices, ranges, body, boundary, boundary_basis=False):
    """
    Wrap `body` into nested `For` loops (last direction innermost), leaving
    out the direction perpendicular to the boundary, then fuse the loops
    with `fusion_loops`.

    Parameters
    ----------
    indices : <list>
        loop index of each direction

    ranges : <list>
        range of each direction (objects with `start` and `stop`)

    body : <list>
        statements of the innermost loop

    boundary : <Boundary> or falsy
        boundary whose `axis` gets no loop; any falsy value means that all
        directions are looped over

    boundary_basis : <boolean>
        not used by this function

    Returns
    -------
    <list>
        the (possibly fused) loops

    Raises
    ------
    TypeError
        If `boundary` is truthy but is not a `Boundary`.
    """
    masked_axes = []
    if boundary:
        if not isinstance(boundary, Boundary):
            raise TypeError
        masked_axes = [boundary.axis]
        # 'ext' is read as in the original implementation, but it plays no
        # role in the selection of the loops
        _ = [boundary.ext]

    # Directions to loop over, from the innermost (last) one outwards
    kept_axes = [i for i in reversed(range(len(indices))) if i not in masked_axes]

    for axis in kept_axes:
        rng   = ranges[axis]
        index = indices[axis]
        body  = [For(index, Range(rng.start, rng.stop), body)]

    return fusion_loops(body)
#==============================================================================
[docs]
def fusion_loops(loops):
    """
    Fuse a chain of perfectly nested `For` loops into a single loop over the
    `Product` of their ranges.

    The chain is followed as long as the current body consists of exactly one
    `For` statement. If at least two ranges are collected, one loop over
    their product is returned; otherwise the input is returned untouched.

    Parameters
    ----------
    loops : <list>
        list of statements

    Returns
    -------
    <list>
        list with the single fused `For` loop, or `loops` itself when there
        is nothing to fuse

    Raises
    ------
    ValueError
        If a loop iterates over a `Product` but its target is not a list or
        a tuple of indices.
    TypeError
        If a loop iterates over something which is neither a `Range` nor a
        `Product`.
    """
    ranges   = []
    indices  = []
    original = loops

    while len(loops) == 1 and isinstance(loops[0], For):
        loop     = loops[0]
        target   = loop.target
        iterable = loop.iterable

        if isinstance(iterable, Product):
            # Check the target *before* using it, so that a wrong target
            # gives the intended error rather than a failure in list()
            if not isinstance(target, (tuple, list, Tuple)):
                raise ValueError('target must be a list or a tuple of indices')
            ranges  += list(iterable.elements)
            indices += list(target)

        elif isinstance(iterable, Range):
            ranges.append(iterable)
            indices.append(target)

        else:
            raise TypeError('only range an product are supported')

        loops = loop.body

    if len(ranges) > 1:
        return [For(indices, Product(*ranges), loops)]

    return original
#==============================================================================
[docs]
def compute_boundary_jacobian(parent_namespace, boundary, mapping=None):
    """
    Build the statement computing the square root of the metric determinant
    on a boundary.

    Parameters
    ----------
    parent_namespace : <dict>
        namespace of the caller; must contain the variable 'det_jac_bnd'
        when a mapping is given

    boundary : <Boundary>
        boundary under consideration

    mapping : <Mapping> or None
        Mapping object

    Returns
    -------
    stmts : <list>
        empty list if there is no mapping, otherwise a single assignment
        `det_jac_bnd = sqrt(det(Jm^T * Jm))`, where `Jm` is the Jacobian
        matrix without the column of index `boundary.axis`

    Raises
    ------
    TypeError
        If `boundary` is not a `Boundary`.
    """
    # Sanity check on arguments
    if not isinstance(boundary, Boundary):
        raise TypeError(boundary)

    if mapping is None:
        return []

    # Compute metric determinant g on manifold
    J = SymbolicExpr(mapping.jacobian)
    tangent_cols = [col for col in range(J.shape[1]) if col != boundary.axis]
    Jm = J[:, tangent_cols]
    g  = (Jm.T * Jm).det()

    # Create statement for computing sqrt(g)
    return [Assign(parent_namespace['det_jac_bnd'], sympy_sqrt(g))]
#==============================================================================
[docs]
def compute_normal_vector(parent_namespace, vector, boundary, mapping=None):
    """
    Build the statements computing the components of the unit normal vector
    on a boundary.

    Parameters
    ----------
    parent_namespace : <dict>
        not used by this function

    vector : <list>
        variables holding the components of the normal vector

    boundary : <Boundary>
        boundary under consideration (provides `axis` and `ext`)

    mapping : <Mapping> or None
        Mapping object

    Returns
    -------
    stmts : <list>
        list of assignments

    Raises
    ------
    TypeError
        If `boundary` is not a `Boundary`.
    """
    # Sanity check on arguments
    if not isinstance(boundary, Boundary):
        raise TypeError(boundary)

    axis = boundary.axis
    ext  = boundary.ext

    # If there is no mapping, normal vector has only one non-zero component,
    # which is +1 or -1 according to the orientation of the boundary.
    if mapping is None:
        stmts = []
        for i, component in enumerate(vector):
            stmts.append(Assign(component, ext if i == axis else 0))
        return stmts

    # Given the Jacobian matrix J, we need to extract the (i=axis) row of
    # J^(-1) and then normalize it. We recall that J^(-1)[i, j] is equal to
    # the cofactor of J[i, j] divided by det(J). For efficiency we only
    # compute the cofactors C[i=0:dim] of the (j=axis) column of J, and we
    # do not divide them by det(J) because the normal vector will need to
    # be normalized anyway.
    #
    # NOTE: we also change the vector orientation according to 'ext'
    J = SymbolicExpr(mapping.jacobian)
    cofactors = [ext * J.cofactor(row, j=axis) for row in range(J.shape[0])]

    # Statements for the (non-normalized) components
    stmts = [Assign(component, value) for component, value in zip(vector, cofactors)]

    # Normalization: multiply every component by 1 / ||cofactors||
    inv_norm = Symbol('inv_norm')
    squared_norm = sum(value**2 for value in cofactors)
    stmts.append(Assign(inv_norm, 1 / sympy_sqrt(squared_norm)))
    stmts.extend([AugAssign(component, '*', inv_norm) for component in vector])

    return stmts
#==============================================================================
[docs]
def compute_tangent_vector(parent_namespace, vector, boundary, mapping):
    """
    Placeholder with the same argument layout as `compute_normal_vector`.

    Raises
    ------
    NotImplementedError
        Always: the computation of tangent vectors is not available yet.
    """
    raise NotImplementedError('TODO')
#==============================================================================
# Pattern of a range inside a name given to `variables`: either a numeric
# range with optional start ('0:3', ':3') or a letter range with optional
# start ('a:c', ':c'). The capturing group makes `_range.split` keep the
# matched ranges in its result.
_range = re.compile('([0-9]*:[0-9]+|[a-zA-Z]?:[a-zA-Z])')
[docs]
def variables(names, dtype, **args):
    """
    Create one or several variables from a string of names.

    Names are separated by commas and/or whitespace. A name may contain
    ranges, which are expanded: numeric ones ('x0:3' gives x0, x1, x2; the
    start defaults to 0) and alphabetic ones ('a:c' gives a, b, c; the start
    defaults to 'a', the end is inclusive). One layer of parentheses around
    a range is removed ('x(1:3)'). The characters ',', ':' and ' ' can be
    used literally in a name when escaped with a backslash.

    Parameters
    ----------
    names : str or list or tuple
        description of the names; a list/tuple is processed item by item

    dtype : object
        data type given to every variable

    **args
        'cls'  : class to build (`Variable` by default, `IndexedVariable`
                 or `Idx`);
        'rank' : rank of the variables (0 by default);
        'seq'  : if true, always return a tuple;
        any other keyword is forwarded to the constructor.

    Returns
    -------
    object or tuple
        a single variable if only one name was given without trailing comma,
        without range and without `seq`; an empty tuple if nothing was
        created under the same conditions; a tuple of variables otherwise.
        For a list/tuple input, a tuple with one result per item.

    Raises
    ------
    ValueError
        If the string is empty, has an empty field between commas, or
        contains a range without end.
    TypeError
        If `names` is not a string, a list or a tuple, or if `cls` is not
        supported.
    """

    def build(cls, name, dtype, rank, **args):
        # Instantiate one object of the requested kind
        if issubclass(cls, Variable):
            return Variable(dtype, name, rank=rank, **args)
        if issubclass(cls, IndexedVariable):
            return IndexedVariable(name, dtype=dtype, rank=rank, **args)
        if cls == Idx:
            assert dtype == "int"
            # NOTE(review): 'rank' has already been removed from the keyword
            # arguments by the enclosing function, hence this always yields 0
            # and the rank requested by the caller is not actually checked.
            rank = args.pop('rank', 0)
            assert rank == 0
            return Idx(name)
        raise TypeError('only Variables and IndexedVariables are supported')

    created = []
    cls  = args.pop('cls', Variable)
    rank = args.pop('rank', 0)

    # Containers are handled item by item
    if isinstance(names, (tuple, list)):
        return tuple(variables(item, dtype, cls=cls, rank=rank, **args) for item in names)

    if not isinstance(names, str):
        raise TypeError('Expecting a string')

    # Replace every escaped separator by a character which does not occur in
    # the string, remembering how to undo the replacement
    escapes = []
    code = 0
    for escaped in (r'\,', r'\:', r'\ '):
        if escaped not in names:
            continue
        while chr(code) in names:
            code += 1
        stand_in = chr(code)
        code += 1
        names = names.replace(escaped, stand_in)
        escapes.append((stand_in, escaped[1:]))

    def restore(text):
        # Put the literal separators back into a final name
        for stand_in, char in escapes:
            text = text.replace(stand_in, char)
        return text

    names = names.strip()
    trailing_comma = names.endswith(',')
    if trailing_comma:
        names = names[:-1].rstrip()
    if not names:
        raise ValueError('no symbols given')

    # split on commas
    fields = [field.strip() for field in names.split(',')]
    if not all(fields):
        raise ValueError('missing symbol between commas')

    # split on spaces
    tokens = [token for field in fields for token in field.split()]

    seq = args.pop('seq', trailing_comma)

    for name in tokens:
        if not name:
            raise ValueError('missing variable')

        # Plain name: no range to expand
        if ':' not in name:
            created.append(build(cls, restore(name), dtype, rank, **args))
            continue

        parts = _range.split(name)

        # remove 1 layer of bounding parentheses around ranges
        for k in range(1, len(parts) - 1):
            is_range = ':' in parts[k] and parts[k] != ':'
            if is_range and parts[k - 1].endswith('(') and parts[k + 1].startswith(')'):
                parts[k - 1] = parts[k - 1][:-1]
                parts[k + 1] = parts[k + 1][1:]

        # Turn every part into the list of strings it stands for
        for k, part in enumerate(parts):
            if ':' not in part:
                parts[k] = [part]
                continue

            if part.endswith(':'):
                raise ValueError('missing end range')

            lo, hi = part.split(':')
            if hi[-1] in string.digits:
                lo = int(lo) if lo else 0
                hi = int(hi)
                parts[k] = [str(c) for c in range(lo, hi)]
            else:
                lo = lo or 'a'
                first = string.ascii_letters.index(lo)
                last  = string.ascii_letters.index(hi)
                parts[k] = [string.ascii_letters[c] for c in range(first, last + 1)]  # inclusive

            # An empty range produces no name at all
            if not parts[k]:
                break
        else:
            # A name with a range always yields a sequence
            seq = True
            if len(parts) == 1:
                expanded = parts[0]
            else:
                expanded = [''.join(combo) for combo in cartes(*parts)]
            created.extend([build(cls, restore(s), dtype, rank, **args) for s in expanded])

    if not seq and len(created) <= 1:
        if not created:
            return ()
        return created[0]

    return tuple(created)
#==============================================================================
[docs]
def build_pyccel_type_annotations(args, order=None):
    """
    Attach a type annotation (given as a string) to every argument.

    The annotation is '<dtype>' for rank 0 and '<dtype>[:,...,:]' with one
    colon per dimension otherwise; when `order` is given and the rank is
    greater than 1, '(order=<order>)' is appended.

    Parameters
    ----------
    args : iterable
        arguments of type `Variable`, `IndexedVariable` or `Constant`

    order : object, optional
        ordering inserted in the annotation of multi-dimensional arrays

    Returns
    -------
    new_args : <list>
        list of `AnnotatedArgument` objects, one per input argument

    Raises
    ------
    TypeError
        If an argument has an unexpected type, or if a `Constant` is neither
        integer, nor real, nor complex.
    """
    annotated = []

    for a in args:
        if isinstance(a, (Variable, IndexedVariable)):
            rank  = a.rank
            dtype = a.dtype.name.lower()

        elif isinstance(a, Constant):
            rank = 0
            if a.is_integer:
                dtype = 'int'
            elif a.is_real:
                dtype = 'float'
            elif a.is_complex:
                dtype = 'complex'
            else:
                raise TypeError(f"The Constant {a} don't have any information about the type of the variable.\n"
                                f"Please create the Constant like this Constant('{a}', real=True), Constant('{a}', complex=True) or Constant('{a}', integer=True).")

        else:
            raise TypeError('unexpected type for {}'.format(a))

        if rank > 0:
            shape = ','.join(':' * rank)
            dtype = '{dtype}[{shape}]'.format(dtype=dtype, shape=shape)
            if order and rank > 1:
                dtype = "{dtype}(order={ordering})".format(dtype=dtype, ordering=order)

        annotated.append(AnnotatedArgument(a, String(dtype)))

    return annotated
#==============================================================================
# Translation of the dtype names 'real' and 'int' into 'float' and 'int'.
# NOTE(review): not referenced anywhere else in the visible part of this
# file -- presumably used by Pythran-related code; verify before removing.
pythran_dtypes = {'real':'float','int':'int'}
#==============================================================================
from sympy import preorder_traversal
from sympy import NumberSymbol
from sympy import Pow, S
# Translation tables used by `math_atoms_as_str`: for each supported library,
# one dictionary for functions and one for constants. Function keys are
# compared with `str(type(node))` and constant keys with
# `type(node).__name__`; values are the names to import from the library.

# Functions of the standard `math` module
_known_functions_math = {
'acos': 'acos',
'acosh': 'acosh',
'asin': 'asin',
'asinh': 'asinh',
'atan': 'atan',
'atan2': 'atan2',
'atanh': 'atanh',
'ceiling': 'ceil',
'cos': 'cos',
'cosh': 'cosh',
'erf': 'erf',
'erfc': 'erfc',
'exp': 'exp',
'expm1': 'expm1',
'factorial': 'factorial',
'floor': 'floor',
'gamma': 'gamma',
'hypot': 'hypot',
'loggamma': 'lgamma',
'log': 'log',
'ln': 'log',
'log10': 'log10',
'log1p': 'log1p',
'log2': 'log2',
'sin': 'sin',
'sinh': 'sinh',
'Sqrt': 'sqrt',
'tan': 'tan',
'tanh': 'tanh'
} # Not used from ``math``: [copysign isclose isfinite isinf isnan ldexp frexp pow modf
# radians trunc fmod fsum gcd degrees fabs]
# Constants of the standard `math` module
_known_constants_math = {
'Exp1': 'e',
'Pi': 'pi',
'E': 'e'
# Only in python >= 3.5:
# 'Infinity': 'inf',
# 'NaN': 'nan'
}
# `mpmath` functions: the `math` table without the entries listed in
# `_not_in_mpmath`, plus a few additional functions
_not_in_mpmath = 'log1p log2'.split()
_in_mpmath = [(k, v) for k, v in _known_functions_math.items() if k not in _not_in_mpmath]
_known_functions_mpmath = dict(_in_mpmath, **{
'beta': 'beta',
'fresnelc': 'fresnelc',
'fresnels': 'fresnels',
'sign': 'sign',
})
# `mpmath` constants
_known_constants_mpmath = {
'Exp1': 'e',
'Pi': 'pi',
'GoldenRatio': 'phi',
'EulerGamma': 'euler',
'Catalan': 'catalan',
'NaN': 'nan',
'Infinity': 'inf',
'NegativeInfinity': 'ninf'
}
# `numpy` functions: the `math` table without the entries listed in
# `_not_in_numpy`; the keyword entries below override the inherited names
# of the inverse trigonometric/hyperbolic functions and add 'exp2', 'sign'
_not_in_numpy = 'erf erfc factorial gamma loggamma'.split()
_in_numpy = [(k, v) for k, v in _known_functions_math.items() if k not in _not_in_numpy]
_known_functions_numpy = dict(_in_numpy, **{
'acos': 'arccos',
'acosh': 'arccosh',
'asin': 'arcsin',
'asinh': 'arcsinh',
'atan': 'arctan',
'atan2': 'arctan2',
'atanh': 'arctanh',
'exp2': 'exp2',
'sign': 'sign',
})
# `numpy` constants
# NOTE(review): the names 'PINF' and 'NINF' are believed to have been removed
# from the NumPy namespace in NumPy 2.0 -- confirm against the supported
# NumPy versions.
_known_constants_numpy = {
'Exp1': 'e',
'Pi': 'pi',
'EulerGamma': 'euler_gamma',
'NaN': 'nan',
'Infinity': 'PINF',
'NegativeInfinity': 'NINF'
}
[docs]
def math_atoms_as_str(expr, lib='math'):
    """
    Given a Sympy expression, find all known mathematical atoms (functions and
    constants) that need to be imported from a math library (e.g. Numpy) when
    generating Python code.

    Parameters
    ----------
    expr : sympy.core.expr.Expr
        Symbolic expression for which Python code is to be generated.

    lib : str
        Library used to translate symbolic functions/constants into standard
        Python ones. Options: ['math', 'mpmath', 'numpy']. Default: 'math'.

    Returns
    -------
    imports : set of str
        Set of all names (strings) to be imported.

    Raises
    ------
    ValueError
        If `lib` is not one of the supported libraries.
    """
    # Choose translation dictionaries
    if lib == 'math':
        known_functions = _known_functions_math
        known_constants = _known_constants_math
    elif lib == 'mpmath':
        known_functions = _known_functions_mpmath
        known_constants = _known_constants_mpmath
    elif lib == 'numpy':
        known_functions = _known_functions_numpy
        known_constants = _known_constants_numpy
    else:
        # The message must be built from 'lib': no other name holding the
        # requested library exists in this scope
        raise ValueError("Library {} not supported.".format(lib))

    # Initialize variables
    math_functions = set()
    math_constants = set()
    sqrt = False

    # Walk expression tree
    for i in preorder_traversal(expr):

        # Search for math functions (e.g. cos, sin, exp, ...)
        if isinstance(i, Function):
            s = str(type(i))
            if s in known_functions:
                math_functions.add(known_functions[s])

        # Search for math constants (e.g. pi, e, ...)
        elif isinstance(i, NumberSymbol):
            s = type(i).__name__
            if s in known_constants:
                math_constants.add(known_constants[s])

        # Search for square roots: a power with exponent 1/2 or -1/2.
        # One occurrence is enough, hence the search stops at the first hit.
        elif not sqrt:
            if isinstance(i, Pow) and ((i.exp is S.Half) or (i.exp == -S.Half)):
                math_functions.add('sqrt')
                sqrt = True

    return set.union(math_functions, math_constants)
[docs]
def get_name(lhs):
    """
    Given a list of variables, return the meaningful part of the name of the
    first variable that has a `_name` attribute.

    Was added to solve issue #327 caused by trying to access the name of a
    variable that has no such attribute.

    Parameters
    ----------
    lhs : list
        list from which a name has to be extracted.

    Returns
    -------
    str
        the `_name` of the first variable having one, without its first 12
        and its last 8 characters, or "zero_term" if no variable has a name.
    """
    named_terms = (term for term in lhs if hasattr(term, '_name'))
    first_named = next(named_terms, None)

    if first_named is None:
        return "zero_term"

    return first_named._name[12:-8]