# -*- coding: utf-8 -*-
#------------------------------------------------------------------------------------------#
# This file is part of Pyccel which is released under MIT License. See the LICENSE file or #
# go to https://github.com/pyccel/pyccel/blob/devel/LICENSE for full license details. #
#------------------------------------------------------------------------------------------#
""" File containing SemanticParser. This class handles the semantic stage of the translation.
See the developer docs for more details
"""
from itertools import chain, product
import os
from types import ModuleType, BuiltinFunctionType
import typing
import warnings
from sympy.utilities.iterables import iterable as sympy_iterable
from sympy import Sum as Summation
from sympy import Symbol as sp_Symbol
from sympy import Integer as sp_Integer
from sympy.logic.boolalg import BooleanTrue as sp_True
from sympy.logic.boolalg import BooleanFalse as sp_False
from sympy import ceiling
from textx.exceptions import TextXSyntaxError
#==============================================================================
from pyccel.ast.basic import PyccelAstNode, TypedAstNode, ScopedAstNode, iterable
from pyccel.ast.bitwise_operators import PyccelBitOr, PyccelLShift, PyccelRShift, PyccelBitAnd
from pyccel.ast.builtins import PythonPrint, PythonTupleFunction, PythonSetFunction
from pyccel.ast.builtins import PythonComplex, PythonDict, PythonListFunction
from pyccel.ast.builtins import builtin_functions_dict, PythonImag, PythonReal
from pyccel.ast.builtins import PythonList, PythonConjugate , PythonSet, VariableIterator
from pyccel.ast.builtins import PythonRange, PythonZip, PythonEnumerate, PythonTuple
from pyccel.ast.builtins import Lambda, PythonMap, PythonBool
from pyccel.ast.builtin_methods.dict_methods import DictKeys
from pyccel.ast.builtin_methods.list_methods import ListAppend, ListPop, ListInsert
from pyccel.ast.builtin_methods.set_methods import SetAdd, SetUnion, SetCopy, SetIntersectionUpdate
from pyccel.ast.builtin_methods.set_methods import SetPop
from pyccel.ast.builtin_methods.dict_methods import DictGetItem, DictGet, DictPop, DictPopitem
from pyccel.ast.core import Comment, CommentBlock, Pass
from pyccel.ast.core import If, IfSection
from pyccel.ast.core import Allocate, Deallocate
from pyccel.ast.core import Assign, AliasAssign
from pyccel.ast.core import AugAssign, CodeBlock
from pyccel.ast.core import Return, FunctionDefArgument, FunctionDefResult
from pyccel.ast.core import ConstructorCall, InlineFunctionDef
from pyccel.ast.core import FunctionDef, Interface, FunctionAddress, FunctionCall, FunctionCallArgument
from pyccel.ast.core import ClassDef
from pyccel.ast.core import For
from pyccel.ast.core import Module
from pyccel.ast.core import While
from pyccel.ast.core import Del
from pyccel.ast.core import Program
from pyccel.ast.core import EmptyNode
from pyccel.ast.core import Concatenate
from pyccel.ast.core import Import
from pyccel.ast.core import AsName
from pyccel.ast.core import With
from pyccel.ast.core import Duplicate
from pyccel.ast.core import StarredArguments
from pyccel.ast.core import Decorator
from pyccel.ast.core import PyccelFunctionDef
from pyccel.ast.core import Assert
from pyccel.ast.core import AllDeclaration
from pyccel.ast.class_defs import get_cls_base, SetClass
from pyccel.ast.datatypes import CustomDataType, PyccelType, TupleType, VoidType, GenericType
from pyccel.ast.datatypes import PrimitiveIntegerType, StringType, SymbolicType
from pyccel.ast.datatypes import PythonNativeBool, PythonNativeInt, PythonNativeFloat
from pyccel.ast.datatypes import DataTypeFactory, HomogeneousContainerType
from pyccel.ast.datatypes import InhomogeneousTupleType, HomogeneousTupleType, HomogeneousSetType, HomogeneousListType
from pyccel.ast.datatypes import PrimitiveComplexType, FixedSizeNumericType, DictType, TypeAlias
from pyccel.ast.datatypes import original_type_to_pyccel_type
from pyccel.ast.functionalexpr import FunctionalSum, FunctionalMax, FunctionalMin, GeneratorComprehension, FunctionalFor
from pyccel.ast.functionalexpr import MaxLimit, MinLimit
from pyccel.ast.headers import Header
from pyccel.ast.internals import PyccelFunction, Slice, PyccelSymbol, PyccelArrayShapeElement
from pyccel.ast.internals import Iterable
from pyccel.ast.itertoolsext import Product
from pyccel.ast.literals import LiteralTrue, LiteralFalse
from pyccel.ast.literals import LiteralInteger, LiteralFloat
from pyccel.ast.literals import Nil, LiteralString, LiteralImaginaryUnit
from pyccel.ast.literals import Literal, convert_to_literal, LiteralEllipsis
from pyccel.ast.low_level_tools import MemoryHandlerType, UnpackManagedMemory, ManagedMemory
from pyccel.ast.mathext import math_constants, MathSqrt, MathAtan2, MathSin, MathCos
from pyccel.ast.numpyext import NumpyMatmul, numpy_funcs
from pyccel.ast.numpyext import NumpyWhere, NumpyArray
from pyccel.ast.numpyext import NumpyTranspose, NumpyConjugate
from pyccel.ast.numpyext import NumpyNewArray, NumpyResultType
from pyccel.ast.numpyext import process_dtype as numpy_process_dtype
from pyccel.ast.numpyext import get_shape_of_multi_level_container
from pyccel.ast.numpytypes import NumpyNDArrayType
from pyccel.ast.omp import (OMP_For_Loop, OMP_Simd_Construct, OMP_Distribute_Construct,
OMP_TaskLoop_Construct, OMP_Sections_Construct, Omp_End_Clause,
OMP_Single_Construct)
from pyccel.ast.operators import PyccelArithmeticOperator, PyccelIs, PyccelIsNot, IfTernaryOperator, PyccelUnarySub
from pyccel.ast.operators import PyccelNot, PyccelAdd, PyccelMinus, PyccelMul, PyccelPow, PyccelOr
from pyccel.ast.operators import PyccelAssociativeParenthesis, PyccelDiv, PyccelIn, PyccelOperator
from pyccel.ast.operators import PyccelAnd
from pyccel.ast.sympy_helper import sympy_to_pyccel, pyccel_to_sympy
from pyccel.ast.type_annotations import VariableTypeAnnotation, UnionTypeAnnotation, SyntacticTypeAnnotation
from pyccel.ast.type_annotations import FunctionTypeAnnotation, typenames_to_dtypes
from pyccel.ast.typingext import TypingFinal, TypingTypeVar
from pyccel.ast.utilities import builtin_import as pyccel_builtin_import
from pyccel.ast.utilities import builtin_import_registry as pyccel_builtin_import_registry
from pyccel.ast.utilities import split_positional_keyword_arguments
from pyccel.ast.utilities import recognised_source, is_literal_integer, get_managed_memory_object
from pyccel.ast.variable import Constant
from pyccel.ast.variable import Variable
from pyccel.ast.variable import IndexedElement, AnnotatedPyccelSymbol
from pyccel.ast.variable import DottedName, DottedVariable
from pyccel.errors.errors import Errors, ErrorsMode, PyccelError, PyccelSemanticError
from pyccel.errors.messages import (PYCCEL_RESTRICTION_TODO, UNDERSCORE_NOT_A_THROWAWAY,
UNDEFINED_VARIABLE, IMPORTING_EXISTING_IDENTIFIED, INDEXED_TUPLE, LIST_OF_TUPLES,
INVALID_INDICES, INCOMPATIBLE_ARGUMENT,
UNRECOGNISED_FUNCTION_CALL, STACK_ARRAY_SHAPE_UNPURE_FUNC, STACK_ARRAY_UNKNOWN_SHAPE,
ARRAY_DEFINITION_IN_LOOP, STACK_ARRAY_DEFINITION_IN_LOOP, MISSING_TYPE_ANNOTATIONS,
INCOMPATIBLE_TYPES_IN_ASSIGNMENT, ARRAY_ALREADY_IN_USE, ASSIGN_ARRAYS_ONE_ANOTHER,
INVALID_POINTER_REASSIGN, ARRAY_IS_ARG,
INCOMPATIBLE_REDEFINITION_STACK_ARRAY, ARRAY_REALLOCATION, RECURSIVE_RESULTS_REQUIRED,
PYCCEL_RESTRICTION_INHOMOG_LIST, UNDEFINED_IMPORT_OBJECT, UNDEFINED_LAMBDA_VARIABLE,
UNDEFINED_LAMBDA_FUNCTION, UNDEFINED_INIT_METHOD, UNDEFINED_FUNCTION,
WRONG_NUMBER_OUTPUT_ARGS, INVALID_FOR_ITERABLE,
PYCCEL_RESTRICTION_LIST_COMPREHENSION_LIMITS, PYCCEL_RESTRICTION_LIST_COMPREHENSION_SIZE,
UNUSED_DECORATORS, UNSUPPORTED_POINTER_RETURN_VALUE, PYCCEL_RESTRICTION_OPTIONAL_NONE,
PYCCEL_RESTRICTION_PRIMITIVE_IMMUTABLE, PYCCEL_RESTRICTION_IS_ISNOT,
FOUND_DUPLICATED_IMPORT, UNDEFINED_WITH_ACCESS,
PYCCEL_INTERNAL_ERROR)
from pyccel.parser.base import BasicParser
from pyccel.parser.syntactic import SyntaxParser
from pyccel.parser.syntax.headers import types_meta
from pyccel.utilities.stage import PyccelStage
import pyccel.decorators as def_decorators
#==============================================================================
# Error reporter used by the module-level helpers and the parser methods below.
errors = Errors()

# Stage tracker; SemanticParser.annotate switches it to 'semantic'.
pyccel_stage = PyccelStage()

# Container type class associated with each container-building function node.
type_container = {
    PythonTupleFunction: HomogeneousTupleType,
    PythonListFunction: HomogeneousListType,
    PythonSetFunction: HomogeneousSetType,
    NumpyArray: NumpyNDArrayType,
}
#==============================================================================
def _get_name(var):
"""."""
if isinstance(var, str):
return var
if isinstance(var, (PyccelSymbol, DottedName)):
return str(var)
if isinstance(var, (IndexedElement)):
return str(var.base)
if isinstance(var, FunctionCall):
return var.funcdef
name = type(var).__name__
msg = f'Name of Object : {name} cannot be determined'
return errors.report(PYCCEL_RESTRICTION_TODO+'\n'+msg, symbol=var,
severity='fatal')
# Name of the magic method associated with each binary operator class.
magic_method_map = {
    # Arithmetic operators
    PyccelAdd: '__add__',
    PyccelMinus: '__sub__',
    PyccelMul: '__mul__',
    PyccelDiv: '__truediv__',
    PyccelPow: '__pow__',
    # Bitwise operators
    PyccelLShift: '__lshift__',
    PyccelRShift: '__rshift__',
    PyccelBitAnd: '__and__',
    PyccelBitOr: '__or__',
}
#==============================================================================
[docs]
class SemanticParser(BasicParser):
"""
Class which handles the semantic stage as described in the developer docs.
This class is described in detail in developer_docs/semantic_stage.md.
It determines all semantic information which must be deduced in order to
print a representation of the AST resulting from the syntactic stage in one
of the target languages.
Parameters
----------
inputs : SyntaxParser
A syntactic parser which has been used to generate a representation of
the input code using Pyccel nodes.
parents : list
A list of parsers describing the files which import this file.
d_parsers : list
A list of parsers describing files imported by this file.
context_dict : dict, optional
A dictionary describing any variables in the context where the translated
object was defined.
**kwargs : dict
Additional keyword arguments for BasicParser.
"""
def __init__(self, inputs, *, parents = (), d_parsers = (), context_dict = None, **kwargs):
    """
    Set up the parser from a syntactic parser and annotate its AST.

    The state collected by the syntactic parser is taken over and
    `annotate` is called before the constructor returns.
    """
    # The parents are the parsers importing this file, the d_parsers
    # describe the files imported by this file
    self._parents = list(parents)
    self._d_parsers = dict(d_parsers)

    if not isinstance(inputs, SyntaxParser):
        raise TypeError('> Expecting a syntactic parser as input')
    syntax_parser = inputs

    BasicParser.__init__(self, **kwargs)

    # State taken over from the syntactic stage
    self._fst = syntax_parser._fst
    self._ast = syntax_parser._ast
    self._filename = syntax_parser._filename
    self._mod_name = ''
    self._metavars = syntax_parser._metavars
    self.scope = syntax_parser.scope
    self.scope.imports['imports'] = {}
    self._module_namespace = self.scope
    self._in_annotation = False

    # Local variables of each code block, needed for garbage collecting
    self._allocs = []

    # Code split into multiple lines which must be reinserted in the CodeBlock
    self._additional_exprs = []

    # Variables stored when optional parameters are changed
    self._optional_params = {}

    # Links between pointers and their targets. This matters for classes
    # which may contain persistent pointers
    self._pointer_targets = []

    # Description of the calling context, used to collect constants
    self._context_dict = context_dict or {}

    self._code = syntax_parser._code

    # Run the semantic stage immediately
    self.annotate()
#================================================================
# Property accessors
#================================================================
@property
def parents(self):
    """
    The parsers passed as `parents` to the constructor.

    The list built in the constructor from the `parents` argument, i.e.
    the parsers describing the files which import this file.
    """
    return self._parents
@property
def d_parsers(self):
    """
    The parsers passed as `d_parsers` to the constructor.

    The dictionary built in the constructor from the `d_parsers` argument,
    i.e. the parsers describing the files imported by this file.
    """
    return self._d_parsers
#================================================================
# Public functions
#================================================================
[docs]
def annotate(self):
    """
    Add type information to the AST.

    This function is the entry point for this class. It visits the AST
    object created by the syntactic stage (collected in the constructor)
    and stores the annotated result. If the semantic stage has already
    been run a message is printed and the stored AST is returned.
    See the developer docs for more details.

    Returns
    -------
    pyccel.ast.basic.PyccelAstNode
        An annotated object which can be printed.
    """
    if self.semantic_done:
        print ('> semantic analysis already done')
        return self.ast

    # TODO - add settings to Errors
    #      - filename
    error_handler = Errors()
    if self.filename:
        error_handler.set_target(self.filename)

    syntactic_ast = self.ast

    # Open the containers describing the outermost scope
    self._allocs.append(set())
    self._pointer_targets.append({})

    pyccel_stage.set_stage('semantic')
    annotated_ast = self._visit(syntactic_ast)

    self._ast = annotated_ast
    self._semantic_done = True
    return annotated_ast
#================================================================
# Utility functions for scope handling
#================================================================
[docs]
def create_new_function_scope(self, syntactic_name, semantic_name, **kwargs):
    """
    Create a new Scope object for a Python function.

    Create a child of the current scope, registered under the function's
    syntactic name, and record the link between the syntactic and semantic
    names of the function in it. Before returning, the new scope is stored
    in self._scope and the semantic name is appended to
    self._current_function_name.

    Parameters
    ----------
    syntactic_name : str
        Function's original name in the translated code, used as a key to
        retrieve the new scope.
    semantic_name : str
        The new name of the function by which it will be known in the target
        language.
    **kwargs : dict
        Keyword arguments passed through to the new scope.

    Returns
    -------
    Scope
        The new scope for the function.
    """
    function_scope = self.scope.new_child_scope(syntactic_name, **kwargs)

    # Record the name mapping in both directions
    function_scope.local_used_symbols[syntactic_name] = semantic_name
    function_scope.python_names[semantic_name] = syntactic_name

    # The new scope becomes the current context
    self._current_function_name.append(semantic_name)
    self._scope = function_scope

    return function_scope
[docs]
def get_class_prefix(self, name):
    """
    Visit the prefix of a dotted name.

    Everything but the last element of the dotted name is treated as the
    prefix. A prefix containing a single element is visited directly,
    otherwise the elements are gathered into a new DottedName which is
    then visited.

    Parameters
    ----------
    name : DottedName
        The dotted name which begins with a class definition.

    Returns
    -------
    object
        The result of visiting the prefix.
    """
    leading_names = name.name[:-1]
    if len(leading_names) == 1:
        prefix = leading_names[0]
    else:
        prefix = DottedName(*leading_names)
    return self._visit(prefix)
[docs]
def check_for_variable(self, name):
    """
    Search for a Variable object with the given name in the current scope.

    Search for a Variable object with the given name in the current scope,
    defined by the local and global Python scopes. Return None if not found.
    For a dotted name the prefix is visited, the class describing the prefix
    is determined and the attribute is searched for in the scope of that
    class.

    Parameters
    ----------
    name : str | DottedName
        The object describing the variable.

    Returns
    -------
    Variable
        Returns the variable if found or None.

    See Also
    --------
    get_variable
        A similar function which raises an error if the Variable is not found
        instead of returning None.
    """
    if not isinstance(name, DottedName):
        return self.scope.find(name, 'variables')

    prefix = self.get_class_prefix(name)
    try:
        class_def = prefix.cls_base
    except AttributeError:
        # The prefix does not carry its class so it is deduced from the type
        class_def = get_cls_base(prefix.class_type) or \
                    self.scope.find(str(prefix.class_type), 'classes')

    # The class must be checked before its scope is accessed. Without this
    # an unknown class leads to an AttributeError instead of None
    if not class_def:
        return None

    class_scope = class_def.scope
    if class_scope is None:
        # Pyccel defined classes have no variables
        return None

    attribute = class_scope.find(name.name[-1], 'variables')
    if attribute:
        return attribute.clone(attribute.name, new_class = DottedVariable, lhs = prefix)
    return None
[docs]
def get_variable(self, name):
    """
    Get a Variable object with the given name from the current scope.

    Look the variable up with `check_for_variable`. If nothing is found a
    fatal error is reported at the position of the node currently being
    treated. A dedicated message is used for the name `_`.

    Parameters
    ----------
    name : str
        The object describing the variable.

    Returns
    -------
    Variable
        Returns the variable found in the scope.

    Raises
    ------
    PyccelSemanticError
        Error raised if variable is not found.

    See Also
    --------
    check_for_variable
        A similar function which returns None if the Variable is not found
        instead of raising an error.
    """
    var = self.check_for_variable(name)
    if var is not None:
        return var

    if name == '_':
        errors.report(UNDERSCORE_NOT_A_THROWAWAY,
                      bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
                      severity='fatal')
    else:
        errors.report(UNDEFINED_VARIABLE, symbol=name,
                      bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
                      severity='fatal')
    return var
[docs]
def get_variables(self, container):
    """
    Get all variables in the scope of interest.

    Get a list of all variables which are stored in the scope, followed by
    the variables found recursively in each of the loops of the scope.

    Parameters
    ----------
    container : Scope
        The object describing the relevant scope.

    Returns
    -------
    list
        A list of variables.
    """
    # this only works if called on a function scope
    # TODO needs more tests when we have nested functions
    collected = list(container.variables.values())
    for loop_scope in container.loops:
        collected += self.get_variables(loop_scope)
    return collected
[docs]
def get_class_construct(self, name):
    """
    Return the class datatype associated with name.

    Search the current scope for a class construct registered under the
    given name. A fatal error is reported if there is no such construct.

    Parameters
    ----------
    name : str
        The name of the class.

    Returns
    -------
    PyccelType
        The datatype for the class.

    Raises
    ------
    PyccelSemanticError
        Raised if the datatype cannot be found.
    """
    result = self.scope.find(name, 'cls_constructs')
    if result is not None:
        return result

    msg = f'class construct {name} not found'
    return errors.report(msg,
                         bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
                         severity='fatal')
[docs]
def insert_import(self, name, target, storage_name = None):
    """
    Insert a new import into the scope.

    Create and insert a new import in scope if it's not defined
    otherwise append target to existing import. When no storage name is
    provided, the key of an existing import with the same source is used
    if there is one.

    Parameters
    ----------
    name : str-like
        The source from which the object is imported.
    target : AsName
        The imported object.
    storage_name : str-like
        The name which will be used to identify the Import in the
        container.
    """
    source = _get_name(name)

    if storage_name is None:
        known_imports = self.scope.find_all('imports')
        storage_name = next((key for key, known in known_imports.items()
                             if known.source == source), None)

    # Look for the import using its source, then using its storage name
    imp = self.scope.find(source, 'imports')
    found_from_import_name = imp is None
    if found_from_import_name:
        imp = self.scope.find(storage_name, 'imports')

    if imp is None:
        # The import is unknown. It is stored in the first enclosing scope
        # which is not a loop
        current_scope = self.scope
        while current_scope.is_loop:
            current_scope = current_scope.parent_scope
        current_scope.imports['imports'][storage_name] = Import(source, target, True)
    elif found_from_import_name or source in (imp.source, getattr(imp.source_module, 'name', '')):
        imp.define_target(target)
    else:
        errors.report(IMPORTING_EXISTING_IDENTIFIED,
                      symbol=name,
                      bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
                      severity='fatal')
[docs]
def create_tuple_of_inhomogeneous_elements(self, tuple_var):
    """
    Create a tuple of variables from a variable representing an inhomogeneous object.

    Create a tuple of variables that can be printed in a low-level language. An
    inhomogeneous object cannot be represented as is in a low-level language so
    it must be unpacked into a PythonTuple. This function is recursive so that
    variables with a type such as `tuple[tuple[int,bool],float]` generate
    `PythonTuple(PythonTuple(var_0_0, var_0_1), var_1)`. Objects whose type is
    not an inhomogeneous tuple are returned unchanged.

    Parameters
    ----------
    tuple_var : Variable
        A variable which may or may not be an inhomogeneous tuple.

    Returns
    -------
    Variable | PythonTuple
        An object containing only variables that can be printed in a low-level language.
    """
    if not isinstance(tuple_var.class_type, InhomogeneousTupleType):
        return tuple_var

    unpacked = []
    for element in tuple_var:
        element_var = self.scope.collect_tuple_element(element)
        unpacked.append(self.create_tuple_of_inhomogeneous_elements(element_var))
    return PythonTuple(*unpacked)
#=======================================================
# Utility functions
#=======================================================
def _garbage_collector(self, expr):
    """
    Create the deallocations needed at the end of a CodeBlock.

    Search the CodeBlock for Return nodes. If none of them returns an
    expression then a Deallocate node is created for the objects stored in
    the last element of self._allocs. Attributes of class instances are
    skipped outside of the `__del__` method, and aliases of class instances
    are always skipped. The last element of self._allocs is removed in all
    cases.

    Parameters
    ----------
    expr : CodeBlock
        The body where the method searches for the absence of trailing `Return` nodes.

    Returns
    -------
    List
        A list of instances of the `Deallocate` type.
    """
    deallocs = []
    returns_a_value = any(r.expr is not None for r in expr.get_attribute_nodes(Return))

    if not returns_a_value:
        for var in self._allocs[-1]:
            is_instance_attribute = isinstance(var, DottedVariable) and \
                    isinstance(var.lhs.class_type, CustomDataType)
            if is_instance_attribute and self.current_function_name != '__del__':
                continue
            if isinstance(var.class_type, CustomDataType) and var.is_alias:
                continue
            deallocs.append(Deallocate(var))

    self._allocs.pop()
    return deallocs
def _check_pointer_targets(self, exceptions = ()):
    """
    Check that all pointer targets to be deallocated are not needed beyond this scope.

    At the end of a scope (function/module/class) the objects contained within it are
    deallocated. However some objects may persist beyond the scope. For example a
    class instance persists after a call to a class method, and the arguments of a
    function persist after a call to that function. If one of these persistent objects
    contains a pointer then it is important that the target of that pointer has the
    same lifetime. The target must not be deallocated at the end of the function if
    the pointer persists.

    This function works on self._pointer_targets[-1], the dictionary describing the
    association of pointers to targets in this scope. First the pointers which are
    deallocated at the end of this context are removed. Next an error is reported for
    each object which will be deallocated but is still found amongst the targets.
    Finally, inside a function, the arguments targeted by a bound argument are marked
    as persistent targets.

    Parameters
    ----------
    exceptions : tuple of Variables
        A list of objects in `_allocs` which are to be ignored (variables appearing
        in a return statement).
    """
    assert len(self._allocs) == len(self._pointer_targets)
    assert not isinstance(exceptions, Variable)

    scope_pointers = self._pointer_targets[-1]
    released = [v for v in self._allocs[-1] if v not in exceptions]

    # Pointers which are released with the scope no longer need their targets
    for v in released:
        scope_pointers.pop(v, None)

    # Map each remaining target to the expression which created the link
    targets = {target: origin for target_list in scope_pointers.values()
                              for target, origin in target_list}

    for v in released:
        if v in targets:
            errors.report(f"Variable {v} goes out of scope but may be the target of a pointer which is still required",
                    severity='error', symbol=targets[v])

    if not self.current_function_name:
        return

    current_func = self._current_function[-1]
    arg_vars = {a.var:a for a in current_func.arguments}

    for p, t_list in scope_pointers.items():
        if p not in arg_vars or not arg_vars[p].bound_argument:
            continue
        for t, _ in t_list:
            if not t.is_argument:
                continue
            argument_objects = t.get_direct_user_nodes(lambda x: isinstance(x, FunctionDefArgument))
            assert len(argument_objects) == 1
            argument_objects[0].persistent_target = True
def _indicate_pointer_target(self, pointer, target, expr):
"""
Indicate that a pointer is targeting a specific target.
Indicate that a pointer is targeting a specific target by adding the pair
to a dictionary in self._pointer_targets (the last dictionary in the list
should be used as this is the one for the current scope). Targets which
wrap another object (attributes, indexed elements, dictionary accesses,
container literals) are handled by calling this function recursively on
the wrapped object(s).
Parameters
----------
pointer : Variable
The variable which is pointing at something.
target : Variable | IndexedElement
The object being pointed at by the pointer.
expr : PyccelAstNode
The expression where the pointer was created (used for clear error
messages).
"""
# An object pointing at itself does not need to be recorded
if pointer is target:
return
assert pointer != target
assert not isinstance(pointer.class_type, (StringType, FixedSizeNumericType))
# True if the pointer is a set/list whose element type is the type of the
# target, or a dict whose value type is the type of the target
pointing_at_container_element = (isinstance(pointer.class_type, (HomogeneousSetType, HomogeneousListType)) \
and (target.class_type is pointer.class_type.element_type)) or \
(isinstance(pointer.class_type, DictType) \
and (target.class_type is pointer.class_type.value_type))
# The same check with the roles of the pointer and the target exchanged
container_pointing_at_element = (isinstance(target.class_type, (HomogeneousSetType, HomogeneousListType)) \
and (pointer.class_type is target.class_type.element_type)) or \
(isinstance(target.class_type, DictType) \
and (pointer.class_type is target.class_type.value_type))
if pointing_at_container_element or container_pointing_at_element:
# The object with the lower rank is the one which receives a memory handler
managed_var = target if target.rank < pointer.rank else pointer
if isinstance(managed_var, Variable):
# A memory handler is only created if the variable is not already
# used by a ManagedMemory object
managed_mem = managed_var.get_direct_user_nodes(lambda u: isinstance(u, ManagedMemory))
if not managed_mem:
mem_var = Variable(MemoryHandlerType(managed_var.class_type),
self.scope.get_new_name(f'{managed_var.name}_mem'),
shape=None, memory_handling='heap')
self.scope.insert_variable(mem_var)
ManagedMemory(managed_var, mem_var)
# The class itself should also be aware of the target for freeing
if isinstance(pointer, DottedVariable):
self._indicate_pointer_target(pointer.lhs, target, expr)
# The treatment depends on the kind of object being targeted
if isinstance(target, DottedVariable):
self._indicate_pointer_target(pointer, target.lhs, expr)
elif isinstance(target, IndexedElement):
self._indicate_pointer_target(pointer, target.base, expr)
elif isinstance(target, (DictGetItem, DictGet)):
self._indicate_pointer_target(pointer, target.dict_obj, expr)
elif isinstance(target, Variable):
if target.is_alias:
# The target is itself a pointer so the pointer is linked to the
# objects which the target points at in the current scope
sub_targets = None
try:
sub_targets = self._pointer_targets[-1][target]
except KeyError:
errors.report("Pointer cannot point at a non-local pointer\n"+PYCCEL_RESTRICTION_TODO,
severity='error', symbol=expr)
if sub_targets:
self._pointer_targets[-1].setdefault(pointer, []).extend((t[0], expr) for t in sub_targets)
else:
target.is_target = True
self._pointer_targets[-1].setdefault(pointer, []).append((target, expr))
elif isinstance(target, FunctionCall):
if isinstance(target.funcdef, FunctionDef):
if target.funcdef.result_pointer_map:
raise NotImplementedError("TODO results point at args")
elif isinstance(target, (PythonList, PythonSet, PythonTuple)):
# Elements which are strings or fixed size numeric objects are not registered
if not isinstance(target.class_type.element_type, (StringType, FixedSizeNumericType)):
for v in target:
self._indicate_pointer_target(pointer, v, expr)
elif isinstance(target, (ListPop, DictPop)):
target_var = target.list_obj if isinstance(target, ListPop) else target.dict_obj
# If the container has known targets the pointer is linked to them
if target_var in self._pointer_targets[-1]:
sub_targets = self._pointer_targets[-1][target_var]
self._pointer_targets[-1].setdefault(pointer, []).extend((t[0], expr) for t in sub_targets)
elif isinstance(pointer, Variable):
self._allocs[-1].add(pointer)
# NOTE(review): the nesting of the following block cannot be determined
# from this listing; it repeats the isinstance check of the branch above
# -- confirm its indentation against the original file
if isinstance(pointer, Variable):
managed_mem = pointer.get_direct_user_nodes(lambda u: isinstance(u, ManagedMemory))
if not managed_mem:
mem_var = Variable(MemoryHandlerType(pointer.class_type),
self.scope.get_new_name(f'{pointer.name}_mem'),
shape=None, memory_handling='heap')
self.scope.insert_variable(mem_var)
ManagedMemory(pointer, mem_var)
elif isinstance(target, PythonDict):
# Values which are strings or fixed size numeric objects are not registered
if not isinstance(target.class_type.value_type, (StringType, FixedSizeNumericType)):
for v in target.values:
self._indicate_pointer_target(pointer, v, expr)
elif isinstance(pointer, Variable) and pointer.is_alias:
# None of the supported kinds of target matched
errors.report("Pointer cannot point at a temporary object",
severity='error', symbol=expr)
def _infer_type(self, expr):
    """
    Infer all relevant type information for the expression.

    Create a dictionary describing all the type information that can be
    inferred about the expression `expr`. This includes information about:
    - `class_type`
    - `shape`
    - `cls_base`
    - `memory_handling`

    Objects which are not typed are described by a symbolic type only.
    For typed objects the memory handling defaults to 'heap' if the rank
    is greater than 0 and 'stack' otherwise. This default is then refined
    for some specific kinds of expression.

    Parameters
    ----------
    expr : pyccel.ast.basic.PyccelAstNode
        An AST object representing an object in the code whose type
        must be determined.

    Returns
    -------
    dict
        Dictionary containing all the type information which was inferred.
    """
    if not isinstance(expr, TypedAstNode):
        return {'class_type' : SymbolicType()}

    class_type = expr.class_type
    d_var = {
            'class_type' : class_type,
            'shape' : expr.shape,
            'cls_base' : self.scope.find(str(class_type), 'classes') or get_cls_base(class_type),
            'memory_handling' : 'heap' if expr.rank > 0 else 'stack'
        }

    if isinstance(expr, Variable):
        # Variables already know how their memory is handled
        d_var['memory_handling'] = expr.memory_handling
        if expr.cls_base:
            d_var['cls_base'] = expr.cls_base

    elif isinstance(expr, Concatenate):
        any_on_heap = any(getattr(a, 'on_heap', False) for a in expr.args)
        d_var['memory_handling'] = 'heap' if any_on_heap else 'stack'

    elif isinstance(expr, Duplicate):
        d = self._infer_type(expr.val)
        # NOTE(review): the dictionaries built by this method have no
        # 'on_stack' key so this condition looks like it is always False
        # -- confirm whether 'memory_handling' should be checked instead
        if d.get('on_stack', False) and isinstance(expr.length, LiteralInteger):
            d_var['memory_handling'] = 'stack'
        else:
            d_var['memory_handling'] = 'heap'

    elif isinstance(expr, NumpyTranspose):
        # The transpose of a variable is an alias to that variable
        is_view = isinstance(expr.internal_var, Variable)
        d_var['memory_handling'] = 'alias' if is_view else 'heap'

    elif isinstance(expr, PythonTuple):
        if isinstance(expr.class_type, HomogeneousTupleType):
            d_var['shape'] = get_shape_of_multi_level_container(expr)

    elif isinstance(expr, (DictGetItem, DictGet)):
        is_numeric = isinstance(expr.class_type, FixedSizeNumericType)
        d_var['memory_handling'] = 'stack' if is_numeric else 'alias'

    # All other typed objects are fully described by the default values
    return d_var
def _extract_indexed_from_var(self, var, indices, expr):
"""
Use indices to extract appropriate element from object 'var'.
Use indices to extract appropriate element from object 'var'.
This contains most of the contents of _visit_IndexedElement
but is a separate function in order to be recursive. Objects
which cannot be indexed directly (e.g. function calls) are first
saved in a temporary variable. The assignment to this temporary is
stored in self._additional_exprs[-1].
Parameters
----------
var : Variable
The variable being indexed.
indices : iterable
The indexes used to access the variable.
expr : PyccelAstNode
The node being parsed. This is useful for raising errors.
Returns
-------
TypedAstNode
The visited object.
"""
# case of Pyccel ast Variable
# if not possible we use symbolic objects
if isinstance(var, PythonTuple):
# A literal tuple can be indexed directly if all the indices are
# literal integers (possibly negated) or slices built from them
def is_literal_index(a):
def is_int(a):
return isinstance(a, (int, LiteralInteger)) or \
(isinstance(a, PyccelUnarySub) and \
isinstance(a.args[0], (int, LiteralInteger)))
if isinstance(a, Slice):
return all(is_int(s) or s is None for s in (a.start, a.step, a.stop))
else:
return is_int(a)
if all(is_literal_index(a) for a in indices):
if len(indices)==1:
return var[indices[0]]
else:
# The first index is applied and the remaining indices are
# handled by visiting the resulting object
return self._visit(var[indices[0]][indices[1:]])
else:
# The tuple is saved in a temporary variable. The assignment is
# created in the syntactic stage so that it can then be visited
pyccel_stage.set_stage('syntactic')
tmp_var = PyccelSymbol(self.scope.get_new_name())
assign = Assign(tmp_var, var)
assign.set_current_ast(expr.python_ast)
pyccel_stage.set_stage('semantic')
self._additional_exprs[-1].append(self._visit(assign))
var = self._visit(tmp_var)
elif isinstance(var, Variable):
# Nothing to do but excludes this case from the subsequent ifs
pass
elif hasattr(var,'__getitem__'):
if len(indices)==1:
return var[indices[0]]
else:
return self._visit(var[indices[0]][indices[1:]])
elif isinstance(var, (PyccelFunction, FunctionCall)):
# The result of the call is saved in a temporary variable in the same
# way as for a tuple which cannot be indexed directly
pyccel_stage.set_stage('syntactic')
tmp_var = PyccelSymbol(self.scope.get_new_name())
assign = Assign(tmp_var, var)
assign.set_current_ast(expr.python_ast)
pyccel_stage.set_stage('semantic')
self._additional_exprs[-1].append(self._visit(assign))
# The syntactic assignment is removed from the users of var
var.remove_user_node(assign)
var = self._visit(tmp_var)
else:
errors.report(f"Can't index {type(var)}", symbol=expr,
severity='fatal')
indices = tuple(indices)
if isinstance(var.class_type, InhomogeneousTupleType):
# Only the first index is applied to the inhomogeneous tuple. The
# remaining indices are applied recursively to the selected element(s)
arg = indices[0]
if isinstance(arg, Slice):
if ((arg.start is not None and not is_literal_integer(arg.start)) or
(arg.stop is not None and not is_literal_integer(arg.stop))):
errors.report(INDEXED_TUPLE, symbol=var,
bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
severity='fatal')
# NOTE(review): arg.step is not used to build this slice -- confirm
# that a step cannot be provided here
idx = slice(arg.start, arg.stop)
orig_vars = [self.scope.collect_tuple_element(v) for v in var]
selected_vars = orig_vars[idx]
if len(indices)==1:
return PythonTuple(*selected_vars)
else:
return PythonTuple(*[self._extract_indexed_from_var(var, indices[1:], expr) for var in selected_vars])
elif isinstance(arg, LiteralInteger):
if len(indices)==1:
return self.scope.collect_tuple_element(var[arg])
var = var[arg]
return self._extract_indexed_from_var(var, indices[1:], expr)
else:
errors.report(INDEXED_TUPLE, symbol=var,
bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
severity='fatal')
if isinstance(var, PythonTuple) and not var.is_homogeneous:
errors.report(LIST_OF_TUPLES, symbol=var,
bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
severity='error')
# Each index must be a slice, an ellipsis or an object whose dtype has an
# integer primitive type
for arg in var[indices].indices:
if not isinstance(arg, (Slice, LiteralEllipsis)) and not (hasattr(arg, 'dtype') and
isinstance(getattr(arg.dtype, 'primitive_type', None), PrimitiveIntegerType)):
errors.report(INVALID_INDICES, symbol=var[indices],
bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
severity='error')
return var[indices]
def _create_PyccelOperator(self, expr, visited_args):
    """
    Create a PyccelOperator.

    Build the semantic operator from arguments which have already been
    visited. If the operands are all type descriptions the expression is
    a type union. If the class of an operand provides the matching magic
    method (or the reflected `__rX__` method for the second operand) the
    operator is translated to a call to that method.
    Called by _visit_PyccelOperator and other classes
    inheriting from PyccelOperator.

    Parameters
    ----------
    expr : PyccelOperator
        The expression being visited.

    visited_args : tuple of TypedAstNode
        The arguments passed to the operator.

    Returns
    -------
    PyccelOperator
        The new operator.
    """
    first_arg = visited_args[0]

    # Operator syntax between types describes a union of types
    if all(isinstance(a, PyccelFunctionDef) for a in visited_args):
        try:
            possible_types = [a.cls_name.static_type() for a in visited_args]
        except AttributeError:
            errors.report("Unrecognised type in type union statement",
                    severity='fatal', symbol=expr)
        annotations = [VariableTypeAnnotation(t) for t in possible_types]
        return UnionTypeAnnotation(*annotations)

    operand_type = first_arg.class_type
    operand_cls = self.scope.find(str(operand_type), 'classes') or get_cls_base(operand_type)
    method_name = magic_method_map.get(type(expr), None)

    magic_method = operand_cls.get_method(method_name) if method_name else None

    if method_name and magic_method is None:
        # The first operand cannot handle the operator, try the reflected
        # method of the second operand
        second_arg = visited_args[1]
        operand_type = second_arg.class_type
        operand_cls = self.scope.find(str(operand_type), 'classes') or get_cls_base(operand_type)
        method_name = '__r'+method_name[2:]
        magic_method = operand_cls.get_method(method_name)
        if magic_method:
            # The reflected method is called on the second operand
            visited_args = [visited_args[1], visited_args[0]]

    if magic_method:
        call_args = [FunctionCallArgument(v) for v in visited_args]
        expr_new = self._handle_function(expr, magic_method, call_args)
    else:
        try:
            expr_new = type(expr)(*visited_args)
        except PyccelSemanticError as err:
            errors.report(str(err), symbol=expr, severity='fatal')
        except TypeError as err:
            types = ', '.join(str(a.class_type) for a in visited_args)
            errors.report(f"Operator {type(expr)} between objects of type ({types}) is not yet handled\n"
                    + PYCCEL_RESTRICTION_TODO, symbol=expr, severity='fatal',
                    traceback = err.__traceback__)
    return expr_new
def _create_Duplicate(self, val, length):
    """
    Create a node which duplicates a tuple.

    Create the node describing `val * length`. Homogeneous containers are
    described by a Duplicate node. The elements of an inhomogeneous tuple
    are repeated explicitly which requires the length to be known.
    Called by _visit_PyccelMul when a Duplicate is identified.

    Parameters
    ----------
    val : PyccelAstNode
        The tuple object. This object should have a class type which inherits from
        TupleType.

    length : LiteralInteger | TypedAstNode
        The number of times the tuple is duplicated.

    Returns
    -------
    Duplicate | PythonTuple
        The duplicated tuple.
    """
    # Arguments have been visited in PyccelMul
    if not isinstance(val.class_type, (TupleType, HomogeneousListType)):
        errors.report("Unexpected Duplicate", symbol=Duplicate(val, length),
            bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
            severity='fatal')

    if isinstance(val.class_type, (HomogeneousTupleType, HomogeneousListType)):
        return Duplicate(val, length)

    # Inhomogeneous tuple : the number of repetitions must be a known integer
    if isinstance(length, LiteralInteger):
        n_copies = length.python_value
    else:
        sympy_length = pyccel_to_sympy(length, {}, set())
        if not isinstance(sympy_length, sp_Integer):
            errors.report("Cannot create inhomogeneous tuple of unknown size",
                symbol=Duplicate(val, length),
                bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
                severity='fatal')
        n_copies = int(sympy_length)

    elements = [self.scope.collect_tuple_element(v) for v in val]
    return PythonTuple(*(elements*n_copies))
def _create_class_destructor(self, expr):
    """
    Create the class destructor.

    Make sure that the class has a `__del__` method which deallocates the
    attributes of the class. If the class already has a destructor then the
    deallocations are added to the existing method, otherwise a new method
    is created. A boolean attribute is also added to the class. It is set to
    False at the start of `__init__` and to True once the destructor has run
    so the body of the destructor is only executed once.

    Parameters
    ----------
    expr : ClassDef
        The class that implicit __init__ and __del__ methods should be created for.
    """
    cls_type = expr.class_type
    known_methods = expr.methods
    cls_scope = expr.scope

    constructor = cls_scope.functions['__init__']
    if isinstance(constructor, Interface):
        errors.report("Pyccel does not support interface constructor", symbol=constructor,
            severity='fatal')

    # Attribute used as a guard to tell if the destructor was already called
    self_var = Variable(cls_type, 'self', cls_base = expr, is_argument=True)
    freed_flag = DottedVariable(lhs = self_var, name = self.scope.get_new_name('is_freed'),
                                class_type = PythonNativeBool(), is_private=True)
    expr.add_new_attribute(freed_flag)
    constructor.body.insert2body(Assign(freed_flag, LiteralFalse()), back=False)

    # Find the destructor or create an empty one
    del_method = next((m for m in known_methods if m.name == '__del__'), None)
    if del_method is not None:
        assert del_method.is_semantic
    else:
        argument = FunctionDefArgument(Variable(cls_type, 'self', cls_base = expr), bound_argument = True)
        del_name = cls_scope.get_new_name('__del__')
        del_scope = self.create_new_function_scope('__del__', del_name)
        del_scope.insert_variable(argument.var)
        self.exit_function_scope()
        del_method = FunctionDef(del_name, [argument], [Pass()], scope=del_scope)
        self.insert_function(del_method, cls_scope)
        expr.add_new_method(del_method)

    # Add destructors to __del__ method
    self._current_function_name.append(del_method.name)

    # Attributes which are not on the stack, and class instances
    # which are not aliases, are handed to the garbage collector
    to_free = [attr for attr in expr.attributes
               if not attr.on_stack
               or (isinstance(attr.class_type, CustomDataType) and not attr.is_alias)]
    if to_free:
        # Create a new set that stores the attributes to be collected
        self._allocs.append(set())
        self._pointer_targets.append({})
        self._allocs[-1].update(to_free)
        del_method.body.insert2body(*self._garbage_collector(del_method.body))
        self._pointer_targets.pop()

    # Only run the body of the destructor if it was not run before
    guarded_body = [del_method.body, Assign(freed_flag, LiteralTrue())]
    del_method.body = [If(IfSection(PyccelNot(freed_flag), guarded_body))]
    self._current_function_name.pop()
def _handle_function_args(self, arguments):
    """
    Get a list of all function arguments.

    Visit the syntactic FunctionCallArguments passed to a function and
    collect the resulting semantic arguments. A function passed as an
    argument is annotated if this was not already done. The contents
    of a starred argument are unpacked into the final list.

    Parameters
    ----------
    arguments : list of FunctionCallArgument
        The arguments which were passed to the function.

    Returns
    -------
    list of FunctionCallArgument
        The arguments passed to the function.
    """
    collected = []
    for syntactic_arg in arguments:
        call_arg = self._visit(syntactic_arg)
        value = call_arg.value

        # A function used as an argument must be annotated before it is passed
        needs_annotation = (isinstance(value, FunctionDef)
                            and not isinstance(value, PyccelFunctionDef)
                            and not value.is_semantic)
        if needs_annotation:
            annotated_func = self._annotate_the_called_function_def(value, ())
            call_arg = FunctionCallArgument(annotated_func, keyword = call_arg.keyword,
                                            python_ast = call_arg.python_ast)

        if not isinstance(value, StarredArguments):
            collected.append(call_arg)
            continue

        # Each element of the starred argument becomes an argument
        unpacked = [FunctionCallArgument(elem) for elem in value.args_var]
        collected.extend(unpacked)
    return collected
def _check_argument_compatibility(self, input_args, func_args, func, elemental, raise_error=True, error_type='error'):
    """
    Check that the provided arguments match the expected types.

    Check that the provided arguments match the expected types. The
    arguments are compared pairwise. For an elemental function only the
    datatype of the class type is compared. Arguments which cannot be
    compared (Nil, function addresses, generic types) are ignored.

    Parameters
    ----------
    input_args : list
        The arguments provided to the function.

    func_args : list
        The arguments expected by the function.

    func : FunctionDef
        The called function (used for error output).

    elemental : bool
        Indicates if the function is elemental.

    raise_error : bool, default : True
        Raise the error if the arguments are incompatible.

    error_type : str, default : error
        The error type if errors are raised from the function.

    Returns
    -------
    bool
        Return True if the arguments are compatible, False otherwise.
        False is also returned when the incompatibility was reported
        with a severity which does not interrupt the execution.
    """
    if elemental:
        def incompatible(i_arg, f_arg):
            return i_arg.class_type.datatype != f_arg.class_type.datatype
    else:
        def incompatible(i_arg, f_arg):
            return i_arg.class_type != f_arg.class_type

    err_msgs = []
    # Compare each set of arguments
    for idx, (i_arg, f_arg) in enumerate(zip(input_args, func_args)):
        i_arg = i_arg.value
        f_arg = f_arg.var
        # Ignore types which cannot be compared
        if (i_arg is Nil()
                or isinstance(f_arg, FunctionAddress)
                or f_arg.class_type is GenericType()):
            continue

        # Check for compatibility
        if incompatible(i_arg, f_arg):
            expected = str(f_arg.class_type)
            type_name = str(i_arg.class_type)
            received = f'{i_arg} ({type_name})'
            err_msgs.append(INCOMPATIBLE_ARGUMENT.format(idx+1, received, func, expected))

    if not err_msgs:
        return True

    if raise_error:
        bounding_box = (self.current_ast_node.lineno, self.current_ast_node.col_offset)
        errors.report('\n\n'.join(err_msgs), symbol = func, bounding_box=bounding_box, severity=error_type)

    # The arguments are incompatible whether or not the report
    # interrupted the execution
    return False
def _handle_function(self, expr, func, args, *, is_method = False, use_build_functions = True):
    """
    Create the node representing the function call.

    Create a FunctionCall or an instance of a PyccelFunction
    from the function information and arguments.

    Parameters
    ----------
    expr : TypedAstNode
        The expression where this call is found (used for error output).

    func : FunctionDef | Interface
        The function being called.

    args : iterable
        The arguments passed to the function.

    is_method : bool, default = False
        Indicates if the function is a class method.

    use_build_functions : bool, default = True
        If `func` is a PyccelFunctionDef, indicates that the `_build_X` methods should
        be used. This is almost always true but may be false if this function is called
        from a `_build_X` method.

    Returns
    -------
    FunctionCall/PyccelFunction
        The semantic representation of the call.
    """
    if isinstance(func, PyccelFunctionDef):
        if use_build_functions:
            # A dedicated `_build_X` method takes precedence if it exists
            annotation_method = '_build_' + func.cls_name.__name__
            if hasattr(self, annotation_method):
                if isinstance(expr, DottedName):
                    # Rebuild the expression as a syntactic call before
                    # handing it to the `_build_X` method
                    pyccel_stage.set_stage('syntactic')
                    if is_method:
                        new_expr = DottedName(args[0].value, FunctionCall(func, args[1:]))
                    else:
                        new_expr = FunctionCall(func, args)
                    new_expr.set_current_ast(expr.python_ast or self._current_ast_node)
                    pyccel_stage.set_stage('semantic')
                    # The new expression has the same users as the expression it replaces
                    for u in expr.get_all_user_nodes():
                        new_expr.set_current_user_node(u)
                    expr = new_expr
                return getattr(self, annotation_method)(expr, args)

        # Otherwise the class describing the function is instantiated directly
        argument_description = func.argument_description
        func = func.cls_name
        args, kwargs = split_positional_keyword_arguments(*args)

        # Ignore values passed by position but add any unspecified keywords
        # with the correct default value
        for kw, val in list(argument_description.items())[len(args):]:
            if kw not in kwargs:
                kwargs[kw] = val

        try:
            new_expr = func(*args, **kwargs)
        except TypeError as e:
            # Fall back on a generic message if the exception has no description
            message = str(e)
            if not message:
                message = UNRECOGNISED_FUNCTION_CALL
            errors.report(message,
                          symbol = expr,
                          traceback = e.__traceback__,
                          severity = 'fatal')

        return new_expr
    else:
        is_inline = func.is_inline if isinstance(func, FunctionDef) else False
        if is_inline:
            return self._visit_InlineFunctionDef(func, args, expr)
        elif not func.is_semantic:
            # The function must be annotated before it can be called
            func = self._annotate_the_called_function_def(func, args)

        # The function is calling itself
        if self.current_function_name == func.name:
            if func.results and not isinstance(func.results.var, TypedAstNode):
                errors.report(RECURSIVE_RESULTS_REQUIRED, symbol=func, severity="fatal")

        parent_assign = expr.get_direct_user_nodes(lambda x: isinstance(x, Assign) and not isinstance(x, AugAssign))

        # For an Interface the first function is used as the reference
        func_results = func.results if isinstance(func, FunctionDef) else func.functions[0].results
        if not parent_assign and func_results.var.rank > 0:
            # A result with rank > 0 which is not directly assigned is first
            # saved in a temporary variable
            pyccel_stage.set_stage('syntactic')
            tmp_var = PyccelSymbol(self.scope.get_new_name())
            assign = Assign(tmp_var, expr)
            assign.set_current_ast(expr.python_ast)
            pyccel_stage.set_stage('semantic')
            self._additional_exprs[-1].append(self._visit(assign))
            return self._visit(tmp_var)

        func_args = func.arguments if isinstance(func,FunctionDef) else func.functions[0].arguments
        if len(args) > len(func_args):
            errors.report("Too many arguments passed in function call",
                    symbol = expr,
                    severity='fatal')

        new_expr = FunctionCall(func, args, self.current_function_name)
        for a, f_a in zip(new_expr.args, func_args):
            if f_a.persistent_target:
                # Only expected for methods : the first argument (args[0]) is
                # indicated as pointing at the argument
                assert is_method
                val = a.value
                if isinstance(val, Variable):
                    a.value.is_target = True
                    self._indicate_pointer_target(args[0].value, a.value, expr)
                else:
                    errors.report(f"{val} cannot be passed to function call as target. Please create a temporary variable.",
                            severity='error', symbol=expr)

        # A None in the arguments of the call indicates a missing argument
        if None in new_expr.args:
            errors.report("Too few arguments passed in function call",
                    symbol = expr,
                    severity='error')
        elif isinstance(func, FunctionDef):
            self._check_argument_compatibility(args, func_args,
                        func, func.is_elemental)
        return new_expr
def _sort_function_call_args(self, func_args, args):
    """
    Sort and add the missing call arguments to match the arguments in the function definition.

    The call arguments are split into the positional arguments, which keep their
    position, and the keyword arguments. The arguments of the function definition
    which are not filled by a positional argument are then examined in order. The
    matching keyword argument is used if it was provided, otherwise the default
    value of the argument is used if there is one.

    Parameters
    ----------
    func_args : list[FunctionDefArgument]
        The arguments of the function definition.

    args : list[FunctionCallArgument]
        The arguments of the function call.

    Returns
    -------
    list[FunctionCallArgument]
        The sorted and complete call arguments.
    """
    sorted_args = [a for a in args if a.keyword is None]
    n_positional = len(sorted_args)

    # Everything which follows the positional arguments is examined as a keyword argument
    keyword_args = args[n_positional:]

    for def_arg in func_args[n_positional:]:
        expected_keyword = def_arg.name
        matches = [a for a in keyword_args if a.keyword == expected_keyword]
        assert len(matches) <= 1
        if matches:
            sorted_args.append(matches[0])
        elif def_arg.has_default:
            sorted_args.append(def_arg.default_call_arg)

    return sorted_args
def _annotate_the_called_function_def(self, old_func, function_call_args):
    """
    Annotate the called FunctionDef.

    Visit a function which was not yet treated by the semantic stage. The
    function is visited from the scope in which it was defined. The scope
    and current function of the parser are restored once the function has
    been annotated.

    Parameters
    ----------
    old_func : FunctionDef|Interface
        The function that needs to be annotated.

    function_call_args : list[FunctionCallArgument]
        The list of the call arguments.

    Returns
    -------
    func: FunctionDef|Interface
        The new annotated function.
    """
    assert not old_func.is_inline

    enclosing_classes = old_func.get_direct_user_nodes(lambda p: isinstance(p, ClassDef))
    if enclosing_classes:
        cls_base = self.scope.find(enclosing_classes[0].name, 'classes')
        cls_scope = cls_base.scope
        definition_scope = cls_scope
    elif isinstance(old_func, FunctionDef):
        definition_scope = old_func.scope
    else:
        definition_scope = old_func.syntactic_node.scope

    # The function call might be in a completely different scope from the FunctionDef
    # Save the current state of the parser so it can be restored afterwards
    saved_scope = self._scope
    saved_function = self._current_function
    saved_function_name = self._current_function_name

    # Climb to the root scope collecting the names of the scopes encountered
    path_to_root = []
    while definition_scope.parent_scope is not None:
        definition_scope = definition_scope.parent_scope
        if definition_scope.name is not None:
            path_to_root.append(definition_scope.name)

    # Descend from the root to find the matching semantic scopes
    for scope_name in reversed(path_to_root):
        definition_scope = definition_scope.sons_scopes[scope_name]

    # Annotate the function from the parent scope of its definition
    self._scope = definition_scope
    self._visit(old_func)

    # Retrieve the annotated function
    lookup_scope = cls_scope if enclosing_classes else self.scope
    new_name = lookup_scope.get_expected_name(old_func.name)
    func = lookup_scope.find(new_name, 'functions')
    assert func is not None

    # Add the Module of the imported function to the new function
    if old_func.is_imported:
        mod = old_func.get_direct_user_nodes(lambda x: isinstance(x, Module))[0]
        func.set_current_user_node(mod)

    # Go back to the original state
    self._scope = saved_scope
    self._current_function_name = saved_function_name
    self._current_function = saved_function

    # Replace the old_func in the imports dict with the new annotated function
    if old_func.is_imported:
        import_scope = self.scope
        while new_name not in import_scope.imports['functions']:
            import_scope = import_scope.parent_scope
        assert old_func is import_scope.imports['functions'].get(new_name)
        func = func.clone(new_name, is_imported=True)
        func.set_current_user_node(mod)
        import_scope.imports['functions'][new_name] = func
    return func
def _create_variable(self, name, class_type, rhs, d_lhs, *, arr_in_multirets=False,
                     insertion_scope = None, rhs_scope = None):
    """
    Create a new variable.

    Create a new variable. In most cases this is just a call to
    `Variable.__init__`
    but in the case of a tuple variable it is a recursive call to
    create all elements in the tuple.
    This is done separately to _assign_lhs_variable to ensure that
    elements of a tuple do not exist in the scope.

    Parameters
    ----------
    name : str
        The name of the new variable.

    class_type : PyccelType
        The type of the new variable.

    rhs : Variable
        The value assigned to the lhs. This is required to call
        self._infer_type recursively for tuples.

    d_lhs : dict
        Dictionary of properties for the new Variable.

    arr_in_multirets : bool, default: False
        If True, the variable that will be created is an array
        in multi-values return, false otherwise.

    insertion_scope : Scope, optional
        The scope where the variable will be inserted. This is used to add any
        symbolic aliases for inhomogeneous tuples.

    rhs_scope : Scope, optional
        The scope where the definition of the right hand side is found. This
        is used to locate any symbolic aliases for inhomogeneous tuples. It is
        necessary for tuples of tuples as function results.

    Returns
    -------
    Variable
        The variable that has been created.
    """
    is_temp = name.is_temp if isinstance(name, PyccelSymbol) else False

    if insertion_scope is None:
        insertion_scope = self.scope

    # Anything other than an inhomogeneous tuple is a simple variable
    if not isinstance(class_type, InhomogeneousTupleType):
        return Variable(class_type, name, **d_lhs, is_temp=is_temp)

    if rhs_scope is None:
        rhs_scope = self.scope

    # Collect the objects which are assigned to the elements of the tuple
    if isinstance(rhs, FunctionCall):
        rhs_scope = rhs.funcdef.scope
        rhs_elems = [rhs_scope.collect_tuple_element(v) for v in rhs.funcdef.results.var]
    elif isinstance(rhs, PyccelFunction):
        rhs_elems = [IndexedElement(rhs, i) for i in range(rhs.shape[0])]
    else:
        rhs_elems = [rhs_scope.collect_tuple_element(r) for r in rhs]

    elem_vars = []
    for i, rhs_elem in enumerate(rhs_elems):
        # Check if lhs element was named in the syntactic stage (this can happen for
        # results of functions)
        pyccel_stage.set_stage('syntactic')
        alias_key = IndexedElement(name, i)
        elem_var = None
        if alias_key in self.scope.symbolic_aliases:
            elem_name = self.scope.symbolic_aliases[alias_key]
            elem_var = self.check_for_variable(elem_name)
        else:
            elem_name = self.scope.get_new_name( f'{name}_{i}' )
        pyccel_stage.set_stage('semantic')

        if elem_var is None:
            # The element does not exist yet, it is created recursively
            elem_d_lhs = self._infer_type( rhs_elem )

            if not arr_in_multirets:
                self._ensure_target( rhs_elem, elem_d_lhs )

            elem_type = elem_d_lhs.pop('class_type')

            elem_var = self._create_variable(elem_name, elem_type, rhs_elem, elem_d_lhs,
                    insertion_scope = insertion_scope, rhs_scope = rhs_scope)
        elem_vars.append(elem_var)

    # The tuple is an alias as soon as one of its elements is an alias
    if any(v.is_alias for v in elem_vars):
        d_lhs['memory_handling'] = 'alias'

    tuple_var = Variable(class_type, name, **d_lhs, is_temp=is_temp)

    # The elements are only accessible through the tuple
    for i, elem_var in enumerate(elem_vars):
        insertion_scope.insert_symbolic_alias(IndexedElement(tuple_var, i), elem_var)

    return tuple_var
def _ensure_target(self, rhs, d_lhs):
    """
    Function using data about the new lhs.

    Examine the right hand side of an assignment to determine whether
    the new lhs is an alias. If this is the case the memory handling
    saved in the properties of the lhs is set to 'alias' and, where
    relevant, the rhs is marked as a target.

    Parameters
    ----------
    rhs : TypedAstNode
        The value assigned to the lhs.

    d_lhs : dict
        Dictionary of properties for the new Variable.
    """
    # rhs is None in an AugAssign
    if rhs is None or isinstance(rhs, FunctionalFor):
        return

    assert rhs.pyccel_staging != 'syntactic'

    # The transpose of an array on the heap is a view on this array
    if isinstance(rhs, NumpyTranspose) and rhs.internal_var.on_heap:
        d_lhs['memory_handling'] = 'alias'
        rhs.internal_var.is_target = True

    # No further examination is carried out for these types
    if isinstance(rhs.class_type, (TupleType, StringType, FixedSizeNumericType)):
        return

    if isinstance(rhs, Variable):
        d_lhs['memory_handling'] = 'alias'
        rhs.is_target = not rhs.is_alias

    elif isinstance(rhs, IndexedElement):
        if isinstance(rhs.class_type, (HomogeneousTupleType, NumpyNDArrayType)):
            d_lhs['memory_handling'] = 'alias'
            rhs.base.is_target = not rhs.base.is_alias
        elif not rhs.is_slice:
            d_lhs['memory_handling'] = 'alias'

    elif isinstance(rhs, (DictPop, DictPopitem, ListPop)):
        # The lhs is an alias if the container is among the current pointer targets
        container = rhs.list_obj if isinstance(rhs, ListPop) else rhs.dict_obj
        if container in self._pointer_targets[-1]:
            d_lhs['memory_handling'] = 'alias'
def _assign_lhs_variable(self, lhs, d_var, rhs, new_expressions, is_augassign = False,
        arr_in_multirets=False):
    """
    Create a variable from the left-hand side (lhs) of an assignment.

    Create a lhs based on the information in d_var, if the lhs already exists
    then check that it has the expected properties.

    Parameters
    ----------
    lhs : PyccelSymbol (or DottedName of PyccelSymbols)
        The representation of the lhs provided by the SyntacticParser.

    d_var : dict
        Dictionary of expected lhs properties.

    rhs : Variable / expression
        The representation of the rhs provided by the SemanticParser.
        This is necessary in order to set the rhs 'is_target' property
        if necessary. It is also used to determine the type of allocation
        (init/resize/reserve).

    new_expressions : list
        A list which allows collection of any additional expressions
        resulting from this operation (e.g. Allocation).

    is_augassign : bool, default=False
        Indicates whether this is an assign ( = ) or an augassign ( += / -= / etc )
        This is necessary as the restrictions on the dtype are less strict in this
        case.

    arr_in_multirets : bool, default=False
        If True, rhs has an array in its results, otherwise, it should be set to False.
        It helps when we don't need lhs to be a pointer in case of a returned array in
        a tuple of results.

    Returns
    -------
    pyccel.ast.variable.Variable
        The representation of the lhs provided by the SemanticParser.

    Raises
    ------
    NotImplementedError
        Raised if the lhs is not an IndexedElement, a PyccelSymbol or a DottedName.
    """
    if isinstance(lhs, IndexedElement):
        lhs = self._visit(lhs)
    elif isinstance(lhs, (PyccelSymbol, DottedName)):

        name = lhs
        if lhs == '_':
            name = self.scope.get_new_name()
        class_type = d_var.pop('class_type')

        d_lhs = d_var.copy()
        # ISSUES #177: lhs must be a pointer when rhs is heap array
        if not arr_in_multirets:
            self._ensure_target(rhs, d_lhs)

        # Look for an existing variable
        if isinstance(lhs, DottedName):
            prefix = self.get_class_prefix(lhs)
            class_def = prefix.cls_base
            attr_name = lhs.name[-1]
            attribute = class_def.scope.variables.get(attr_name, None) \
                    if class_def else None
            if attribute:
                var = attribute.clone(attribute.name, new_class = DottedVariable, lhs = prefix)
            else:
                var = None
        else:
            symbolic_var = self.scope.find(lhs, 'symbolic_aliases')
            if symbolic_var:
                errors.report(f"{lhs} variable represents a symbolic concept. Its value cannot be changed.",
                        severity='fatal')
            var = self.scope.find(lhs)

        # Variable not yet declared (hence array not yet allocated)
        if var is None:

            if isinstance(lhs, DottedName):
                prefix_parts = lhs.name[:-1]
                syntactic_prefix = prefix_parts[0] if len(prefix_parts) == 1 else DottedName(*prefix_parts)
                prefix = self._visit(syntactic_prefix)
                class_def = prefix.cls_base
                if prefix.name == 'self':
                    var = self.get_variable('self')

                    # Collect the name that should be used in the generated code
                    attribute_name = lhs.name[-1]
                    new_name = class_def.scope.get_expected_name(attribute_name)
                    # Create the attribute
                    member = self._create_variable(new_name, class_type, rhs, d_lhs,
                            insertion_scope = class_def.scope)

                    # Insert the attribute to the class scope
                    # Passing the original name ensures that the attribute can be found under this name
                    class_def.scope.insert_variable(member, attribute_name)

                    lhs = self.insert_attribute_to_class(class_def, var, member)
                else:
                    errors.report(f"{lhs.name[0]} should be named : self", symbol=lhs, severity='fatal')

            # Update variable's dictionary with information from function decorators
            decorators = self.scope.decorators
            if decorators:
                if 'stack_array' in decorators:
                    if name in decorators['stack_array']:
                        d_lhs.update(memory_handling='stack')
                if 'allow_negative_index' in decorators:
                    if lhs in decorators['allow_negative_index']:
                        d_lhs.update(allows_negative_indexes=True)

            # We cannot allow the definition of a stack array from a shape which
            # is unknown at the declaration
            if class_type.rank > 0 and d_lhs.get('memory_handling', None) == 'stack':
                for a in d_lhs['shape']:
                    # Collect all function calls found in the description of the
                    # shape. The error must name the impure function which was
                    # found, this is not necessarily `a` itself (which may not
                    # even be a function call).
                    called_funcs = a.get_attribute_nodes(FunctionCall)
                    if isinstance(a, FunctionCall):
                        called_funcs = [a, *called_funcs]
                    impure_call = next((f for f in called_funcs if not f.funcdef.is_pure), None)
                    if impure_call is not None:
                        errors.report(STACK_ARRAY_SHAPE_UNPURE_FUNC, symbol=impure_call.funcdef.name,
                            severity='error',
                            bounding_box=(self.current_ast_node.lineno,
                                self.current_ast_node.col_offset))
                    if (isinstance(a, Variable) and not a.is_argument) \
                            or not all(b.is_argument for b in a.get_attribute_nodes(Variable)):
                        errors.report(STACK_ARRAY_UNKNOWN_SHAPE, symbol=name,
                            severity='error',
                            bounding_box=(self.current_ast_node.lineno,
                                self.current_ast_node.col_offset))

            if not isinstance(lhs, DottedVariable):
                new_name = self.scope.get_expected_name(name)
                # Create new variable
                lhs = self._create_variable(new_name, class_type, rhs, d_lhs, arr_in_multirets=arr_in_multirets)

                # Add variable to scope
                self.scope.insert_variable(lhs, name)

            # ...
            # Add memory allocation if needed
            array_declared_in_function = (isinstance(rhs, FunctionCall) and not isinstance(rhs.funcdef, PyccelFunctionDef) \
                                        and not getattr(rhs.funcdef, 'is_elemental', False) and \
                                        not isinstance(lhs.class_type, HomogeneousTupleType)) or arr_in_multirets or \
                                        isinstance(rhs, (ListPop, SetPop, DictPop, DictPopitem, DictGet, DictGetItem))
            if lhs.on_heap and not array_declared_in_function:
                if self.scope.is_loop:
                    # Array defined in a loop may need reallocation at every cycle
                    errors.report(ARRAY_DEFINITION_IN_LOOP, symbol=name,
                        severity='warning',
                        bounding_box=(self.current_ast_node.lineno,
                            self.current_ast_node.col_offset))
                    status='unknown'
                else:
                    # Array defined outside of a loop will be allocated only once
                    status='unallocated'

                # Create Allocate node
                if isinstance(lhs.class_type, InhomogeneousTupleType):
                    # The elements of the (possibly nested) tuple are allocated individually
                    args = [self.scope.collect_tuple_element(v) for v in lhs if v.rank>0]
                    new_args = []
                    while len(args) > 0:
                        for a in args:
                            if isinstance(a.class_type, InhomogeneousTupleType):
                                new_args.extend(self.scope.collect_tuple_element(v) for v in a if v.rank>0)
                            elif a.rank > 0:
                                new_expressions.append(Allocate(a,
                                    shape=a.alloc_shape, status=status))
                        args = new_args
                        new_args = []
                elif isinstance(lhs.class_type, (HomogeneousListType, HomogeneousSetType,DictType)):
                    if isinstance(rhs, (PythonList, PythonDict, PythonSet, FunctionCall)):
                        alloc_type = 'init'
                    elif isinstance(rhs, IndexedElement) or rhs.get_attribute_nodes(IndexedElement):
                        alloc_type = 'resize'
                    else:
                        alloc_type = 'reserve'
                    new_expressions.append(Allocate(lhs, shape=lhs.alloc_shape, status=status, alloc_type=alloc_type))
                else:
                    new_expressions.append(Allocate(lhs, shape=lhs.alloc_shape, status=status))
            # ...

            # ...
            # Add memory deallocation
            if isinstance(lhs.class_type, CustomDataType) or not lhs.on_stack:
                if isinstance(lhs.class_type, InhomogeneousTupleType):
                    # The elements of the (possibly nested) tuple are deallocated individually
                    args = [self.scope.collect_tuple_element(v) for v in lhs if v.rank>0]
                    new_args = []
                    while len(args) > 0:
                        for a in args:
                            if isinstance(a.class_type, InhomogeneousTupleType):
                                new_args.extend(self.scope.collect_tuple_element(v) for v in a if v.rank>0)
                            else:
                                self._allocs[-1].add(a)
                        args = new_args
                        new_args = []
                else:
                    self._allocs[-1].add(lhs)
            # ...

            # We cannot allow the definition of a stack array in a loop
            if lhs.is_stack_array and self.scope.is_loop:
                errors.report(STACK_ARRAY_DEFINITION_IN_LOOP, symbol=name,
                    severity='error',
                    bounding_box=(self.current_ast_node.lineno,
                        self.current_ast_node.col_offset))

            # Not yet supported for arrays: x=y+z, x=b[:]
            # Because we cannot infer shape of right-hand side yet
            if array_declared_in_function:
                know_lhs_shape = True
            elif isinstance(lhs.dtype, StringType):
                know_lhs_shape = (lhs.rank == 1) or all(sh is not None for sh in lhs.alloc_shape[:-1])
            else:
                know_lhs_shape = (lhs.rank == 0) or all(sh is not None for sh in lhs.alloc_shape)

            if isinstance(class_type, (NumpyNDArrayType, HomogeneousTupleType)) and not know_lhs_shape \
                    and not array_declared_in_function:
                msg = f"Cannot infer shape of right-hand side for expression {lhs} = {rhs}"
                errors.report(PYCCEL_RESTRICTION_TODO+'\n'+msg,
                    bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
                    severity='fatal')

        # Variable already exists
        else:

            self._ensure_inferred_type_matches_existing(class_type, d_lhs, var, is_augassign, new_expressions, rhs)

            # in the case of elemental, lhs is not of the same class_type as
            # var.
            # TODO d_lhs must be consistent with var!
            # the following is a small fix, since lhs must be already
            # declared
            if isinstance(lhs, DottedName):
                lhs = var.clone(var.name, new_class = DottedVariable, lhs = self._visit(lhs.name[0]))
            else:
                lhs = var
    else:
        lhs_type = str(type(lhs))
        raise NotImplementedError(f"_assign_lhs_variable does not handle {lhs_type}")

    return lhs
def _ensure_inferred_type_matches_existing(self, class_type, d_var, var, is_augassign, new_expressions, rhs):
"""
Ensure that the inferred type matches the existing variable.
Ensure that the inferred type of the new variable, matches the existing variable (which has the
same name). If this is not the case then errors are raised preventing pyccel reaching the codegen
stage.
This function also handles any reallocations caused by differing shapes between the two objects.
These allocations/deallocations are saved in the list new_expressions
Parameters
----------
class_type : PyccelType
The inferred PyccelType.
d_var : dict
The inferred information about the variable. Usually created by the _infer_type function.
var : Variable
The existing variable.
is_augassign : bool
A boolean indicating if the assign statement is an augassign (tests are less strict).
new_expressions : list
A list to which any new expressions created are appended.
rhs : TypedAstNode
The right hand side of the expression : lhs=rhs.
For an augmented assignment it is combined with var to compute the
resulting type. Otherwise its node type is inspected to select the
error cases and the allocation strategy.
"""
# TODO improve check type compatibility
# The name is already bound to something that is not a Variable: this is
# always reported as a fatal type incompatibility.
if not isinstance(var, Variable):
name = var.name
message = INCOMPATIBLE_TYPES_IN_ASSIGNMENT.format(type(var), class_type)
# For an object still in the syntactic stage, ask the scope which name it
# expects; if it differs, extend the message with a renaming hint.
if var.pyccel_staging == "syntactic":
new_name = self.scope.get_expected_name(name)
if new_name != name:
message += '\nThis error may be due to object renaming to avoid name clashes (language-specific or otherwise).'
message += f'The conflict is with "{name}".'
name = new_name
errors.report(message,
symbol=f'{name}={class_type}',
bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
severity='fatal')
# Plain reassignment of an ndarray flagged as a target is rejected.
if not is_augassign and var.is_ndarray and var.is_target:
errors.report(ARRAY_ALREADY_IN_USE,
bounding_box=(self.current_ast_node.lineno,
self.current_ast_node.col_offset),
severity='error', symbol=var.name)
return
# Plain assignment of a Variable/IndexedElement to an existing non-alias
# object of rank > 0 is rejected, except for strings and tuples.
elif not is_augassign and not var.is_alias and var.rank > 0 and \
isinstance(rhs, (Variable, IndexedElement)) and \
not isinstance(var.class_type, (StringType, TupleType)):
errors.report(ASSIGN_ARRAYS_ONE_ANOTHER,
bounding_box=(self.current_ast_node.lineno,
self.current_ast_node.col_offset),
severity='error', symbol=var)
return
# An alias of rank > 0 may not be reassigned from a newly created
# array/list/set/dict.
elif var.rank > 0 and var.is_alias and isinstance(rhs, (NumpyNewArray, PythonList, PythonSet, PythonDict)):
errors.report(INVALID_POINTER_REASSIGN,
bounding_box=(self.current_ast_node.lineno,
self.current_ast_node.col_offset),
severity='error', symbol=var.name)
return
elif var.is_ndarray and var.is_alias and not is_augassign:
# we allow pointers to be reassigned multiple times
# pointers reassigning need to call free_pointer func
# to remove memory leaks
new_expressions.append(Deallocate(var))
return
# The inferred type differs from the existing one: decide below whether
# the difference is tolerated (raise_error False) or reported.
elif class_type != var.class_type:
if is_augassign:
# The type of `var + rhs` must equal the existing type.
# NOTE(review): PyccelAdd is used whatever the actual augmented
# operator is - confirm this is sufficient for type promotion.
tmp_result = PyccelAdd(var, rhs)
result_type = tmp_result.class_type
raise_error = var.class_type != result_type
elif isinstance(var.class_type, InhomogeneousTupleType) and \
isinstance(class_type, HomogeneousTupleType):
# Same length required; only the first element of the existing tuple
# is checked (recursively) against the homogeneous element type.
if d_var['shape'][0] == var.shape[0]:
rhs_elem = self.scope.collect_tuple_element(var[0])
self._ensure_inferred_type_matches_existing(class_type.element_type,
self._infer_type(rhs_elem), rhs_elem, is_augassign, new_expressions, rhs)
raise_error = False
else:
raise_error = True
elif isinstance(var.class_type, InhomogeneousTupleType) and \
isinstance(class_type, InhomogeneousTupleType):
# Check the existing tuple elements one by one against the inferred
# element types.
for i, element_type in enumerate(class_type):
rhs_elem = self.scope.collect_tuple_element(var[i])
self._ensure_inferred_type_matches_existing(element_type,
self._infer_type(rhs_elem), rhs_elem, is_augassign, new_expressions, rhs)
raise_error = False
elif isinstance(var.class_type, HomogeneousTupleType) and \
isinstance(class_type, InhomogeneousTupleType):
# TODO: Remove isinstance(rhs, Variable) condition when tuples are saved like lists
if isinstance(rhs, PythonTuple):
shape = get_shape_of_multi_level_container(rhs)
raise_error = len(shape) != class_type.rank or any(a != var.class_type.element_type for a in class_type)
else:
raise_error = any(a != var.class_type.element_type for a in class_type) or \
not isinstance(rhs, Variable)
else:
raise_error = True
if raise_error:
name = var.name
rhs_str = str(rhs)
errors.report(INCOMPATIBLE_TYPES_IN_ASSIGNMENT.format(var.class_type, class_type),
symbol=f'{name}={rhs_str}',
bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
severity='error')
return
# Shape handling for plain assignments: compare the inferred shape with the
# existing one and emit Allocate/Deallocate nodes where required.
if not is_augassign:
shape = var.shape
# Get previous allocation calls
previous_allocations = var.get_direct_user_nodes(lambda p: isinstance(p, Allocate))
# No Allocate node uses the variable yet: record the inferred shape as
# its initial shape.
if len(previous_allocations) == 0:
var.set_init_shape(d_var['shape'])
if d_var['shape'] != shape:
if var.is_argument:
errors.report(ARRAY_IS_ARG, symbol=var,
severity='error',
bounding_box=(self.current_ast_node.lineno,
self.current_ast_node.col_offset))
elif var.is_stack_array:
# Only an error if an Assign with this variable as lhs already exists.
if var.get_direct_user_nodes(lambda a: isinstance(a, Assign) and a.lhs is var):
errors.report(INCOMPATIBLE_REDEFINITION_STACK_ARRAY, symbol=var.name,
severity='error',
bounding_box=(self.current_ast_node.lineno,
self.current_ast_node.col_offset))
else:
# alloc_type is only set for list/set/dict variables; it is chosen
# from the kind of node found on the right hand side.
alloc_type = None
if isinstance(var.class_type, (HomogeneousListType, HomogeneousSetType,DictType)):
if isinstance(rhs, (PythonList, PythonDict, PythonSet, FunctionCall)):
alloc_type = 'init'
elif isinstance(rhs, IndexedElement) or rhs.get_attribute_nodes(IndexedElement):
alloc_type = 'resize'
else:
alloc_type = 'reserve'
# Determine the allocation status passed to the new Allocate node.
if previous_allocations:
var.set_changeable_shape()
last_allocation = previous_allocations[-1]
# Find outermost IfSection of last allocation
last_alloc_ifsection = last_allocation.get_user_nodes(IfSection)
alloc_ifsection = last_alloc_ifsection[-1] if last_alloc_ifsection else None
while len(last_alloc_ifsection)>0:
alloc_ifsection = last_alloc_ifsection[-1]
last_alloc_ifsection = alloc_ifsection.get_user_nodes(IfSection)
# True when exactly one If node directly uses that outermost IfSection.
ifsection_has_if = len(alloc_ifsection.get_direct_user_nodes(
lambda x: isinstance(x,If))) == 1 \
if alloc_ifsection else False
if alloc_ifsection and not ifsection_has_if:
status = last_allocation.status
elif last_allocation.get_user_nodes((If, For, While)):
status='unknown'
else:
status='allocated'
else:
status = 'unallocated'
new_expressions.append(Allocate(var, shape=d_var['shape'], status=status, alloc_type=alloc_type))
# First allocation of the variable: register it in the innermost
# set of allocated variables.
if status == 'unallocated':
self._allocs[-1].add(var)
elif isinstance(var.class_type, NumpyNDArrayType):
errors.report(ARRAY_REALLOCATION.format(class_type = var.class_type), symbol=var.name,
severity='warning',
bounding_box=(self.current_ast_node.lineno,
self.current_ast_node.col_offset))
# Shapes are equal here.
elif previous_allocations and previous_allocations[-1].get_user_nodes(IfSection) \
and not previous_allocations[-1].get_user_nodes((If)):
# If previously allocated in If still under construction
status = previous_allocations[-1].status
new_expressions.append(Allocate(var, shape=d_var['shape'], status=status))
# A non-alias class instance is deallocated before being overwritten.
elif isinstance(var.class_type, CustomDataType) and not var.is_alias:
new_expressions.append(Deallocate(var))
def _assign_GeneratorComprehension(self, lhs_name, expr):
"""
Visit the GeneratorComprehension node.
Create all necessary expressions for the
GeneratorComprehension node definition.
Parameters
----------
lhs_name : str
The name to which the expression is assigned.
expr : GeneratorComprehension
The GeneratorComprehension node.
Returns
-------
pyccel.ast.functionalexpr.GeneratorComprehension
The semantic version of the GeneratorComprehension node (a FunctionalSum,
FunctionalMin or FunctionalMax matching the type of expr).
"""
result = expr.expr
loop = expr.loops
nlevels = 0
# Create throw-away variable to help obtain result type
index = Variable(PythonNativeInt(),self.scope.get_new_name('to_delete'), is_temp=True)
self.scope.insert_variable(index)
# Collects additional expressions created while handling the loops
# (including assignments created for nested generators).
new_expr = []
# Walk down the nested loops, declaring the iterators of each level and
# counting the levels.
while isinstance(loop, (For, If)):
nlevels+=1
self._get_for_iterators(loop.iterable, loop.target, new_expr, expr)
loop_elem = loop.body.body[0]
# Look through a condition to the first statement of its first block.
if isinstance(loop_elem, If):
loop_elem = loop_elem.blocks[0].body.body[0]
if isinstance(loop_elem, Assign):
# If the result contains a GeneratorComprehension, treat it and replace
# it with its lhs variable before continuing
gens = set(loop_elem.get_attribute_nodes(GeneratorComprehension))
if len(gens)==1:
gen = gens.pop()
# The Assign wrapping the nested generator is built in the syntactic
# stage, then visited in the semantic stage.
pyccel_stage.set_stage('syntactic')
assert isinstance(gen.lhs, PyccelSymbol) and gen.lhs.is_temp
gen_lhs = self.scope.get_new_name() if gen.lhs.is_temp else gen.lhs
syntactic_assign = Assign(gen_lhs, gen, python_ast=gen.python_ast)
pyccel_stage.set_stage('semantic')
assign = self._visit(syntactic_assign)
new_expr.append(assign)
loop.substitute(gen, assign.lhs)
loop_elem = loop.body.body[0]
loop = loop_elem
# Remove the throw-away variable from the scope
self.scope.remove_variable(index)
# Visit result expression (correctly defined as iterator
# objects exist in the scope despite not being defined)
result = self._visit(result)
# Only the last statement of a CodeBlock is used as the result.
if isinstance(result, CodeBlock):
result = result.body[-1]
# Create start value
# NOTE(review): d_var and val are only assigned for FunctionalSum,
# FunctionalMin and FunctionalMax - confirm no other type reaches here.
if isinstance(expr, FunctionalSum):
dtype = result.dtype
if isinstance(dtype, PythonNativeBool):
val = LiteralInteger(0, dtype)
else:
val = convert_to_literal(0, dtype)
d_var = self._infer_type(PyccelAdd(result, val))
elif isinstance(expr, FunctionalMin):
d_var = self._infer_type(result)
val = MaxLimit(d_var['class_type'])
elif isinstance(expr, FunctionalMax):
d_var = self._infer_type(result)
val = MinLimit(d_var['class_type'])
# Infer the final dtype of the expression
class_type = d_var.pop('class_type')
d_var['is_temp'] = expr.lhs.is_temp
# Reuse an existing variable of that name after checking its type,
# otherwise declare a new one in the current scope.
lhs = self.check_for_variable(lhs_name)
if lhs:
self._ensure_inferred_type_matches_existing(class_type, d_var, lhs, False, new_expr, None)
else:
lhs_name = self.scope.get_expected_name(lhs_name)
lhs = Variable(class_type, lhs_name, **d_var)
self.scope.insert_variable(lhs)
# Iterate over the loops
# This provides the definitions of iterators as well
# as the central expression
loops = [self._visit(expr.loops)]
# If necessary add additional expressions corresponding
# to nested GeneratorComprehensions
if new_expr:
loop = loops[0]
# Descend to the innermost of the nlevels loops.
for _ in range(nlevels-1):
loop = loop.body.body[0]
for e in new_expr:
loop.body.insert2body(e, back=False)
e.loops[-1].scope.update_parent_scope(loop.scope, is_loop = True)
# Initialise result with correct initial value
stmt = Assign(lhs, val)
stmt.set_current_ast(expr.python_ast)
loops.insert(0, stmt)
indices = [self._visit(i) for i in expr.indices]
# Rebuild a node of the same functional type from the semantic pieces.
if isinstance(expr, FunctionalSum):
expr_new = FunctionalSum(loops, lhs=lhs, indices = indices, conditions=expr.conditions)
elif isinstance(expr, FunctionalMin):
expr_new = FunctionalMin(loops, lhs=lhs, indices = indices, conditions=expr.conditions)
elif isinstance(expr, FunctionalMax):
expr_new = FunctionalMax(loops, lhs=lhs, indices = indices, conditions=expr.conditions)
expr_new.set_current_ast(expr.python_ast)
return expr_new
def _find_superclasses(self, expr):
"""
Find all the superclasses in the scope.
From a syntactic ClassDef, extract the names of the superclasses and
search through the scope to find their definitions. If there is no
definition then an error is raised.
Parameters
----------
expr : ClassDef
The class whose superclasses we wish to find.
Returns
-------
list
An iterable containing the definitions of all the superclasses.
Raises
------
PyccelSemanticError
A `PyccelSemanticError` is reported and will be raised after the
semantic stage is complete.
"""
parent = {s: self.scope.find(s, 'classes') for s in expr.superclasses}
if any(c is None for c in parent.values()):
for s,c in parent.items():
if c is None:
errors.report(f"Couldn't find class {s} in scope", symbol=expr,
severity='error')
parent = {s:c for s,c in parent.items() if c is not None}
return list(parent.values())
def _convert_syntactic_object_to_type_annotation(self, syntactic_annotation):
    """
    Wrap a syntactic object so that it is visited as a type annotation.

    Objects which are already a SyntacticTypeAnnotation are returned
    unchanged. Any other object is wrapped in a new SyntacticTypeAnnotation,
    which is created while the syntactic stage is active.

    Parameters
    ----------
    syntactic_annotation : PyccelAstNode
        A syntactic object that needs to be visited as a type annotation.

    Returns
    -------
    SyntacticTypeAnnotation
        A syntactic object that will be recognised as a type annotation.
    """
    if isinstance(syntactic_annotation, SyntacticTypeAnnotation):
        return syntactic_annotation

    # The wrapper is a syntactic node, so switch stage while building it.
    pyccel_stage.set_stage('syntactic')
    wrapped_annotation = SyntacticTypeAnnotation(dtype=syntactic_annotation)
    pyccel_stage.set_stage('semantic')
    return wrapped_annotation
def _get_indexed_type(self, base, args, expr):
    """
    Extract a type annotation from an IndexedElement.

    Extract a type annotation from an IndexedElement. This may be a type indexed with
    slices (indicating a NumPy array), or a class type such as tuple/list/etc which is
    indexed with the datatype.

    Parameters
    ----------
    base : type deriving from PyccelAstNode
        The object being indexed.
    args : tuple of PyccelAstNode
        The indices being used to access the base.
    expr : PyccelAstNode
        The annotation, used for error printing.

    Returns
    -------
    UnionTypeAnnotation | VariableTypeAnnotation
        The type annotation described by this object. A single
        VariableTypeAnnotation is returned when the base is indexed with
        slices only.
    """
    if isinstance(base, PyccelFunctionDef) and base.cls_name is TypingFinal:
        # Final[T] : visit T as an annotation and mark every resulting type as constant
        syntactic_annotation = self._convert_syntactic_object_to_type_annotation(args[0])
        annotation = self._visit(syntactic_annotation)
        for t in annotation.type_list:
            t.is_const = True
        return annotation
    elif isinstance(base, UnionTypeAnnotation):
        # Apply the index to each member of the union
        return UnionTypeAnnotation(*[self._get_indexed_type(t, args, expr) for t in base.type_list])

    if all(isinstance(a, Slice) for a in args):
        # Only slices : the number of slices gives the rank of a NumPy array
        rank = len(args)
        order = None if rank < 2 else 'C'
        if isinstance(base, VariableTypeAnnotation):
            dtype = base.class_type
            if dtype.rank != 0:
                raise errors.report("NumPy element must be a scalar type", severity='fatal', symbol=expr)
            class_type = NumpyNDArrayType(numpy_process_dtype(dtype), rank, order)
        elif isinstance(base, PyccelFunctionDef):
            dtype_cls = base.cls_name
            try:
                dtype = numpy_process_dtype(dtype_cls.static_type())
            except AttributeError:
                errors.report(f"Unrecognised datatype {dtype_cls}", severity='fatal', symbol=expr)
            class_type = NumpyNDArrayType(dtype, rank, order)
        else:
            # Without this branch class_type would be unbound below
            raise errors.report(f"Unknown annotation base {base}\n"+PYCCEL_RESTRICTION_TODO,
                    severity='fatal', symbol=expr)
        return VariableTypeAnnotation(class_type)

    if not any(isinstance(a, Slice) for a in args):
        if isinstance(base, PyccelFunctionDef):
            dtype_cls = base.cls_name.static_type()
        else:
            raise errors.report(f"Unknown annotation base {base}\n"+PYCCEL_RESTRICTION_TODO,
                    severity='fatal', symbol=expr)

        if (len(args) == 2 and args[1] is LiteralEllipsis()) or \
                (len(args) == 1 and dtype_cls is not TupleType):
            # container[T] or tuple[T, ...] : homogeneous container
            syntactic_annotation = self._convert_syntactic_object_to_type_annotation(args[0])
            internal_datatypes = self._visit(syntactic_annotation)
            class_type = HomogeneousTupleType if dtype_cls is TupleType else dtype_cls
            type_annotations = [VariableTypeAnnotation(class_type(u.class_type), u.is_const)
                                for u in internal_datatypes.type_list]
            return UnionTypeAnnotation(*type_annotations)
        elif len(args) == 2 and dtype_cls is DictType:
            # dict[K, V] : key and value types are paired in order
            syntactic_key_annotation = self._convert_syntactic_object_to_type_annotation(args[0])
            syntactic_val_annotation = self._convert_syntactic_object_to_type_annotation(args[1])
            key_types = self._visit(syntactic_key_annotation)
            val_types = self._visit(syntactic_val_annotation)
            type_annotations = [VariableTypeAnnotation(dtype_cls(k.class_type, v.class_type)) \
                                for k,v in zip(key_types.type_list, val_types.type_list)]
            return UnionTypeAnnotation(*type_annotations)
        elif dtype_cls is TupleType:
            # tuple[T1, T2, ...] : one inhomogeneous tuple per combination of the
            # possible element types
            syntactic_annotations = [self._convert_syntactic_object_to_type_annotation(a) for a in args]
            types = [self._visit(a).type_list for a in syntactic_annotations]
            internal_datatypes = list(product(*types))
            type_annotations = [VariableTypeAnnotation(InhomogeneousTupleType(*[ui.class_type for ui in u]), True)
                                for u in internal_datatypes]
            return UnionTypeAnnotation(*type_annotations)
        else:
            raise errors.report("Cannot handle non-homogenous type index\n"+PYCCEL_RESTRICTION_TODO,
                    severity='fatal', symbol=expr)

    # A mixture of slices and other indices
    raise errors.report("Unrecognised type slice",
            severity='fatal', symbol=expr)
# [docs]
def insert_attribute_to_class(self, class_def, self_var, attrib):
    """
    Add a new attribute to an existing class definition.

    A DottedVariable is created from the attribute and the 'self' variable.
    If the attribute is an inhomogeneous tuple, this function is called
    recursively on each of the variables comprising the tuple. Otherwise
    the DottedVariable is added to the class definition.

    Parameters
    ----------
    class_def : ClassDef
        The class definition to which the attribute should be added.
    self_var : Variable
        The variable representing the 'self' variable of the class instance.
    attrib : Variable
        The attribute which should be inserted into the class definition.

    Returns
    -------
    DottedVariable | PythonTuple
        The object that was inserted into the class definition.
    """
    # Local view of the attribute as `self_var.attrib`
    dotted_attrib = attrib.clone(attrib.name, new_class = DottedVariable, lhs = self_var)

    if not isinstance(attrib.class_type, InhomogeneousTupleType):
        # update the attributes of the class and push it to the scope
        class_def.add_new_attribute(dotted_attrib)
        return dotted_attrib

    # Inhomogeneous tuple : insert each element individually
    for element in attrib:
        self.insert_attribute_to_class(class_def, self_var,
                class_def.scope.collect_tuple_element(element))
    return dotted_attrib
def _get_iterable(self, syntactic_iterable):
    """
    Build an Iterable from a syntactic object used in an iterable context.

    The syntactic object is visited. Variables and indexed elements are
    wrapped (DictKeys for dictionaries, VariableIterator otherwise). Any
    other typed object which is not already an Iterable is first saved in
    a temporary variable, which is then wrapped in a VariableIterator.
    A typical example of an iterable context is the iterable of a for loop.

    Parameters
    ----------
    syntactic_iterable : PyccelAstNode
        The syntactic object that should be usable as an iterable.

    Returns
    -------
    Iterable
        A semantic Iterable object.
    """
    visited = self._visit(syntactic_iterable)

    if isinstance(visited, (Variable, IndexedElement)):
        if isinstance(visited.class_type, DictType):
            return DictKeys(visited)
        return VariableIterator(visited)

    if isinstance(visited, Iterable):
        return visited

    if not isinstance(visited, TypedAstNode):
        errors.report(f"{visited} is not handled as the iterable of a for loop",
                symbol=syntactic_iterable, severity='fatal')
        return visited

    # Save the expression in a temporary; the assignment is added to the
    # additional expressions of the current context.
    pyccel_stage.set_stage('syntactic')
    tmp_name = self.scope.get_new_name()
    syntactic_assign = Assign(tmp_name, visited, python_ast = visited.python_ast)
    pyccel_stage.set_stage('semantic')

    semantic_assign = self._visit(syntactic_assign)
    self._additional_exprs[-1].append(semantic_assign)
    return VariableIterator(self._visit(tmp_name))
def _get_for_iterators(self, syntactic_iterable, iterator, new_expr, expr):
    """
    Get the semantic target and iterable of a for loop.

    Get the semantic target and iterable of a for loop. This method can be used to
    handle generators, comprehension expressions or basic for loops.

    Parameters
    ----------
    syntactic_iterable : TypedAstNode
        The iterable that the for loop iterates over.
    iterator : TypedAstNode
        The syntactic iterator that takes the value of the elements of the iterable.
    new_expr : list[PyccelAstNode]
        A list which allows collection of any additional expressions
        resulting from this operation (e.g. Allocation).
    expr : PyccelAstNode
        The expression being visited. This is used for error handling.

    Returns
    -------
    target : list[TypedAstNode]
        The semantic iterator(s) that take the value of the elements of the iterable.
    iterable : TypedAstNode
        The semantic iterable that the for loop iterates over.
    """
    iterable = self._get_iterable(syntactic_iterable)

    if iterable.num_loop_counters_required:
        # The iterable needs dedicated temporary counters
        indices = [Variable(PythonNativeInt(), self.scope.get_new_name(), is_temp=True)
                   for _ in range(iterable.num_loop_counters_required)]
        iterable.set_loop_counter(*indices)
    else:
        # The loop counter is the iterator itself (or its first element for enumerate)
        if isinstance(iterable, PythonEnumerate):
            if isinstance(iterator, PythonTuple):
                syntactic_index = iterator[0]
            else:
                pyccel_stage.set_stage('syntactic')
                syntactic_index = IndexedElement(iterator,0)
                pyccel_stage.set_stage('semantic')
        else:
            syntactic_index = iterator

        index = self.check_for_variable(syntactic_index)
        if index is None:
            # Declare the counter as an integer
            start = LiteralInteger(0)
            d_var = self._infer_type(start)
            if isinstance(syntactic_index, PyccelSymbol):
                index = self._assign_lhs_variable(syntactic_index, d_var,
                                rhs=start, new_expressions=new_expr)
            else:
                index = self.scope.get_temporary_variable(PythonNativeInt())
        iterable.set_loop_counter(index)

    # Collect a target with a deducible dtype
    iterator_rhs = iterable.get_python_iterable_item()
    # Use _assign_lhs_variable to create the requested iterator with the correct type
    # The result of this operation is not stored, it is just used to declare
    # iterator with the correct dtype to allow correct dtype deductions later
    if isinstance(iterator, PyccelSymbol):
        if len(iterator_rhs) != 1:
            iterator_rhs = PythonTuple(*iterator_rhs, prefer_inhomogeneous=True)
        else:
            iterator_rhs = iterator_rhs[0]
        iterator_d_var = self._infer_type(iterator_rhs)
        target = self._assign_lhs_variable(iterator, iterator_d_var,
                        rhs=iterator_rhs, new_expressions=new_expr)
        if target.is_alias:
            self._indicate_pointer_target(target, iterator_rhs, expr.python_ast)
        if isinstance(target.class_type, InhomogeneousTupleType):
            target = [self.scope.collect_tuple_element(v) for v in target]
        else:
            target = [target]
    elif isinstance(iterator, PythonTuple):
        target = [self._assign_lhs_variable(it, self._infer_type(rhs),
                            rhs=rhs, new_expressions=new_expr)
                    for it, rhs in zip(iterator, iterator_rhs)]
        for t, rhs in zip(target, iterator_rhs):
            if t.is_alias:
                self._indicate_pointer_target(t, rhs, expr.python_ast)
    else:
        # No target can be built so the error must stop the visit. A report
        # which does not raise would leave `raise` with a non-exception value.
        raise errors.report(INVALID_FOR_ITERABLE, symbol=iterator,
               bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
               severity='fatal')

    return target, iterable
# [docs]
def env_var_to_pyccel(self, env_var, *, name = None):
    """
    Convert an environment variable to a Pyccel AST node.

    Convert an environment variable (i.e. a variable deduced from the
    context where epyccel was called) into a Pyccel AST node as though
    the object had been declared explicitly in the code.

    Parameters
    ----------
    env_var : object
        The environment variable.
    name : str, optional
        The name that was used to identify the variable.

    Returns
    -------
    PyccelAstNode
        The usable Pyccel AST node. None is returned if the object could
        not be converted (an error is reported in this case).
    """
    # A known type
    if env_var in original_type_to_pyccel_type:
        return VariableTypeAnnotation(original_type_to_pyccel_type[env_var])

    # An instance of a known type
    env_type = type(env_var)
    if env_type in original_type_to_pyccel_type:
        return convert_to_literal(env_var, dtype = original_type_to_pyccel_type[env_type])

    if env_var is typing.Final:
        return PyccelFunctionDef('Final', TypingFinal)

    if isinstance(env_var, typing.GenericAlias):
        origin = self.env_var_to_pyccel(typing.get_origin(env_var))
        class_type = origin.class_type.static_type()
        arg_types = [self.env_var_to_pyccel(a).class_type for a in typing.get_args(env_var)]
        return VariableTypeAnnotation(class_type(*arg_types))

    if isinstance(env_var, typing.TypeVar):
        constraints = [self.env_var_to_pyccel(c) for c in env_var.__constraints__]
        return TypingTypeVar(env_var.__name__, *constraints,
                             covariant = env_var.__covariant__,
                             contravariant = env_var.__contravariant__)

    if isinstance(env_var, ModuleType):
        mod_name = env_var.__name__
        if recognised_source(mod_name):
            pyccel_stage.set_stage('syntactic')
            import_node = Import(AsName(mod_name, name))
            pyccel_stage.set_stage('semantic')

            # Insert import at global scope
            local_scope = self.scope
            global_scope = local_scope
            while global_scope.parent_scope:
                global_scope = global_scope.parent_scope
            self.scope = global_scope
            self._additional_exprs[-1].append(self._visit(import_node))
            self.scope = local_scope
            return self.scope.find(name)

        errors.report(f"Unrecognised module {mod_name} imported in global scope. Please import the module locally if it was previously Pyccelised.",
                severity='error', symbol = self.current_ast_node)
    elif isinstance(env_var, (typing.ForwardRef, str)):
        # The text of the annotation is parsed while the syntactic stage is active
        pyccel_stage.set_stage('syntactic')
        annotation_str = getattr(env_var, '__forward_arg__', env_var)
        try:
            annotation = types_meta.model_from_str(annotation_str)
        except TextXSyntaxError as e:
            errors.report(f"Invalid annotation. {e.message}",
                    symbol = self.current_ast_node, severity='fatal')
        annot = annotation.expr
        pyccel_stage.set_stage('semantic')
        return self._visit(annot)

    # Nothing above produced a node
    errors.report(PYCCEL_RESTRICTION_TODO,
            severity='error', symbol = self.current_ast_node)
    return None
#====================================================
# _visit functions
#====================================================
def _visit(self, expr):
"""
Annotate the AST.
The annotation is done by finding the appropriate function _visit_X
for the object expr. X is the type of the object expr. If this function
does not exist then the method resolution order is used to search for
other compatible _visit_X functions. If none are found then an error is
raised.
Parameters
----------
expr : pyccel.ast.basic.PyccelAstNode | PyccelSymbol
Object to visit of type X.
Returns
-------
pyccel.ast.basic.PyccelAstNode
AST object which is the semantic equivalent of expr.
"""
if getattr(expr, 'pyccel_staging', 'syntactic') == 'semantic':
return expr
# TODO - add settings to Errors
# - line and column
# - blocking errors
current_ast = self.current_ast_node
if getattr(expr,'python_ast', None) is not None:
self._current_ast_node = expr.python_ast
classes = type(expr).__mro__
for cls in classes:
annotation_method = '_visit_' + cls.__name__
try:
if hasattr(self, annotation_method):
if self._verbose > 2:
print(f">>>> Calling SemanticParser.{annotation_method}")
obj = getattr(self, annotation_method)(expr)
if isinstance(obj, PyccelAstNode) and self.current_ast_node:
obj.set_current_ast(self.current_ast_node)
self._current_ast_node = current_ast
return obj
except (PyccelError, NotImplementedError) as err:
raise err
except Exception as err: #pylint: disable=broad-exception-caught
if ErrorsMode().value == 'user':
errors.report(PYCCEL_INTERNAL_ERROR,
symbol = self._current_ast_node, severity='fatal')
else:
raise err
# Unknown object, we raise an error.
return errors.report(PYCCEL_RESTRICTION_TODO, symbol=type(expr),
bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
severity='fatal')
def _visit_Module(self, expr):
    """
    Visit a syntactic Module and create its semantic equivalent.

    The imports, classes, initialisation statements and (if present) the
    program of the module are visited, followed by all functions and
    methods which have not yet been visited. If the module contains
    initialisation statements then an initialisation function is created,
    along with a matching free function if anything needs deallocating.

    Parameters
    ----------
    expr : Module
        The syntactic module.

    Returns
    -------
    Module
        The semantic module. Its `program` attribute is set if the
        syntactic module contains a program.
    """
    imports = [self._visit(i) for i in expr.imports]
    init_func_body = [i for i in imports if not isinstance(i, EmptyNode)]

    for f in expr.funcs:
        self.insert_function(f)

    # Avoid conflicts with symbols from Program
    if expr.program:
        self.scope.insert_symbols(expr.program.scope.all_used_symbols)

    for c in expr.classes:
        self._visit(c)

    init_func_body += self._visit(expr.init_func).body
    mod_name = self.metavars.get('module_name', None)
    if mod_name is None:
        mod_name = expr.name
    else:
        self.scope.insert_symbol(mod_name)
    self._mod_name = mod_name
    if isinstance(expr.name, AsName):
        name_suffix = expr.name.name
    else:
        name_suffix = expr.name

    if expr.program:
        prog_name = 'prog_'+name_suffix
        prog_name = self.scope.get_new_name(prog_name)
        self._allocs.append(set())
        self._pointer_targets.append({})

        # The program is visited in its own child scope of the module scope
        mod_scope = self.scope
        prog_syntactic_scope = expr.program.scope
        self.scope = mod_scope.new_child_scope(prog_name,
                used_symbols = prog_syntactic_scope.local_used_symbols.copy(),
                original_symbols = prog_syntactic_scope.python_names.copy())
        prog_scope = self.scope

        imports = [self._visit(i) for i in expr.program.imports]
        body = [i for i in imports if not isinstance(i, EmptyNode)]

        body += self._visit(expr.program.body).body

        program_body = CodeBlock(body)

        # Calling the Garbage collecting,
        # it will add the necessary Deallocate nodes
        # to the ast
        program_body.insert2body(*self._garbage_collector(program_body))
        self._pointer_targets.pop()

        self.scope = mod_scope

    # Visit any functions and methods which have not yet been visited
    funcs_to_visit = list(self.scope.functions.values())
    funcs_to_visit.extend(m for c in self.scope.classes.values() for m in c.methods)

    for f in funcs_to_visit:
        if not f.is_semantic and not isinstance(f, InlineFunctionDef):
            assert isinstance(f, FunctionDef)
            self._visit(f)

    classes = self.scope.classes.values()
    for c in classes:
        self._create_class_destructor(c)

    for f in self.scope.functions.values():
        assert f.is_semantic or f.is_inline

    variables = self.get_variables(self.scope)
    init_func = None
    free_func = None
    program   = None

    comment_types = (Header, EmptyNode, Comment, CommentBlock)

    if not all(isinstance(l, comment_types) for l in init_func_body):
        # If there are any initialisation statements then create an initialisation function
        init_var = Variable(PythonNativeBool(), self.scope.get_new_name('initialised'),
                            is_private=True, is_temp = True)
        syntactic_init_func_name = name_suffix+'__init'
        init_func_name = self.scope.get_new_name(syntactic_init_func_name)
        # Ensure that the function is correctly defined within the namespaces
        init_scope = self.create_new_function_scope(syntactic_init_func_name, init_func_name)
        for b in init_func_body:
            if isinstance(b, ScopedAstNode):
                b.scope.update_parent_scope(init_scope, is_loop = True)
            if isinstance(b, FunctionalFor):
                for l in b.loops:
                    if isinstance(l, ScopedAstNode):
                        l.scope.update_parent_scope(init_scope, is_loop = True)

        self.exit_function_scope()

        # Update variable scope for temporaries
        to_remove = []
        scope_variables = list(self.scope.variables.values())
        for v in scope_variables:
            if v.is_temp:
                self.scope.remove_variable(v)
                init_scope.insert_variable(v)
                to_remove.append(v)
                variables.remove(v)

        # Get deallocations
        deallocs = self._garbage_collector(CodeBlock(init_func_body))

        # Deallocate temporaries in init function
        # The list is partitioned in one pass. Popping by index while walking
        # a snapshot of the list would use stale indices once an element has
        # been removed.
        module_deallocs = []
        for d in deallocs:
            if d.variable in to_remove:
                init_func_body.append(d)
            else:
                module_deallocs.append(d)
        deallocs[:] = module_deallocs

        init_func_body = If(IfSection(PyccelNot(init_var),
                            init_func_body+[Assign(init_var, LiteralTrue())]))

        init_func = FunctionDef(init_func_name, [], [init_func_body],
                global_vars = variables, scope=init_scope)
        self.insert_function(init_func)

    if init_func:
        syntactic_free_func_name = name_suffix+'__free'
        free_func_name = self.scope.get_new_name(syntactic_free_func_name)
        pyccelised_imports = [imp for imp_name, imp in self.scope.imports['imports'].items() \
                         if imp_name in self.d_parsers]

        import_frees = [self.d_parsers[imp.source].semantic_parser.ast.free_func for imp in pyccelised_imports \
                            if imp.source in self.d_parsers]
        import_frees = [f if f.name in imp.target else \
                         f.clone(next(i.target for i in imp.target \
                                    if isinstance(i, AsName) and i.name == f.name)) \
                        for f,imp in zip(import_frees, pyccelised_imports) if f]

        if deallocs or import_frees:
            # If there is anything that needs deallocating when the module goes out of scope
            # create a deallocation function
            import_free_calls = [f() for f in import_frees if f is not None]
            free_func_body = If(IfSection(init_var,
                import_free_calls+deallocs+[Assign(init_var, LiteralFalse())]))
            # Ensure that the function is correctly defined within the namespaces
            scope = self.create_new_function_scope(syntactic_free_func_name, free_func_name)
            free_func = FunctionDef(free_func_name, [], [free_func_body],
                                global_vars = variables, scope = scope)
            self.exit_function_scope()
            self.insert_function(free_func)

    funcs = []
    interfaces = []
    for f in self.scope.functions.values():
        if isinstance(f, FunctionDef):
            funcs.append(f)
        elif isinstance(f, Interface):
            interfaces.append(f)

    # in the case of a header file, we need to convert all headers to
    # FunctionDef etc ...

    if self.is_header_file:
        if self.metavars.get('external', False):
            for f in funcs:
                f.is_external = True
            for c in classes:
                for m in c.methods:
                    m.is_external = True

    for v in variables:
        if v.rank > 0 and not v.is_alias:
            v.is_target = True

    mod = Module(mod_name,
                variables,
                funcs,
                init_func = init_func,
                free_func = free_func,
                interfaces=interfaces,
                classes=classes,
                imports=self.scope.imports['imports'].values(),
                scope=self.scope)

    if expr.program:
        # The program imports the module it belongs to
        container = prog_scope.imports
        container['imports'][mod_name] = Import(self.scope.get_python_name(mod_name), mod)

        if init_func:
            import_init  = init_func()
            program_body.insert2body(import_init, back=False)

        if free_func:
            import_free  = free_func()
            program_body.insert2body(import_free)

        # Add the module imports whose targets are used in the program body
        imports = list(container['imports'].values())
        for i in self.scope.imports['imports'].values():
            target = []
            for t in i.target:
                local_t = self.scope.find(t.name)
                if local_t and program_body.is_user_of(local_t, excluded_nodes = (FunctionDef,)):
                    target.append(t)
            if target:
                imports.append(Import(i.source, target, ignore_at_print = i.ignore, mod = i.source_module))

        program = Program(prog_name,
                        self.get_variables(prog_scope),
                        program_body,
                        imports,
                        scope=prog_scope)

        mod.program = program

    return mod
def _visit_PythonTuple(self, expr):
    """
    Visit a syntactic PythonTuple and create its semantic equivalent.

    Parameters
    ----------
    expr : PythonTuple
        The syntactic tuple.

    Returns
    -------
    PythonTuple
        The semantic tuple built from the visited elements.
    """
    elements = [self._visit(elem) for elem in expr]

    # NOTE(review): the second argument presumably lists node types to be
    # excluded from the search - confirm against get_user_nodes.
    if expr.get_user_nodes(Return, (IndexedElement, FunctionCall, PyccelFunction, PyccelOperator)):
        enclosing_func = expr.get_user_nodes(FunctionDef)[0]
        # Inhomogeneous is preferred when every Return in the function has
        # the same number of explicit results.
        result_counts = {r.n_explicit_results for r in enclosing_func.get_attribute_nodes(Return)}
        prefer_inhomogeneous = len(result_counts) == 1
    else:
        prefer_inhomogeneous = False

    return PythonTuple(*elements, prefer_inhomogeneous = prefer_inhomogeneous)
def _visit_PythonList(self, expr):
    """
    Create the semantic list for a syntactic list expression.

    The elements are visited and passed to `PythonList`. A `TypeError`
    raised by the constructor is reported as
    `PYCCEL_RESTRICTION_INHOMOG_LIST` with fatal severity.

    Parameters
    ----------
    expr : iterable of syntactic nodes
        The list expression being visited.

    Returns
    -------
    PythonList
        The list built from the visited elements.
    """
    elements = [self._visit(item) for item in expr]
    try:
        result = PythonList(*elements)
    except TypeError:
        errors.report(PYCCEL_RESTRICTION_INHOMOG_LIST, symbol=expr,
                      severity='fatal')
        result = expr
    return result
def _visit_PythonSet(self, expr):
    """
    Create the semantic set for a syntactic set expression.

    The elements are visited and passed to `PythonSet`. A `TypeError`
    raised by the constructor is reported with fatal severity, using the
    text of the exception as the message.

    Parameters
    ----------
    expr : iterable of syntactic nodes
        The set expression being visited.

    Returns
    -------
    PythonSet
        The set built from the visited elements.
    """
    elements = [self._visit(item) for item in expr]
    try:
        result = PythonSet(*elements)
    except TypeError as err:
        errors.report(str(err), symbol=expr,
                      severity='fatal')
        result = expr
    return result
def _visit_PythonDict(self, expr):
    """
    Create the semantic dictionary for a syntactic dictionary expression.

    All keys are visited first, then all values, and both lists are passed
    to `PythonDict`. A `TypeError` raised by the constructor is reported
    with fatal severity, using the text of the exception as the message.

    Parameters
    ----------
    expr : syntactic dictionary node
        The expression being visited. Its `keys` and `values` attributes
        are read.

    Returns
    -------
    PythonDict
        The dictionary built from the visited keys and values.
    """
    visited_keys = [self._visit(key) for key in expr.keys]
    visited_vals = [self._visit(val) for val in expr.values]
    try:
        result = PythonDict(visited_keys, visited_vals)
    except TypeError as err:
        errors.report(str(err), symbol=expr,
                      severity='fatal')
        result = expr
    return result
def _visit_FunctionCallArgument(self, expr):
    """
    Create the semantic argument of a function call.

    The value of the argument is visited. When the value is a ranked
    arithmetic operator / `PyccelFunction`, or a `FunctionCall` returning a
    non-alias object of a `CustomDataType`, the value is first saved to a
    new temporary variable (the assignment is appended to
    `self._additional_exprs[-1]`) and the temporary is passed instead.

    The keyword of the argument is preserved in every case, including when
    a temporary is created, so that keyword arguments are still matched by
    name.

    Parameters
    ----------
    expr : FunctionCallArgument
        The syntactic argument. Its `value` and `keyword` are read.

    Returns
    -------
    FunctionCallArgument
        The semantic argument.
    """
    value = self._visit(expr.value)
    a = FunctionCallArgument(value, expr.keyword)

    def generate_and_assign_temp_var():
        """Assign the argument value to a new temporary and wrap that temporary."""
        # The assignment is built in the syntactic stage so it can be visited like user code
        pyccel_stage.set_stage('syntactic')
        tmp_var = self.scope.get_new_name()
        syntactic_assign = Assign(tmp_var, expr.value, python_ast = expr.value.python_ast)
        pyccel_stage.set_stage('semantic')

        assign = self._visit(syntactic_assign)
        self._additional_exprs[-1].append(assign)
        # Forward the keyword: dropping it would turn a keyword argument
        # into a positional one
        return FunctionCallArgument(self._visit(tmp_var), expr.keyword)

    if isinstance(value, (PyccelArithmeticOperator, PyccelFunction)) and value.rank:
        a = generate_and_assign_temp_var()
    elif isinstance(value, FunctionCall) and isinstance(value.class_type, CustomDataType):
        if value.funcdef.results.var and not value.funcdef.results.var.is_alias:
            a = generate_and_assign_temp_var()
    return a
def _visit_UnionTypeAnnotation(self, expr):
    """
    Create the semantic union of type annotations.

    Each annotation in `expr.type_list` is visited. Visited annotations
    which are themselves a `UnionTypeAnnotation` are flattened so the
    result contains their members directly.

    Parameters
    ----------
    expr : UnionTypeAnnotation
        The syntactic union annotation.

    Returns
    -------
    UnionTypeAnnotation
        A flat union of the visited annotations.
    """
    flattened = []
    for syntactic_annot in expr.type_list:
        annot = self._visit(syntactic_annot)
        if isinstance(annot, UnionTypeAnnotation):
            flattened.extend(annot.type_list)
        else:
            flattened.append(annot)
    return UnionTypeAnnotation(*flattened)
def _visit_FunctionTypeAnnotation(self, expr):
    """
    Create the semantic annotation describing a function type.

    Each argument is visited and the first element of the visited result
    is kept. The result annotation is visited as a whole.

    Parameters
    ----------
    expr : FunctionTypeAnnotation
        The syntactic annotation. Its `args` and `result` are read.

    Returns
    -------
    UnionTypeAnnotation
        A union containing a single `FunctionTypeAnnotation`.
    """
    visited_args = []
    for syntactic_arg in expr.args:
        visited_args.append(self._visit(syntactic_arg)[0])
    visited_result = self._visit(expr.result)
    func_annotation = FunctionTypeAnnotation(visited_args, visited_result)
    return UnionTypeAnnotation(func_annotation)
def _visit_TypingFinal(self, expr):
    """
    Create the semantic annotation for a `Final[...]` annotation.

    The wrapped annotation (`expr.arg`) is visited and every type found
    when iterating over the result is marked as constant.

    Parameters
    ----------
    expr : syntactic node with an `arg` attribute
        The annotation being visited.

    Returns
    -------
    object
        The visited annotation, whose members now have `is_const = True`.
    """
    annotation = self._visit(expr.arg)
    for type_annot in annotation:
        type_annot.is_const = True
    return annotation
def _visit_FunctionDefArgument(self, expr):
    """
    Create the semantic arguments of a function definition.

    Visiting the argument variable may yield several objects (one per
    possible type); one `FunctionDefArgument` is created for each of them.
    The argument is optional when its default value is `Nil`.

    Parameters
    ----------
    expr : FunctionDefArgument
        The syntactic argument.

    Returns
    -------
    list of FunctionDefArgument
        One semantic argument for each object obtained from `expr.var`.
    """
    candidates = self._visit(expr.var)
    value = self._visit(expr.value) if expr.value is not None else None
    kwonly = expr.is_kwonly
    bound_argument = expr.bound_argument
    is_optional = isinstance(value, Nil)

    args = []
    for v in candidates:
        if not isinstance(v, Variable):
            clone = v.clone(v.name, is_optional = is_optional,
                            is_kwonly = kwonly, is_argument = True)
            args.append(FunctionDefArgument(clone, bound_argument = bound_argument,
                                            value = value, kwonly = kwonly,
                                            annotation = expr.annotation))
            continue

        dtype = v.class_type
        # Literal defaults are converted to the type of the argument
        if isinstance(value, Literal) and value is not Nil():
            value = convert_to_literal(value.python_value, dtype)

        if isinstance(dtype, InhomogeneousTupleType):
            # Raise an error as elements are not yet correctly marked with is_argument.
            # This leads to printing errors
            errors.report("Inhomogeneous tuples are not yet supported as arguments",
                          severity='error', symbol=expr)

        if isinstance(dtype, CustomDataType) and not bound_argument:
            cls = self.scope.find(str(dtype), 'classes')
            if cls:
                # Make sure the constructor of the class has been visited
                init_method = cls.get_method('__init__', expr)
                if not init_method.is_semantic:
                    self._visit(init_method)

        clone_var = v.clone(v.name, is_optional = is_optional, is_argument = True)
        args.append(FunctionDefArgument(clone_var, bound_argument = bound_argument,
                                        value = value, kwonly = kwonly,
                                        annotation = expr.annotation))
    return args
def _visit_CodeBlock(self, expr):
    """
    Create the semantic version of a block of code.

    Each statement of the body is visited, skipping `EmptyNode` objects.
    Expressions collected in `self._additional_exprs[-1]` during the visit
    of a statement are inserted before the result of that statement.
    A visited `CodeBlock` is flattened into the result. A visited list
    whose first element is a `Variable` produces no statement; the variable
    is inserted into the scope instead.

    Parameters
    ----------
    expr : CodeBlock
        The syntactic block.

    Returns
    -------
    CodeBlock
        The block containing the visited statements.
    """
    statements = []
    # New collection level for expressions which must precede a statement
    self._additional_exprs.append([])
    for stmt in expr.body:
        if isinstance(stmt, EmptyNode):
            continue

        # Save parsed code
        visited = self._visit(stmt)
        statements.extend(self._additional_exprs[-1])
        self._additional_exprs[-1] = []

        if isinstance(visited, CodeBlock):
            statements.extend(visited.body)
        elif isinstance(visited, list) and isinstance(visited[0], Variable):
            self.scope.insert_variable(visited[0])
        else:
            statements.append(visited)
    self._additional_exprs.pop()

    return CodeBlock(statements)
def _visit_Nil(self, expr):
    """
    Visit a `Nil` node.

    The node is reused: its syntactic user nodes are cleared and its
    pyccel staging is updated.

    Parameters
    ----------
    expr : Nil
        The node to visit.

    Returns
    -------
    Nil
        The same object that was passed in.
    """
    expr.clear_syntactic_user_nodes()
    expr.update_pyccel_staging()
    return expr
def _visit_Break(self, expr):
    """
    Visit a `Break` node.

    The node is reused: its syntactic user nodes are cleared and its
    pyccel staging is updated.

    Parameters
    ----------
    expr : Break
        The node to visit.

    Returns
    -------
    Break
        The same object that was passed in.
    """
    expr.clear_syntactic_user_nodes()
    expr.update_pyccel_staging()
    return expr
def _visit_Continue(self, expr):
    """
    Visit a `Continue` node.

    The node is reused: its syntactic user nodes are cleared and its
    pyccel staging is updated.

    Parameters
    ----------
    expr : Continue
        The node to visit.

    Returns
    -------
    Continue
        The same object that was passed in.
    """
    expr.clear_syntactic_user_nodes()
    expr.update_pyccel_staging()
    return expr
def _visit_Comment(self, expr):
    """
    Visit a `Comment` node.

    The node is reused: its syntactic user nodes are cleared and its
    pyccel staging is updated.

    Parameters
    ----------
    expr : Comment
        The node to visit.

    Returns
    -------
    Comment
        The same object that was passed in.
    """
    expr.clear_syntactic_user_nodes()
    expr.update_pyccel_staging()
    return expr
def _visit_CommentBlock(self, expr):
    """
    Visit a `CommentBlock` node.

    The node is reused: its syntactic user nodes are cleared and its
    pyccel staging is updated.

    Parameters
    ----------
    expr : CommentBlock
        The node to visit.

    Returns
    -------
    CommentBlock
        The same object that was passed in.
    """
    expr.clear_syntactic_user_nodes()
    expr.update_pyccel_staging()
    return expr
def _visit_AnnotatedComment(self, expr):
    """
    Visit an `AnnotatedComment` node.

    The node is reused: its syntactic user nodes are cleared and its
    pyccel staging is updated.

    Parameters
    ----------
    expr : AnnotatedComment
        The node to visit.

    Returns
    -------
    AnnotatedComment
        The same object that was passed in.
    """
    expr.clear_syntactic_user_nodes()
    expr.update_pyccel_staging()
    return expr
def _visit_OmpAnnotatedComment(self, expr):
    """
    Visit an OpenMP annotated comment.

    Two adjustments are made to the code containing the comment (the last
    element of `expr._user_nodes`):

    - for a sections/single construct with `nowait`, `has_nowait` is set on
      every following `Omp_End_Clause` whose text matches the construct
      name from character 4 onwards;
    - for a loop construct (or a combined construct containing 'for',
      'distribute' or 'taskloop'), the next statement which is not a
      comment or `Pass` must be a `For` loop. Its `end_annotation` is set
      to the matching '!$omp end ...' text. Otherwise a fatal error is
      reported.

    Parameters
    ----------
    expr : OmpAnnotatedComment
        The node to visit.

    Returns
    -------
    OmpAnnotatedComment
        The same object that was passed in.
    """
    parent = expr._user_nodes[-1]
    position = parent.body.index(expr)
    combined_loop = expr.combined and any(kw in expr.combined
                                          for kw in ('for', 'distribute', 'taskloop'))

    nowait_constructs = (OMP_Sections_Construct, OMP_Single_Construct)
    if isinstance(expr, nowait_constructs) and expr.has_nowait:
        for following in parent.body[position+1:]:
            if isinstance(following, Omp_End_Clause) and following.txt.startswith(expr.name, 4):
                following.has_nowait = True

    loop_constructs = (OMP_For_Loop, OMP_Simd_Construct,
                       OMP_Distribute_Construct, OMP_TaskLoop_Construct)
    if isinstance(expr, loop_constructs) or combined_loop:
        # Find the first statement after the pragma which is not a comment or a pass
        position += 1
        while position < len(parent.body) and isinstance(parent.body[position], (Comment, CommentBlock, Pass)):
            position += 1

        if position < len(parent.body) and isinstance(parent.body[position], For):
            end_parts = ['!$omp', 'end', expr.name]
            if expr.combined:
                end_parts.append(expr.combined)
            if expr.has_nowait:
                end_parts.append('nowait')
            parent.body[position].end_annotation = ' '.join(part for part in end_parts if part)+'\n'
        else:
            type_name = type(expr).__name__
            errors.report(f"Statement after {type_name} must be a for loop.", symbol=expr,
                          severity='fatal')

    expr.clear_syntactic_user_nodes()
    expr.update_pyccel_staging()
    return expr
def _visit_Omp_End_Clause(self, expr):
    """
    Visit an OpenMP end clause.

    An end clause whose text contains 'for', 'distribute', 'taskloop' or
    'simd' is reported with a warning and replaced by an `EmptyNode`.
    Any other end clause is reused after clearing its syntactic user nodes
    and updating its pyccel staging.

    Parameters
    ----------
    expr : Omp_End_Clause
        The node to visit.

    Returns
    -------
    EmptyNode or Omp_End_Clause
        An empty node for loop end clauses, `expr` itself otherwise.
    """
    loop_keywords = ('for', 'distribute', 'taskloop', 'simd')
    if any(keyword in expr.txt for keyword in loop_keywords):
        errors.report("For loops do not require an end clause. This clause is ignored",
                      severity='warning', symbol=expr)
        return EmptyNode()

    expr.clear_syntactic_user_nodes()
    expr.update_pyccel_staging()
    return expr
def _visit_Literal(self, expr):
    """
    Visit a `Literal` node.

    The node is reused: its syntactic user nodes are cleared and its
    pyccel staging is updated.

    Parameters
    ----------
    expr : Literal
        The node to visit.

    Returns
    -------
    Literal
        The same object that was passed in.
    """
    expr.clear_syntactic_user_nodes()
    expr.update_pyccel_staging()
    return expr
def _visit_Pass(self, expr):
    """
    Visit a `Pass` node.

    The node is reused: its syntactic user nodes are cleared and its
    pyccel staging is updated.

    Parameters
    ----------
    expr : Pass
        The node to visit.

    Returns
    -------
    Pass
        The same object that was passed in.
    """
    expr.clear_syntactic_user_nodes()
    expr.update_pyccel_staging()
    return expr
def _visit_Variable(self, expr):
    """
    Find the semantic variable matching a `Variable` node.

    The name of the variable is translated with `scope.get_python_name`
    and looked up with `get_variable`. If the variable found is a key of
    `self._optional_params`, the associated value is returned instead.

    Parameters
    ----------
    expr : Variable
        The node to visit.

    Returns
    -------
    object
        The variable found, or its replacement from `_optional_params`.
    """
    python_name = self.scope.get_python_name(expr.name)
    variable = self.get_variable(python_name)
    return self._optional_params.get(variable, variable)
def _visit_str(self, expr):
    """
    Visit a plain Python string.

    Parameters
    ----------
    expr : str
        The string to visit.

    Returns
    -------
    str
        The `repr` of the string (i.e. the string surrounded by quotes).
    """
    return f"{expr!r}"
def _visit_Slice(self, expr):
    """
    Create the semantic version of a slice.

    The start, stop and step of the slice are visited in that order.
    Bounds which are `None` are left as `None`.

    Parameters
    ----------
    expr : Slice
        The syntactic slice.

    Returns
    -------
    Slice
        The slice built from the visited bounds.
    """
    def visit_bound(bound):
        """Visit a bound of the slice unless it is missing."""
        return None if bound is None else self._visit(bound)

    return Slice(visit_bound(expr.start), visit_bound(expr.stop), visit_bound(expr.step))
def _visit_IndexedElement(self, expr):
    """
    Create the semantic object for an indexing expression `base[indices]`.

    Three situations are handled:

    - the base is a type-like object (`PyccelFunctionDef`,
      `VariableTypeAnnotation`, `UnionTypeAnnotation`): the work is
      delegated to `_get_indexed_type`;
    - the base is a NumPy array, homogeneous list or tuple: the indices are
      visited and passed to `_extract_indexed_from_var`. When several
      indices are given and at least one is a tuple with a literal length,
      a `NumpyArray` of the individually indexed elements is built instead;
    - any other base: the `__getitem__` method of its class is called.
      A fatal error is reported if there is no such method.

    Parameters
    ----------
    expr : IndexedElement
        The syntactic indexing expression.

    Returns
    -------
    object
        The semantic object describing the indexed element.
    """
    base = self._visit(expr.base)

    if isinstance(base, (PyccelFunctionDef, VariableTypeAnnotation, UnionTypeAnnotation)):
        return self._get_indexed_type(base, expr.indices, expr)

    class_type = base.class_type

    if not isinstance(class_type, (NumpyNDArrayType, HomogeneousListType, TupleType)):
        cls_base = self.scope.find(str(class_type), 'classes') or get_cls_base(class_type)
        getitem_method = cls_base.get_method('__getitem__')
        if not getitem_method:
            raise errors.report(f"No __getitem__ found for type {class_type}",
                                severity='fatal', symbol=expr)
        class_args = self._handle_function_args([FunctionCallArgument(a) for a in expr.indices])
        call_args = [FunctionCallArgument(base), *class_args]
        return self._handle_function(expr, getitem_method, call_args)

    def has_tuple_type(obj):
        """Indicate if the object has a class type which is a tuple type."""
        return isinstance(getattr(obj, 'class_type', None), TupleType)

    # TODO check consistency of indices with shape/rank
    indices = [self._visit(idx) for idx in expr.indices]

    if len(indices) == 1 and has_tuple_type(indices[0]):
        # A single tuple provides the indices
        indices = indices[0]
    elif any(has_tuple_type(idx) for idx in indices):
        # Find the literal length shared by the indices
        n_exprs = None
        for idx in indices:
            if getattr(idx, 'shape', None) and isinstance(idx.shape[0], LiteralInteger):
                idx_len = idx.shape[0]
                if n_exprs:
                    assert n_exprs == idx_len
                else:
                    n_exprs = idx_len

        if n_exprs is not None:
            new_expr_args = [[a[i] if hasattr(a, '__getitem__') else a for a in indices]
                             for i in range(n_exprs)]
            return NumpyArray(PythonTuple(*[base[a] for a in new_expr_args]))

    return self._extract_indexed_from_var(base, indices, expr)
def _visit_PyccelSymbol(self, expr):
    """
    Find the semantic object that a symbol refers to.

    The symbol is searched for, in order: among the variables
    (`check_for_variable`), in the scope, among the builtin functions,
    among the NumPy functions (only while visiting an annotation, where
    the symbol '*' yields a `GenericType`) and in the context dictionary.
    A fatal error is reported if nothing is found.

    Parameters
    ----------
    expr : PyccelSymbol
        The symbol to look up.

    Returns
    -------
    object
        The object found, or its replacement from `_optional_params` if it
        is one of the keys of that dictionary.
    """
    name = expr

    var = self.check_for_variable(name)

    if var is None:
        var = self.scope.find(name)

    if var is None:
        builtin = builtin_functions_dict.get(name, None)
        var = builtin
        if builtin is not None:
            var = PyccelFunctionDef(name, builtin)

    if var is None and self._in_annotation:
        var = numpy_funcs.get(name, None)
        if name == '*':
            return GenericType()

    if var is None and name in self._context_dict:
        var = self.env_var_to_pyccel(self._context_dict[name], name = name)

    if var is None:
        position = (self.current_ast_node.lineno, self.current_ast_node.col_offset)
        if name == '_':
            errors.report(UNDERSCORE_NOT_A_THROWAWAY,
                          bounding_box=position,
                          severity='fatal')
        else:
            errors.report(UNDEFINED_VARIABLE, symbol=name,
                          bounding_box=position,
                          severity='fatal')

    return self._optional_params.get(var, var)
def _visit_AnnotatedPyccelSymbol(self, expr):
    """
    Create the objects described by an annotated symbol.

    The annotation is visited to obtain the possible types. For each
    type, a `FunctionAddress` (function type annotations) or a variable
    (`Variable`, or `DottedVariable` when the name is a `DottedName`) is
    created. For inhomogeneous tuples, one additional annotated symbol is
    visited per element and registered in the scope as a symbolic alias of
    the indexed element.

    Parameters
    ----------
    expr : AnnotatedPyccelSymbol
        The annotated symbol. Its `name` and `annotation` are read.

    Returns
    -------
    list
        The objects created, one per possible type. If a variable with this
        name was already found locally (it is then removed from the scope),
        a list containing only that variable is returned.
    """
    # Check if the variable already exists
    var = self.scope.find(expr.name, 'variables', local_only = True)
    # A pre-existing variable is only accepted if one of its users is a FunctionDefResult
    if var is not None and not any(isinstance(n, FunctionDefResult) for n in var.get_all_user_nodes()):
        errors.report("Variable has been declared multiple times",
                symbol=expr, severity='error')

    if expr.annotation is None:
        errors.report(MISSING_TYPE_ANNOTATIONS,
                symbol=expr, severity='fatal')

    # Get the semantic type annotation (should be UnionTypeAnnotation)
    types = self._visit(expr.annotation)
    assert not isinstance(types, TypingTypeVar)

    if len(types.type_list) == 0:
        errors.report(MISSING_TYPE_ANNOTATIONS,
                symbol=expr, severity='fatal')

    python_name = expr.name
    # Get the collisionless name from the scope
    if isinstance(python_name, DottedName):
        # The name is an attribute: everything but the last part is the
        # prefix, whose class scope provides the name of the attribute
        prefix_parts = python_name.name[:-1]
        syntactic_prefix = prefix_parts[0] if len(prefix_parts) == 1 else DottedName(*prefix_parts)
        prefix = self._visit(syntactic_prefix)
        class_def = prefix.cls_base
        attribute_name = python_name.name[-1]

        name = class_def.scope.get_expected_name(attribute_name)
        var_class = DottedVariable
        kwargs = {'lhs': prefix}
    else:
        name = self.scope.get_expected_name(python_name)
        var_class = Variable
        kwargs = {}

    # Use the local decorators to define the memory and index handling
    array_memory_handling = 'heap'
    decorators = self.scope.decorators
    if decorators:
        if 'stack_array' in decorators:
            if expr.name in decorators['stack_array']:
                array_memory_handling = 'stack'
        if 'allow_negative_index' in decorators:
            if expr.name in decorators['allow_negative_index']:
                kwargs['allows_negative_indexes'] = True

    # For each possible data type create the necessary variables
    possible_args = []
    for t in types.type_list:
        if isinstance(t, FunctionTypeAnnotation):
            args = t.args
            # A function scope is opened and immediately closed; the scope
            # object is handed to the FunctionAddress
            scope = self.create_new_function_scope(name, name)
            if t.result.var:
                results = FunctionDefResult(t.result.var.clone(t.result.var.name, is_argument = False),
                                            annotation=t.result.annotation)
            else:
                results = FunctionDefResult(Nil())
            self.exit_function_scope()

            address = FunctionAddress(name, args, results, scope = scope)
            possible_args.append(address)
        elif isinstance(t, VariableTypeAnnotation):
            class_type = t.class_type
            cls_base = self.scope.find(str(class_type), 'classes') or get_cls_base(class_type)
            # Only the length of an inhomogeneous tuple is set in the shape,
            # the other dimensions are None
            if isinstance(class_type, InhomogeneousTupleType):
                shape = (len(class_type),)
            elif isinstance(class_type, HomogeneousTupleType):
                shape = (None,)*class_type.rank
            elif class_type.rank:
                shape = (None,)*class_type.container_rank
            else:
                shape = None
            v = var_class(class_type, name, cls_base = cls_base,
                    shape = shape,
                    is_const = t.is_const, is_optional = False,
                    memory_handling = array_memory_handling if class_type.rank > 0 else 'stack',
                    **kwargs)
            possible_args.append(v)
            if isinstance(class_type, InhomogeneousTupleType):
                # Create one annotated symbol per element of the tuple.
                # NOTE: this loop rebinds `t`, the variable of the enclosing loop
                for i, t in enumerate(class_type):
                    pyccel_stage.set_stage('syntactic')
                    syntactic_elem = AnnotatedPyccelSymbol(self.scope.get_new_name( f'{name}_{i}'),
                                        annotation = UnionTypeAnnotation(VariableTypeAnnotation(t)))
                    pyccel_stage.set_stage('semantic')
                    elem = self._visit(syntactic_elem)
                    self.scope.insert_symbolic_alias(IndexedElement(v, i), elem[0])
        else:
            errors.report(PYCCEL_RESTRICTION_TODO + '\nUnrecognised type annotation',
                    severity='fatal', symbol=expr)

    # An annotated variable must have a type
    assert len(possible_args) != 0

    # If var was declared in results
    if var is not None:
        new_var = possible_args[0]
        if len(possible_args) != 1 or new_var.class_type != var.class_type:
            errors.report(f"Variable was declared as the result of the function {self.current_function_name} but is now declared with a different type",
                    symbol=expr, severity='error')
        # Remove variable from scope as AnnotatedPyccelSymbol is always inserted into scope
        self.scope.remove_variable(var)
        return [var]

    return possible_args
def _visit_SyntacticTypeAnnotation(self, expr):
    """
    Create the semantic type annotation for a syntactic type annotation.

    The data type of the annotation is visited with `self._in_annotation`
    set to True. The previous value of the flag is restored afterwards,
    even if the visit raises, so that visiting a nested annotation does
    not switch off annotation mode for the enclosing one.

    Parameters
    ----------
    expr : SyntacticTypeAnnotation
        The syntactic annotation. Its `dtype` and `order` are read.

    Returns
    -------
    UnionTypeAnnotation or TypingTypeVar
        The semantic annotation. A visited `UnionTypeAnnotation` with
        several types or a `TypingTypeVar` is returned unchanged; anything
        else is wrapped in a `UnionTypeAnnotation`.

    Raises
    ------
    Exception
        The result of `errors.report` (fatal severity) is raised if no
        type information can be deduced from the visited data type.
    """
    was_in_annotation = self._in_annotation
    self._in_annotation = True
    try:
        visited_dtype = self._visit(expr.dtype)
    finally:
        # Restore rather than reset so nested annotations leave the flag untouched
        self._in_annotation = was_in_annotation
    order = expr.order

    # A union with one member is treated like its member
    if isinstance(visited_dtype, UnionTypeAnnotation) and len(visited_dtype.type_list) == 1:
        visited_dtype = visited_dtype.type_list[0]

    if isinstance(visited_dtype, PyccelFunctionDef):
        dtype_cls = visited_dtype.cls_name
        try:
            class_type = dtype_cls.static_type()
        except AttributeError:
            errors.report(f"Unrecognised datatype {dtype_cls}", severity='fatal', symbol=expr)
        return UnionTypeAnnotation(VariableTypeAnnotation(class_type))
    elif isinstance(visited_dtype, VariableTypeAnnotation):
        if order and order != visited_dtype.class_type.order:
            visited_dtype = VariableTypeAnnotation(visited_dtype.class_type.swap_order())
        return UnionTypeAnnotation(visited_dtype)
    elif isinstance(visited_dtype, (UnionTypeAnnotation, TypingTypeVar)):
        return visited_dtype
    elif isinstance(visited_dtype, ClassDef):
        # TODO: Improve when #1676 is merged
        dtype = self.get_class_construct(visited_dtype.name)
        return UnionTypeAnnotation(VariableTypeAnnotation(dtype))
    elif isinstance(visited_dtype, PyccelType):
        return UnionTypeAnnotation(VariableTypeAnnotation(visited_dtype))
    else:
        raise errors.report(PYCCEL_RESTRICTION_TODO + ' Could not deduce type information',
                severity='fatal', symbol=expr)
def _visit_VariableTypeAnnotation(self, expr):
    """
    Visit a `VariableTypeAnnotation`.

    Parameters
    ----------
    expr : VariableTypeAnnotation
        The annotation to visit.

    Returns
    -------
    VariableTypeAnnotation
        The object that was passed in, unchanged.
    """
    return expr
def _visit_DottedName(self, expr):
    """
    Create the semantic object described by a dotted name (`lhs.rhs`).

    If the full name matches a known variable, that variable is returned.
    Otherwise the prefix (everything but the last part of the name) is
    visited and the last part is interpreted relative to it: as a member of
    an imported module, as a method of the class of the prefix, or as an
    attribute or property of that class.

    Parameters
    ----------
    expr : DottedName
        The dotted name to resolve.

    Returns
    -------
    object
        The variable, module member or result of `_handle_function`
        matching the name.
    """

    var = self.check_for_variable(_get_name(expr))
    if var:
        return var

    # Split the name into the prefix (lhs) and the final part (rhs)
    lhs = expr.name[0] if len(expr.name) == 2 \
            else DottedName(*expr.name[:-1])
    rhs = expr.name[-1]

    visited_lhs = self._visit(lhs)
    first = visited_lhs
    # When the prefix is a function call, work with the result of the function
    if isinstance(visited_lhs, FunctionCall):
        results = visited_lhs.funcdef.results
        if len(results) != 1:
            errors.report("Cannot get attribute of function call with multiple returns",
                    symbol=expr, severity='fatal')
        first = results.var
    rhs_name = _get_name(rhs)

    # Handle case of imported module
    if isinstance(first, Module):

        if rhs_name in first:
            imp = self.scope.find(_get_name(lhs), 'imports')

            new_name = rhs_name
            if imp is not None:
                new_name = imp.find_module_target(rhs_name)
                if new_name is None:
                    new_name = self.scope.get_new_name(rhs_name)

                    # Save the import target that has been used
                    imp.define_target(AsName(first[rhs_name], PyccelSymbol(new_name)))
            elif isinstance(rhs, FunctionCall):
                self.scope.imports['functions'][new_name] = first[rhs_name]
            elif isinstance(rhs, ConstructorCall):
                self.scope.imports['classes'][new_name] = first[rhs_name]
            elif isinstance(rhs, Variable):
                self.scope.imports['variables'][new_name] = rhs

            if isinstance(rhs, FunctionCall):
                # If object is a function
                args = self._handle_function_args(rhs.args)
                func = first[rhs_name]
                if new_name != rhs_name:
                    if hasattr(func, 'clone') and not isinstance(func, PyccelFunctionDef):
                        func = func.clone(new_name)
                # The call is created in the syntactic stage before being handled
                pyccel_stage.set_stage('syntactic')
                syntactic_call = FunctionCall(func, args)
                pyccel_stage.set_stage('semantic')
                if first.__module__.startswith('pyccel.'):
                    self.insert_import(first.name, AsName(func, func.name), _get_name(lhs))
                return self._handle_function(syntactic_call, func, args)
            elif isinstance(rhs, Constant):
                var = first[rhs_name]
                if new_name != rhs_name:
                    var.name = new_name
                return var
            else:
                # If object is something else (eg. dict)
                var = first[rhs_name]
                return var
        else:
            errors.report(UNDEFINED_IMPORT_OBJECT.format(rhs_name, str(lhs)),
                    symbol=expr, severity='fatal')
    if isinstance(first, ClassDef):
        errors.report("Static class methods are not yet supported", symbol=expr,
                severity='fatal')

    # Find the class describing the type of the prefix
    d_var = self._infer_type(first)
    class_type = d_var['class_type']
    cls_base = get_cls_base(class_type)
    if cls_base is None:
        cls_base = self.scope.find(str(class_type), 'classes')

    # look for a class method
    if isinstance(rhs, FunctionCall):
        method = cls_base.get_method(rhs_name, expr)
        # The prefix is passed as the first argument of the method
        args = [FunctionCallArgument(visited_lhs), *self._handle_function_args(rhs.args)]
        if cls_base.name == 'numpy.ndarray':
            numpy_class = method.cls_name
            self.insert_import('numpy', AsName(numpy_class, numpy_class.name))
        return self._handle_function(expr, method, args, is_method = True)

    # look for a class attribute / property
    elif isinstance(rhs, PyccelSymbol) and cls_base:
        # standard class attribute
        second = self.check_for_variable(expr)
        if second:
            return second

        # class property?
        else:
            method = cls_base.get_method(rhs_name, expr)
            assert 'property' in method.decorators
            if cls_base.name == 'numpy.ndarray':
                numpy_class = method.cls_name
                self.insert_import('numpy', AsName(numpy_class, numpy_class.name))
            return self._handle_function(expr, method, [FunctionCallArgument(visited_lhs)], is_method = True)

    # did something go wrong?
    return errors.report(f'Attribute {rhs_name} not found',
        bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
        severity='fatal')
def _visit_PyccelOperator(self, expr):
    """
    Create the semantic version of an operator.

    The arguments of the operator are visited in order and passed, with the
    syntactic operator, to `_create_PyccelOperator`.

    Parameters
    ----------
    expr : PyccelOperator
        The syntactic operator.

    Returns
    -------
    object
        The result of `_create_PyccelOperator`.
    """
    visited_args = list(map(self._visit, expr.args))
    return self._create_PyccelOperator(expr, visited_args)
def _visit_PyccelBooleanOperator(self, expr):
    """
    Create the semantic version of a boolean operator.

    The arguments are visited, then every argument whose `dtype` is not
    `PythonNativeBool()` is wrapped in a `PythonBool` cast before the
    operator is created with `_create_PyccelOperator`.

    Parameters
    ----------
    expr : PyccelBooleanOperator
        The syntactic operator.

    Returns
    -------
    object
        The result of `_create_PyccelOperator`.
    """
    visited_args = [self._visit(a) for a in expr.args]

    bool_args = []
    for operand in visited_args:
        if operand.dtype is not PythonNativeBool():
            operand = PythonBool(operand)
        bool_args.append(operand)

    return self._create_PyccelOperator(expr, bool_args)
def _visit_PyccelAdd(self, expr):
    """
    Create the semantic version of an addition.

    When the first operand is a tuple or a homogeneous list, the addition
    is a concatenation: a `Concatenate` is created if the first operand is
    not an inhomogeneous tuple and both operands share one class type,
    otherwise a `PythonTuple` of the elements of both operands is created
    (a fatal error is reported if the first dimension of either operand is
    neither a `LiteralInteger` nor an `int`). Any other addition is created
    with `_create_PyccelOperator`.

    Parameters
    ----------
    expr : PyccelAdd
        The syntactic addition.

    Returns
    -------
    object
        The `Concatenate`, `PythonTuple` or operator created.
    """
    args = [self._visit(a) for a in expr.args]
    first = args[0]

    if not isinstance(first.class_type, (TupleType, HomogeneousListType)):
        return self._create_PyccelOperator(expr, args)

    second = args[1]
    if not isinstance(first.class_type, InhomogeneousTupleType) and \
            first.class_type == second.class_type:
        return Concatenate(*args)

    known_sizes = isinstance(first.shape[0], (LiteralInteger, int)) and \
                  isinstance(second.shape[0], (LiteralInteger, int))
    if not known_sizes:
        errors.report("Can't create an inhomogeneous object from objects of unknown size",
                      severity='fatal', symbol=expr)

    elements = [self.scope.collect_tuple_element(v) for v in first]
    elements += [self.scope.collect_tuple_element(v) for v in second]
    return PythonTuple(*elements)
def _visit_PyccelMul(self, expr):
    """
    Create the semantic version of a multiplication.

    If either operand is a tuple or a homogeneous list (the first operand
    is tested first), the multiplication is a duplication created with
    `_create_Duplicate(container, other_operand)`. Any other
    multiplication is created with `_create_PyccelOperator`.

    Parameters
    ----------
    expr : PyccelMul
        The syntactic multiplication.

    Returns
    -------
    object
        The result of `_create_Duplicate` or `_create_PyccelOperator`.
    """
    args = [self._visit(a) for a in expr.args]
    container_types = (TupleType, HomogeneousListType)

    if isinstance(args[0].class_type, container_types):
        return self._create_Duplicate(args[0], args[1])
    if isinstance(args[1].class_type, container_types):
        return self._create_Duplicate(args[1], args[0])
    return self._create_PyccelOperator(expr, args)
def _visit_PyccelPow(self, expr):
    """
    Create the semantic version of a power expression.

    Two special cases are recognised from the value of the exponent (an
    integer literal, or a parenthesised division of two literals such as
    `(1/2)`):

    - a literal or variable base raised to the power 2 becomes the
      multiplication `base * base`;
    - an exponent of 0.5 becomes a call to `sqrt`, imported from `cmath`
      if the base or the exponent is complex and from `math` otherwise.

    Any other expression becomes a `PyccelPow`.

    Parameters
    ----------
    expr : PyccelPow
        The syntactic power expression.

    Returns
    -------
    object
        The `PyccelMul`, visited `sqrt` call or `PyccelPow` created.
    """
    base, exponent = [self._visit(a) for a in expr.args]

    # Find the numeric value of the exponent where possible
    exp_val = exponent
    if isinstance(exponent, LiteralInteger):
        exp_val = exponent.python_value
    elif isinstance(exponent, PyccelAssociativeParenthesis):
        inner = exponent.args[0]
        # Handle (1/2)
        if isinstance(inner, PyccelDiv) and all(isinstance(a, Literal) for a in inner.args):
            exp_val = inner.args[0].python_value / inner.args[1].python_value

    if isinstance(base, (Literal, Variable)) and exp_val == 2:
        return PyccelMul(base, base)

    if exp_val == 0.5:
        pyccel_stage.set_stage('syntactic')

        sqrt_name = self.scope.get_new_name('sqrt')
        imp_name = AsName('sqrt', sqrt_name)
        is_complex = isinstance(base.class_type.primitive_type, PrimitiveComplexType) or \
                     isinstance(exponent.class_type.primitive_type, PrimitiveComplexType)
        if is_complex:
            new_import = Import('cmath',imp_name)
        else:
            new_import = Import('math',imp_name)
        self._visit(new_import)

        # The parentheses around the base are not needed in a function call
        sqrt_arg = expr.args[0]
        if isinstance(sqrt_arg, PyccelAssociativeParenthesis):
            sqrt_arg = sqrt_arg.args[0]
        new_call = FunctionCall(sqrt_name, [sqrt_arg])

        pyccel_stage.set_stage('semantic')

        return self._visit(new_call)

    return PyccelPow(base, exponent)
def _visit_PyccelIn(self, expr):
    """
    Create the semantic version of a membership test (`element in container`).

    For dictionaries, homogeneous sets and homogeneous lists, a `PyccelIn`
    is created if the class type of the element equals the key type
    (dictionaries) or element type (sets, lists) of the container,
    otherwise the test is replaced by `LiteralFalse`. For any other
    container, the `__contains__` method of its class is called; a fatal
    error is raised if there is no such method.

    Parameters
    ----------
    expr : PyccelIn
        The syntactic membership test.

    Returns
    -------
    object
        The `PyccelIn`, `LiteralFalse` or result of `_handle_function`.
    """
    element = self._visit(expr.element)
    container = self._visit(expr.container)
    container_type = container.class_type

    if isinstance(container_type, (DictType, HomogeneousSetType, HomogeneousListType)):
        if isinstance(container_type, DictType):
            expected_type = container_type.key_type
        else:
            expected_type = container_type.element_type
        if element.class_type == expected_type:
            return PyccelIn(element, container)
        return LiteralFalse()

    container_base = self.scope.find(str(container_type), 'classes') or get_cls_base(container_type)
    error_symbol = expr if isinstance(container_type, CustomDataType) else None
    contains_method = container_base.get_method('__contains__',
                                                raise_error_from = error_symbol)
    if not contains_method:
        raise errors.report(f"In operator is not yet implemented for type {container_type}",
                            severity='fatal', symbol=expr)

    call_args = [FunctionCallArgument(container), FunctionCallArgument(element)]
    return self._handle_function(expr, contains_method, call_args)
def _visit_Lambda(self, expr):
    """
    Visit a lambda function.

    The first statement reports that lambda functions are not supported,
    with fatal severity.

    NOTE(review): if a fatal report interrupts execution (to be confirmed
    against `errors.report`), the rest of this method is never reached.

    Parameters
    ----------
    expr : Lambda
        The syntactic lambda function.

    Returns
    -------
    Lambda
        The lambda function after substitution of the symbolic functions
        called in its expression.
    """
    errors.report("Lambda functions are not currently supported",
            symbol=expr, severity='fatal')
    # Every symbol used in the expression must be one of the lambda variables
    expr_names = set(str(a) for a in expr.expr.get_attribute_nodes(PyccelSymbol))
    var_names = map(str, expr.variables)
    missing_vars = expr_names.difference(var_names)
    if len(missing_vars) > 0:
        errors.report(UNDEFINED_LAMBDA_VARIABLE, symbol = missing_vars,
            bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
            severity='fatal')
    # Each function called must be a known symbolic function. Its call is
    # substituted into the expression
    funcs = expr.expr.get_attribute_nodes(FunctionCall)
    for func in funcs:
        name = _get_name(func)
        f = self.scope.find(name, 'symbolic_functions')
        if f is None:
            errors.report(UNDEFINED_LAMBDA_FUNCTION, symbol=name,
                bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
                severity='fatal')
        else:

            f = f(*func.args)
            expr_new = expr.expr.subs(func, f)
            expr = Lambda(tuple(expr.variables), expr_new)
    return expr
def _visit_FunctionCall(self, expr):
    """
    Create the semantic object for a function call.

    The function is searched for in the scope, then among the builtin
    functions. The arguments are visited, their keywords are translated to
    the names expected by the scope of the function and they are sorted to
    match the function definition. Then:

    - if the name is a class constructor, the `__init__` method is
      annotated if necessary, a variable is created to hold the new object
      and a `ConstructorCall` is returned;
    - otherwise, if the function is still unknown, it is searched for in
      the context dictionary (importing it from a recognised module when
      possible) and the call is created by `_handle_function`.

    Parameters
    ----------
    expr : FunctionCall
        The syntactic function call.

    Returns
    -------
    object
        A `ConstructorCall` for class constructors, otherwise the result of
        `_handle_function`.
    """
    name = expr.funcdef
    try:
        name = self.scope.get_expected_name(name)
    except RuntimeError:
        pass

    func = self.scope.find(name, 'functions')

    if func is None:
        name = str(expr.funcdef)
        if name in builtin_functions_dict:
            func = PyccelFunctionDef(name, builtin_functions_dict[name])

    args = self._handle_function_args(expr.args)

    # String arguments of a TypeVar (other than the first one) describe
    # types and are parsed as type annotations
    if isinstance(func, PyccelFunctionDef) and func.cls_name is TypingTypeVar:
        new_args = [args[0]]
        for a in args[1:]:
            a_val = a.value
            if isinstance(a_val, LiteralString):
                pyccel_stage.set_stage('syntactic')
                try:
                    syntactic_a = types_meta.model_from_str(a_val.python_value)
                except TextXSyntaxError as e:
                    errors.report(f"Invalid annotation. {e.message}",
                            symbol = self.current_ast_node, severity='fatal')
                annot = syntactic_a.expr
                pyccel_stage.set_stage('semantic')
                new_args.append(FunctionCallArgument(self._visit(annot)))
            else:
                new_args.append(a)
        args = new_args

    # Correct keyword names if scope is available
    # The scope is only available if the function body has been parsed
    # (i.e. not for headers or builtin functions)
    if (isinstance(func, FunctionDef) and func.scope) or isinstance(func, Interface):
        scope = func.scope if isinstance(func, FunctionDef) else func.functions[0].scope
        args = [a if a.keyword is None else \
                FunctionCallArgument(a.value, scope.get_expected_name(a.keyword)) \
                for a in args]
        func_args = func.arguments if isinstance(func,FunctionDef) else func.functions[0].arguments
        if not func.is_semantic:
            # Correct func_args keyword names
            func_args = [FunctionDefArgument(AnnotatedPyccelSymbol(scope.get_expected_name(a.var.name), a.annotation),
                        annotation=a.annotation, value=a.value, kwonly=a.is_kwonly, bound_argument=a.bound_argument)
                        for a in func_args]
        args = self._sort_function_call_args(func_args, args)

    if name == 'lambdify':
        args = self.scope.find(str(expr.args[0]), 'symbolic_functions')

    if self.scope.find(name, 'cls_constructs'):

        # TODO improve the test
        # we must not invoke the scope like this

        cls = self.scope.find(name, 'classes')
        d_methods = cls.methods_as_dict
        method = d_methods.pop('__init__', None)

        # Only inspect the method if it exists, so that a missing
        # constructor is reported below instead of raising an AttributeError
        if method is not None and not method.is_semantic:
            if method.is_inline:
                errors.report("An __init__ method cannot be inlined",
                            severity='fatal', symbol=expr)
            method = self._annotate_the_called_function_def(method, args)

        if method is None:

            # TODO improve case of class with the no __init__
            # The error is fatal as nothing below can be done without the method

            errors.report(UNDEFINED_INIT_METHOD, symbol=name,
                bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
                severity='fatal')

        dtype = cls.class_type
        cls_def = cls
        d_var = {'class_type' : dtype,
                'memory_handling':'stack',
                'shape' : None,
                'cls_base' : cls_def,
                }
        new_expression = []

        # The object is saved in the left-hand side of the assignment
        # using the call if there is one, otherwise in a new variable
        assigns = expr.get_direct_user_nodes(lambda a: isinstance(a, Assign))
        if assigns:
            lhs = assigns[0].lhs
        else:
            lhs = self.scope.get_new_name()
        if isinstance(lhs, AnnotatedPyccelSymbol):
            annotation = self._visit(lhs.annotation)
            if len(annotation.type_list) != 1 or annotation.type_list[0].class_type != method.arguments[0].var.class_type:
                errors.report(f"Unexpected type annotation in creation of {cls_def.name}",
                        symbol=annotation, severity='error')
            lhs = lhs.name

        cls_variable = self._assign_lhs_variable(lhs, d_var,
                                                 rhs = method.results.var,
                                                 new_expressions = new_expression,
                                                 is_augassign = False)
        self._additional_exprs[-1].extend(new_expression)
        # The new object is the first argument of the constructor
        args = (FunctionCallArgument(cls_variable), *args)
        self._check_argument_compatibility(args, method.arguments,
                        method, method.is_elemental)

        new_expr = ConstructorCall(method, args, cls_variable)

        for a, f_a in zip(new_expr.args, method.arguments):
            if f_a.persistent_target:
                val = a.value
                if isinstance(val, Variable):
                    a.value.is_target = True
                    self._indicate_pointer_target(cls_variable, a.value, expr.get_user_nodes(Assign)[0])
                else:
                    errors.report(f"{val} cannot be passed to class constructor call as target. Please create a temporary variable.",
                            severity='error', symbol=expr)

        self._allocs[-1].add(cls_variable)
        return new_expr
    else:

        if func is None and name in self._context_dict:
            env_var = self._context_dict[name]
            func = builtin_functions_dict.get(env_var.__name__, None)
            if func is not None:
                func = PyccelFunctionDef(env_var.__name__, func)
            mod_name = env_var.__module__
            if mod_name:
                recognised_mod = recognised_source(mod_name)
            elif mod_name is None and isinstance(env_var, BuiltinFunctionType):
                # Handling of BuiltinFunctionType is necessary for Python 3.9 (NumPy 1.* doesn't specify __module__)
                mod_name = str(env_var).split(' of ',1)[-1].split(' object ',1)[0]
                # Remove trailing components until a recognised module is found.
                # rpartition yields '' when there is no dot left, which ends
                # the loop (rsplit would return the name unchanged forever)
                while mod_name and not recognised_source(mod_name):
                    mod_name = mod_name.rpartition('.')[0]
                recognised_mod = len(mod_name) != 0
            else:
                recognised_mod = False

            if func is None and recognised_mod:
                pyccel_stage.set_stage('syntactic')
                import_node = Import(mod_name, name)
                pyccel_stage.set_stage('semantic')
                self._additional_exprs[-1].append(self._visit(import_node))
                func = self.scope.find(name)

        if func is None:
            return errors.report(UNDEFINED_FUNCTION, symbol=name,
                    bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
                    severity='fatal')
        else:
            return self._handle_function(expr, func, args)
def _visit_Assign(self, expr):
# TODO unset position at the end of this part
new_expressions = []
python_ast = expr.python_ast
assert python_ast
rhs = expr.rhs
lhs = expr.lhs
if isinstance(lhs, AnnotatedPyccelSymbol):
semantic_lhs = self._visit(lhs)
if len(semantic_lhs) != 1:
errors.report("Cannot declare variable with multiple types",
symbol=expr, severity='error')
semantic_lhs_var = semantic_lhs[0]
if isinstance(semantic_lhs_var, DottedVariable):
cls_def = semantic_lhs_var.lhs.cls_base
insert_scope = cls_def.scope
cls_def.add_new_attribute(semantic_lhs_var)
lhs_scope_name = lhs.name.name[-1]
else:
insert_scope = self.scope
lhs_scope_name = lhs.name
lhs = lhs.name
if semantic_lhs_var.class_type is TypeAlias():
pyccel_stage.set_stage('syntactic')
if isinstance(rhs, LiteralString):
try:
annotation = types_meta.model_from_str(rhs.python_value)
except TextXSyntaxError as e:
errors.report(f"Invalid header. {e.message}",
symbol = expr, severity = 'fatal')
rhs = annotation.expr
rhs.set_current_ast(expr.python_ast)
elif not isinstance(rhs, (SyntacticTypeAnnotation, FunctionTypeAnnotation,
VariableTypeAnnotation, UnionTypeAnnotation)):
rhs = SyntacticTypeAnnotation(rhs)
pyccel_stage.set_stage('semantic')
type_annot = self._visit(rhs)
self.scope.insert_symbolic_alias(lhs, type_annot)
return EmptyNode()
try:
insert_scope.insert_variable(semantic_lhs_var, lhs_scope_name)
except RuntimeError as e:
errors.report(e, symbol=expr, severity='error')
if isinstance(rhs, (PythonTuple, PythonList)):
assign_elems = None
if isinstance(lhs, PythonTuple):
# Create variables to handle swap expressions
unsaved_vars = set()
pyccel_stage.set_stage('syntactic')
unsaved_vars = set(rhs.get_attribute_nodes((PyccelSymbol, DottedName, IndexedElement),
excluded_nodes = (FunctionDef,)))
pyccel_stage.set_stage('semantic')
# Test if the expression describes a basic swap or if the rhs contains expressions
# (e.g. arithmetic expressions or further tuples)
# using variables from the left-hand side.
modified_vars = set(lhs.get_attribute_nodes((PyccelSymbol, DottedName, IndexedElement)))
used_vars = set(rhs.get_attribute_nodes((PyccelSymbol, DottedName, IndexedElement),
excluded_nodes = (FunctionDef,)))
trivial_assign = len(modified_vars.intersection(unsaved_vars)) == 0
all_indexed_are_simple = all(all(isinstance(idx, (PyccelSymbol, DottedName, Literal)) for idx in elem.indices)
for elem in modified_vars if isinstance(elem, IndexedElement))
if not trivial_assign and (used_vars.intersection(modified_vars).difference(unsaved_vars) or not all_indexed_are_simple):
errors.report("Assign statement is too complex. It seems that some of the variables used non-trivially on the right-hand side appear on the left-hand side.",
severity='error', symbol=expr)
assign_elems = []
for i, l in enumerate(lhs):
r = rhs[i]
# Get unsaved variables that are still needed
pyccel_stage.set_stage('syntactic')
tmp_rhs_tuple = PythonTuple(*rhs.args[i+1:])
unsaved_vars = set(tmp_rhs_tuple.get_attribute_nodes((PyccelSymbol, DottedName, IndexedElement),
excluded_nodes = (FunctionDef,)))
pyccel_stage.set_stage('semantic')
# If the lhs element has not yet been saved to a variable create a new
# variable to hold this value
if l in unsaved_vars:
temp = self.scope.get_new_name()
pyccel_stage.set_stage('syntactic')
local_assign = Assign(temp, l, python_ast = expr.python_ast)
pyccel_stage.set_stage('semantic')
assign_elems.append(self._visit(local_assign))
# Save the variable containing the value to rhs so it can be
# used when it appears in the assignment
if isinstance(l, IndexedElement):
# A list is required for IndexedElements as they are not singletons
l_list = [r for r in rhs.get_attribute_nodes(IndexedElement) if r == l]
else:
l_list = [l]
for l_elem in l_list:
rhs.substitute(l_elem, temp)
if r == l:
r = temp
# Check for a replacement right-hand side if the rhs is found among the lhs variables
pyccel_stage.set_stage('syntactic')
local_assign = Assign(l, r, python_ast = expr.python_ast)
pyccel_stage.set_stage('semantic')
assign_elems.append(self._visit(local_assign))
elif isinstance(lhs, (PyccelSymbol, DottedName)):
semantic_lhs = self.scope.find(lhs)
if semantic_lhs and isinstance(semantic_lhs.class_type, InhomogeneousTupleType):
pyccel_stage.set_stage('syntactic')
syntactic_assign_elems = [Assign(IndexedElement(lhs,i), r, python_ast=expr.python_ast) for i, r in enumerate(rhs)]
pyccel_stage.set_stage('semantic')
assign_elems = [self._visit(a) for a in syntactic_assign_elems]
elif isinstance(lhs, IndexedElement):
semantic_lhs = self._visit(lhs)
if isinstance(semantic_lhs.class_type, InhomogeneousTupleType):
pyccel_stage.set_stage('syntactic')
syntactic_assign_elems = [Assign(IndexedElement(lhs,i), r, python_ast=expr.python_ast) for i, r in enumerate(rhs)]
pyccel_stage.set_stage('semantic')
assign_elems = [self._visit(a) for a in syntactic_assign_elems]
if assign_elems is not None:
return CodeBlock([l for a in assign_elems for l in (a.body if isinstance(a, CodeBlock) else [a])])
# Steps before visiting
if isinstance(rhs, GeneratorComprehension):
rhs.substitute(rhs.lhs, lhs)
genexp = self._assign_GeneratorComprehension(_get_name(lhs), rhs)
if isinstance(expr, AugAssign):
new_expressions.append(genexp)
rhs = genexp.lhs
elif genexp.lhs.name == lhs:
return genexp
else:
new_expressions.append(genexp)
rhs = genexp.lhs
elif isinstance(rhs, IfTernaryOperator):
value_true = self._visit(rhs.value_true)
if value_true.rank > 0 or value_true.dtype is StringType():
# Temporarily deactivate type checks to construct syntactic assigns
pyccel_stage.set_stage('syntactic')
assign_true = Assign(lhs, rhs.value_true, python_ast = python_ast)
assign_false = Assign(lhs, rhs.value_false, python_ast = python_ast)
pyccel_stage.set_stage('semantic')
cond = self._visit(rhs.cond)
true_section = IfSection(cond, [self._visit(assign_true)])
false_section = IfSection(LiteralTrue(), [self._visit(assign_false)])
return If(true_section, false_section)
# Visit object
if isinstance(rhs, FunctionCall):
name = rhs.funcdef
rhs = self._visit(rhs)
if isinstance(rhs, (PythonMap, PythonZip, PythonEnumerate, PythonRange)):
errors.report(f"{type(rhs)} cannot be saved to variables", symbol=expr, severity='fatal')
else:
rhs = self._visit(rhs)
if isinstance(rhs, NumpyResultType):
errors.report("Cannot assign a datatype to a variable.",
symbol=expr, severity='error')
# Checking for the result of _build_ListExtend or _build_PythonSetFunction
if isinstance(rhs, (For, CodeBlock, ConstructorCall)):
return rhs
elif isinstance(rhs, FunctionCall):
func = rhs.funcdef
results = func.results.var
if results:
d_var = self._infer_type(results)
elif expr.lhs.is_temp:
return rhs
else:
raise errors.report("Cannot assign result of a function without a return",
severity='fatal', symbol=expr)
if isinstance(results.class_type, NumpyNDArrayType) and isinstance(lhs, IndexedElement):
temp = self.scope.get_new_name()
semantic_temp = self._assign_lhs_variable(temp, d_var, rhs, new_expressions)
new_expressions.append(Assign(semantic_temp, rhs))
rhs = semantic_temp
errors.report((f"Saving the result of the function {func.name} to a slice requires unnecessary "
"data allocation and copies. This has a performance cost. Consider modifying "
f"{func.name} so {lhs} can be passed as an argument whose contents are modified."),
severity='warning', symbol=expr)
# case of elemental function
# if the input and args of func do not have the same shape,
# then the lhs must be already declared
if func.is_elemental:
# we first compare the funcdef args with the func call
# args
# d_var = None
func_args = func.arguments
call_args = rhs.args
f_ranks = [x.var.rank for x in func_args]
c_ranks = [x.value.rank for x in call_args]
same_ranks = [x==y for (x,y) in zip(f_ranks, c_ranks)]
if not all(same_ranks):
assert len(c_ranks) == 1
arg = call_args[0].value
d_var['shape' ] = arg.shape
d_var['memory_handling'] = arg.memory_handling
d_var['class_type' ] = arg.class_type
d_var['cls_base' ] = arg.cls_base
elif isinstance(rhs, NumpyTranspose):
d_var = self._infer_type(rhs)
if d_var['memory_handling'] == 'alias' and not isinstance(lhs, IndexedElement):
rhs = rhs.internal_var
elif isinstance(rhs, PyccelFunction) and isinstance(rhs.dtype, VoidType):
if expr.lhs.is_temp:
return rhs
else:
raise NotImplementedError("Cannot assign result of a function without a return")
elif isinstance(rhs, TypingTypeVar):
self.scope.insert_symbolic_alias(lhs, rhs)
return EmptyNode()
else:
d_var = self._infer_type(rhs)
d_list = d_var if isinstance(d_var, list) else [d_var]
for d in d_list:
name = d['class_type'].__class__.__name__
if name.startswith('Pyccel'):
name = name[6:]
d['cls_base'] = self.scope.find(name, 'classes')
if d_var['memory_handling'] == 'alias':
d['memory_handling'] = 'alias'
else:
d['memory_handling'] = d_var['memory_handling'] or 'heap'
# TODO if we want to use pointers then we set target to true
# in the ConstructorCall
if isinstance(rhs, Variable) and rhs.is_target:
# case of rhs is a target variable the lhs must be a pointer
d['memory_handling'] = 'alias'
if isinstance(lhs, (PyccelSymbol, DottedName)):
if isinstance(d_var, list):
if len(d_var) == 1:
d_var = d_var[0]
else:
errors.report(WRONG_NUMBER_OUTPUT_ARGS, symbol=expr,
severity='error')
return None
lhs = self._assign_lhs_variable(lhs, d_var, rhs, new_expressions,
arr_in_multirets = (isinstance(rhs, FunctionCall) and \
not getattr(rhs.funcdef, 'is_elemental', False)))
# If lhs is a purely symbolic object to link tuple elements to their containing tuple
# then no semantic object should be returned
# This can happen when returning an inhomogeneous tuple
if isinstance(rhs, PythonTuple) and isinstance(lhs.class_type, InhomogeneousTupleType):
for li, ri in zip(lhs, rhs):
li_var = self.scope.collect_tuple_element(li)
if li_var == ri:
new_expressions = [n for n in new_expressions if not n.is_user_of(li_var)]
# Handle assignment to multiple variables
elif isinstance(lhs, (PythonTuple, PythonList)):
if isinstance(rhs, FunctionCall):
new_lhs = []
for i,(l,r) in enumerate(zip(lhs, rhs.funcdef.results.var)):
d = self._infer_type(r)
new_lhs.append( self._assign_lhs_variable(l, d, r, new_expressions,
arr_in_multirets=r.rank>0 ) )
if not isinstance(rhs.class_type, InhomogeneousTupleType):
rhs_var = self.scope.get_temporary_variable(rhs.funcdef.results.var)
new_expressions.append(Assign(rhs_var, rhs))
rhs = rhs_var
lhs = PythonTuple(*new_lhs)
elif isinstance(rhs, PyccelFunction):
assert isinstance(rhs.class_type, InhomogeneousTupleType)
r_iter = [self.scope.collect_tuple_element(v) for v in rhs]
new_lhs = []
for i,(l,r) in enumerate(zip(lhs, r_iter)):
d = self._infer_type(r)
new_lhs.append( self._assign_lhs_variable(l, d, r, new_expressions,
arr_in_multirets=r.rank>0 ) )
lhs = PythonTuple(*new_lhs)
else:
if isinstance(rhs.class_type, InhomogeneousTupleType):
r_iter = [self.scope.collect_tuple_element(v) for v in rhs]
else:
r_iter = rhs
body = []
for i,(l,r) in enumerate(zip(lhs,r_iter)):
pyccel_stage.set_stage('syntactic')
local_assign = Assign(l, r, python_ast = expr.python_ast)
pyccel_stage.set_stage('semantic')
body.append(self._visit(local_assign))
return CodeBlock(body)
else:
lhs = self._visit(lhs)
if not isinstance(lhs, (list, tuple)):
lhs = [lhs]
if isinstance(d_var,dict):
d_var = [d_var]
if len(lhs) == 1:
lhs = lhs[0]
if isinstance(lhs, Variable):
is_pointer = lhs.is_alias
elif isinstance(lhs, (IndexedElement, DictGetItem)):
is_pointer = False
elif isinstance(lhs, (PythonTuple, PythonList)):
is_pointer = any(l.is_alias for l in lhs if isinstance(lhs, Variable))
else:
raise NotImplementedError()
# TODO: does is_pointer refer to any/all or last variable in list (currently last)
is_pointer = is_pointer and isinstance(rhs, (Variable, Duplicate))
is_pointer = is_pointer or isinstance(lhs, Variable) and lhs.is_alias
lhs = [lhs]
rhs = [rhs]
# Split into multiple Assigns to ensure AliasAssign is used where necessary
unravelling = True
while unravelling:
unravelling = False
new_lhs = []
new_rhs = []
for l,r in zip(lhs, rhs):
# Split assign (e.g. for a,b = 1,c)
if (isinstance(l.class_type, InhomogeneousTupleType) or isinstance(l, PythonTuple)) \
and not isinstance(r, (FunctionCall, PyccelFunction)):
new_lhs.extend(self.scope.collect_tuple_element(v) for v in l)
new_rhs.extend(self.scope.collect_tuple_element(r[i]) \
for i in range(l.shape[0]))
# Repeat step to handle tuples of tuples of etc.
unravelling = True
elif isinstance(l, Variable) and isinstance(l.class_type, InhomogeneousTupleType):
new_lhs.append(PythonTuple(*self.scope.collect_all_tuple_elements(l)))
new_rhs.append(r)
# Repeat step to handle tuples of tuples of etc.
unravelling = True
elif isinstance(l, Variable) and isinstance(r.class_type, InhomogeneousTupleType):
new_lhs.extend(l[i] for i in range(len(r.class_type)))
new_rhs.extend(self.scope.collect_tuple_element(ri) for ri in r)
# Repeat step to handle tuples of tuples of etc.
unravelling = True
elif l is not r:
# Manage a non-tuple assignment
# Manage memory for optionals
if isinstance(l, Variable) and l.is_optional:
if l in self._optional_params:
# Collect temporary variable which provides
# allocated memory space for this optional variable
new_lhs.append(self._optional_params[l])
else:
# Create temporary variable to provide allocated
# memory space before assigning to the pointer value
# (may be NULL)
tmp_var = self.scope.get_temporary_variable(l,
name = l.name+'_loc', is_optional = False,
is_argument = False)
self._optional_params[l] = tmp_var
l = tmp_var
if isinstance(r, ConstructorCall):
# Manage a ConstructorCall in a tuple assignment.
# In this case a temporary variable is created which must be
# replaced with the tuple element.
cls_var = r.cls_variable
if cls_var.is_temp:
r.substitute(cls_var, l)
self._allocs[-1].remove(cls_var)
self.scope.remove_variable(cls_var)
self._allocs[-1].add(l)
new_expressions.append(r)
else:
new_lhs.append(l)
new_rhs.append(r)
lhs = new_lhs
rhs = new_rhs
# Examine each assign and determine assign type (Assign, AliasAssign, etc)
for l, r in zip(lhs,rhs):
if isinstance(l, PythonTuple):
for li in l:
if li.is_const:
# If constant (can't use annotations on tuple assignment)
errors.report("Cannot modify variable marked as Final",
bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
symbol=li, severity='error')
else:
if getattr(l, 'is_const', False) and (not isinstance(expr.lhs, AnnotatedPyccelSymbol) or \
any(not isinstance(u, (Allocate, PyccelArrayShapeElement)) for u in l.get_all_user_nodes())):
# If constant and not the initialising declaration of a constant variable
errors.report("Cannot modify variable marked as Final",
bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
symbol=l, severity='error')
if isinstance(expr, AugAssign):
new_expr = AugAssign(l, expr.op, r)
else:
is_pointer_i = l.is_alias if isinstance(l, Variable) else is_pointer
new_expr = Assign(l, r)
if is_pointer_i:
if isinstance(r, FunctionCall):
funcdef = r.funcdef
target_r_idx = funcdef.result_pointer_map[funcdef.results.var]
for ti in target_r_idx:
self._indicate_pointer_target(l, r.args[ti].value, expr)
new_expr = AliasAssign(l, r)
else:
self._indicate_pointer_target(l, r, expr)
if not isinstance(r.class_type, NumpyNDArrayType) and not isinstance(r, Variable):
mem_var = get_managed_memory_object(l)
new_expr = UnpackManagedMemory(l, r, mem_var)
else:
new_expr = AliasAssign(l, r)
elif isinstance(l.class_type, SymbolicType):
errors.report(PYCCEL_RESTRICTION_TODO,
symbol=expr,
severity='fatal')
elif isinstance(r, (PythonList, PythonSet, PythonTuple, PythonDict)):
self._indicate_pointer_target(l, r, expr)
new_expressions.append(new_expr)
if expr.lhs == '__all__':
self.scope.remove_variable(lhs[0])
self._allocs[-1].discard(lhs[0])
if isinstance(lhs[0].class_type, HomogeneousListType):
# Remove the last element of the errors (if it is a warning)
# This will be the list of list warning
try:
error_info_map = errors.error_info_map[os.path.basename(errors.target)]
if error_info_map[-1].severity == 'warning':
error_info_map.pop()
except KeyError:
# There may be a KeyError if this is not the first time that this DataType
# of list of rank>0 is created.
pass
return AllDeclaration(new_expressions[-1].rhs)
if (len(new_expressions)==1):
new_expressions = new_expressions[0]
return new_expressions
else:
result = CodeBlock(new_expressions)
return result
def _visit_AugAssign(self, expr):
    """
    Annotate an augmented assignment such as `a += b`.

    Both sides are visited and the Pyccel operator associated with the
    augmented assignment is applied to them as a trial. If the trial
    produces a (truthy) node, that node is only used to infer the type
    of the result and an `AugAssign` is built. Otherwise the magic
    methods of the class of the left hand side are used: the in-place
    method (`__i<op>__`) when the class provides one, else the standard
    method whose result is assigned back to the left hand side.

    Parameters
    ----------
    expr : AugAssign
        The syntactic augmented assignment.

    Returns
    -------
    PyccelAstNode
        The `AugAssign`/`Assign` describing the operation, the result of
        `_handle_function` for an in-place magic method, or a `CodeBlock`
        when `_assign_lhs_variable` produced additional expressions.
    """
    target = self._visit(expr.lhs)
    if target.is_const:
        errors.report("Cannot modify variable marked as Final",
            bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
            symbol=target, severity='error')
    value = self._visit(expr.rhs)
    operator = expr.pyccel_operator
    setup_exprs = []

    # Trial application of the operator to find out whether it is supported
    try:
        test_node = operator(target, value)
    except TypeError:
        test_node = None

    if not test_node:
        # The operator cannot be applied directly: use the magic methods
        method_name = magic_method_map[operator]
        inplace_method_name = '__i' + method_name[2:]
        class_type = target.class_type
        class_base = self.scope.find(str(class_type), 'classes') or get_cls_base(class_type)
        inplace_method = class_base.get_method(inplace_method_name)
        call_args = [FunctionCallArgument(target), FunctionCallArgument(value)]
        if inplace_method:
            return self._handle_function(expr, inplace_method, call_args)
        method = class_base.get_method(method_name, expr)
        operator_node = self._handle_function(expr, method, call_args)
        # NOTE(review): test_node is falsy here; it is passed on unchanged.
        target = self._assign_lhs_variable(expr.lhs, self._infer_type(operator_node), test_node,
                            setup_exprs, is_augassign = True)
        target = self._optional_params.get(target, target)
        statement = Assign(target, operator_node)
    else:
        # The trial node must not remain registered as a user of its operands
        target.remove_user_node(test_node, invalidate = False)
        if test_node in value.get_all_user_nodes():
            detached = test_node
        else:
            detached = value.current_user_node
            assert isinstance(detached, PyccelAssociativeParenthesis)
        value.remove_user_node(detached, invalidate = False)
        target = self._assign_lhs_variable(expr.lhs, self._infer_type(test_node), test_node,
                            setup_exprs, is_augassign = True)
        target = self._optional_params.get(target, target)
        statement = AugAssign(target, expr.op, value)

    return CodeBlock(setup_exprs + [statement]) if setup_exprs else statement
def _visit_For(self, expr):
    """
    Annotate a for loop.

    A loop scope is created, the iterators and the body are visited and
    the scope is exited. When the iterable is a `Product`, one nested
    `For` is built per loop counter (innermost first); otherwise a single
    `For` is built.

    Parameters
    ----------
    expr : For
        The syntactic for loop.

    Returns
    -------
    For
        The semantic for loop (the outermost one for a `Product`).
    """
    loop_scope = self.create_new_loop_scope()

    # treatment of the index/indices
    extra_exprs = []
    target, iterable = self._get_for_iterators(expr.iterable, expr.target, extra_exprs, expr)
    body = self._visit(expr.body)

    self.exit_loop_scope()

    if not isinstance(iterable, Product):
        loop = For(target, iterable, body, scope=loop_scope)
        loop.end_annotation = expr.end_annotation
        return loop

    nested_scopes = self.scope.create_product_loop_scope(loop_scope, len(target))
    levels = zip(target[::-1],
                 iterable.loop_counters[::-1],
                 iterable.get_python_iterable_item()[::-1],
                 nested_scopes[::-1])

    inner = body
    for loop_target, counter, item, level_scope in levels:
        # Create Variable iterable
        level_iter = VariableIterator(item.base)
        level_iter.set_loop_counter(counter)

        # Create a For loop for each level of the Product
        loop = For((loop_target,), level_iter, inner, scope=level_scope)
        loop.end_annotation = expr.end_annotation
        # The loop becomes the (single statement) body of the next level
        inner = [loop]
    return inner[0]
def _visit_FunctionalFor(self, expr):
    """
    Visit and transform a FunctionalFor AST node into an equivalent code block.

    This method processes a `FunctionalFor` expression and transforms the loop structure
    into a corresponding code block. The loops are nested into each other, the loop
    variables are created, the number of generated elements is deduced symbolically
    with sympy in order to allocate the left hand side, and finally the operations
    filling the container are inserted into the innermost loop before the loops are
    visited.

    Parameters
    ----------
    expr : pyccel.ast.functionalexpr.FunctionalFor
        The FunctionalFor AST node.

    Returns
    -------
    pyccel.ast.basic.CodeBlock
        A code block containing the equivalent loops and necessary variable allocations for the given `FunctionalFor` expression.
    """
    target = expr.expr
    # Names of the loop variables, in the order in which they are created
    indices = []
    # One (size, step, start, stop) tuple of sympy expressions per loop level
    dims = []
    # Symbol map shared by pyccel_to_sympy/sympy_to_pyccel, also used to
    # look up the semantic variable of each loop variable name
    idx_subs = {}
    tmp_used_names = self.scope.all_used_symbols.copy()
    # Position of the current loop level in expr.indices
    i = 0
    loops = list(expr.loops)
    # Inner function to handle PythonNativeInt variables
    def handle_int_loop_variable(var_name, var_scope):
        # Record the name and create an integer variable in the given scope
        indices.append(var_name)
        var = self._create_variable(var_name, PythonNativeInt(), None, {}, insertion_scope=var_scope)
        return var
    # Inner function to handle iterable variables
    def handle_iterable_variable(var_name, element, var_scope):
        # Record the name and create a variable describing one element of
        # `element` : the rank is reduced by one and the first dimension
        # is dropped from the shape
        indices.append(var_name)
        dvar = self._infer_type(element)
        class_type = dvar.pop('class_type')
        if class_type.rank > 0:
            class_type = class_type.switch_rank(class_type.rank - 1)
            dvar['shape'] = dvar['shape'][1:]
        if class_type.rank == 0:
            # Rank 0 elements have no shape and are stored on the stack
            dvar['shape'] = None
            dvar['memory_handling'] = 'stack'
        var = self._create_variable(var_name, class_type, None, dvar, insertion_scope=var_scope)
        return var
    # Place each (non-empty) condition in the body of its loop
    for loop, condition in zip(loops, expr.conditions):
        if condition:
            loop.insert2body(condition)
    # Nest the loops : the last loop of the list is repeatedly inserted
    # into the loop preceding it
    while len(loops) > 1:
        outer_loop = loops.pop()
        inserted_into = loops[-1]
        if inserted_into.body.body:
            # The receiving loop already has a body : the loop is placed in the first
            # block of its first statement (presumably the condition inserted above)
            inserted_into.body.body[0].blocks[0].body.insert2body(outer_loop)
        else:
            inserted_into.insert2body(outer_loop)
    body = loops[0]
    # Walk down the loop nest, creating the loop variables and collecting
    # the symbolic bounds of each level
    while isinstance(body, (For, If)):
        if isinstance(body, If):
            # Step into the first block of the If (or stop if it is empty)
            body = None if not body.blocks[0].body.body else body.blocks[0].body.body[0]
            continue
        # Default bounds, overwritten below where the iterable provides them
        stop = None
        start = LiteralInteger(0)
        step = LiteralInteger(1)
        variables = []
        a = self._get_iterable(self._visit(body.iterable))
        if isinstance(a, PythonRange):
            var_name = self.scope.get_expected_name(expr.indices[i])
            variables.append(handle_int_loop_variable(var_name, body.scope))
            start = a.start
            stop = a.stop
            step = a.step
        elif isinstance(a, PythonEnumerate):
            # One integer counter and one element of the enumerated object
            var_name1 = self.scope.get_expected_name(expr.indices[i][0])
            var_name2 = self.scope.get_expected_name(expr.indices[i][1])
            variables.append(handle_int_loop_variable(var_name1, body.scope))
            variables.append(handle_iterable_variable(var_name2, a.element, body.scope))
            stop = a.element.shape[0]
        elif isinstance(a, PythonZip):
            # One element variable per zipped argument
            for idx, arg in enumerate(a.args):
                var = self.scope.get_expected_name(expr.indices[i][idx])
                variables.append(handle_iterable_variable(var, arg, body.scope))
            stop = a.get_range().stop
        elif isinstance(a, VariableIterator):
            var = self.scope.get_expected_name(expr.indices[i])
            variables.append(handle_iterable_variable(var, a.variable, body.scope))
            stop = a.variable.shape[0]
        else:
            errors.report(PYCCEL_RESTRICTION_TODO,
                    bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
                    severity='fatal')
        # Declare the loop variables, checking for clashes with existing variables
        for var in variables:
            existing_var = self.scope.find(var.name, 'variables')
            if var.name == expr.lhs:
                errors.report(f"Variable {var} has the same name as the left hand side",
                        symbol = expr, severity='fatal')
            if existing_var or var.name == expr.lhs:
                if self._infer_type(existing_var)['class_type'] != var.class_type:
                    errors.report(f"Variable {var} already exists with different type",
                            symbol = expr, severity='fatal')
            else:
                self.scope.insert_variable(var)
        # Convert the bounds to sympy to allow symbolic manipulation
        step = pyccel_to_sympy(step , idx_subs, tmp_used_names)
        start = pyccel_to_sympy(start, idx_subs, tmp_used_names)
        stop = pyccel_to_sympy(stop , idx_subs, tmp_used_names)
        # Symbolic expression for the number of iterations of this level
        size = (stop - start) / step
        if (step != 1):
            size = ceiling(size)
        # Move on to the first statement of the loop body (or stop if it is empty)
        body = None if not body.body.body else body.body.body[0]
        dims.append((size, step, start, stop))
        i += 1
    # Map each loop variable name to its semantic variable
    for idx in indices:
        var = self.get_variable(idx)
        idx_subs[idx] = var
    # Sympy symbols representing the loop variables
    sp_indices = [sp_Symbol(i) for i in indices]
    # Total number of elements, built up from the innermost level outwards
    dim = sp_Integer(1)
    for i in reversed(range(len(dims))):
        size = dims[i][0]
        step = dims[i][1]
        start = dims[i][2]
        stop = dims[i][3]
        # For complicated cases we must ensure that the upper bound is never smaller than the
        # lower bound as this leads to too little memory being allocated
        min_size = size
        # Collect all uses of other indices
        # (-1 is a sentinel so that both lists are never empty)
        start_idx = [-1] + [sp_indices.index(a) for a in start.atoms(sp_Symbol) if a in sp_indices]
        stop_idx = [-1] + [sp_indices.index(a) for a in stop.atoms(sp_Symbol) if a in sp_indices]
        start_idx.sort()
        stop_idx.sort()
        # Find the minimum size
        while max(len(start_idx),len(stop_idx))>1:
            # Use the maximum value of the start
            if start_idx[-1] > stop_idx[-1]:
                s = start_idx.pop()
                min_size = min_size.subs(sp_indices[s], dims[s][3])
            # and the minimum value of the stop
            else:
                s = stop_idx.pop()
                min_size = min_size.subs(sp_indices[s], dims[s][2])
        # While the min_size is not a known integer, assume that the bounds are positive
        j = 0
        while not isinstance(min_size, sp_Integer) and j<=i:
            min_size = min_size.subs(dims[j][3]-dims[j][2], 1).simplify()
            j+=1
        # If the min_size is negative then the size will be wrong and an error is raised
        if isinstance(min_size, sp_Integer) and min_size < 0:
            errors.report(PYCCEL_RESTRICTION_LIST_COMPREHENSION_LIMITS.format(indices[i]),
                    bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
                    severity='error')
        # sympy is necessary to carry out the summation
        # The index is first expressed using a counter starting from 0
        dim = dim.subs(sp_indices[i], start+step*sp_indices[i])
        dim = Summation(dim, (sp_indices[i], 0, size-1))
        dim = dim.doit()
    # Convert the deduced size back to a Pyccel expression
    try:
        dim = sympy_to_pyccel(dim, idx_subs)
    except TypeError:
        errors.report(PYCCEL_RESTRICTION_LIST_COMPREHENSION_SIZE + f'\n Deduced size : {dim}',
                symbol=expr,
                severity='fatal')
    # The target is visited to infer the type of the elements
    target = self._visit(target)
    d_var = self._infer_type(target)
    class_type = d_var['class_type']
    if class_type is GenericType():
        errors.report(LIST_OF_TUPLES,
                bounding_box=(self.current_ast_node.lineno, self.current_ast_node.col_offset),
                severity='fatal')
    d_var['memory_handling'] = 'heap'
    # The container is a list unless a target type was specified
    target_type_name = 'list' if not expr.target_type else expr.target_type
    # Find the function describing the conversion to the container type
    if isinstance(target_type_name, DottedName):
        # The last name is the function, the preceding names locate its module
        lhs = target_type_name.name[0] if len(target_type_name.name) == 2 \
                else DottedName(*target_type_name.name[:-1])
        first = self._visit(lhs)
        if isinstance(first, Module):
            conversion_func = first[target_type_name.name[-1]]
        else:
            conversion_func = None
    else:
        conversion_func = self.scope.find(target_type_name, 'functions')
        if conversion_func is None:
            # Fall back on the builtin functions
            if target_type_name in builtin_functions_dict:
                conversion_func = PyccelFunctionDef(target_type_name,
                        builtin_functions_dict[target_type_name])
    if conversion_func is None:
        errors.report("Unrecognised output type from functional for.\n"+PYCCEL_RESTRICTION_TODO,
                symbol=expr,
                severity='fatal')
    if target_type_name != 'list' and any(cond is not None for cond in expr.conditions):
        errors.report("Cannot handle if statements in list comprehensions if lhs is a numpy array.\
List length cannot be calculated.\n" + PYCCEL_RESTRICTION_TODO,
                symbol=expr, severity='error')
    # Build the type of the container from the type of its elements
    try:
        class_type = type_container[conversion_func.cls_name](class_type)
    except TypeError:
        # The container type could not be built from the element type alone :
        # a rank 1 container of the processed datatype is built instead
        if class_type.rank > 0:
            errors.report("ND comprehension expressions cannot be saved directly to an array yet.\n"+PYCCEL_RESTRICTION_TODO,
                    symbol=expr,
                    severity='fatal')
        class_type = type_container[conversion_func.cls_name](numpy_process_dtype(class_type), rank=1, order=None)
    d_var['class_type'] = class_type
    d_var['shape'] = (dim,)
    d_var['cls_base'] = get_cls_base(class_type)
    # ...
    # TODO [YG, 30.10.2020]:
    # - Check if we should allow the possibility that is_stack_array=True
    # ...
    lhs_symbol = expr.lhs
    ne = []
    lhs = self._assign_lhs_variable(lhs_symbol, d_var, rhs=expr, new_expressions=ne)
    # The first expression created for the lhs is kept to be placed before the loops
    lhs_alloc = ne[0]
    if isinstance(target, PythonTuple) and not target.is_homogeneous:
        errors.report(LIST_OF_TUPLES, symbol=expr, severity='error')
    # The visited target is no longer needed
    target.invalidate_node()
    # Syntactic operations which must be inserted into the innermost loop
    operations = []
    # Statement placed before the loops (only set for NumPy arrays)
    assign = None
    target_conversion_func = self._visit(target_type_name)
    if (isinstance(target_conversion_func, PyccelFunctionDef)
            and target_conversion_func.cls_name is NumpyArray):
        # Replace the syntactic index with a new name in the expression and in the
        # operations (objects are only replaced if they are the exact same object)
        old_index = expr.index
        new_index = self.scope.get_new_name()
        expr.substitute(old_index, new_index, is_equivalent = lambda x,y: x is y)
        array_ops = expr.operations['numpy_array']
        for operation in array_ops:
            operation.substitute(old_index, new_index, is_equivalent = lambda x,y: x is y)
        # The first operation is visited immediately and placed before the loops,
        # the others are inserted into the innermost loop
        assign = array_ops[0]
        assign = self._visit(array_ops[0])
        operations.extend(array_ops[1:])
        index = new_index
        index = self._visit(index)
    elif target_conversion_func == "'list'":
        # NOTE(review): presumably visiting the string 'list' yields its repr — TODO confirm
        index = None
        operations.extend(expr.operations['list'])
    else:
        errors.report("Unrecognised target for functional for.\n"+PYCCEL_RESTRICTION_TODO,
                symbol=expr,
                severity='fatal')
    # Insert the visited operations into the innermost loop (inside the first block
    # of its first statement if the loop already has a body)
    if expr.loops[-1].body.body:
        for operation in operations:
            expr.loops[-1].body.body[0].blocks[0].body.insert2body(self._visit(operation))
    else:
        for operation in operations:
            expr.loops[-1].insert2body(self._visit(operation))
    # Only the outermost loop remains in loops after the nesting step above
    loops = [self._visit(i) for i in loops]
    if assign:
        loops = [assign, *loops]
    # Walk down the semantic loops, replacing the targets of temporary loop
    # variables with the variables created in this method
    l = loops[-1]
    cnt = 0
    for idx in indices:
        assert isinstance(l, For)
        if idx.is_temp:
            self.scope.remove_variable(l.target[cnt])
            l.substitute(l.target[cnt], idx_subs[idx])
        cnt += 1
        if cnt == len(l.target):
            # All targets of this loop were handled : move to the next loop level
            if l.body.body:
                if isinstance(l.body.body[0], If):
                    l = l.body.body[0].blocks[0].body.body[0]
                else:
                    l = l.body.body[0]
            cnt = 0
    return CodeBlock([lhs_alloc, FunctionalFor(loops, lhs=lhs, index=index, indices=expr.indices, target_type=target_type_name, conditions=expr.conditions)])
def _visit_GeneratorComprehension(self, expr):
    """
    Annotate a generator comprehension used inside an expression.

    If the variable storing the result already exists it is returned
    directly. Otherwise a syntactic assignment of the comprehension to
    its left hand side (or to a new temporary symbol if the left hand
    side is temporary) is built and visited, the resulting code is
    added to the current additional expressions and the variable
    holding the result is returned.

    Parameters
    ----------
    expr : GeneratorComprehension
        The syntactic generator comprehension.

    Returns
    -------
    Variable
        The variable containing the result of the comprehension.
    """
    existing_var = self.check_for_variable(expr.lhs)
    if existing_var is not None:
        return existing_var

    # Temporarily return to the syntactic stage to build the assignment
    pyccel_stage.set_stage('syntactic')
    lhs = PyccelSymbol(self.scope.get_new_name(), is_temp=True) if expr.lhs.is_temp else expr.lhs
    syntactic_assign = Assign(lhs, expr, python_ast=expr.python_ast)
    pyccel_stage.set_stage('semantic')

    creation = self._visit(syntactic_assign)
    self._additional_exprs[-1].append(creation)
    return self.get_variable(lhs)
def _visit_While(self, expr):
    """
    Annotate a while loop.

    The test and the body are visited (in that order) inside a new loop
    scope which is exited before the semantic node is created.

    Parameters
    ----------
    expr : While
        The syntactic while loop.

    Returns
    -------
    While
        The semantic while loop.
    """
    loop_scope = self.create_new_loop_scope()
    loop_test = self._visit(expr.test)
    loop_body = self._visit(expr.body)
    self.exit_loop_scope()

    return While(loop_test, loop_body, scope=loop_scope)
def _visit_IfSection(self, expr):
    """
    Annotate one section (condition and body) of an If block.

    The condition is visited and simplified with sympy where possible:
    a section whose condition is known to be false is replaced by an
    empty section and a condition known to be true is replaced by
    `LiteralTrue`. Conditions which are not booleans are cast to bool.
    Finally `a is not None` checks which are combined with other uses of
    `a` are separated out so that the check is made before `a` is used.

    Parameters
    ----------
    expr : IfSection
        The syntactic section.

    Returns
    -------
    IfSection
        The semantic section.
    """
    cond = self._visit(expr.condition)

    symbol_map = {}
    used_names = self.scope.all_used_symbols.copy()
    try:
        sympy_cond = pyccel_to_sympy(cond, symbol_map, used_names)
    except TypeError:
        sympy_cond = 'unknown'

    if sympy_cond == sp_False():
        return IfSection(LiteralFalse(), CodeBlock([]))
    elif sympy_cond == sp_True():
        cond = LiteralTrue()

    if cond.dtype is not PythonNativeBool():
        cond = PythonBool(cond)
    cond.set_current_ast(cond.python_ast or expr.python_ast)

    body = self._visit(expr.body)

    parent_if = expr.get_direct_user_nodes(lambda u: isinstance(u, If))[0]
    is_last_block = expr is parent_if.blocks[-1]

    def treat_condition(cond, body):
        """
        Run through the condition of the If to try to extract `if a is not None`
        conditions. These must be done in their own line in low-level languages.
        """
        none_checks = [c for c in cond.get_attribute_nodes(PyccelIsNot) if c.args[1] is Nil()]
        for none_check in none_checks:
            checked_var = none_check.args[0]
            uses = checked_var.get_direct_user_nodes(cond.is_user_of)
            # If variable is only used in `a is not None` the condition is ok
            if len(uses) <= 1:
                continue

            if not (isinstance(cond, PyccelAnd) and none_check in cond.args):
                errors.report("Cannot evaluate condition. Checking if a variable is present must be done before using the variable",
                        severity='error', symbol=cond)
                continue

            # `a is not None` is in an `and` so this can be split into valid conditions
            if len(cond.args) > 2:
                remaining_cond = PyccelAnd(*[a for a in cond.args if a is not none_check])
            else:
                remaining_cond = next(a for a in cond.args if a is not none_check)
            remaining_cond.set_current_ast(cond.python_ast)
            cond_var = self.scope.get_temporary_variable(PythonNativeBool(),
                              self.scope.get_new_name('condition'))
            treated_remaining_cond, body = treat_condition(remaining_cond, body)

            if is_last_block:
                # if in the last block create an if in the current if
                return none_check, [If(IfSection(treated_remaining_cond, body))]

            # Otherwise evaluate the condition before the if block
            check_passed = IfSection(none_check, [Assign(cond_var, treated_remaining_cond)])
            check_failed = IfSection(LiteralTrue(), [Assign(cond_var, LiteralFalse())])
            self._additional_exprs[-1].append(If(check_passed, check_failed))
            return cond_var, body

        return cond, body

    final_cond, final_body = treat_condition(cond, body)
    return IfSection(final_cond, final_body)
def _visit_If(self, expr):
    """
    Annotate an If block.

    Sections whose condition is `LiteralFalse` are discarded and the
    sections following a section whose condition is `LiteralTrue` are
    not visited. If the only remaining section is always true its body
    is returned in place of the If. Variables which are not allocated
    with the same shape in all remaining sections are marked as having
    a changeable shape.

    Parameters
    ----------
    expr : If
        The syntactic If block.

    Returns
    -------
    If | CodeBlock
        The semantic If block, or the body of the section which is
        always executed.
    """
    sections = []
    for block in expr.blocks:
        section = self._visit(block)
        cond = section.condition
        if isinstance(cond, LiteralFalse):
            # This section can never be executed
            continue
        sections.append(section)
        if isinstance(cond, LiteralTrue):
            if len(sections) == 1:
                return section.body
            # The following sections can never be reached
            break

    # Shapes of the variables allocated in each section
    shapes_per_section = [{alloc.variable : alloc.shape
                                for alloc in section.get_attribute_nodes(Allocate)}
                          for section in sections]

    for section_shapes in shapes_per_section:
        for v in section_shapes:
            always_allocated = all(v in other_shapes for other_shapes in shapes_per_section)
            same_shapes = always_allocated and \
                    all(shapes_per_section[0][v] == other_shapes[v]
                            for other_shapes in shapes_per_section[1:])
            if not same_shapes:
                v.set_changeable_shape()

    return If(*sections)
def _visit_IfTernaryOperator(self, expr):
    """
    Annotate a ternary operator (`a if cond else b`).

    If the value is a scalar which is not a string, an
    `IfTernaryOperator` is returned (or directly one of the values if
    the condition is a literal). Otherwise the operator is expanded into
    an If block assigning to a temporary variable; the block is added to
    the current additional expressions and the temporary is returned.

    Parameters
    ----------
    expr : IfTernaryOperator
        The syntactic ternary operator.

    Returns
    -------
    TypedAstNode
        The object representing the result of the operator.
    """
    value_true = self._visit(expr.value_true)
    needs_if_block = value_true.rank > 0 or value_true.dtype is StringType()

    if not needs_if_block:
        cond = self._visit(expr.cond)
        value_false = self._visit(expr.value_false)
        if isinstance(cond, LiteralTrue):
            return value_true
        if isinstance(cond, LiteralFalse):
            return value_false
        return IfTernaryOperator(cond, value_true, value_false)

    tmp_symbol = PyccelSymbol(self.scope.get_new_name(), is_temp=True)
    # Temporarily deactivate type checks to construct syntactic assigns
    pyccel_stage.set_stage('syntactic')
    assign_true = Assign(tmp_symbol, expr.value_true, python_ast = expr.python_ast)
    assign_false = Assign(tmp_symbol, expr.value_false, python_ast = expr.python_ast)
    pyccel_stage.set_stage('semantic')

    cond = self._visit(expr.cond)
    section_if_true = IfSection(cond, [self._visit(assign_true)])
    section_otherwise = IfSection(LiteralTrue(), [self._visit(assign_false)])
    self._additional_exprs[-1].append(If(section_if_true, section_otherwise))

    return self._visit(tmp_symbol)
def _visit_Return(self, expr):
    """
    Annotate a return statement.

    For in-place magic methods (`__iadd__`, ...) the statement must
    return the first argument of the function, in which case nothing
    needs to be returned. Otherwise the returned expression is assigned
    to the result variable of the function (unless it already is this
    variable) and the objects allocated in the function which are not
    returned are deallocated before the return.

    Parameters
    ----------
    expr : Return
        The syntactic return statement.

    Returns
    -------
    Return | EmptyNode
        The semantic return statement, or an `EmptyNode` for a valid
        return from an in-place magic method.
    """
    results = expr.expr
    f_name = self.current_function_name
    if isinstance(f_name, DottedName):
        f_name = f_name.name[-1]

    func = self._current_function[-1]

    original_name = self.scope.get_python_name(f_name)
    is_increment_method = original_name.startswith('__i') and \
            ('__'+original_name[3:]) in magic_method_map.values()
    if is_increment_method:
        valid_return = isinstance(expr.expr, PyccelSymbol) and expr.stmt is None and len(func.arguments) > 0
        if valid_return:
            out = self._visit(expr.expr)
            expected = func.arguments[0].var
            valid_return &= (out == expected)
        if valid_return:
            return EmptyNode()
        errors.report("Increment functions must return the class instance",
                severity='fatal', symbol=expr)

    return_var = func.results.var
    if isinstance(return_var, (AnnotatedPyccelSymbol, Variable)):
        return_var = return_var.name

    code = []
    if return_var != results:
        # Create a syntactic object to visit
        pyccel_stage.set_stage('syntactic')
        syntactic_assign = Assign(return_var, results, python_ast=expr.python_ast)
        pyccel_stage.set_stage('semantic')

        a = self._visit(syntactic_assign)
        is_needed = not isinstance(a, Assign) or a.lhs != a.rhs
        if is_needed:
            code.append(a)
            if isinstance(a, ConstructorCall):
                a.cls_variable.is_temp = False
        else:
            # Assigning a variable to itself is pointless
            a.invalidate_node()

    results = self._visit(return_var)

    # add the Deallocate node before the Return node and eliminating the Deallocate nodes
    # the arrays that will be returned.
    results_vars = self.scope.collect_all_tuple_elements(results)
    self._check_pointer_targets(results_vars)
    code.extend(Deallocate(v) for v in self._allocs[-1] if v not in results_vars)

    if results is Nil():
        results = None

    return Return(results, CodeBlock(code)) if code else Return(results)
def _visit_FunctionDef(self, expr):
"""
Semantically analyse the FunctionDef.
Analyse the FunctionDef adding all necessary semantic information.
Parameters
----------
expr : FunctionDef|Interface
The node that needs to be annotated.
If we provide an Interface, this means that the function has been annotated partially,
and we need to continue annotating the needed ones.
"""
if expr.get_direct_user_nodes(lambda u: isinstance(u, CodeBlock)):
errors.report("Functions can only be declared in modules, classes or inside other functions.",
symbol=expr, severity='error')
current_class = expr.get_direct_user_nodes(lambda u: isinstance(u, ClassDef))
cls_name = current_class[0].name if current_class else None
insertion_scope = self.scope
if cls_name:
bound_class = self.scope.find(cls_name, 'classes', raise_if_missing = True)
insertion_scope = bound_class.scope
existing_semantic_funcs = []
if not expr.is_semantic:
name = expr.scope.get_expected_name(expr.name)
func = insertion_scope.functions.get(name, None)
if func:
if func.is_semantic:
if self.is_header_file:
# Only Interfaces should be revisited in a header file
assert isinstance(func, Interface)
existing_semantic_funcs = [*func.functions]
else:
return EmptyNode()
else:
insertion_scope.functions.pop(name)
elif isinstance(expr, Interface):
existing_semantic_funcs = [*expr.functions]
expr.invalidate_node()
expr = expr.syntactic_node
name = expr.scope.get_expected_name(expr.name)
decorators = expr.decorators.copy()
new_semantic_funcs = []
sub_funcs = []
func_interfaces = []
docstring = self._visit(expr.docstring) if expr.docstring else expr.docstring
is_pure = expr.is_pure
is_elemental = expr.is_elemental
is_private = expr.is_private
is_inline = expr.is_inline
not_used = [d for d in decorators if d not in (*def_decorators.__all__, 'property', 'overload')]
if len(not_used) >= 1:
errors.report(UNUSED_DECORATORS, symbol=', '.join(not_used), severity='warning')
available_type_vars = {n:v for n,v in self._context_dict.items() if isinstance(v, typing.TypeVar)}
available_type_vars.update(self.scope.collect_all_type_vars())
used_type_vars = {}
for a in expr.arguments:
used_objs = a.annotation.get_attribute_nodes(PyccelSymbol)
for o in used_objs:
if o in available_type_vars:
used_type_vars[o] = available_type_vars[o]
for o, t in used_type_vars.items():
if isinstance(t, typing.TypeVar):
pyccel_type_var = self.env_var_to_pyccel(t)
used_type_vars[o] = pyccel_type_var
global_scope = self.scope
while global_scope.parent_scope:
global_scope = global_scope.parent_scope
global_scope.insert_symbol(o)
global_scope.insert_symbolic_alias(o, pyccel_type_var)
possible_combinations = list(product(*[t.type_list for t in used_type_vars.values()]))
argument_combinations = []
type_var_indices = []
for i,p in enumerate(possible_combinations):
scope = self.create_new_function_scope(expr.name, '_', decorators = decorators,
used_symbols = expr.scope.local_used_symbols.copy(),
original_symbols = expr.scope.python_names.copy(),
symbolic_aliases = expr.scope.symbolic_aliases)
for n, dtype in zip(used_type_vars, p):
self.scope.insert_symbolic_alias(n, dtype)
args = list(product(*[self._visit(a) for a in expr.arguments]))
argument_combinations.extend(args)
type_var_indices.extend([i]*len(args))
self.exit_function_scope()
# this for the case of a function without arguments => no headers
interface_name = name
interface_counter = 0
is_interface = len(argument_combinations) > 1 or 'overload' in decorators
for interface_idx, (arguments, type_var_idx) in enumerate(zip(argument_combinations, type_var_indices)):
if is_interface:
name, _ = self.scope.get_new_incremented_symbol(interface_name, interface_idx)
insertion_scope.python_names[name] = expr.name
scope = self.create_new_function_scope(expr.name, name, decorators = decorators,
used_symbols = expr.scope.local_used_symbols.copy(),
original_symbols = expr.scope.python_names.copy(),
symbolic_aliases = expr.scope.symbolic_aliases)
self.scope.decorators.update(decorators)
for n, dtype in zip(used_type_vars, possible_combinations[type_var_idx]):
self.scope.insert_symbolic_alias(n, dtype)
arg_dict = {a.name:a.var for a in arguments}
for a in arguments:
a_var = a.var
if isinstance(a_var, FunctionAddress):
self.insert_function(a_var)
else:
self.scope.insert_variable(a_var, expr.scope.get_python_name(a.name))
if arguments and arguments[0].bound_argument:
if arguments[0].var.cls_base.name != cls_name:
errors.report('Class method self argument does not have the expected type',
severity='error', symbol=arguments[0])
for s in expr.scope.dotted_symbols:
base = s.name[0]
if base in arg_dict:
cls_base = arg_dict[base].cls_base
cls_base.scope.insert_symbol(DottedName(*s.name[1:]))
results = expr.results
if results.annotation:
results = self._visit(expr.results)
# insert the FunctionDef into the scope
# to handle the case of a recursive function
# TODO improve in the case of an interface
recursive_func_obj = FunctionDef(name, arguments, [], results, scope = scope)
self.insert_function(recursive_func_obj, insertion_scope)
# Create a new list that store local variables for each FunctionDef to handle nested functions
self._allocs.append(set())
self._pointer_targets.append({})
import_init_calls = [self._visit(i) for i in expr.imports]
for f in expr.functions:
self.insert_function(f)
# we annotate the body
body = self._visit(expr.body)
body.insert2body(*import_init_calls, back=False)
# Annotate the remaining functions
sub_funcs = [i for i in self.scope.functions.values() if not i.is_header and\
not isinstance(i, (InlineFunctionDef, FunctionAddress)) and \
not i.is_semantic]
for i in sub_funcs:
self._visit(i)
results = self._visit(results)
if isinstance(results, EmptyNode):
results = FunctionDefResult(Nil())
if results.var is Nil():
results_vars = []
else:
results_vars = self.scope.collect_all_tuple_elements(results.var)
self._check_pointer_targets(results_vars)
# Calling the Garbage collecting,
# it will add the necessary Deallocate nodes
# to the body of the function
body.insert2body(*self._garbage_collector(body))
# Determine local and global variables
global_vars = list(self.get_variables(self.scope.parent_scope))
global_vars = [g for g in global_vars if body.is_user_of(g)]
# get the imports
imports = self.scope.imports['imports'].values()
# Prefer dict to set to preserve order
imports = list({imp:None for imp in imports}.keys())
# remove the FunctionDef from the function scope
func_ = insertion_scope.functions.pop(name)
is_recursive = False
# check if the function is recursive if it was called on the same scope
if func_.is_recursive and not is_inline:
is_recursive = True
elif func_.is_recursive and is_inline:
errors.report("Pyccel does not support an inlined recursive function", symbol=expr,
severity='fatal')
sub_funcs = [i for i in self.scope.functions.values() if not i.is_header and not isinstance(i, FunctionAddress)]
func_args = [i for i in self.scope.functions.values() if isinstance(i, FunctionAddress)]
if func_args:
func_interfaces.append(Interface('', func_args, is_argument = True))
namespace_imports = self.scope.imports
self.exit_function_scope()
# Raise an error if one of the return arguments is an alias.
pointer_targets = self._pointer_targets.pop()
result_pointer_map = {}
for r in results_vars:
t = pointer_targets.get(r, ())
if r.is_alias:
arg_vars = [a.var for a in arguments]
temp_targets = [target for target, _ in t if target not in arg_vars]
if temp_targets:
errors.report(UNSUPPORTED_POINTER_RETURN_VALUE,
symbol=r, severity='error')
else:
result_pointer_map[r] = [next(i for i,a in enumerate(arguments) if a.var == target) for target, _ in t]
optional_inits = []
for a in arguments:
var = self._optional_params.pop(a.var, None)
if var:
optional_inits.append(If(IfSection(PyccelIsNot(a.var, Nil()),
[Assign(var, a.var)])))
body.insert2body(*optional_inits, back=False)
func_kwargs = {
'global_vars':global_vars,
'is_pure':is_pure,
'is_elemental':is_elemental,
'is_private':is_private,
'imports':imports,
'decorators':decorators,
'is_recursive':is_recursive,
'functions': sub_funcs,
'interfaces': func_interfaces,
'result_pointer_map': result_pointer_map,
'docstring': docstring,
'scope': scope,
}
if is_inline:
func_kwargs['namespace_imports'] = namespace_imports
global_funcs = [f for f in body.get_attribute_nodes(FunctionDef) if self.scope.find(f.name, 'functions')]
func_kwargs['global_funcs'] = global_funcs
cls = InlineFunctionDef
else:
cls = FunctionDef
func = cls(name,
arguments,
body,
results,
**func_kwargs)
if not is_recursive:
recursive_func_obj.invalidate_node()
if cls_name:
# update the class methods
if not is_interface:
bound_class.update_method(expr, func)
new_semantic_funcs += [func]
if expr.python_ast:
func.set_current_ast(expr.python_ast)
if existing_semantic_funcs:
new_semantic_funcs = existing_semantic_funcs + new_semantic_funcs
if len(new_semantic_funcs) == 1 and not is_interface:
new_semantic_funcs = new_semantic_funcs[0]
self.insert_function(new_semantic_funcs, insertion_scope)
else:
for f in new_semantic_funcs:
self.insert_function(f, insertion_scope)
new_semantic_funcs = Interface(interface_name, new_semantic_funcs, syntactic_node=expr)
if expr.python_ast:
new_semantic_funcs.set_current_ast(expr.python_ast)
if cls_name:
bound_class.update_interface(expr, new_semantic_funcs)
self.insert_function(new_semantic_funcs, insertion_scope)
return EmptyNode()
def _visit_InlineFunctionDef(self, expr, function_call_args, function_call):
    """
    Visit an inline function definition to add the code to the calling scope.

    Visit an inline function definition to add the code to the calling scope.
    The code is inlined at this stage. The syntactic body of the function is
    temporarily modified (symbols are renamed and Return nodes are replaced)
    before being visited. The modifications are reverted afterwards so the
    definition can be inlined again.

    Parameters
    ----------
    expr : InlineFunctionDef
        The inline function definition being called.
    function_call_args : list[FunctionDefArgument]
        The semantic arguments passed to the function.
    function_call : FunctionCall
        The syntactic function call being expanded to a function definition.

    Returns
    -------
    PyccelAstNode
        The semantic body if the function call is directly used by an Assign
        (which is not an AugAssign). Otherwise the body is appended to
        `self._additional_exprs[-1]` and the result of visiting the symbol
        which stores the result is returned.
    """
    assign = function_call.get_direct_user_nodes(lambda a: isinstance(a, Assign) and not isinstance(a, AugAssign))
    self._current_function.append(expr)
    # The result is saved in the lhs of the assignment if there is one,
    # otherwise a new name is created
    if assign:
        lhs = assign[-1].lhs
    else:
        lhs = self.scope.get_new_name()

    # Build the syntactic body
    # replace_map : maps the symbols used in the function to the symbols
    # which replace them in the calling scope
    replace_map = {}
    pyccel_stage.set_stage('syntactic')
    # Maps module names to the targets which must be imported from them
    global_scope_import_targets = {}
    if expr.is_imported:
        mod_name = expr.get_direct_user_nodes(lambda m: isinstance(m, Module))[0].name
        mod = self.d_parsers[mod_name].semantic_parser.ast
        # Symbols used in the body which are not local to the function
        global_symbols = set(expr.body.get_attribute_nodes(PyccelSymbol))
        global_symbols.difference_update(expr.scope.local_used_symbols)
        for v in global_symbols:
            import_mod_name = mod_name
            if mod.scope.find(v):
                # If the module imported the object then it is imported
                # from the module in which it was found
                imported_obj = None
                for import_type in mod.scope.imports.values():
                    if v in import_type:
                        imported_obj = import_type[v]
                        break
                if imported_obj:
                    import_mod_name = imported_obj.get_direct_user_nodes(lambda m: isinstance(m, Module))[0].name
                # Import under a new name if the name is already in use
                if self.scope.symbol_in_use(v):
                    new_v = self.scope.get_new_name(self.scope.get_expected_name(v))
                    replace_map[v] = new_v
                    global_scope_import_targets.setdefault(import_mod_name, []).append(AsName(v, new_v))
                else:
                    global_scope_import_targets.setdefault(import_mod_name, []).append(v)

    # Swap in the function call arguments to replace the variables representing
    # the arguments of the inlined function
    res_vars = ()
    if expr.results:
        # Swap in the result of the function to replace the variable representing
        # the result of the inlined function
        res_var = expr.results.var
        if isinstance(res_var, AnnotatedPyccelSymbol):
            res_var = res_var.name
        if isinstance(lhs, PyccelSymbol):
            replace_map[res_var] = lhs
            res_vars = (res_var,)

    # Names of the arguments of the inline function
    func_args = [a.var for a in expr.arguments]
    func_args = [a.name if isinstance(a, AnnotatedPyccelSymbol) else a for a in func_args]

    # Ensure local variables will be recognised and use a name that is not already in use
    for v in expr.scope.local_used_symbols:
        if v != expr.name and v not in res_vars:
            if self.scope.symbol_in_use(v):
                new_v = self.scope.get_new_name(self.scope.get_expected_name(v))
                replace_map[v] = new_v
            else:
                self.scope.insert_symbol(v)

    # Map local call arguments to function arguments
    positional_call_args = [a.value for a in function_call_args if not a.has_keyword]
    for func_a, call_a in zip(func_args, positional_call_args):
        if isinstance(call_a, Variable) and func_a == self.scope.get_expected_name(call_a.name):
            # If call argument is a variable with the same name as the target function
            # argument then there is no need to rename
            new_func_a = replace_map.pop(func_a)
            self.scope.remove_symbol(new_func_a)
        else:
            # Otherwise the symbol used for the function arguments should be mapped
            # to the call argument
            func_a_name = replace_map.get(func_a, func_a)
            self.scope.variables[func_a_name] = call_a
            self.scope.local_used_symbols[func_a_name] = func_a_name

    # Map local keyword call arguments to function arguments
    nargs = len(positional_call_args)
    kw_call_args = {a.keyword: a.value for a in function_call_args[nargs:]}
    for func_a, func_a_name in zip(expr.arguments[nargs:], func_args[nargs:]):
        # Use the default value if the argument was not passed to the function
        call_a = kw_call_args.get(func_a_name, getattr(func_a.default_call_arg, 'value', func_a.default_call_arg))
        if isinstance(call_a, Variable) and func_a_name == self.scope.get_expected_name(call_a.name):
            # If call argument is a variable with the same name as the target function
            # argument then there is no need to rename
            new_func_a = replace_map.pop(func_a_name)
            self.scope.remove_symbol(new_func_a)
        else:
            # Otherwise the symbol used for the function arguments should be mapped
            # to the call argument
            used_func_a_name = replace_map.get(func_a_name, func_a_name)
            self.scope.variables[used_func_a_name] = call_a
            # NOTE(review): the positional loop above saves the (possibly renamed)
            # name in local_used_symbols whereas the original name is saved here.
            # Confirm that this difference is intended
            self.scope.local_used_symbols[func_a_name] = func_a_name

    to_replace = list(replace_map.keys())
    local_var = list(replace_map.values())

    # Replace local syntactic variables from the inline functions with the syntactic
    # variables defined above which are sure to not cause name collisions
    expr.substitute(to_replace, local_var, invalidate = False)

    # Replace return expressions with an assign to the results
    # (no assignment is needed if both the returned expression and the lhs are symbols)
    returns = expr.body.get_attribute_nodes(Return)
    replace_return = [Assign(lhs, r.expr, python_ast = r.python_ast) \
                        if not isinstance(r.expr, PyccelSymbol) or not isinstance(lhs, PyccelSymbol) \
                        else EmptyNode() for r in returns]
    expr.body.substitute(returns, replace_return, invalidate = False)

    # The imports of the function and the imports required for the global
    # symbols found above
    imports = list(expr.imports)
    imports.extend(Import(m_name, targets) for m_name, targets in global_scope_import_targets.items())

    pyccel_stage.set_stage('semantic')
    import_init_calls = [self._visit(i) for i in imports]

    if expr.functions:
        errors.report("Functions in inline functions are not supported",
                severity='error', symbol=expr)

    # Visit the body as though it appeared directly in the code
    body = self._visit(expr.body)
    body.insert2body(*import_init_calls, back=False)

    self._current_function.pop()

    pyccel_stage.set_stage('syntactic')
    # Put back the returns to create custom Assign nodes on the next visit
    expr.body.substitute(replace_return, returns)

    # Remove the symbol maps added to handle the function arguments
    # These are found in self.scope.variables but do not represent variables
    # that need to be declared.
    for func_a, call_a in zip(func_args, positional_call_args):
        func_a_name = replace_map.get(func_a, func_a)
        if not isinstance(call_a, Variable) or func_a_name != call_a.name:
            self.scope.remove_variable(call_a, func_a_name)
    # Only the keyword arguments which were passed explicitly are removed here
    for func_a, func_a_name in zip(expr.arguments[nargs:], func_args[nargs:]):
        if func_a_name in kw_call_args:
            used_func_a_name = replace_map.get(func_a_name, func_a_name)
            call_a = kw_call_args[func_a_name]
            if not isinstance(call_a, Variable) or used_func_a_name != call_a.name:
                self.scope.remove_variable(call_a, used_func_a_name)

    # Swap the arguments back to the original version to preserve the syntactic
    # inline function definition.
    expr.substitute(local_var, to_replace)
    pyccel_stage.set_stage('semantic')

    if assign:
        # The body already contains the assignment to the lhs
        return body
    else:
        # The body must be run before the expression containing the call
        self._additional_exprs[-1].append(body)
        return self._visit(lhs)
def _visit_PythonPrint(self, expr):
    """
    Semantically analyse a call to `print`.

    Visit the arguments of the print and build the semantic PythonPrint.
    Arguments which are inhomogeneous tuples are replaced by an argument
    built with `create_tuple_of_inhomogeneous_elements`. If such an argument
    is a function call then its result is first saved in a temporary
    variable. The assignment is added to `self._additional_exprs[-1]`.

    Parameters
    ----------
    expr : PythonPrint
        The syntactic print node.

    Returns
    -------
    PythonPrint
        The semantic print node.
    """
    args = [self._visit(i) for i in expr.expr]
    if not args:
        return PythonPrint(args)

    if any(isinstance(a.value.class_type, InhomogeneousTupleType) for a in args):
        new_args = []
        for a in args:
            val = a.value
            if isinstance(val.class_type, InhomogeneousTupleType):
                assert not a.has_keyword
                if isinstance(val, FunctionCall):
                    # Save the result of the function in a temporary variable
                    pyccel_stage.set_stage('syntactic')
                    tmp_var = PyccelSymbol(self.scope.get_new_name())
                    assign = Assign(tmp_var, val)
                    assign.set_current_ast(expr.python_ast)
                    pyccel_stage.set_stage('semantic')
                    self._additional_exprs[-1].append(self._visit(assign))
                    # The syntactic assign is no longer needed once visited
                    val.remove_user_node(assign)
                    val = self._visit(tmp_var)
                new_args.append(FunctionCallArgument(self.create_tuple_of_inhomogeneous_elements(val)))
            else:
                new_args.append(a)

        args = new_args

    # TODO fix: not yet working because of mpi examples
    # Report a fatal error unless either all arguments are symbolic or none
    # of them are (a description should be added to parser/messages.py)

    return PythonPrint(args)
def _visit_ClassDef(self, expr):
    """
    Semantically analyse a class definition.

    Create the datatype and the scope of the class, collect the annotated
    attributes and insert the semantic ClassDef into the current scope.
    The methods are attached to the class in their syntactic form.

    Parameters
    ----------
    expr : ClassDef
        The syntactic class definition.

    Returns
    -------
    EmptyNode
        The class is stored in the scope rather than returned.
    """
    # TODO - improve the use and def of interfaces
    #      - wouldn't be better if it is done inside ClassDef?

    block_users = expr.get_direct_user_nodes(lambda u: isinstance(u, CodeBlock))
    if block_users:
        errors.report("Classes can only be declared in modules.",
                symbol=expr, severity='error')

    class_name = self.scope.get_expected_name(expr.name)

    # Register a new datatype describing the class
    class_type = DataTypeFactory(class_name)()
    typenames_to_dtypes[class_name] = class_type
    self.scope.cls_constructs[class_name] = class_type

    superclasses = self._find_superclasses(expr)

    class_scope = self.create_new_class_scope(class_name,
                used_symbols = expr.scope.local_used_symbols,
                original_symbols = expr.scope.python_names.copy())

    # All annotations are visited before any attribute is inserted
    visited_attributes = [self._visit(attr) for attr in expr.attributes]
    class_attributes = []
    for a in visited_attributes:
        if len(a) == 1:
            attribute = a[0]
            class_scope.insert_variable(attribute)
            class_attributes.append(attribute)
        else:
            errors.report(f"Couldn't determine type of {a}",
                    severity='error', symbol=a)

    self.exit_class_scope()

    if expr.docstring:
        docstring = self._visit(expr.docstring)
    else:
        docstring = expr.docstring

    semantic_class = ClassDef(class_name, class_attributes, [], superclasses=superclasses,
                scope=class_scope, docstring = docstring, class_type = class_type)
    self.scope.insert_class(semantic_class)

    for method in expr.methods:
        semantic_class.add_new_method(method)

    return EmptyNode()
def _visit_Del(self, expr):
    """
    Semantically analyse a `del` statement.

    Parameters
    ----------
    expr : Del
        The syntactic node.

    Returns
    -------
    Del
        A node containing a Deallocate for each visited variable.
    """
    deallocations = []
    for variable in expr.variables:
        deallocations.append(Deallocate(self._visit(variable)))
    return Del(deallocations)
def _visit_PyccelIs(self, expr):
    """
    Semantically analyse an `is` or `is not` comparison.

    Handles both PyccelIs and PyccelIsNot. The comparison is replaced by a
    literal boolean when the result can be determined from the visited
    operands (identical operands, comparison of a non-optional object with
    None, operands with different dtypes).

    Parameters
    ----------
    expr : PyccelIs|PyccelIsNot
        The syntactic comparison.

    Returns
    -------
    TypedAstNode
        A LiteralTrue, a LiteralFalse or a semantic node of the same class
        as expr.
    """
    IsClass = type(expr)

    # TODO ERROR wrong position ??

    var1 = self._visit(expr.lhs)
    var2 = self._visit(expr.rhs)

    # The same object (or None) on both sides
    if (var1 is var2) or (isinstance(var2, Nil) and isinstance(var1, Nil)):
        if IsClass == PyccelIsNot:
            return LiteralFalse()
        elif IsClass == PyccelIs:
            return LiteralTrue()

    # Ensure that None is on the right hand side
    if isinstance(var1, Nil):
        var1, var2 = var2, var1

    if isinstance(var2, Nil):
        # Only optional variables can be None
        if not isinstance(var1, Variable) or not var1.is_optional:
            if IsClass == PyccelIsNot:
                return LiteralTrue()
            elif IsClass == PyccelIs:
                return LiteralFalse()
        # Use the semantic operand var2 rather than expr.rhs. After the swap
        # above expr.rhs may be the syntactic version of var1
        return IsClass(var1, var2)

    if (var1.dtype != var2.dtype):
        if IsClass == PyccelIs:
            return LiteralFalse()
        elif IsClass == PyccelIsNot:
            return LiteralTrue()

    if (isinstance(var1.dtype, PythonNativeBool) and
        isinstance(var2.dtype, PythonNativeBool)):
        return IsClass(var1, var2)

    if isinstance(var1.class_type, (StringType, FixedSizeNumericType)):
        errors.report(PYCCEL_RESTRICTION_PRIMITIVE_IMMUTABLE, symbol=expr,
            severity='error')
        return IsClass(var1, var2)

    errors.report(PYCCEL_RESTRICTION_IS_ISNOT,
        symbol=expr, severity='error')
    return IsClass(var1, var2)
def _visit_Import(self, expr):
    """
    Semantically analyse an import.

    The imported objects are stored in the imports of the current scope.
    Three cases are handled: modules found in
    `pyccel_builtin_import_registry`, sources identified by
    `recognised_source` (reported as unsupported) and modules which have
    their own parser in `self.d_parsers`.

    Parameters
    ----------
    expr : Import
        The syntactic import.

    Returns
    -------
    PyccelAstNode
        A call to the initialisation function of the imported module if
        there is one which has not yet been imported, otherwise an EmptyNode.
    """
    # TODO - must have a dict where to store things that have been
    #        imported
    #      - should not use scope
    if expr.get_direct_user_nodes(lambda u: isinstance(u, CodeBlock)):
        errors.report("Imports can only be used in modules or inside functions.",
                symbol=expr, severity='error')

    container = self.scope.imports

    result = EmptyNode()

    if isinstance(expr.source, AsName):
        source = expr.source.name
        source_target = expr.source.local_alias
    else:
        source = str(expr.source)
        source_target = source

    if source in pyccel_builtin_import_registry:
        imports = pyccel_builtin_import(expr)

        def _insert_obj(location, target, obj):
            # Save obj in the imports unless the name is already in use
            F = self.scope.find(target)

            if obj is F:
                errors.report(FOUND_DUPLICATED_IMPORT,
                            symbol=target, severity='warning')
            elif F is None or isinstance(F, dict):
                container[location][target] = obj
            else:
                errors.report(IMPORTING_EXISTING_IDENTIFIED, symbol=expr,
                              severity='fatal')

        if expr.target:
            for t in expr.target:
                t_name = t.name if isinstance(t, AsName) else t
                if t_name not in pyccel_builtin_import_registry[source]:
                    errors.report(f"Function '{t}' from module '{source}' is not currently supported by pyccel",
                            symbol=expr,
                            severity='error')
            for (name, atom) in imports:
                if name is not None:
                    if isinstance(atom, Decorator):
                        continue
                    elif isinstance(atom, Constant):
                        _insert_obj('variables', name, atom)
                    else:
                        _insert_obj('functions', name, atom)
        else:
            # The module itself is imported
            assert len(imports) == 1
            mod = imports[0][1]
            assert isinstance(mod, Module)
            _insert_obj('variables', source_target, mod)

        self.insert_import(source, [AsName(v,n) for n,v in imports], source_target)

    elif recognised_source(source):
        errors.report(f"Module {source} is not currently supported by pyccel",
                symbol=expr,
                severity='error')
    else:

        # we need to use str here since source has been defined
        # using repr.
        # TODO shall we improve it?

        p = self.d_parsers[source_target]
        # The init and free functions are only handled the first time that
        # the module is imported in this scope
        import_init = p.semantic_parser.ast.init_func if source_target not in container['imports'] else None
        import_free = p.semantic_parser.ast.free_func if source_target not in container['imports'] else None
        if expr.target:
            # Maps the local name of each target to the imported object
            targets = {i.local_alias if isinstance(i,AsName) else i:None for i in expr.target}
            names = [i.name if isinstance(i,AsName) else i for i in expr.target]

            p_scope = p.scope
            p_imports = p_scope.imports
            entries = ['variables', 'classes', 'functions']
            # Search the objects defined in the module and the objects imported by it
            direct_sons = ((e,getattr(p.scope, e)) for e in entries)
            import_sons = ((e,p_imports[e]) for e in entries)
            for entry, d_son in chain(direct_sons, import_sons):
                for t,n in zip(targets.keys(),names):
                    if n in d_son:
                        e = d_son[n]
                        if entry == 'functions':
                            container[entry][t] = e.clone(t, is_imported=True)
                            m = e.get_direct_user_nodes(lambda x: isinstance(x, Module))[0]
                            container[entry][t].set_current_user_node(m)
                        elif entry == 'variables':
                            container[entry][t] = e.clone(t)
                        else:
                            container[entry][t] = e
                        targets[t] = e
            # Name the targets which were not found in the warning
            missing_targets = [str(t) for t, obj in targets.items() if obj is None]
            if missing_targets:
                errors.report(f"Import target {', '.join(missing_targets)} could not be found",
                        severity="warning", symbol=expr)
            targets = [AsName(v,k) for k,v in targets.items() if v is not None]
        else:
            mod = p.semantic_parser.ast
            container['variables'][source_target] = mod
            targets = [AsName(mod, source_target)]

        self.scope.cls_constructs.update(p.scope.cls_constructs)

        # ... meta variables

        # in some cases (blas, lapack and openacc level-0)
        # the import should not appear in the final file
        # all metavars here, will have a prefix and suffix = __
        __ignore_at_import__ = p.metavars.get('ignore_at_import', False)

        # Indicates that the module must be imported with the syntax 'from mod import *'
        __import_all__ = p.metavars.get('import_all', False)

        # Indicates the name of the fortran module containing the functions
        __module_name__ = p.metavars.get('module_name', None)

        # Keep the targets of a previous import of the same module
        if source_target in container['imports']:
            targets.extend(container['imports'][source_target].target)

        if import_init:
            # Import the init function using a name which is not in use
            old_name = import_init.name
            new_name = self.scope.get_new_name(old_name)
            targets.append(AsName(import_init, new_name))

            if new_name != old_name:
                import_init = import_init.clone(new_name)

            container['functions'][old_name] = import_init

            result = import_init()

        if import_free:
            # Import the free function using a name which is not in use
            old_name = import_free.name
            new_name = self.scope.get_new_name(old_name)
            targets.append(AsName(import_free, new_name))

            if new_name != old_name:
                import_free = import_free.clone(new_name)

        mod = p.semantic_parser.ast

        if __import_all__:
            expr = Import(source_target, AsName(mod, __module_name__), mod=mod)
            container['imports'][source_target] = expr

        elif __module_name__:
            expr = Import(__module_name__, targets, mod=mod)
            container['imports'][source_target] = expr

        elif not __ignore_at_import__:
            expr = Import(source, targets, mod=mod)
            container['imports'][source_target] = expr

    return result
def _visit_With(self, expr):
    """
    Semantically analyse a `with` statement.

    The test and the body are visited inside a new loop scope. A fatal
    error is reported if the class of the visited test is not a with
    construct.

    Parameters
    ----------
    expr : With
        The syntactic node.

    Returns
    -------
    PyccelAstNode
        The `block` of the semantic With node.
    """
    with_scope = self.create_new_loop_scope()

    context = self._visit(expr.test)
    if not context.cls_base.is_with_construct:
        errors.report(UNDEFINED_WITH_ACCESS, symbol=expr,
               severity='fatal')

    with_body = self._visit(expr.body)

    self.exit_loop_scope()
    semantic_with = With(context, with_body, with_scope)
    return semantic_with.block
def _visit_StarredArguments(self, expr):
    """
    Semantically analyse starred arguments.

    The visited object must have rank 1. It is unpacked into one indexed
    element for each position along its first dimension.

    Parameters
    ----------
    expr : StarredArguments
        The syntactic node.

    Returns
    -------
    StarredArguments
        A node containing the elements of the visited object.
    """
    unpacked = self._visit(expr.args_var)
    assert unpacked.rank==1
    n_elements = unpacked.shape[0]
    elements = []
    for idx in range(n_elements):
        elements.append(unpacked[idx])
    return StarredArguments(elements)
def _visit_NumpyMatmul(self, expr):
    """
    Semantically analyse a matrix multiplication.

    The import of `matmul` from numpy is inserted before the operands are
    visited.

    Parameters
    ----------
    expr : NumpyMatmul
        The syntactic node.

    Returns
    -------
    NumpyMatmul
        The semantic node built from the visited operands.
    """
    matmul_import = AsName(NumpyMatmul, 'matmul')
    self.insert_import('numpy', matmul_import)
    # The operands are visited from left to right
    return NumpyMatmul(self._visit(expr.a), self._visit(expr.b))
def _visit_Assert(self, expr):
    """
    Semantically analyse an `assert` statement.

    Parameters
    ----------
    expr : Assert
        The syntactic node.

    Returns
    -------
    Assert
        The semantic node built from the visited test.
    """
    return Assert(self._visit(expr.test))
def _visit_FunctionDefResult(self, expr):
    """
    Semantically analyse the result of a function definition.

    No result is created if the Python name of the current function starts
    with `__i` and the name obtained by removing the `i` is one of the
    values of `magic_method_map`. Otherwise the result variable is visited
    and inserted into the scope.

    Parameters
    ----------
    expr : FunctionDefResult
        The syntactic node.

    Returns
    -------
    FunctionDefResult|EmptyNode
        The semantic result, or an EmptyNode in the case described above.
    """
    function_name = self.current_function_name
    if isinstance(function_name, DottedName):
        function_name = function_name.name[-1]

    # There may be no name if we are in a FunctionTypeAnnotation
    if function_name:
        python_name = self.scope.get_python_name(function_name)
        if python_name.startswith('__i'):
            method_name = '__' + python_name[3:]
            if method_name in magic_method_map.values():
                return EmptyNode()

    result_var = self._visit(expr.var)
    if isinstance(result_var, list):
        # A list contains one variable for each possible type
        n_candidates = len(result_var)
        if n_candidates == 0:
            errors.report("Can't deduce type for function definition result.",
                    severity = 'fatal', symbol = expr)
        elif n_candidates > 1:
            errors.report("The type of the result of a function definition cannot be a union of multiple types.",
                    severity = 'error', symbol = expr)
        result_var = result_var[0]

    self.scope.insert_variable(result_var)
    return FunctionDefResult(result_var, annotation = expr.annotation)
#====================================================
# _build functions
#====================================================
def _build_NumpyWhere(self, func_call, func_call_args):
"""
Method for building the node created by a call to `numpy.where`.
Method for building the node created by a call to `numpy.where`. If only one argument is passed to `numpy.where`
then it is equivalent to a call to `numpy.nonzero`. The result of a call to `numpy.nonzero`
is a complex object so there is a `_build_NumpyNonZero` function which must be called.
Parameters
----------
func_call : FunctionCall
The syntactic FunctionCall describing the call to `numpy.nonzero.
func_call_args : iterable[FunctionCallArgument]
The semantic arguments passed to the function.
Returns
-------
TypedAstNode
A node describing the result of a call to the `numpy.nonzero` function.
"""
# expr is a FunctionCall
args = [a.value for a in func_call_args if not a.has_keyword]
kwargs = {a.keyword: a.value for a in func_call.args if a.has_keyword}
nargs = len(args)+len(kwargs)
if nargs == 1:
return self._build_NumpyNonZero(func_call, func_call_args)
return NumpyWhere(*args, **kwargs)
def _build_NumpyNonZero(self, func_call, func_call_args):
    """
    Method for building the node created by a call to `numpy.nonzero`.

    Method for building the node created by a call to `numpy.nonzero`. The result of such a call
    is a complex object (tuple of arrays). An argument which is not a variable is therefore first
    saved in a new variable by visiting an intermediate assignment. The assignment is added to
    `self._additional_exprs[-1]`.

    Parameters
    ----------
    func_call : FunctionCall
        The syntactic FunctionCall describing the call to `numpy.nonzero`.
    func_call_args : iterable[FunctionCallArgument]
        The semantic arguments passed to the function.

    Returns
    -------
    TypedAstNode
        A node describing the result of a call to the `numpy.nonzero` function.
    """
    condition = func_call_args[0].value
    if isinstance(condition, Variable):
        return NumpyWhere(condition)

    # Create a syntactic assignment to save the argument in a new variable
    pyccel_stage.set_stage('syntactic')
    tmp_symbol = PyccelSymbol(self.scope.get_new_name())
    tmp_assign = Assign(tmp_symbol, condition, python_ast=func_call.python_ast)
    pyccel_stage.set_stage('semantic')

    self._additional_exprs[-1].append(self._visit(tmp_assign))
    return NumpyWhere(self._visit(tmp_symbol))
def _build_ListExtend(self, expr, args):
    """
    Method to navigate the syntactic DottedName node of an `extend()` call.

    Build the semantic nodes for a call to `extend()`. If the argument of
    the call is a `PythonList` or a `PythonTuple` then a CodeBlock
    containing one `ListAppend` for each element is created. Otherwise a
    syntactic `For` loop which appends each element of the argument to
    the list is created and visited.

    Parameters
    ----------
    expr : DottedName
        The syntactic DottedName node that represents the call to `.extend()`.
    args : iterable[FunctionCallArgument]
        The semantic arguments passed to the function.

    Returns
    -------
    PyccelAstNode
        CodeBlock or For containing ListAppend objects.
    """
    extension = expr.name[1].args[0].value

    if not isinstance(extension, (PythonList, PythonTuple)):
        # Build the syntactic equivalent of:
        # for index in extension: _ = lst.append(index)
        pyccel_stage.set_stage('syntactic')
        loop_target = self.scope.get_new_name('index')
        append_call = FunctionCall('append', [FunctionCallArgument(loop_target)])
        append_expr = DottedName(expr.name[0], append_call)
        append_expr.set_current_ast(expr.python_ast)
        unused_lhs = PyccelSymbol('_', is_temp=True)
        append_assign = Assign(unused_lhs, append_expr)
        append_assign.set_current_ast(expr.python_ast)
        loop = For(loop_target, extension, CodeBlock([append_assign]))
        pyccel_stage.set_stage('semantic')
        return self._visit(loop)

    target_list = self._visit(expr.name[0])
    new_elements = self._visit(extension)
    try:
        appends = [ListAppend(target_list, elem) for elem in new_elements]
    except TypeError as e:
        errors.report(str(e), symbol=expr, severity='fatal')

    simple_element_types = (StringType, FixedSizeNumericType)
    if not isinstance(target_list.class_type.element_type, simple_element_types):
        # Elements which are not created in place are targeted by the list
        literal_types = (PythonList, PythonSet, PythonTuple, NumpyNewArray)
        for elem in new_elements:
            if not isinstance(elem, literal_types):
                self._indicate_pointer_target(target_list, elem, expr)
    return CodeBlock(appends)
def _build_MathSqrt(self, func_call, func_call_args):
    """
    Method for building the node created by a call to `math.sqrt`.

    Method for building the node created by a call to `math.sqrt`. A separate method is needed for
    this because some expressions are simplified. This is notably the case for expressions such as
    `math.sqrt(a**2)` or `math.sqrt(a*a)` which are translated to a call to `math.fabs(a)`.
    The associated import therefore needs to be inserted into the parser.

    Parameters
    ----------
    func_call : FunctionCall
        The syntactic FunctionCall describing the call to `math.sqrt`.
    func_call_args : iterable[FunctionCallArgument]
        The semantic argument passed to the function.

    Returns
    -------
    TypedAstNode
        A node describing the result of a call to the `math.sqrt` function.
    """
    func = self.scope.find(func_call.funcdef, 'functions')
    arg = func_call_args[0]
    value = arg.value

    # Find the object whose square is passed to sqrt (if there is one)
    fabs_arg = None
    if isinstance(value, PyccelMul):
        mul1, mul2 = value.args
        if mul1 is mul2:
            fabs_arg = mul1
    elif isinstance(value, PyccelPow):
        base, exponent = value.args
        if exponent == 2:
            fabs_arg = base

    if fabs_arg is not None:
        # Create and visit the syntactic import and call to math.fabs
        pyccel_stage.set_stage('syntactic')

        fabs_name = self.scope.get_new_name('fabs')
        imp_name = AsName('fabs', fabs_name)
        new_import = Import('math',imp_name)
        new_call = FunctionCall(fabs_name, [fabs_arg])

        pyccel_stage.set_stage('semantic')

        self._visit(new_import)

        return self._visit(new_call)

    return self._handle_function(func_call, func, (arg,), use_build_functions = False)
def _build_CmathSqrt(self, func_call, func_call_args):
    """
    Method for building the node created by a call to `cmath.sqrt`.

    Method for building the node created by a call to `cmath.sqrt`. A separate method is needed for
    this because some expressions are simplified. If the argument is the product of an object and
    its conjugate then the expression is translated to a call to `numpy.abs` which is cast to a
    complex. The associated import therefore needs to be inserted into the parser.

    Parameters
    ----------
    func_call : FunctionCall
        The syntactic FunctionCall describing the call to `cmath.sqrt`.
    func_call_args : iterable[FunctionCallArgument]
        The semantic argument passed to the function.

    Returns
    -------
    TypedAstNode
        A node describing the result of a call to the `cmath.sqrt` function.
    """
    func = self.scope.find(func_call.funcdef, 'functions')
    arg = func_call_args[0]
    conjugate_types = (NumpyConjugate, PythonConjugate)

    if isinstance(arg.value, PyccelMul):
        left, right = arg.value.args
        # Find the object which is multiplied by its conjugate (if there is one)
        modulus_of = None
        if isinstance(left, conjugate_types) and left.internal_var is right:
            modulus_of = right
        elif isinstance(right, conjugate_types) and left is right.internal_var:
            modulus_of = left

        if modulus_of is not None:
            pyccel_stage.set_stage('syntactic')

            abs_name = self.scope.get_new_name('abs')
            abs_import = Import('numpy', AsName('abs', abs_name))
            abs_call = FunctionCall(abs_name, [modulus_of])

            pyccel_stage.set_stage('semantic')

            self._visit(abs_import)

            # Cast to preserve final dtype
            return PythonComplex(self._visit(abs_call))

    return self._handle_function(func_call, func, (arg,), use_build_functions = False)
def _build_CmathPolar(self, func_call, func_call_args):
    """
    Build the node created by a call to `cmath.polar`.

    Build the node created by a call to `cmath.polar`. A dedicated method is
    needed because the call is rewritten as a tuple made of a `math.sqrt`
    expression and a `math.atan2` expression. The real and imaginary parts of
    the argument are first saved in temporary variables (the assignments are
    added to the additional expressions of the current statement) and the
    imports of the two `math` functions are inserted into the parser.

    Parameters
    ----------
    func_call : FunctionCall
        The syntactic FunctionCall describing the call to `cmath.polar`.
    func_call_args : iterable[FunctionCallArgument]
        The semantic argument passed to the function.

    Returns
    -------
    TypedAstNode
        A node describing the result of a call to the `cmath.polar` function.
    """
    complex_value = func_call_args[0].value
    real_part = PythonReal(complex_value)
    imag_part = PythonImag(complex_value)

    # Temporaries holding the two components so each is only evaluated once
    real_tmp = self.scope.get_temporary_variable(complex_value,
                        class_type=PythonNativeFloat(), is_argument=False)
    imag_tmp = self.scope.get_temporary_variable(complex_value,
                        class_type=PythonNativeFloat(), is_argument=False)

    pending_exprs = self._additional_exprs[-1]
    pending_exprs.append(Assign(real_tmp, real_part))
    pending_exprs.append(Assign(imag_tmp, imag_part))

    sum_of_squares = PyccelAdd(PyccelMul(real_tmp, real_tmp), PyccelMul(imag_tmp, imag_tmp))
    modulus = MathSqrt(sum_of_squares)
    angle = MathAtan2(imag_tmp, real_tmp)

    self.insert_import('math', AsName(MathSqrt, 'sqrt'))
    self.insert_import('math', AsName(MathAtan2, 'atan2'))

    return PythonTuple(modulus, angle)
def _build_CmathRect(self, func_call, func_call_args):
    """
    Build the node created by a call to `cmath.rect`.

    Build the node created by a call to `cmath.rect`. A dedicated method is
    needed because the call is rewritten as the expression
    `r*cos(phi) + r*sin(phi)*1j` using `math.cos` and `math.sin`. The
    imports of these two functions are inserted into the parser.

    Parameters
    ----------
    func_call : FunctionCall
        The syntactic FunctionCall describing the call to `cmath.rect`.
    func_call_args : iterable[FunctionCallArgument]
        The 2 semantic arguments passed to the function.

    Returns
    -------
    TypedAstNode
        A node describing the result of a call to the `cmath.rect` function.
    """
    modulus_arg, angle_arg = func_call_args
    modulus = modulus_arg.value
    angle = angle_arg.value

    real_part = PyccelMul(modulus, MathCos(angle))
    imag_part = PyccelMul(modulus, MathSin(angle))

    self.insert_import('math', AsName(MathCos, 'cos'))
    self.insert_import('math', AsName(MathSin, 'sin'))

    imag_term = PyccelMul(imag_part, LiteralImaginaryUnit())
    return PyccelAdd(real_part, imag_term)
def _build_CmathPhase(self, func_call, func_call_args):
    """
    Build the node created by a call to `cmath.phase`.

    Build the node created by a call to `cmath.phase`. A dedicated method is
    needed because the function is translated to a call to `math.atan2`. The
    associated import therefore needs to be inserted into the parser.
    For a complex argument `z` the result is `atan2(imag(z), real(z))`. For
    a non-complex argument `x` the imaginary part is zero so the result is
    `atan2(0.0, x)`; this is 0 for positive values but pi for negative values.

    Parameters
    ----------
    func_call : FunctionCall
        The syntactic FunctionCall describing the call to `cmath.phase`.
    func_call_args : iterable[FunctionCallArgument]
        The semantic argument passed to the function.

    Returns
    -------
    TypedAstNode
        A node describing the result of a call to the `cmath.phase` function.
    """
    var = func_call_args[0].value
    self.insert_import('math', AsName(MathAtan2, 'atan2'))

    if isinstance(var.dtype.primitive_type, PrimitiveComplexType):
        return MathAtan2(PythonImag(var), PythonReal(var))

    # A real number has no imaginary part but its phase still depends on its
    # sign (e.g. cmath.phase(-1.0) == pi) so a constant 0.0 cannot be returned
    return MathAtan2(LiteralFloat(0.0), var)
def _build_PythonTupleFunction(self, func_call, func_args):
    """
    Build the node created by a call to `tuple()`.

    Build the node created by a call to `tuple()`. A dedicated method is
    needed because inhomogeneous variables can be passed to this function and
    the scope is required to collect the variables describing their elements.
    An argument which is already a `PythonTuple` is returned untouched. An
    argument whose first dimension is a literal integer is unpacked element
    by element. Anything else cannot be unpacked.

    Parameters
    ----------
    func_call : FunctionCall
        The syntactic FunctionCall describing the call to `tuple()`.
    func_args : iterable[FunctionCallArgument]
        The semantic arguments passed to the function.

    Returns
    -------
    PythonTuple
        A node describing the result of a call to the `tuple()` function.

    Raises
    ------
    TypeError
        Raised if the first dimension of the argument is not a literal integer.
    """
    arg = func_args[0].value

    if isinstance(arg, PythonTuple):
        return arg

    # The number of elements must be known in order to unpack the object
    if not isinstance(arg.shape[0], LiteralInteger):
        raise TypeError(f"Can't unpack {arg} into a tuple")

    elements = [self.scope.collect_tuple_element(elem) for elem in arg]
    return PythonTuple(*elements)
def _build_NumpyArray(self, expr, func_call_args):
    """
    Build the node created by a call to `numpy.array`.

    Build the node created by a call to `numpy.array`. A dedicated method is
    needed because inhomogeneous variables can be passed to this function. In
    order to access the underlying variables for the indexed elements access
    to the scope is required. The method also checks that the object can be
    converted to an array and that the `order` and `ndmin` arguments are
    literals, reporting a fatal error otherwise.

    Parameters
    ----------
    expr : FunctionCall | DottedName
        The syntactic FunctionCall describing the call to `numpy.array`.
        If `numpy.array` is called via a call to `numpy.copy` then this is a
        DottedName describing the call.
    func_call_args : iterable[FunctionCallArgument]
        The semantic arguments passed to the function.

    Returns
    -------
    NumpyArray
        A node describing the result of a call to the `numpy.array` function.
    """
    if isinstance(expr, DottedName):
        # Method call: the object before the dot is the array argument and
        # the only argument which may be passed to the method is the order
        arg = expr.name[0]
        dtype = None
        ndmin = None
        func_call = expr.name[1]
        func = func_call.funcdef
        func_call_args = func_call.args
        order = func_call_args[0].value if func_call_args else func.argument_description['order']
    else:
        args, kwargs = split_positional_keyword_arguments(*func_call_args)

        def unpack_args(arg, dtype = None, order = 'K', ndmin = None):
            """ Small function to reorder and get access to the named variables from args and kwargs.
            """
            return arg, dtype, order, ndmin

        arg, dtype, order, ndmin = unpack_args(*args, **kwargs)

    if not isinstance(arg, (PythonTuple, PythonList, Variable, IndexedElement)):
        errors.report('Unexpected object passed to numpy.array',
                severity='fatal', symbol=expr)

    is_homogeneous_tuple = isinstance(arg.class_type, HomogeneousTupleType)
    # Inhomogeneous tuples can contain homogeneous data if it is inhomogeneous due to pointers
    if isinstance(arg.class_type, InhomogeneousTupleType):
        # The data is only homogeneous if all the elements share one rank.
        # The number of distinct ranks must therefore be compared to 1 (the
        # bare length is truthy as soon as the tuple contains an element).
        is_homogeneous_tuple = isinstance(arg.dtype, FixedSizeNumericType) and \
                len({a.rank for a in arg}) == 1
        if not isinstance(arg, PythonTuple):
            arg = PythonTuple(*(self.scope.collect_tuple_element(a) for a in arg))

    if not (is_homogeneous_tuple or isinstance(arg.class_type, HomogeneousContainerType)):
        errors.report('Inhomogeneous type passed to numpy.array',
                severity='fatal', symbol=expr)

    # The order must be known at compile time
    if not isinstance(order, (LiteralString, str)):
        errors.report('Order must be specified with a literal string',
                severity='fatal', symbol=expr)
    elif isinstance(order, LiteralString):
        order = order.python_value

    # The minimum number of dimensions must be known at compile time
    if ndmin is not None:
        if not isinstance(ndmin, (LiteralInteger, int)):
            errors.report("The minimum number of dimensions must be specified explicitly with an integer.",
                    severity='fatal', symbol=expr)
        elif isinstance(ndmin, LiteralInteger):
            ndmin = ndmin.python_value

    return NumpyArray(arg, dtype, order, ndmin)
def _build_SetUpdate(self, expr, args):
    """
    Build the nodes describing a call to `set.update()`.

    Build the nodes describing a call to `set.update()` (or the equivalent
    augmented assignment). Each iterable passed to `update()` is examined.
    If it is a `PythonList`, `PythonSet` or `PythonTuple` then one `SetAdd`
    node is created for each of its elements. Otherwise a syntactic `For`
    loop calling `add` on each element is constructed and passed to
    `_visit()` for semantic parsing.

    Parameters
    ----------
    expr : DottedName | AugAssign
        The syntactic node that represents the call to `.update()`.
    args : iterable[FunctionCallArgument]
        The semantic arguments passed to the function.

    Returns
    -------
    PyccelAstNode
        The single node which was created, or a CodeBlock containing all the
        created nodes if there is not exactly one.

    Raises
    ------
    NotImplementedError
        Raised if `expr` is neither a DottedName nor an AugAssign.
    """
    if isinstance(expr, DottedName):
        set_obj = expr.name[0]
        iterable_args = [a.value for a in expr.name[1].args]
    elif isinstance(expr, AugAssign):
        set_obj = expr.lhs
        iterable_args = [expr.rhs]
    else:
        raise NotImplementedError(f"Function doesn't handle {type(expr)}")

    literal_containers = (PythonList, PythonSet, PythonTuple)
    code = []
    for syntactic_iter in iterable_args:
        if isinstance(syntactic_iter, literal_containers):
            # The elements are known so they can be added one by one
            semantic_set = self._visit(set_obj)
            new_elements = self._visit(syntactic_iter)
            try:
                code.extend(SetAdd(semantic_set, elem) for elem in new_elements)
            except TypeError as e:
                errors.report(str(e), symbol=expr, severity='fatal')
        else:
            # Describe `for tmp in iterable: _ = set_obj.add(tmp)` with
            # syntactic nodes and let the visitors handle the rest
            pyccel_stage.set_stage('syntactic')
            for_target = self.scope.get_new_name()
            add_call = FunctionCall('add', [FunctionCallArgument(for_target)])
            dotted_add = DottedName(set_obj, add_call)
            assign = Assign(PyccelSymbol('_', is_temp=True), dotted_add)
            assign.set_current_ast(expr.python_ast)
            for_obj = For(for_target, syntactic_iter, CodeBlock([assign]))
            pyccel_stage.set_stage('semantic')
            code.append(self._visit(for_obj))

    return code[0] if len(code) == 1 else CodeBlock(code)
def _build_SetUnion(self, expr, function_call_args):
    """
    Build the nodes describing a call to `set.union()`.

    Build the nodes describing a call to `set.union()` (or the equivalent
    `|` operator). A `SetUnion` node is created if the type of every argument
    matches the type of the original set. Otherwise the union is described as
    a copy of the original set, assigned to the left hand side of the
    enclosing assignment, followed by one call to `update` per argument.

    Parameters
    ----------
    expr : DottedName | PyccelBitOr
        The syntactic node that represents the call to `.union()`.
    function_call_args : iterable[FunctionCallArgument]
        The semantic arguments passed to the function.

    Returns
    -------
    SetUnion | CodeBlock
        The nodes describing the union operator.

    Raises
    ------
    NotImplementedError
        Raised if `expr` is neither a DottedName nor a PyccelBitOr.
    """
    if isinstance(expr, DottedName):
        syntactic_set_obj = expr.name[0]
        syntactic_args = [a.value for a in expr.name[1].args]
    elif isinstance(expr, PyccelBitOr):
        syntactic_set_obj = expr.args[0]
        syntactic_args = expr.args[1:]
    else:
        raise NotImplementedError(f"Function doesn't handle {type(expr)}")

    args = [a.value for a in function_call_args]
    set_obj = self._visit(syntactic_set_obj)
    class_type = set_obj.class_type

    # Simple case: everything is a set of the same type
    if all(a.class_type == class_type for a in args):
        return SetUnion(set_obj, *args[1:])

    element_type = class_type.element_type
    if any(a.class_type.element_type != element_type for a in args):
        errors.report(("Containers containing objects of a different type cannot be used as "
                       f"arguments to {class_type}.union"),
                severity='fatal', symbol=expr)

    lhs = expr.get_user_nodes(Assign)[0].lhs
    python_ast = expr.python_ast

    # Describe `lhs = set_obj.copy()` followed by `lhs.update(arg)` for each
    # argument using syntactic nodes
    pyccel_stage.set_stage('syntactic')
    copy_call = DottedName(syntactic_set_obj, FunctionCall('copy', ()))
    statements = [Assign(lhs, copy_call, python_ast = python_ast)]
    update_calls = [DottedName(lhs, FunctionCall('update', (synt_arg,)))
                    for synt_arg in syntactic_args]
    for call in update_calls:
        call.set_current_ast(python_ast)
    statements.extend(Assign(PyccelSymbol('_', is_temp=True), call, python_ast = python_ast)
                      for call in update_calls)
    pyccel_stage.set_stage('semantic')

    return CodeBlock([self._visit(stmt) for stmt in statements])
def _build_SetIntersection(self, expr, function_call_args):
    """
    Build the nodes describing a call to `set.intersection()`.

    Build the nodes describing a call to `set.intersection()`. The single
    call is replaced by an assignment of the starting set (copied with
    `SetCopy` unless it is a `PythonSet`) followed by one
    `SetIntersectionUpdate` per remaining argument. If the call is not the
    right hand side of an assignment then the result is stored in a new
    variable and the statements are added to the additional expressions of
    the current statement.

    Parameters
    ----------
    expr : DottedName
        The syntactic DottedName node that represents the call to `.intersection()`.
    function_call_args : iterable[FunctionCallArgument]
        The semantic arguments passed to the function.

    Returns
    -------
    CodeBlock | Variable
        A CodeBlock containing the assignment and the SetIntersectionUpdate
        objects if the call is directly used by an assignment, otherwise the
        variable in which the result is stored.
    """
    start_set = function_call_args[0].value
    set_args = [self._visit(a.value) for a in function_call_args[1:]]

    assign = expr.get_direct_user_nodes(lambda a: isinstance(a, Assign))
    syntactic_lhs = assign[-1].lhs if assign else self.scope.get_new_name()

    d_var = self._infer_type(start_set)
    # A set literal is a new object so no copy is necessary
    rhs = start_set if isinstance(start_set, PythonSet) else SetCopy(start_set)

    body = []
    lhs = self._assign_lhs_variable(syntactic_lhs, d_var, rhs, body)
    body.append(Assign(lhs, rhs, python_ast = expr.python_ast))

    try:
        body.extend([SetIntersectionUpdate(lhs, other) for other in set_args])
    except TypeError as e:
        errors.report(e, symbol=expr, severity='error')

    if assign:
        return CodeBlock(body)

    self._additional_exprs[-1].extend(body)
    return lhs
def _build_PythonLen(self, expr, function_call_args):
    """
    Build the node describing a call to `len()`.

    Build the node describing a call to `len()`. The result is a literal
    integer for a literal string, a call to the `__len__` method for an
    object of a custom class which provides one, and the first element of
    the shape for any other object whose rank is greater than 0. A fatal
    error is reported in all other cases.

    Parameters
    ----------
    expr : FunctionCall
        The syntactic node that represents the call to `len()`.
    function_call_args : iterable[FunctionCallArgument]
        The semantic argument passed to the function.

    Returns
    -------
    TypedAstNode
        The node representing an object which allows the result of the
        PythonLen function to be obtained.
    """
    arg = function_call_args[0].value
    class_type = arg.class_type

    if isinstance(arg, LiteralString):
        return LiteralInteger(len(arg.python_value))

    if isinstance(class_type, CustomDataType):
        class_base = self.scope.find(str(class_type), 'classes') or get_cls_base(class_type)
        magic_method = class_base.get_method('__len__')
        if magic_method:
            return self._handle_function(expr, magic_method, function_call_args)
    elif arg.rank > 0:
        return arg.shape[0]

    # No way of calculating the length was found
    raise errors.report(f"__len__ not implemented for type {class_type}",
            severity='fatal', symbol=expr)
def _build_PythonSetFunction(self, expr, function_call_args):
    """
    Build the nodes describing a call to `set()`.

    Build the nodes describing a call to `set()`. An empty call creates an
    empty `PythonSet`, a list/set/tuple literal is unpacked into a
    `PythonSet` and an object which is already a set is copied. For any other
    argument the set is filled by a `For` loop calling `SetAdd` on each
    element. This requires a variable to fill: the left hand side of the
    enclosing assignment if there is one, otherwise a new variable.

    Parameters
    ----------
    expr : FunctionCall
        The syntactic node that represents the call to `PythonSetFunction`.
    function_call_args : iterable[FunctionCallArgument]
        The semantic arguments passed to the function.

    Returns
    -------
    TypedAstNode | CodeBlock
        The node representing an object which allows the set to be created.
    """
    if not function_call_args:
        return PythonSet()

    arg = function_call_args[0].value
    class_type = arg.class_type

    if isinstance(arg, (PythonList, PythonSet, PythonTuple)):
        return PythonSet(*arg)
    if isinstance(class_type, HomogeneousSetType):
        return SetCopy(arg)

    # Find the object in which the set will be constructed
    assigns = expr.get_direct_user_nodes(lambda a: isinstance(a, Assign))
    if assigns:
        assert len(assigns) == 1
        lhs = assigns[0].lhs
    else:
        lhs = self.scope.get_new_name()

    d_var = {
            'class_type' : HomogeneousSetType(class_type.element_type),
            'shape' : arg.shape,
            'cls_base' : SetClass,
            'memory_handling' : 'heap'
        }
    body = []
    lhs_semantic_var = self._assign_lhs_variable(lhs, d_var, PythonSetFunction(arg), body)

    # Create the loop which adds the elements to the set one by one
    scope = self.create_new_loop_scope()
    targets, loop_iterable = self._get_for_iterators(arg, self.scope.get_new_name(), body, expr)
    self.exit_loop_scope()
    add_element = SetAdd(lhs_semantic_var, targets[0])
    body.append(For(targets, loop_iterable, [add_element], scope=scope))

    if assigns:
        return CodeBlock(body)

    self._additional_exprs[-1].extend(body)
    return lhs_semantic_var
def _build_PythonIsInstance(self, expr, function_call_args):
    """
    Build the node describing a call to `isinstance()`.

    Build the node describing a call to `isinstance()`. The result is a
    literal boolean indicating whether or not the object has the expected
    type. If several types are provided (in a tuple or a union annotation)
    then each type is tested separately and the results are combined with a
    simplified `PyccelOr`. If the expected type cannot be determined an error
    is reported.

    Parameters
    ----------
    expr : FunctionCall
        The syntactic node that represents the call to `isinstance()`.
    function_call_args : iterable[FunctionCallArgument]
        The 2 semantic arguments passed to the function.

    Returns
    -------
    TypedAstNode
        A LiteralTrue or LiteralFalse node describing the result of the
        `isinstance` call, or the result of the simplified PyccelOr when
        several types are tested.
    """
    obj = function_call_args[0].value
    class_or_tuple = function_call_args[1].value

    # Collect the different possibilities if multiple types are provided
    if isinstance(class_or_tuple, PythonTuple):
        alternatives = class_or_tuple
    elif isinstance(class_or_tuple, UnionTypeAnnotation):
        alternatives = class_or_tuple.type_list
    else:
        alternatives = None

    if alternatives is not None:
        obj_arg = function_call_args[0]
        checks = [self._build_PythonIsInstance(expr, [obj_arg, FunctionCallArgument(candidate)])
                  for candidate in alternatives]
        return PyccelOr(*checks, simplify=True)

    if isinstance(class_or_tuple, VariableTypeAnnotation):
        expected_type = class_or_tuple.class_type
    else:
        cls_name = class_or_tuple.cls_name
        try:
            expected_type = cls_name.static_type()
        except AttributeError:
            expected_type = None

    # The expected type is a class describing a family of types
    if isinstance(expected_type, type):
        return convert_to_literal(isinstance(obj.class_type, expected_type))

    if not expected_type:
        errors.report(f"Type {class_or_tuple} is not handled in isinstance call.",
                severity='error', symbol=expr)
        return LiteralTrue()

    # The expected type is an instance: compare it with the type of the
    # object and the types of all the classes from which it inherits
    obj_type = obj.class_type
    bases_to_check = [self.scope.find(str(obj_type), 'classes') or get_cls_base(obj_type)]
    possible_types = {obj_type}
    while bases_to_check:
        cls_base = bases_to_check.pop()
        possible_types.add(cls_base.class_type)
        bases_to_check.extend(cls_base.superclasses)
    possible_types.discard(None)

    return convert_to_literal(expected_type in possible_types)
def _build_ListAppend(self, expr, args):
    """
    Create the semantic ListAppend node.

    Create the semantic ListAppend node ensuring that pointers are correctly
    handled. The appended object is indicated as a pointer target of the list
    unless it is a string, a fixed size numeric object, or a newly created
    list, set, tuple or array.

    Parameters
    ----------
    expr : DottedName
        The syntactic DottedName node that represents the call to `.append()`.
    args : iterable[FunctionCallArgument]
        The semantic arguments. These are unpacked as the list followed by
        the object to be appended.

    Returns
    -------
    ListAppend
        The semantic ListAppend object.
    """
    list_obj, append_arg = (a.value for a in args)

    semantic_node = ListAppend(list_obj, append_arg)

    stored_by_value = (StringType, FixedSizeNumericType)
    new_objects = (PythonList, PythonSet, PythonTuple, NumpyNewArray)
    if not (isinstance(append_arg.class_type, stored_by_value) or isinstance(append_arg, new_objects)):
        self._indicate_pointer_target(list_obj, append_arg, expr)

    return semantic_node
def _build_ListInsert(self, expr, args):
    """
    Create the semantic ListInsert node.

    Create the semantic ListInsert node ensuring that pointers are correctly
    handled. The inserted object is indicated as a pointer target of the list
    unless it is a string, a fixed size numeric object, or a newly created
    list, set, tuple or array.

    Parameters
    ----------
    expr : DottedName
        The syntactic DottedName node that represents the call to `.insert()`.
    args : iterable[FunctionCallArgument]
        The semantic arguments. These are unpacked as the list, the index and
        the object to be inserted.

    Returns
    -------
    ListInsert
        The semantic ListInsert object.
    """
    list_obj, index, new_elem = (a.value for a in args)

    semantic_node = ListInsert(list_obj, index, new_elem)

    stored_by_value = (StringType, FixedSizeNumericType)
    new_objects = (PythonList, PythonSet, PythonTuple, NumpyNewArray)
    if not (isinstance(new_elem.class_type, stored_by_value) or isinstance(new_elem, new_objects)):
        self._indicate_pointer_target(list_obj, new_elem, expr)

    return semantic_node