Skip to content

Commit

Permalink
More NumPy operation implementations (#1498)
Browse files Browse the repository at this point in the history
* Concatenation and stacking (numpy.concatenate, numpy.stack, and
their variants)
* numpy.linspace
* Fix nested attribute parsing (Fixes #1295)
* numpy.clip
* numpy.split and its variants
* numpy.full variants (zeros, ones, etc.) with a single value for
shape (`np.zeros(N)`)
* NumPy-compatible numpy.arange dtype inference
* `numpy.fft.{fft, ifft}`
  • Loading branch information
tbennun authored Oct 30, 2024
1 parent 1343a6e commit 2811e40
Show file tree
Hide file tree
Showing 23 changed files with 1,458 additions and 86 deletions.
8 changes: 4 additions & 4 deletions dace/codegen/cppunparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,8 @@ def _Num(self, t):
# For complex values, use ``dtype_to_typeclass``
if isinstance(t_n, complex):
dtype = dtypes.dtype_to_typeclass(complex)
repr_n = f'{dtype}({t_n.real}, {t_n.imag})'


# Handle large integer values
if isinstance(t_n, int):
Expand All @@ -765,10 +767,8 @@ def _Num(self, t):
elif bits >= 64:
warnings.warn(f'Value wider than 64 bits encountered in expression ({t_n}), emitting as-is')

if repr_n.endswith("j"):
self.write("%s(0, %s)" % (dtype, repr_n.replace("inf", INFSTR)[:-1]))
else:
self.write(repr_n.replace("inf", INFSTR))
repr_n = repr_n.replace("inf", INFSTR)
self.write(repr_n)

def _List(self, t):
raise NotImplementedError('Invalid C++')
Expand Down
4 changes: 4 additions & 0 deletions dace/distr_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ def _validate(self):
raise ValueError('Color must have only logical true (1) or false (0) values.')
return True

@property
def dtype(self):
return type(self)

def to_json(self):
attrs = serialize.all_properties_to_json(self)
retdict = {"type": type(self).__name__, "attributes": attrs}
Expand Down
32 changes: 16 additions & 16 deletions dace/frontend/common/distr.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,14 @@ def _cart_create(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, dims: Shape


@oprepo.replaces_method('Intracomm', 'Create_cart')
def _intracomm_create(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: 'Intracomm', dims: ShapeType):
def _intracomm_create(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: str, dims: ShapeType):
""" Equivalent to `dace.comm.Cart_create(dims).
:param dims: Shape of the process-grid (see `dims` parameter of `MPI_Cart_create`), e.g., [2, 3, 3].
:return: Name of the new process-grid descriptor.
"""

from mpi4py import MPI
icomm_name, icomm_obj = icomm
icomm_name, icomm_obj = icomm, pv.globals[icomm]
if icomm_obj != MPI.COMM_WORLD:
raise ValueError('Only the mpi4py.MPI.COMM_WORLD Intracomm is supported in DaCe Python programs.')
return _cart_create(pv, sdfg, state, dims)
Expand Down Expand Up @@ -186,13 +186,13 @@ def _bcast(pv: ProgramVisitor,
def _intracomm_bcast(pv: 'ProgramVisitor',
sdfg: SDFG,
state: SDFGState,
comm: Tuple[str, 'Comm'],
comm: str,
buffer: str,
root: Union[str, sp.Expr, Number] = 0):
""" Equivalent to `dace.comm.Bcast(buffer, root)`. """

from mpi4py import MPI
comm_name, comm_obj = comm
comm_name, comm_obj = comm, pv.globals[comm]
if comm_obj == MPI.COMM_WORLD:
return _bcast(pv, sdfg, state, buffer, root)
# NOTE: Highly experimental
Expand Down Expand Up @@ -267,12 +267,12 @@ def _alltoall(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, inbuffer: str,


@oprepo.replaces_method('Intracomm', 'Alltoall')
def _intracomm_alltoall(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: 'Intracomm', inp_buffer: str,
def _intracomm_alltoall(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: str, inp_buffer: str,
out_buffer: str):
""" Equivalent to `dace.comm.Alltoall(inp_buffer, out_buffer)`. """

from mpi4py import MPI
icomm_name, icomm_obj = icomm
icomm_name, icomm_obj = icomm, pv.globals[icomm]
if icomm_obj != MPI.COMM_WORLD:
raise ValueError('Only the mpi4py.MPI.COMM_WORLD Intracomm is supported in DaCe Python programs.')
return _alltoall(pv, sdfg, state, inp_buffer, out_buffer)
Expand Down Expand Up @@ -303,12 +303,12 @@ def _allreduce(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, buffer: str, op


@oprepo.replaces_method('Intracomm', 'Allreduce')
def _intracomm_allreduce(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: 'Intracomm', inp_buffer: 'InPlace',
def _intracomm_allreduce(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: str, inp_buffer: 'InPlace',
out_buffer: str, op: str):
""" Equivalent to `dace.comm.Allreduce(out_buffer, op)`. """

from mpi4py import MPI
icomm_name, icomm_obj = icomm
icomm_name, icomm_obj = icomm, pv.globals[icomm]
if icomm_obj != MPI.COMM_WORLD:
raise ValueError('Only the mpi4py.MPI.COMM_WORLD Intracomm is supported in DaCe Python programs.')
if inp_buffer != MPI.IN_PLACE:
Expand Down Expand Up @@ -470,12 +470,12 @@ def _send(pv: ProgramVisitor,


@oprepo.replaces_method('Intracomm', 'Send')
def _intracomm_send(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: 'Intracomm', buffer: str,
def _intracomm_send(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: str, buffer: str,
dst: Union[str, sp.Expr, Number], tag: Union[str, sp.Expr, Number]):
""" Equivalent to `dace.comm.end(buffer, dst, tag)`. """

from mpi4py import MPI
icomm_name, icomm_obj = icomm
icomm_name, icomm_obj = icomm, pv.globals[icomm]
if icomm_obj != MPI.COMM_WORLD:
raise ValueError('Only the mpi4py.MPI.COMM_WORLD Intracomm is supported in DaCe Python programs.')
return _send(pv, sdfg, state, buffer, dst, tag)
Expand Down Expand Up @@ -592,12 +592,12 @@ def _isend(pv: ProgramVisitor,


@oprepo.replaces_method('Intracomm', 'Isend')
def _intracomm_isend(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: 'Intracomm', buffer: str,
def _intracomm_isend(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: str, buffer: str,
dst: Union[str, sp.Expr, Number], tag: Union[str, sp.Expr, Number]):
""" Equivalent to `dace.comm.Isend(buffer, dst, tag, req)`. """

from mpi4py import MPI
icomm_name, icomm_obj = icomm
icomm_name, icomm_obj = icomm, pv.globals[icomm]
if icomm_obj != MPI.COMM_WORLD:
raise ValueError('Only the mpi4py.MPI.COMM_WORLD Intracomm is supported in DaCe Python programs.')
req, _ = sdfg.add_array("isend_req", [1], dace.dtypes.opaque("MPI_Request"), transient=True, find_new_name=True)
Expand Down Expand Up @@ -690,12 +690,12 @@ def _recv(pv: ProgramVisitor,


@oprepo.replaces_method('Intracomm', 'Recv')
def _intracomm_Recv(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: 'Intracomm', buffer: str,
def _intracomm_Recv(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: str, buffer: str,
src: Union[str, sp.Expr, Number], tag: Union[str, sp.Expr, Number]):
""" Equivalent to `dace.comm.Recv(buffer, src, tagq)`. """

from mpi4py import MPI
icomm_name, icomm_obj = icomm
icomm_name, icomm_obj = icomm, pv.globals[icomm]
if icomm_obj != MPI.COMM_WORLD:
raise ValueError('Only the mpi4py.MPI.COMM_WORLD Intracomm is supported in DaCe Python programs.')
return _recv(pv, sdfg, state, buffer, src, tag)
Expand Down Expand Up @@ -810,12 +810,12 @@ def _irecv(pv: ProgramVisitor,


@oprepo.replaces_method('Intracomm', 'Irecv')
def _intracomm_irecv(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: 'Intracomm', buffer: str,
def _intracomm_irecv(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: str, buffer: str,
src: Union[str, sp.Expr, Number], tag: Union[str, sp.Expr, Number]):
""" Equivalent to `dace.comm.Irecv(buffer, src, tag, req)`. """

from mpi4py import MPI
icomm_name, icomm_obj = icomm
icomm_name, icomm_obj = icomm, pv.globals[icomm]
if icomm_obj != MPI.COMM_WORLD:
raise ValueError('Only the mpi4py.MPI.COMM_WORLD Intracomm is supported in DaCe Python programs.')
req, _ = sdfg.add_array("irecv_req", [1], dace.dtypes.opaque("MPI_Request"), transient=True, find_new_name=True)
Expand Down
57 changes: 40 additions & 17 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1342,7 +1342,7 @@ def defined(self):

# MPI-related stuff
result.update({
k: self.sdfg.process_grids[v]
v: self.sdfg.process_grids[v]
for k, v in self.variables.items() if v in self.sdfg.process_grids
})
try:
Expand Down Expand Up @@ -4461,7 +4461,14 @@ def visit_Call(self, node: ast.Call, create_callbacks=False):
func = node.func.value

if func is None:
funcname = rname(node)
func_result = self.visit(node.func)
if isinstance(func_result, str):
if isinstance(node.func, ast.Attribute):
funcname = f'{func_result}.{node.func.attr}'
else:
funcname = func_result
else:
funcname = rname(node)
# Check if the function exists as an SDFG in a different module
modname = until(funcname, '.')
if ('.' in funcname and len(modname) > 0 and modname in self.globals
Expand Down Expand Up @@ -4576,7 +4583,7 @@ def visit_Call(self, node: ast.Call, create_callbacks=False):
arg = self.scope_vars[modname]
else:
# Fallback to (name, object)
arg = (modname, self.defined[modname])
arg = modname
args.append(arg)
# Otherwise, try to find a default implementation for the SDFG
elif not found_ufunc:
Expand Down Expand Up @@ -4795,12 +4802,18 @@ def _visitname(self, name: str, node: ast.AST):
self.sdfg.add_symbol(result.name, result.dtype)
return result

if name in self.closure.callbacks:
return name

if name in self.sdfg.arrays:
return name

if name in self.sdfg.symbols:
return name

if name in __builtins__:
return name

if name not in self.scope_vars:
raise DaceSyntaxError(self, node, 'Use of undefined variable "%s"' % name)
rname = self.scope_vars[name]
Expand Down Expand Up @@ -4845,33 +4858,43 @@ def visit_NameConstant(self, node: NameConstant):
return self.visit_Constant(node)

def visit_Attribute(self, node: ast.Attribute):
# If visiting an attribute, return attribute value if it's of an array or global
name = until(astutils.unparse(node), '.')
result = self._visitname(name, node)
result = self.visit(node.value)
if isinstance(result, (tuple, list, dict)):
if len(result) > 1:
raise DaceSyntaxError(
self, node.value, f'{type(result)} object cannot use attributes. Try storing the '
'object to a different variable first (e.g., ``a = result; a.attribute``')
else:
result = result[0]

tmpname = f"{result}.{astutils.unparse(node.attr)}"
if tmpname in self.sdfg.arrays:
return tmpname

if isinstance(result, str) and result in self.sdfg.arrays:
arr = self.sdfg.arrays[result]
elif isinstance(result, str) and result in self.scope_arrays:
arr = self.scope_arrays[result]
else:
return result
arr = None

# Try to find sub-SDFG attribute
func = oprepo.Replacements.get_attribute(type(arr), node.attr)
if func is not None:
# A new state is likely needed here, e.g., for transposition (ndarray.T)
self._add_state('%s_%d' % (type(node).__name__, node.lineno))
self.last_block.set_default_lineinfo(self.current_lineinfo)
result = func(self, self.sdfg, self.last_block, result)
self.last_block.set_default_lineinfo(None)
return result
if arr is not None:
func = oprepo.Replacements.get_attribute(type(arr), node.attr)
if func is not None:
# A new state is likely needed here, e.g., for transposition (ndarray.T)
self._add_state('%s_%d' % (type(node).__name__, node.lineno))
self.last_block.set_default_lineinfo(self.current_lineinfo)
result = func(self, self.sdfg, self.last_block, result)
self.last_block.set_default_lineinfo(None)
return result

# Otherwise, try to find compile-time attribute (such as shape)
try:
return getattr(arr, node.attr)
except KeyError:
if arr is not None:
return getattr(arr, node.attr)
return getattr(result, node.attr)
except (AttributeError, KeyError):
return result

def visit_List(self, node: ast.List):
Expand Down
2 changes: 2 additions & 0 deletions dace/frontend/python/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,8 @@ def global_value_to_node(self,
elif isinstance(value, symbolic.symbol):
# Symbols resolve to the symbol name
newnode = ast.Name(id=value.name, ctx=ast.Load())
elif isinstance(value, sympy.Basic): # Symbolic or constant expression
newnode = ast.parse(symbolic.symstr(value)).body[0].value
elif isinstance(value, ast.Name):
newnode = ast.Name(id=value.id, ctx=ast.Load())
elif (dtypes.isconstant(value) or isinstance(value, (StringLiteral, SDFG)) or hasattr(value, '__sdfg__')):
Expand Down
Loading

0 comments on commit 2811e40

Please sign in to comment.