diff --git a/dace/codegen/cppunparse.py b/dace/codegen/cppunparse.py
index 6a4837d67a..91c0819257 100644
--- a/dace/codegen/cppunparse.py
+++ b/dace/codegen/cppunparse.py
@@ -748,6 +748,8 @@ def _Num(self, t):
         # For complex values, use ``dtype_to_typeclass``
         if isinstance(t_n, complex):
             dtype = dtypes.dtype_to_typeclass(complex)
+            repr_n = f'{dtype}({t_n.real}, {t_n.imag})'
+
         # Handle large integer values
         if isinstance(t_n, int):
@@ -764,10 +766,8 @@ def _Num(self, t):
             elif bits >= 64:
                 warnings.warn(f'Value wider than 64 bits encountered in expression ({t_n}), emitting as-is')
 
-        if repr_n.endswith("j"):
-            self.write("%s(0, %s)" % (dtype, repr_n.replace("inf", INFSTR)[:-1]))
-        else:
-            self.write(repr_n.replace("inf", INFSTR))
+        repr_n = repr_n.replace("inf", INFSTR)
+        self.write(repr_n)
 
     def _List(self, t):
         raise NotImplementedError('Invalid C++')
diff --git a/dace/distr_types.py b/dace/distr_types.py
index 1b595a1b84..b60eb4925e 100644
--- a/dace/distr_types.py
+++ b/dace/distr_types.py
@@ -96,6 +96,10 @@ def _validate(self):
             raise ValueError('Color must have only logical true (1) or false (0) values.')
         return True
 
+    @property
+    def dtype(self):
+        return type(self)
+
     def to_json(self):
         attrs = serialize.all_properties_to_json(self)
         retdict = {"type": type(self).__name__, "attributes": attrs}
diff --git a/dace/frontend/common/distr.py b/dace/frontend/common/distr.py
index 88a6b0c54a..c517028d53 100644
--- a/dace/frontend/common/distr.py
+++ b/dace/frontend/common/distr.py
@@ -50,14 +50,14 @@ def _cart_create(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, dims: Shape
 
 
 @oprepo.replaces_method('Intracomm', 'Create_cart')
-def _intracomm_create(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: 'Intracomm', dims: ShapeType):
+def _intracomm_create(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: str, dims: ShapeType):
     """ Equivalent to `dace.comm.Cart_create(dims)`.
 
         :param dims: Shape of the process-grid (see `dims` parameter of `MPI_Cart_create`), e.g., [2, 3, 3].
        :return: Name of the new process-grid descriptor.
     """
     from mpi4py import MPI
-    icomm_name, icomm_obj = icomm
+    icomm_name, icomm_obj = icomm, pv.globals[icomm]
     if icomm_obj != MPI.COMM_WORLD:
         raise ValueError('Only the mpi4py.MPI.COMM_WORLD Intracomm is supported in DaCe Python programs.')
     return _cart_create(pv, sdfg, state, dims)
@@ -186,13 +186,13 @@ def _bcast(pv: ProgramVisitor,
 
 
 def _intracomm_bcast(pv: 'ProgramVisitor',
                      sdfg: SDFG,
                      state: SDFGState,
-                     comm: Tuple[str, 'Comm'],
+                     comm: str,
                      buffer: str,
                      root: Union[str, sp.Expr, Number] = 0):
     """ Equivalent to `dace.comm.Bcast(buffer, root)`. """
     from mpi4py import MPI
-    comm_name, comm_obj = comm
+    comm_name, comm_obj = comm, pv.globals[comm]
     if comm_obj == MPI.COMM_WORLD:
         return _bcast(pv, sdfg, state, buffer, root)
     # NOTE: Highly experimental
@@ -267,12 +267,12 @@ def _alltoall(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, inbuffer: str,
 
 
 @oprepo.replaces_method('Intracomm', 'Alltoall')
-def _intracomm_alltoall(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: 'Intracomm', inp_buffer: str,
+def _intracomm_alltoall(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: str, inp_buffer: str,
                         out_buffer: str):
     """ Equivalent to `dace.comm.Alltoall(inp_buffer, out_buffer)`. """
     from mpi4py import MPI
-    icomm_name, icomm_obj = icomm
+    icomm_name, icomm_obj = icomm, pv.globals[icomm]
     if icomm_obj != MPI.COMM_WORLD:
         raise ValueError('Only the mpi4py.MPI.COMM_WORLD Intracomm is supported in DaCe Python programs.')
     return _alltoall(pv, sdfg, state, inp_buffer, out_buffer)
@@ -303,12 +303,12 @@ def _allreduce(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, buffer: str, op
 
 
 @oprepo.replaces_method('Intracomm', 'Allreduce')
-def _intracomm_allreduce(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: 'Intracomm', inp_buffer: 'InPlace',
+def _intracomm_allreduce(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: str, inp_buffer: 'InPlace',
                          out_buffer: str, op: str):
     """ Equivalent to `dace.comm.Allreduce(out_buffer, op)`. """
     from mpi4py import MPI
-    icomm_name, icomm_obj = icomm
+    icomm_name, icomm_obj = icomm, pv.globals[icomm]
     if icomm_obj != MPI.COMM_WORLD:
         raise ValueError('Only the mpi4py.MPI.COMM_WORLD Intracomm is supported in DaCe Python programs.')
     if inp_buffer != MPI.IN_PLACE:
@@ -470,12 +470,12 @@ def _send(pv: ProgramVisitor,
 
 
 @oprepo.replaces_method('Intracomm', 'Send')
-def _intracomm_send(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: 'Intracomm', buffer: str,
+def _intracomm_send(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: str, buffer: str,
                     dst: Union[str, sp.Expr, Number], tag: Union[str, sp.Expr, Number]):
     """ Equivalent to `dace.comm.Send(buffer, dst, tag)`. """
     from mpi4py import MPI
-    icomm_name, icomm_obj = icomm
+    icomm_name, icomm_obj = icomm, pv.globals[icomm]
     if icomm_obj != MPI.COMM_WORLD:
         raise ValueError('Only the mpi4py.MPI.COMM_WORLD Intracomm is supported in DaCe Python programs.')
     return _send(pv, sdfg, state, buffer, dst, tag)
@@ -592,12 +592,12 @@ def _isend(pv: ProgramVisitor,
 
 
 @oprepo.replaces_method('Intracomm', 'Isend')
-def _intracomm_isend(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: 'Intracomm', buffer: str,
+def _intracomm_isend(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: str, buffer: str,
                      dst: Union[str, sp.Expr, Number], tag: Union[str, sp.Expr, Number]):
     """ Equivalent to `dace.comm.Isend(buffer, dst, tag, req)`. """
     from mpi4py import MPI
-    icomm_name, icomm_obj = icomm
+    icomm_name, icomm_obj = icomm, pv.globals[icomm]
     if icomm_obj != MPI.COMM_WORLD:
         raise ValueError('Only the mpi4py.MPI.COMM_WORLD Intracomm is supported in DaCe Python programs.')
     req, _ = sdfg.add_array("isend_req", [1], dace.dtypes.opaque("MPI_Request"), transient=True, find_new_name=True)
@@ -690,12 +690,12 @@ def _recv(pv: ProgramVisitor,
 
 
 @oprepo.replaces_method('Intracomm', 'Recv')
-def _intracomm_Recv(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: 'Intracomm', buffer: str,
+def _intracomm_Recv(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: str, buffer: str,
                     src: Union[str, sp.Expr, Number], tag: Union[str, sp.Expr, Number]):
     """ Equivalent to `dace.comm.Recv(buffer, src, tag)`.
""" from mpi4py import MPI - icomm_name, icomm_obj = icomm + icomm_name, icomm_obj = icomm, pv.globals[icomm] if icomm_obj != MPI.COMM_WORLD: raise ValueError('Only the mpi4py.MPI.COMM_WORLD Intracomm is supported in DaCe Python programs.') return _recv(pv, sdfg, state, buffer, src, tag) @@ -810,12 +810,12 @@ def _irecv(pv: ProgramVisitor, @oprepo.replaces_method('Intracomm', 'Irecv') -def _intracomm_irecv(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: 'Intracomm', buffer: str, +def _intracomm_irecv(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: str, buffer: str, src: Union[str, sp.Expr, Number], tag: Union[str, sp.Expr, Number]): """ Equivalent to `dace.comm.Irecv(buffer, src, tag, req)`. """ from mpi4py import MPI - icomm_name, icomm_obj = icomm + icomm_name, icomm_obj = icomm, pv.globals[icomm] if icomm_obj != MPI.COMM_WORLD: raise ValueError('Only the mpi4py.MPI.COMM_WORLD Intracomm is supported in DaCe Python programs.') req, _ = sdfg.add_array("irecv_req", [1], dace.dtypes.opaque("MPI_Request"), transient=True, find_new_name=True) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index f10ff7c8cf..30c457281d 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -1346,7 +1346,7 @@ def defined(self): # MPI-related stuff result.update({ - k: self.sdfg.process_grids[v] + v: self.sdfg.process_grids[v] for k, v in self.variables.items() if v in self.sdfg.process_grids }) try: @@ -4477,7 +4477,14 @@ def visit_Call(self, node: ast.Call, create_callbacks=False): func = node.func.value if func is None: - funcname = rname(node) + func_result = self.visit(node.func) + if isinstance(func_result, str): + if isinstance(node.func, ast.Attribute): + funcname = f'{func_result}.{node.func.attr}' + else: + funcname = func_result + else: + funcname = rname(node) # Check if the function exists as an SDFG in a different module modname = until(funcname, '.') if ('.' in funcname and len(modname) > 0 and modname in self.globals @@ -4592,7 +4599,7 @@ def visit_Call(self, node: ast.Call, create_callbacks=False): arg = self.scope_vars[modname] else: # Fallback to (name, object) - arg = (modname, self.defined[modname]) + arg = modname args.append(arg) # Otherwise, try to find a default implementation for the SDFG elif not found_ufunc: @@ -4811,12 +4818,18 @@ def _visitname(self, name: str, node: ast.AST): self.sdfg.add_symbol(result.name, result.dtype) return result + if name in self.closure.callbacks: + return name + if name in self.sdfg.arrays: return name if name in self.sdfg.symbols: return name + if name in __builtins__: + return name + if name not in self.scope_vars: raise DaceSyntaxError(self, node, 'Use of undefined variable "%s"' % name) rname = self.scope_vars[name] @@ -4861,33 +4874,43 @@ def visit_NameConstant(self, node: NameConstant): return self.visit_Constant(node) def visit_Attribute(self, node: ast.Attribute): - # If visiting an attribute, return attribute value if it's of an array or global - name = until(astutils.unparse(node), '.') - result = self._visitname(name, node) + result = self.visit(node.value) + if isinstance(result, (tuple, list, dict)): + if len(result) > 1: + raise DaceSyntaxError( + self, node.value, f'{type(result)} object cannot use attributes. 
Try storing the '
+                    'object to a different variable first (e.g., ``a = result; a.attribute``)')
+            else:
+                result = result[0]
+
+        tmpname = f"{result}.{astutils.unparse(node.attr)}"
         if tmpname in self.sdfg.arrays:
             return tmpname
+
         if isinstance(result, str) and result in self.sdfg.arrays:
             arr = self.sdfg.arrays[result]
         elif isinstance(result, str) and result in self.scope_arrays:
             arr = self.scope_arrays[result]
         else:
-            return result
+            arr = None
 
         # Try to find sub-SDFG attribute
-        func = oprepo.Replacements.get_attribute(type(arr), node.attr)
-        if func is not None:
-            # A new state is likely needed here, e.g., for transposition (ndarray.T)
-            self._add_state('%s_%d' % (type(node).__name__, node.lineno))
-            self.last_block.set_default_lineinfo(self.current_lineinfo)
-            result = func(self, self.sdfg, self.last_block, result)
-            self.last_block.set_default_lineinfo(None)
-            return result
+        if arr is not None:
+            func = oprepo.Replacements.get_attribute(type(arr), node.attr)
+            if func is not None:
+                # A new state is likely needed here, e.g., for transposition (ndarray.T)
+                self._add_state('%s_%d' % (type(node).__name__, node.lineno))
+                self.last_block.set_default_lineinfo(self.current_lineinfo)
+                result = func(self, self.sdfg, self.last_block, result)
+                self.last_block.set_default_lineinfo(None)
+                return result
 
         # Otherwise, try to find compile-time attribute (such as shape)
         try:
-            return getattr(arr, node.attr)
-        except KeyError:
+            if arr is not None:
+                return getattr(arr, node.attr)
+            return getattr(result, node.attr)
+        except (AttributeError, KeyError):
             return result
 
     def visit_List(self, node: ast.List):
diff --git a/dace/frontend/python/preprocessing.py b/dace/frontend/python/preprocessing.py
index 17c5e8b03d..716b483fb5 100644
--- a/dace/frontend/python/preprocessing.py
+++ b/dace/frontend/python/preprocessing.py
@@ -531,6 +531,8 @@ def global_value_to_node(self,
         elif isinstance(value, symbolic.symbol):
             # Symbols resolve to the symbol name
             newnode = ast.Name(id=value.name, ctx=ast.Load())
+        elif isinstance(value, sympy.Basic):  # Symbolic or constant expression
+            newnode = ast.parse(symbolic.symstr(value)).body[0].value
         elif isinstance(value, ast.Name):
             newnode = ast.Name(id=value.id, ctx=ast.Load())
         elif (dtypes.isconstant(value) or isinstance(value, (StringLiteral, SDFG)) or hasattr(value, '__sdfg__')):
diff --git a/dace/frontend/python/replacements.py b/dace/frontend/python/replacements.py
index 39d9e703fa..e64bca7790 100644
--- a/dace/frontend/python/replacements.py
+++ b/dace/frontend/python/replacements.py
@@ -318,6 +318,9 @@ def _numpy_full(pv: ProgramVisitor,
     """ Creates an array of the specified shape and initializes it with the fill value.
""" + if isinstance(shape, Number) or symbolic.issymbolic(shape): + shape = [shape] + is_data = False if isinstance(fill_value, (Number, np.bool_)): vtype = dtypes.dtype_to_typeclass(type(fill_value)) @@ -553,8 +556,13 @@ def _numpy_rot90(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str, k=1 @oprepo.replaces('numpy.arange') @oprepo.replaces('dace.arange') -def _arange(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, *args, **kwargs): - """ Implementes numpy.arange """ +def _arange(pv: ProgramVisitor, + sdfg: SDFG, + state: SDFGState, + *args, + dtype: dtypes.typeclass = None, + like: Optional[str] = None): + """ Implements numpy.arange """ start = 0 stop = None @@ -568,35 +576,42 @@ def _arange(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, *args, **kwargs): else: start, stop, step = args + if isinstance(start, str): + raise TypeError(f'Cannot compile numpy.arange with a scalar start value "{start}" (only constants and symbolic ' + 'expressions are supported). Please use numpy.linspace instead.') + if isinstance(stop, str): + raise TypeError(f'Cannot compile numpy.arange with a scalar stop value "{stop}" (only constants and symbolic ' + 'expressions are supported). Please use numpy.linspace instead.') + if isinstance(step, str): + raise TypeError(f'Cannot compile numpy.arange with a scalar step value "{step}" (only constants and symbolic ' + 'expressions are supported). Please use numpy.linspace instead.') + actual_step = step if isinstance(start, Number) and isinstance(stop, Number): actual_step = type(start + step)(start + step) - start if any(not isinstance(s, Number) for s in [start, stop, step]): - shape = (symbolic.int_ceil(stop - start, step), ) + if step == 1: # Common case where ceiling is not necessary + shape = (stop - start,) + else: + shape = (symbolic.int_ceil(stop - start, step), ) else: shape = (np.int64(np.ceil((stop - start) / step)), ) - if not isinstance(shape[0], Number) and ('dtype' not in kwargs or kwargs['dtype'] == None): - raise NotImplementedError("The current implementation of numpy.arange requires that the output dtype is given " - "when at least one of (start, stop, step) is symbolic.") + # Infer dtype from input arguments + if dtype is None: + dtype, _ = _result_type(args) + # TODO: Unclear what 'like' does - # if 'like' in kwargs and kwargs['like'] != None: - # outname, outarr = sdfg.add_temp_transient_like(sdfg.arrays[kwargs['like']]) + # if like is not None: + # outname, outarr = sdfg.add_temp_transient_like(sdfg.arrays[like]) # outarr.shape = shape - if 'dtype' in kwargs and kwargs['dtype'] != None: - dtype = kwargs['dtype'] - if not isinstance(dtype, dtypes.typeclass): - dtype = dtypes.dtype_to_typeclass(dtype) - outname, outarr = sdfg.add_temp_transient(shape, dtype) - else: - # infer dtype based on args's dtype - # (since the `dtype` keyword argument isn't given, none of the arguments can be symbolic) - if any(isinstance(arg, (float, np.float32, np.float64)) for arg in args): - dtype = dtypes.float64 - else: - dtype = dtypes.int64 - outname, outarr = sdfg.add_temp_transient(shape, dtype) + if not isinstance(dtype, dtypes.typeclass): + dtype = dtypes.dtype_to_typeclass(dtype) + outname, outarr = sdfg.add_temp_transient(shape, dtype) + + start = f'decltype(__out)({start})' + actual_step = f'decltype(__out)({actual_step})' state.add_mapped_tasklet(name="_numpy_arange_", map_ranges={'__i': f"0:{shape[0]}"}, @@ -608,6 +623,131 @@ def _arange(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, *args, **kwargs): return outname +def 
_add_axis_to_shape(shape: Sequence[symbolic.SymbolicType], axis: int, + axis_value: Any) -> List[symbolic.SymbolicType]: + if axis > len(shape): + raise ValueError(f'axis {axis} is out of bounds for array of dimension {len(shape)}') + if axis < 0: + naxis = len(shape) + 1 + axis + if naxis < 0 or naxis > len(shape): + raise ValueError(f'axis {axis} is out of bounds for array of dimension {len(shape)}') + axis = naxis + + # Make a new shape list with the inserted dimension + new_shape = [None] * (len(shape) + 1) + for i in range(len(shape) + 1): + if i == axis: + new_shape[i] = axis_value + elif i < axis: + new_shape[i] = shape[i] + else: + new_shape[i] = shape[i - 1] + + return new_shape + + +@oprepo.replaces('numpy.linspace') +def _linspace(pv: ProgramVisitor, + sdfg: SDFG, + state: SDFGState, + start: Union[Number, symbolic.SymbolicType, str], + stop: Union[Number, symbolic.SymbolicType, str], + num: Union[Integral, symbolic.SymbolicType] = 50, + endpoint: bool = True, + retstep: bool = False, + dtype: dtypes.typeclass = None, + axis: int = 0): + """ Implements numpy.linspace """ + # Argument checks + if not isinstance(num, (Integral, sp.Basic)): + raise TypeError('numpy.linspace can only be compiled when the ``num`` argument is symbolic or constant.') + if not isinstance(axis, Integral): + raise TypeError('numpy.linspace can only be compiled when the ``axis`` argument is constant.') + + # Start and stop are broadcast together, then, a new dimension is added to axis (taken from ``ndim + 1``), + # along which the numbers are filled. + start_shape = sdfg.arrays[start].shape if (isinstance(start, str) and start in sdfg.arrays) else [] + stop_shape = sdfg.arrays[stop].shape if (isinstance(stop, str) and stop in sdfg.arrays) else [] + + shape, ranges, outind, ind1, ind2 = _broadcast_together(start_shape, stop_shape) + shape_with_axis = _add_axis_to_shape(shape, axis, num) + ranges_with_axis = _add_axis_to_shape(ranges, axis, ('__sind', f'0:{symbolic.symstr(num)}')) + if outind: + outind_with_axis = _add_axis_to_shape(outind.split(', '), axis, '__sind') + else: + outind_with_axis = ['__sind'] + + if dtype is None: + # Infer output type from start and stop + start_type = sdfg.arrays[start] if (isinstance(start, str) and start in sdfg.arrays) else start + stop_type = sdfg.arrays[stop] if (isinstance(stop, str) and stop in sdfg.arrays) else stop + + dtype, _ = _result_type((start_type, stop_type), 'Add') + + # From the NumPy documentation: The inferred dtype will never be an integer; float is chosen even if the + # arguments would produce an array of integers. 
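+        # For example, np.linspace(0, 4, num=5) infers int64 from its endpoints but still
+        # yields float64 output; only an explicit dtype= argument keeps an integer type.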
+ if dtype in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, + dtypes.uint64): + dtype = dtypes.dtype_to_typeclass(float) + + outname, _ = sdfg.add_temp_transient(shape_with_axis, dtype) + + if endpoint == True: + num -= 1 + + # Fill in input memlets as necessary + inputs = {} + if isinstance(start, str) and start in sdfg.arrays: + index = f'[{ind1}]' if ind1 else '' + inputs['__start'] = Memlet(f'{start}{index}') + startcode = '__start' + else: + startcode = symbolic.symstr(start) + + if isinstance(stop, str) and stop in sdfg.arrays: + index = f'[{ind2}]' if ind2 else '' + inputs['__stop'] = Memlet(f'{stop}{index}') + stopcode = '__stop' + else: + stopcode = symbolic.symstr(stop) + + # Create tasklet code based on inputs + code = f'__out = {startcode} + __sind * decltype(__out)({stopcode} - {startcode}) / decltype(__out)({symbolic.symstr(num)})' + + state.add_mapped_tasklet(name="linspace", + map_ranges=ranges_with_axis, + inputs=inputs, + code=code, + outputs={'__out': Memlet(f"{outname}[{','.join(outind_with_axis)}]")}, + external_edges=True) + + if retstep == False: + return outname + + # Return step if requested + + # Handle scalar outputs + if not ranges: + ranges = [('__unused', '0:1')] + out_index = f'[{outind}]' + + if len(shape) > 0: + stepname, _ = sdfg.add_temp_transient(shape, dtype) + else: + stepname, _ = sdfg.add_scalar(sdfg.temp_data_name(), dtype, transient=True) + out_index = '[0]' + + state.add_mapped_tasklet( + 'retstep', + ranges, + copy.deepcopy(inputs), + f'__out = decltype(__out)({stopcode} - {startcode}) / decltype(__out)({symbolic.symstr(num)})', + {'__out': Memlet(f"{stepname}{out_index}")}, + external_edges=True) + + return outname, stepname + + @oprepo.replaces('elementwise') @oprepo.replaces('dace.elementwise') def _elementwise(pv: 'ProgramVisitor', @@ -713,9 +853,9 @@ def _simple_call(sdfg: SDFG, state: SDFGState, inpname: str, func: str, restype: def _complex_to_scalar(complex_type: dace.typeclass): - if complex_type is dace.complex64: + if complex_type == dace.complex64: return dace.float32 - elif complex_type is dace.complex128: + elif complex_type == dace.complex128: return dace.float64 else: return complex_type @@ -819,7 +959,8 @@ def _len_array(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, a: str): return sdfg.arrays[a].shape[0] if a in sdfg.constants_prop: return len(sdfg.constants[a]) - raise TypeError(f'`len` is not supported for input "{a}" (type {type(a)})') + else: + return len(a) @oprepo.replaces('transpose') @@ -1637,8 +1778,17 @@ def _result_type(arguments: Sequence[Union[str, Number, symbolic.symbol, sp.Basi else: # Operators with 3 or more arguments result_type = _np_result_type(dtypes_for_result) + coarse_result_type = None + if result_type in complex_types: + coarse_result_type = 3 # complex + elif result_type in float_types: + coarse_result_type = 2 # float + elif result_type in signed_types: + coarse_result_type = 1 # signed integer, bool + else: + coarse_result_type = 0 # unsigned integer for i, t in enumerate(coarse_types): - if t != result_type: + if t != coarse_result_type: casting[i] = _cast_str(result_type) return result_type, casting @@ -2520,6 +2670,13 @@ def ufuncs(): code="__out = log1p(__in1)", reduce=None, initial=np.log1p.identity), + clip=dict(name="_numpy_clip_", + operator=None, + inputs=["__in_a", "__in_amin", "__in_amax"], + outputs=["__out"], + code="__out = min(max(__in_a, __in_amin), __in_amax)", + reduce=None, + initial=np.inf), 
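+        # NOTE: clip is lowered to a composition of min and max over three inputs, so it
+        # reuses the same elementwise ufunc machinery as the other entries in this table.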
sqrt=dict(name="_numpy_sqrt_",
                  operator="Sqrt",
                  inputs=["__in1"],
@@ -4095,14 +4252,13 @@ def implement_ufunc_outer(visitor: ProgramVisitor, ast_node: ast.Call, sdfg: SDF
 
 
 @oprepo.replaces('numpy.reshape')
-def reshape(
-        pv: ProgramVisitor,
-        sdfg: SDFG,
-        state: SDFGState,
-        arr: str,
-        newshape: Union[str, symbolic.SymbolicType, Tuple[Union[str, symbolic.SymbolicType]]],
-        order: StringLiteral = StringLiteral('C')
-) -> str:
+def reshape(pv: ProgramVisitor,
+            sdfg: SDFG,
+            state: SDFGState,
+            arr: str,
+            newshape: Union[str, symbolic.SymbolicType, Tuple[Union[str, symbolic.SymbolicType]]],
+            order: StringLiteral = StringLiteral('C'),
+            strides: Optional[Any] = None) -> str:
     if isinstance(arr, (list, tuple)) and len(arr) == 1:
         arr = arr[0]
     desc = sdfg.arrays[arr]
@@ -4116,10 +4272,11 @@ def reshape(
 
     # New shape and strides as symbolic expressions
     newshape = [symbolic.pystr_to_symbolic(s) for s in newshape]
-    if fortran_strides:
-        strides = [data._prod(newshape[:i]) for i in range(len(newshape))]
-    else:
-        strides = [data._prod(newshape[i + 1:]) for i in range(len(newshape))]
+    if strides is None:
+        if fortran_strides:
+            strides = [data._prod(newshape[:i]) for i in range(len(newshape))]
+        else:
+            strides = [data._prod(newshape[i + 1:]) for i in range(len(newshape))]
 
     newarr, newdesc = sdfg.add_view(arr,
                                     newshape,
@@ -4334,9 +4491,13 @@ def _ndarray_reshape(
     sdfg: SDFG,
     state: SDFGState,
     arr: str,
-    newshape: Union[str, symbolic.SymbolicType, Tuple[Union[str, symbolic.SymbolicType]]],
+    *newshape: Union[str, symbolic.SymbolicType, Tuple[Union[str, symbolic.SymbolicType]]],
     order: StringLiteral = StringLiteral('C')
 ) -> str:
+    if len(newshape) == 0:
+        raise TypeError('reshape() takes at least 1 argument (0 given)')
+    if len(newshape) == 1 and isinstance(newshape[0], (list, tuple)):
+        newshape = newshape[0]
     return reshape(pv, sdfg, state, arr, newshape, order)
 
 
@@ -4843,3 +5004,407 @@ def _op(visitor: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, op1: StringLite
 
 for op, method in _boolop_to_method.items():
     _makeboolop(op, method)
+
+
+@oprepo.replaces('numpy.concatenate')
+def _concat(visitor: ProgramVisitor,
+            sdfg: SDFG,
+            state: SDFGState,
+            arrays: Tuple[Any],
+            axis: Optional[int] = 0,
+            out: Optional[Any] = None,
+            *,
+            dtype=None,
+            casting: str = 'same_kind'):
+    if dtype is not None and out is not None:
+        raise ValueError('Arguments dtype and out cannot be given together')
+    if casting != 'same_kind':
+        raise NotImplementedError('The casting argument is currently unsupported')
+    if not isinstance(arrays, (tuple, list)):
+        raise ValueError('List of arrays is not iterable, cannot compile concatenation')
+    if axis is not None and not isinstance(axis, Integral):
+        raise ValueError('Axis is not a compile-time evaluatable integer, cannot compile concatenation')
+    if len(arrays) == 1:
+        return arrays[0]
+    for i in range(len(arrays)):
+        if arrays[i] not in sdfg.arrays:
+            raise TypeError(f'Index {i} is not an array')
+    if out is not None:
+        if out not in sdfg.arrays:
+            raise TypeError('Output is not an array')
+        dtype = sdfg.arrays[out].dtype
+
+    descs = [sdfg.arrays[arr] for arr in arrays]
+    shape = list(descs[0].shape)
+
+    if axis is None:  # Flatten arrays, then concatenate
+        arrays = [flat(visitor, sdfg, state, arr) for arr in arrays]
+        descs = [sdfg.arrays[arr] for arr in arrays]
+        shape = list(descs[0].shape)
+        axis = 0
+    else:
+        # Check shapes for validity
+        first_shape = copy.copy(shape)
+        first_shape[axis] = 0
+        for i, d in enumerate(descs[1:]):
+            other_shape = list(d.shape)
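+            # Shapes must match on all dimensions except the concatenation axis.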
other_shape[axis] = 0 + if other_shape != first_shape: + raise ValueError(f'Array shapes do not match at index {i}') + + shape[axis] = sum(desc.shape[axis] for desc in descs) + if out is None: + if dtype is None: + dtype = descs[0].dtype + name, odesc = sdfg.add_temp_transient(shape, dtype, storage=descs[0].storage, lifetime=descs[0].lifetime) + else: + name = out + odesc = sdfg.arrays[out] + + # Make copies + w = state.add_write(name) + offset = 0 + subset = subsets.Range.from_array(odesc) + for arr, desc in zip(arrays, descs): + r = state.add_read(arr) + subset = copy.deepcopy(subset) + subset[axis] = (offset, offset + desc.shape[axis] - 1, 1) + state.add_edge(r, None, w, None, Memlet(data=name, subset=subset)) + offset += desc.shape[axis] + + return name + + +@oprepo.replaces('numpy.stack') +def _stack(visitor: ProgramVisitor, + sdfg: SDFG, + state: SDFGState, + arrays: Tuple[Any], + axis: int = 0, + out: Any = None, + *, + dtype=None, + casting: str = 'same_kind'): + if dtype is not None and out is not None: + raise ValueError('Arguments dtype and out cannot be given together') + if casting != 'same_kind': + raise NotImplementedError('The casting argument is currently unsupported') + if not isinstance(arrays, (tuple, list)): + raise ValueError('List of arrays is not iterable, cannot compile stack call') + if not isinstance(axis, Integral): + raise ValueError('Axis is not a compile-time evaluatable integer, cannot compile stack call') + + for i in range(len(arrays)): + if arrays[i] not in sdfg.arrays: + raise TypeError(f'Index {i} is not an array') + + descs = [sdfg.arrays[a] for a in arrays] + shape = descs[0].shape + for i, d in enumerate(descs[1:]): + if d.shape != shape: + raise ValueError(f'Array shapes are not equal ({shape} != {d.shape} at index {i})') + + if axis > len(shape): + raise ValueError(f'axis {axis} is out of bounds for array of dimension {len(shape)}') + if axis < 0: + naxis = len(shape) + 1 + axis + if naxis < 0 or naxis > len(shape): + raise ValueError(f'axis {axis} is out of bounds for array of dimension {len(shape)}') + axis = naxis + + # Stacking is implemented as a reshape followed by concatenation + reshaped = [] + for arr, desc in zip(arrays, descs): + # Make a reshaped view with the inserted dimension + new_shape = [0] * (len(shape) + 1) + new_strides = [0] * (len(shape) + 1) + for i in range(len(shape) + 1): + if i == axis: + new_shape[i] = 1 + new_strides[i] = desc.strides[i - 1] if i != 0 else desc.strides[i] + elif i < axis: + new_shape[i] = shape[i] + new_strides[i] = desc.strides[i] + else: + new_shape[i] = shape[i - 1] + new_strides[i] = desc.strides[i - 1] + + rname = reshape(visitor, sdfg, state, arr, new_shape, strides=new_strides) + reshaped.append(rname) + + return _concat(visitor, sdfg, state, reshaped, axis, out, dtype=dtype, casting=casting) + + +@oprepo.replaces('numpy.vstack') +@oprepo.replaces('numpy.row_stack') +def _vstack(visitor: ProgramVisitor, + sdfg: SDFG, + state: SDFGState, + tup: Tuple[Any], + *, + dtype=None, + casting: str = 'same_kind'): + if not isinstance(tup, (tuple, list)): + raise ValueError('List of arrays is not iterable, cannot compile stack call') + if tup[0] not in sdfg.arrays: + raise TypeError(f'Index 0 is not an array') + + # In the 1-D case, stacking is performed along the first axis + if len(sdfg.arrays[tup[0]].shape) == 1: + return _stack(visitor, sdfg, state, tup, axis=0, out=None, dtype=dtype, casting=casting) + # Otherwise, concatenation is performed + return _concat(visitor, sdfg, state, tup, axis=0, 
out=None, dtype=dtype, casting=casting) + + +@oprepo.replaces('numpy.hstack') +@oprepo.replaces('numpy.column_stack') +def _hstack(visitor: ProgramVisitor, + sdfg: SDFG, + state: SDFGState, + tup: Tuple[Any], + *, + dtype=None, + casting: str = 'same_kind'): + if not isinstance(tup, (tuple, list)): + raise ValueError('List of arrays is not iterable, cannot compile stack call') + if tup[0] not in sdfg.arrays: + raise TypeError(f'Index 0 is not an array') + + # In the 1-D case, concatenation is performed along the first axis + if len(sdfg.arrays[tup[0]].shape) == 1: + return _concat(visitor, sdfg, state, tup, axis=0, out=None, dtype=dtype, casting=casting) + + return _concat(visitor, sdfg, state, tup, axis=1, out=None, dtype=dtype, casting=casting) + + +@oprepo.replaces('numpy.dstack') +def _dstack(visitor: ProgramVisitor, + sdfg: SDFG, + state: SDFGState, + tup: Tuple[Any], + *, + dtype=None, + casting: str = 'same_kind'): + if not isinstance(tup, (tuple, list)): + raise ValueError('List of arrays is not iterable, cannot compile a stack call') + if tup[0] not in sdfg.arrays: + raise TypeError(f'Index 0 is not an array') + if len(sdfg.arrays[tup[0]].shape) < 3: + raise NotImplementedError('dstack is not implemented for arrays that are smaller than 3D') + + return _concat(visitor, sdfg, state, tup, axis=2, out=None, dtype=dtype, casting=casting) + + +def _split_core(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, ary: str, + indices_or_sections: Union[int, Sequence[symbolic.SymbolicType], str], axis: int, allow_uneven: bool): + # Argument checks + if not isinstance(ary, str) or ary not in sdfg.arrays: + raise TypeError('Split object must be an array') + if not isinstance(axis, Integral): + raise ValueError('Cannot determine split dimension, axis is not a compile-time evaluatable integer') + + desc = sdfg.arrays[ary] + + # Test validity of axis + orig_axis = axis + if axis < 0: + axis = len(desc.shape) + axis + if axis < 0 or axis >= len(desc.shape): + raise ValueError(f'axis {orig_axis} is out of bounds for array of dimension {len(desc.shape)}') + + # indices_or_sections may only be an integer (not symbolic), list of integers, list of symbols, or an array + if isinstance(indices_or_sections, str): + raise ValueError('Array-indexed split cannot be compiled due to data-dependent sizes. ' + 'Consider using numpy.reshape instead.') + elif isinstance(indices_or_sections, (list, tuple)): + if any(isinstance(i, str) for i in indices_or_sections): + raise ValueError('Array-indexed split cannot be compiled due to data-dependent sizes. ' + 'Use symbolic values as an argument instead.') + # Sequence is given + sections = indices_or_sections + elif isinstance(indices_or_sections, Integral): # Constant integer given + if indices_or_sections <= 0: + raise ValueError('Number of sections must be larger than zero.') + + # If uneven sizes are not allowed and ary shape is numeric, check evenness + if not allow_uneven and not symbolic.issymbolic(desc.shape[axis]): + if desc.shape[axis] % indices_or_sections != 0: + raise ValueError('Array split does not result in an equal division. 
Consider using numpy.array_split ' + 'instead.') + if indices_or_sections > desc.shape[axis]: + raise ValueError('Cannot compile array split as it will result in empty arrays.') + + # Sequence is not given, compute sections + # Mimic behavior of array_split in numpy: Sections are [s+1 x N%s], s, ..., s + size = desc.shape[axis] // indices_or_sections + remainder = desc.shape[axis] % indices_or_sections + sections = [] + offset = 0 + for _ in range(min(remainder, indices_or_sections)): + offset += size + 1 + sections.append(offset) + for _ in range(remainder, indices_or_sections - 1): + offset += size + sections.append(offset) + + elif symbolic.issymbolic(indices_or_sections): + raise ValueError('Symbolic split cannot be compiled due to output tuple size being unknown. ' + 'Consider using numpy.reshape instead.') + else: + raise TypeError(f'Unsupported type {type(indices_or_sections)} for indices_or_sections in numpy.split') + + # Split according to sections + r = state.add_read(ary) + result = [] + offset = 0 + for section in sections: + shape = list(desc.shape) + shape[axis] = section - offset + name, _ = sdfg.add_temp_transient(shape, desc.dtype, storage=desc.storage, lifetime=desc.lifetime) + # Add copy + w = state.add_write(name) + subset = subsets.Range.from_array(desc) + subset[axis] = (offset, offset + shape[axis] - 1, 1) + offset += shape[axis] + state.add_nedge(r, w, Memlet(data=ary, subset=subset)) + result.append(name) + + # Add final section + shape = list(desc.shape) + shape[axis] -= offset + name, _ = sdfg.add_temp_transient(shape, desc.dtype, storage=desc.storage, lifetime=desc.lifetime) + w = state.add_write(name) + subset = subsets.Range.from_array(desc) + subset[axis] = (offset, offset + shape[axis] - 1, 1) + state.add_nedge(r, w, Memlet(data=ary, subset=subset)) + result.append(name) + + # Always return a list of results, even if the size is 1 + return result + + +@oprepo.replaces('numpy.split') +def _split(visitor: ProgramVisitor, + sdfg: SDFG, + state: SDFGState, + ary: str, + indices_or_sections: Union[symbolic.SymbolicType, List[symbolic.SymbolicType], str], + axis: int = 0): + return _split_core(visitor, sdfg, state, ary, indices_or_sections, axis, allow_uneven=False) + + +@oprepo.replaces('numpy.array_split') +def _array_split(visitor: ProgramVisitor, + sdfg: SDFG, + state: SDFGState, + ary: str, + indices_or_sections: Union[symbolic.SymbolicType, List[symbolic.SymbolicType], str], + axis: int = 0): + return _split_core(visitor, sdfg, state, ary, indices_or_sections, axis, allow_uneven=True) + + +@oprepo.replaces('numpy.dsplit') +def _dsplit(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, ary: str, + indices_or_sections: Union[symbolic.SymbolicType, List[symbolic.SymbolicType], str]): + if isinstance(ary, str) and ary in sdfg.arrays: + if len(sdfg.arrays[ary].shape) < 3: + raise ValueError('Array dimensionality must be 3 or above for dsplit') + return _split_core(visitor, sdfg, state, ary, indices_or_sections, axis=2, allow_uneven=False) + + +@oprepo.replaces('numpy.hsplit') +def _hsplit(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, ary: str, + indices_or_sections: Union[symbolic.SymbolicType, List[symbolic.SymbolicType], str]): + if isinstance(ary, str) and ary in sdfg.arrays: + # In case of a 1D array, split with axis=0 + if len(sdfg.arrays[ary].shape) <= 1: + return _split_core(visitor, sdfg, state, ary, indices_or_sections, axis=0, allow_uneven=False) + return _split_core(visitor, sdfg, state, ary, indices_or_sections, axis=1, 
allow_uneven=False) + + +@oprepo.replaces('numpy.vsplit') +def _vsplit(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, ary: str, + indices_or_sections: Union[symbolic.SymbolicType, List[symbolic.SymbolicType], str]): + return _split_core(visitor, sdfg, state, ary, indices_or_sections, axis=0, allow_uneven=False) + + +############################################################################################################ +# Fast Fourier Transform numpy package (numpy.fft) + +def _real_to_complex(real_type: dace.typeclass): + if real_type == dace.float32: + return dace.complex64 + elif real_type == dace.float64: + return dace.complex128 + else: + return real_type + + +def _fft_core(pv: 'ProgramVisitor', + sdfg: SDFG, + state: SDFGState, + a: str, + n: Optional[dace.symbolic.SymbolicType] = None, + axis=-1, + norm: StringLiteral = StringLiteral('backward'), + is_inverse: bool = False): + from dace.libraries.fft.nodes import FFT, IFFT # Avoid import loops + if axis != 0 and axis != -1: + raise NotImplementedError('Only one dimensional arrays are supported at the moment') + if not isinstance(a, str) or a not in sdfg.arrays: + raise ValueError('Input must be a valid array') + + libnode = FFT('fft') if not is_inverse else IFFT('ifft') + + desc = sdfg.arrays[a] + N = desc.shape[axis] + + # If n is not None, either pad input or slice and add a view + if n is not None: + raise NotImplementedError + + # Compute factor + if norm == 'forward': + factor = (1 / N) if not is_inverse else 1 + elif norm == 'backward': + factor = 1 if not is_inverse else (1 / N) + elif norm == 'ortho': + factor = sp.sqrt(1 / N) + else: + raise ValueError('norm argument can only be one of "forward", "backward", or "ortho".') + libnode.factor = factor + + # Compute output type from input type + if is_inverse and desc.dtype not in (dace.complex64, dace.complex128): + raise TypeError(f'Inverse FFT only accepts complex inputs, got {desc.dtype}') + dtype = _real_to_complex(desc.dtype) + + name, odesc = sdfg.add_temp_transient_like(desc, dtype) + r = state.add_read(a) + w = state.add_write(name) + state.add_edge(r, None, libnode, '_inp', Memlet.from_array(a, desc)) + state.add_edge(libnode, '_out', w, None, Memlet.from_array(name, odesc)) + + return name + + +@oprepo.replaces('numpy.fft.fft') +def _fft(pv: 'ProgramVisitor', + sdfg: SDFG, + state: SDFGState, + a: str, + n: Optional[dace.symbolic.SymbolicType] = None, + axis=-1, + norm: StringLiteral = StringLiteral('backward')): + return _fft_core(pv, sdfg, state, a, n, axis, norm, False) + + +@oprepo.replaces('numpy.fft.ifft') +def _ifft(pv: 'ProgramVisitor', + sdfg: SDFG, + state: SDFGState, + a, + n=None, + axis=-1, + norm: StringLiteral = StringLiteral('backward')): + return _fft_core(pv, sdfg, state, a, n, axis, norm, True) diff --git a/dace/libraries/blas/nodes/gemv.py b/dace/libraries/blas/nodes/gemv.py index d55e2e3b04..6464cc26d4 100644 --- a/dace/libraries/blas/nodes/gemv.py +++ b/dace/libraries/blas/nodes/gemv.py @@ -729,6 +729,9 @@ def expansion(node: 'Gemv', state, sdfg, m=None, n=None, **kwargs): dtype_a = outer_array_a.dtype.type dtype = outer_array_x.dtype.base_type veclen = outer_array_x.dtype.veclen + alpha = f'{dtype.ctype}({node.alpha})' + beta = f'{dtype.ctype}({node.beta})' + m = m or node.m n = n or node.n if m is None: @@ -764,8 +767,17 @@ def expansion(node: 'Gemv', state, sdfg, m=None, n=None, **kwargs): func = func.lower() + 'gemv' - code = f"""cblas_{func}({layout}, {trans}, {m}, {n}, {node.alpha}, _A, {lda}, - _x, {strides_x[0]}, {node.beta}, 
_y, {strides_y[0]});""" + code = '' + if dtype in (dace.complex64, dace.complex128): + code = f''' + {dtype.ctype} __alpha = {alpha}; + {dtype.ctype} __beta = {beta}; + ''' + alpha = '&__alpha' + beta = '&__beta' + + code += f"""cblas_{func}({layout}, {trans}, {m}, {n}, {alpha}, _A, {lda}, + _x, {strides_x[0]}, {beta}, _y, {strides_y[0]});""" tasklet = dace.sdfg.nodes.Tasklet(node.name, node.in_connectors, diff --git a/dace/libraries/fft/__init__.py b/dace/libraries/fft/__init__.py new file mode 100644 index 0000000000..71fb014f32 --- /dev/null +++ b/dace/libraries/fft/__init__.py @@ -0,0 +1,6 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +from dace.library import register_library +from .nodes import * +from .environments import * + +register_library(__name__, "fft") diff --git a/dace/libraries/fft/algorithms/__init__.py b/dace/libraries/fft/algorithms/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dace/libraries/fft/algorithms/dft.py b/dace/libraries/fft/algorithms/dft.py new file mode 100644 index 0000000000..340dfed22d --- /dev/null +++ b/dace/libraries/fft/algorithms/dft.py @@ -0,0 +1,45 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +""" +One-dimensional Discrete Fourier Transform (DFT) native implementations. +""" +import dace +import numpy as np +import math + + +# Native, naive version of the Discrete Fourier Transform +@dace.program +def dft(_inp, _out, N: dace.compiletime, factor: dace.compiletime): + i = np.arange(N) + e = np.exp(-2j * np.pi * i * i[:, None] / N) + _out[:] = factor * (e @ _inp.astype(dace.complex128)) + + +@dace.program +def idft(_inp, _out, N: dace.compiletime, factor: dace.compiletime): + i = np.arange(N) + e = np.exp(2j * np.pi * i * i[:, None] / N) + _out[:] = factor * (e @ _inp.astype(dace.complex128)) + + +# Single-map version of DFT, useful for integrating small Fourier transforms into other operations +@dace.program +def dft_explicit(_inp, _out, N: dace.compiletime, factor: dace.compiletime): + _out[:] = 0 + for i, n in dace.map[0:N, 0:N]: + with dace.tasklet: + inp << _inp[n] + exponent = 2 * math.pi * i * n / N + b = decltype(b)(math.cos(exponent), -math.sin(exponent)) * inp * factor + b >> _out(1, lambda a, b: a + b)[i] + + +@dace.program +def idft_explicit(_inp, _out, N: dace.compiletime, factor: dace.compiletime): + _out[:] = 0 + for i, n in dace.map[0:N, 0:N]: + with dace.tasklet: + inp << _inp[n] + exponent = 2 * math.pi * i * n / N + b = decltype(b)(math.cos(exponent), math.sin(exponent)) * inp * factor + b >> _out(1, lambda a, b: a + b)[i] diff --git a/dace/libraries/fft/environments/__init__.py b/dace/libraries/fft/environments/__init__.py new file mode 100644 index 0000000000..0900214e68 --- /dev/null +++ b/dace/libraries/fft/environments/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +from .cufft import * diff --git a/dace/libraries/fft/environments/cufft.py b/dace/libraries/fft/environments/cufft.py new file mode 100644 index 0000000000..dd243d376a --- /dev/null +++ b/dace/libraries/fft/environments/cufft.py @@ -0,0 +1,21 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
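+"""
+Build environment for NVIDIA cuFFT: declares the CMake package, the ``cufft`` link
+library, and the ``cufft.h``/``cufftXt.h`` headers used by the cuFFT node expansions.
+"""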
+import dace.library + + +@dace.library.environment +class cuFFT: + + cmake_minimum_version = None + cmake_packages = ["CUDA"] + cmake_variables = {} + cmake_includes = [] + cmake_libraries = ["cufft"] + cmake_compile_flags = [] + cmake_link_flags = [] + cmake_files = [] + + headers = {'frame': ["cufft.h", "cufftXt.h"], 'cuda': ["cufft.h", "cufftXt.h"]} + state_fields = [] + init_code = "" + finalize_code = "" + dependencies = [] diff --git a/dace/libraries/fft/nodes/__init__.py b/dace/libraries/fft/nodes/__init__.py new file mode 100644 index 0000000000..dd8f132aa4 --- /dev/null +++ b/dace/libraries/fft/nodes/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +from .fft import FFT, IFFT diff --git a/dace/libraries/fft/nodes/fft.py b/dace/libraries/fft/nodes/fft.py new file mode 100644 index 0000000000..bc85f8785b --- /dev/null +++ b/dace/libraries/fft/nodes/fft.py @@ -0,0 +1,204 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +""" +Implements Forward and Inverse Fast Fourier Transform (FFT) library nodes +""" +import warnings + +from dace import data, dtypes, SDFG, SDFGState, symbolic, library, nodes, properties +from dace import transformation as xf +from dace.libraries.fft import environments as env + + +# Define the library nodes +@library.node +class FFT(nodes.LibraryNode): + implementations = {} + default_implementation = 'pure' + + factor = properties.SymbolicProperty(desc='Coefficient to multiply outputs. Used for normalization', default=1.0) + + def __init__(self, name, *args, schedule=None, **kwargs): + super().__init__(name, *args, schedule=schedule, inputs={'_inp'}, outputs={'_out'}, **kwargs) + + +@library.node +class IFFT(nodes.LibraryNode): + implementations = {} + default_implementation = 'pure' + + factor = properties.SymbolicProperty(desc='Coefficient to multiply outputs. 
Used for normalization', default=1.0) + + def __init__(self, name, *args, schedule=None, **kwargs): + super().__init__(name, *args, schedule=schedule, inputs={'_inp'}, outputs={'_out'}, **kwargs) + + +################################################################################################## +# Native SDFG expansions +################################################################################################## + + +@library.register_expansion(FFT, 'pure') +class DFTExpansion(xf.ExpandTransformation): + environments = [] + + @staticmethod + def expansion(node: FFT, parent_state: SDFGState, parent_sdfg: SDFG) -> SDFG: + from dace.libraries.fft.algorithms import dft # Lazy import functions + input, output = _get_input_and_output(parent_state, node) + indesc = parent_sdfg.arrays[input] + outdesc = parent_sdfg.arrays[output] + if len(indesc.shape) != 1: + raise NotImplementedError('Native SDFG expansion for FFT does not yet support N-dimensional inputs') + + warnings.warn('Performance Warning: No assumptions on FFT input size, falling back to DFT') + return dft.dft_explicit.to_sdfg(indesc, outdesc, N=indesc.shape[0], factor=node.factor) + + +@library.register_expansion(IFFT, 'pure') +class IDFTExpansion(xf.ExpandTransformation): + environments = [] + + @staticmethod + def expansion(node: IFFT, parent_state: SDFGState, parent_sdfg: SDFG) -> SDFG: + from dace.libraries.fft.algorithms import dft # Lazy import functions + input, output = _get_input_and_output(parent_state, node) + indesc = parent_sdfg.arrays[input] + outdesc = parent_sdfg.arrays[output] + if len(indesc.shape) != 1: + raise NotImplementedError('Native SDFG expansion for IFFT does not yet support N-dimensional inputs') + + warnings.warn('Performance Warning: No assumptions on IFFT input size, falling back to DFT') + return dft.idft_explicit.to_sdfg(indesc, outdesc, N=indesc.shape[0], factor=node.factor) + + +################################################################################################## +# cuFFT expansions +################################################################################################## + + +@library.register_expansion(FFT, 'cuFFT') +class cuFFTFFTExpansion(xf.ExpandTransformation): + environments = [env.cuFFT] + plan_uid = 0 + + @staticmethod + def expansion(node: FFT, parent_state: SDFGState, parent_sdfg: SDFG) -> SDFG: + input, output = _get_input_and_output(parent_state, node) + indesc = parent_sdfg.arrays[input] + outdesc = parent_sdfg.arrays[output] + if str(node.factor) != '1': + raise NotImplementedError('Multiplicative post-FFT factors are not yet implemented') + return _generate_cufft_code(indesc, outdesc, parent_sdfg, False) + + +@library.register_expansion(IFFT, 'cuFFT') +class cuFFTIFFTExpansion(xf.ExpandTransformation): + environments = [env.cuFFT] + plan_uid = 0 + + @staticmethod + def expansion(node: IFFT, parent_state: SDFGState, parent_sdfg: SDFG) -> SDFG: + input, output = _get_input_and_output(parent_state, node) + indesc = parent_sdfg.arrays[input] + outdesc = parent_sdfg.arrays[output] + if str(node.factor) != '1': + raise NotImplementedError('Multiplicative post-FFT factors are not yet implemented') + return _generate_cufft_code(indesc, outdesc, parent_sdfg, True) + + +def _generate_cufft_code(indesc: data.Data, outdesc: data.Data, sdfg: SDFG, is_inverse: bool): + from dace.codegen.targets import cpp # Avoid import loops + if len(indesc.shape) not in (1, 2, 3): + raise ValueError('cuFFT only supports 1/2/3-dimensional FFT') + if indesc.storage != 
dtypes.StorageType.GPU_Global: + raise ValueError('cuFFT implementation requires input array to be on GPU') + if outdesc.storage != dtypes.StorageType.GPU_Global: + raise ValueError('cuFFT implementation requires output array to be on GPU') + + cufft_type = _types_to_cufft(indesc.dtype, outdesc.dtype) + init_code = '' + exit_code = '' + callsite_code = '' + + # Make a unique name for this plan + if not is_inverse: + plan_name = f'fwdplan{cuFFTFFTExpansion.plan_uid}' + cuFFTFFTExpansion.plan_uid += 1 + direction = 'CUFFT_FORWARD' + tasklet_prefix = '' + else: + plan_name = f'invplan{cuFFTIFFTExpansion.plan_uid}' + cuFFTIFFTExpansion.plan_uid += 1 + direction = 'CUFFT_INVERSE' + tasklet_prefix = 'i' + + fields = [ + f'cufftHandle {plan_name};', + ] + plan_name = f'__state->{plan_name}' + + init_code += f''' + cufftCreate(&{plan_name}); + ''' + exit_code += f''' + cufftDestroy({plan_name}); + ''' + + cdims = ', '.join([cpp.sym2cpp(s) for s in indesc.shape]) + make_plan = f''' + {{ + size_t __work_size = 0; + cufftMakePlan{len(indesc.shape)}d({plan_name}, {cdims}, {cufft_type}, /*batch=*/1, &__work_size); + }} + ''' + + # Make plan in init if not symbolic or not data-dependent, otherwise make at callsite. + symbols_that_change = set(s for ise in sdfg.edges() for s in ise.data.assignments.keys()) + symbols_that_change &= set(map(str, sdfg.symbols.keys())) + + def _fsyms(x): + if symbolic.issymbolic(x): + return set(map(str, x.free_symbols)) + return set() + + if symbols_that_change and any(_fsyms(s) & symbols_that_change for s in indesc.shape): + callsite_code += make_plan + else: + init_code += make_plan + + # Execute plan + callsite_code += f''' + cufftSetStream({plan_name}, __dace_current_stream); + cufftXtExec({plan_name}, _inp, _out, {direction}); + ''' + + return nodes.Tasklet(f'cufft_{tasklet_prefix}fft', {'_inp'}, {'_out'}, + callsite_code, + language=dtypes.Language.CPP, + state_fields=fields, + code_init=init_code, + code_exit=exit_code) + + +################################################################################################## +# Helper functions +################################################################################################## + + +def _get_input_and_output(state: SDFGState, node: nodes.LibraryNode): + """ + Helper function that returns the input and output arrays of the library node + """ + in_edge = next(e for e in state.in_edges(node) if e.dst_conn) + out_edge = next(e for e in state.out_edges(node) if e.src_conn) + return in_edge.data.data, out_edge.data.data + + +def _types_to_cufft(indtype: dtypes.typeclass, outdtype: dtypes.typeclass): + typedict = { + dtypes.float32: 'R', + dtypes.float64: 'D', + dtypes.complex64: 'C', + dtypes.complex128: 'Z', + } + return f'CUFFT_{typedict[indtype]}2{typedict[outdtype]}' diff --git a/dace/libraries/standard/nodes/transpose.py b/dace/libraries/standard/nodes/transpose.py index 58c6cfc33e..e2795ef951 100644 --- a/dace/libraries/standard/nodes/transpose.py +++ b/dace/libraries/standard/nodes/transpose.py @@ -100,6 +100,12 @@ class ExpandTransposeMKL(ExpandTransformation): @staticmethod def expansion(node, state, sdfg): node.validate(sdfg, state) + + # Fall back to native implementation if input and output types are not the same + if (sdfg.arrays[list(state.in_edges_by_connector(node, '_inp'))[0].data.data].dtype != sdfg.arrays[list( + state.out_edges_by_connector(node, '_out'))[0].data.data].dtype): + return ExpandTransposePure.make_sdfg(node, state, sdfg) + dtype = node.dtype if dtype == dace.float32: func = 
"somatcopy" @@ -141,22 +147,30 @@ class ExpandTransposeOpenBLAS(ExpandTransformation): @staticmethod def expansion(node, state, sdfg): node.validate(sdfg, state) + + # Fall back to native implementation if input and output types are not the same + if (sdfg.arrays[list(state.in_edges_by_connector(node, '_inp'))[0].data.data].dtype != sdfg.arrays[list( + state.out_edges_by_connector(node, '_out'))[0].data.data].dtype): + return ExpandTransposePure.make_sdfg(node, state, sdfg) + dtype = node.dtype cast = "" if dtype == dace.float32: func = "somatcopy" alpha = "1.0f" + cast = '' elif dtype == dace.float64: func = "domatcopy" alpha = "1.0" + cast = '' elif dtype == dace.complex64: func = "comatcopy" - cast = "(float*)" - alpha = f"{cast}dace::blas::BlasConstants::Get().Complex64Pone()" + alpha = "dace::blas::BlasConstants::Get().Complex64Pone()" + cast = '(float*)' elif dtype == dace.complex128: func = "zomatcopy" - cast = "(double*)" - alpha = f"{cast}dace::blas::BlasConstants::Get().Complex128Pone()" + alpha = "dace::blas::BlasConstants::Get().Complex128Pone()" + cast = '(double*)' else: raise ValueError("Unsupported type for OpenBLAS omatcopy extension: " + str(dtype)) # TODO: Add stride support @@ -164,8 +178,8 @@ def expansion(node, state, sdfg): # Adaptations for BLAS API order = 'CblasRowMajor' trans = 'CblasTrans' - code = ("cblas_{f}({o}, {t}, {m}, {n}, {a}, {c}_inp, " - "{n}, {c}_out, {m});").format(f=func, o=order, t=trans, m=m, n=n, a=alpha, c=cast) + code = ("cblas_{f}({o}, {t}, {m}, {n}, {cast}{a}, {cast}_inp, " + "{n}, {cast}_out, {m});").format(f=func, o=order, t=trans, m=m, n=n, a=alpha, cast=cast) tasklet = dace.sdfg.nodes.Tasklet(node.name, node.in_connectors, node.out_connectors, @@ -184,6 +198,11 @@ def expansion(node, state, sdfg, **kwargs): node.validate(sdfg, state) dtype = node.dtype + # Fall back to native implementation if input and output types are not the same + if (sdfg.arrays[list(state.in_edges_by_connector(node, '_inp'))[0].data.data].dtype != sdfg.arrays[list( + state.out_edges_by_connector(node, '_out'))[0].data.data].dtype): + return ExpandTransposePure.make_sdfg(node, state, sdfg) + try: func, cdtype, factort = blas_helpers.cublas_type_metadata(dtype) except TypeError as ex: diff --git a/tests/library/fft_test.py b/tests/library/fft_test.py new file mode 100644 index 0000000000..440d0a46cf --- /dev/null +++ b/tests/library/fft_test.py @@ -0,0 +1,101 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
+import pytest +import numpy as np + +import dace + + +@pytest.mark.parametrize('symbolic', (False, True)) +def test_fft(symbolic): + if symbolic: + N = dace.symbol('N') + else: + N = 21 + + @dace.program + def tester(x: dace.complex128[N]): + return np.fft.fft(x) + + a = np.random.rand(21) + 1j * np.random.rand(21) + b = tester(a) + assert np.allclose(b, np.fft.fft(a)) + + +def test_fft_r2c(): + """ + Tests implicit conversion to complex types + """ + + @dace.program + def tester(x: dace.float32[20]): + return np.fft.fft(x) + + a = np.random.rand(20).astype(np.float32) + b = tester(a) + assert b.dtype == np.complex64 + assert np.allclose(b, np.fft.fft(a)) + + +@pytest.mark.parametrize('norm', ('backward', 'forward', 'ortho')) +def test_ifft(norm): + + @dace.program + def tester(x: dace.complex128[21]): + return np.fft.ifft(x, norm=norm) + + a = np.random.rand(21) + 1j * np.random.rand(21) + b = tester(a) + assert np.allclose(b, np.fft.ifft(a, norm=norm)) + + +@pytest.mark.gpu +def test_cufft(): + import dace.libraries.fft as fftlib + + @dace.program + def tester(x: dace.complex128[210]): + return np.fft.fft(x) + + sdfg = tester.to_sdfg() + sdfg.apply_gpu_transformations() + fftlib.FFT.default_implementation = 'cuFFT' + sdfg.expand_library_nodes() + fftlib.FFT.default_implementation = 'pure' + + a = np.random.rand(210) + 1j * np.random.rand(210) + b = sdfg(a) + assert np.allclose(b, np.fft.fft(a)) + + +@pytest.mark.gpu +def test_cufft_twoplans(): + import dace.libraries.fft as fftlib + + @dace.program + def tester(x: dace.complex128[210], y: dace.complex64[19]): + return np.fft.fft(x), np.fft.ifft(y, norm='forward') + + sdfg = tester.to_sdfg() + sdfg.apply_gpu_transformations() + fftlib.FFT.default_implementation = 'cuFFT' + fftlib.IFFT.default_implementation = 'cuFFT' + sdfg.expand_library_nodes() + fftlib.FFT.default_implementation = 'pure' + fftlib.IFFT.default_implementation = 'pure' + + a = np.random.rand(210) + 1j * np.random.rand(210) + b = (np.random.rand(19) + 1j * np.random.rand(19)).astype(np.complex64) + c, d = sdfg(a, b) + assert np.allclose(c, np.fft.fft(a)) + assert np.allclose(d, np.fft.ifft(b, norm='forward')) + + +if __name__ == '__main__': + test_fft(False) + test_fft(True) + test_fft_r2c() + test_ifft('backward') + test_ifft('forward') + test_ifft('ortho') + test_cufft() + test_cufft_twoplans() diff --git a/tests/numpy/array_creation_test.py b/tests/numpy/array_creation_test.py index 7329b48b3f..a1f6d0329f 100644 --- a/tests/numpy/array_creation_test.py +++ b/tests/numpy/array_creation_test.py @@ -152,6 +152,42 @@ def test_arange_6(): return np.arange(2.5, 10, 3) +@compare_numpy_output() +def test_linspace_1(): + return np.linspace(2.5, 10, num=3) + + +@compare_numpy_output() +def test_linspace_2(): + space, step = np.linspace(2.5, 10, num=3, retstep=True) + return space, step + + +@compare_numpy_output() +def test_linspace_3(): + a = np.array([1, 2, 3]) + return np.linspace(a, 5, num=10) + + +@compare_numpy_output() +def test_linspace_4(): + a = np.array([[1, 2, 3], [4, 5, 6]]) + space, step = np.linspace(a, 10, endpoint=False, retstep=True) + return space, step + + +@compare_numpy_output() +def test_linspace_5(): + a = np.array([[1, 2, 3], [4, 5, 6]]) + b = np.array([[5], [10]]) + return np.linspace(a, b, endpoint=False, axis=1) + + +@compare_numpy_output() +def test_linspace_6(): + return np.linspace(-5, 5.5, dtype=np.float32) + + @dace.program def program_strides_0(): A = dace.ndarray((2, 2), dtype=dace.int32, strides=(2, 1)) @@ -267,6 +303,12 @@ def 
@@ -267,6 +303,12 @@ def ones_scalar_size(k: dace.int32):
     test_arange_4()
     test_arange_5()
     test_arange_6()
+    test_linspace_1()
+    test_linspace_2()
+    test_linspace_3()
+    test_linspace_4()
+    test_linspace_5()
+    test_linspace_6()
     test_strides_0()
     test_strides_1()
     test_strides_2()
diff --git a/tests/numpy/attention_simple_test.py b/tests/numpy/attention_simple_test.py
index 49558a154b..2ce0205e3f 100644
--- a/tests/numpy/attention_simple_test.py
+++ b/tests/numpy/attention_simple_test.py
@@ -11,7 +11,7 @@
 def dace_softmax(X_in: dace.float32[N], X_out: dace.float32[N]):
 
     tmp_max = dace.reduce(lambda a, b: max(a, b), X_in)
-    X_out[:] = exp(X_in - tmp_max)
+    X_out[:] = np.exp(X_in - tmp_max)
     tmp_sum = dace.reduce(lambda a, b: a + b, X_out, identity=0)
     X_out[:] /= tmp_sum
diff --git a/tests/numpy/attribute_test.py b/tests/numpy/attribute_test.py
index 2181883015..e011eafc89 100644
--- a/tests/numpy/attribute_test.py
+++ b/tests/numpy/attribute_test.py
@@ -54,7 +54,50 @@ def fn(a: dace.float64[N, F_in], b: dace.float64[N, heads, F_out], c: dace.float
     assert np.allclose(c, c_expected)
 
 
+def test_nested_attribute():
+
+    @dace.program
+    def tester(a: dace.complex128[20, 10]):
+        return a.T.real
+
+    r = np.random.rand(20, 10)
+    im = np.random.rand(20, 10)
+    a = r + 1j * im
+    res = tester(a)
+    assert np.allclose(res, r.T)
+
+
+def test_attribute_of_expr():
+    """
+    Regression reported in Issue #1295.
+    """
+
+    @dace.program
+    def tester(a: dace.float64[20, 20], b: dace.float64[20, 20], c: dace.float64[20, 20]):
+        c[:, :] = (a @ b).T
+
+    a = np.random.rand(20, 20)
+    b = np.random.rand(20, 20)
+    c = np.random.rand(20, 20)
+    ref = (a @ b).T
+    tester(a, b, c)
+    assert np.allclose(c, ref)
+
+
+def test_attribute_function():
+
+    @dace.program
+    def tester():
+        return np.arange(10).reshape(10, 1)
+
+    a = tester()
+    assert np.allclose(a, np.arange(10).reshape(10, 1))
+
+
 if __name__ == '__main__':
     test_attribute_in_ranged_loop()
     test_attribute_in_ranged_loop_symbolic()
     test_attribute_new_state()
+    test_nested_attribute()
+    test_attribute_of_expr()
+    test_attribute_function()
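Many of the new tests here and in the files below use the compare_numpy_output decorator from tests/numpy/common.py. Roughly, and only as an illustrative sketch (the real helper also generates random inputs from the annotations and checks dtypes), it runs the function once through DaCe and once through plain NumPy and compares the results:

# Rough sketch of the test harness; names and details are illustrative,
# not the actual implementation in tests/numpy/common.py.
import numpy as np
import dace

def compare_numpy_output_sketch(func):
    def wrapper():
        dace_result = dace.program(func)()  # compile and run through DaCe
        numpy_result = func()               # run the same body in plain NumPy
        assert np.allclose(dace_result, numpy_result)
    return wrapper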
diff --git a/tests/numpy/concat_test.py b/tests/numpy/concat_test.py
new file mode 100644
index 0000000000..614258e34f
--- /dev/null
+++ b/tests/numpy/concat_test.py
@@ -0,0 +1,133 @@
+# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
+import dace
+import numpy as np
+from common import compare_numpy_output
+import pytest
+
+M = 10
+N = 20
+K = 30
+
+
+@compare_numpy_output()
+def test_concatenate():
+    a = np.zeros([N, N], dtype=np.float32)
+    b = np.ones([N, 1], dtype=np.float32)
+    return np.concatenate((a, b), axis=-1)
+
+
+@compare_numpy_output()
+def test_concatenate_four():
+    a = np.zeros([N, N], dtype=np.float32)
+    b = np.ones([N, 1], dtype=np.float32)
+    c = np.full([N, M], 2.0, dtype=np.float32)
+    return np.concatenate((a, b, c, a), axis=-1)
+
+
+@compare_numpy_output()
+def test_concatenate_out():
+    a = np.zeros([N, N], dtype=np.float32)
+    b = np.ones([M, N], dtype=np.float32)
+    c = np.full([N + M, N], -1, dtype=np.float32)
+    np.concatenate([a, b], out=c)
+    return c + 1
+
+
+def test_concatenate_symbolic():
+    n = dace.symbol('n')
+    m = dace.symbol('m')
+    k = dace.symbol('k')
+
+    @dace.program
+    def tester(a: dace.float64[k, m], b: dace.float64[k, n]):
+        return np.concatenate((a, b), axis=1)
+
+    aa = np.random.rand(10, 4)
+    bb = np.random.rand(10, 5)
+    cc = tester(aa, bb)
+    assert tuple(cc.shape) == (10, 9)
+    assert np.allclose(np.concatenate((aa, bb), axis=1), cc)
+
+
+def test_concatenate_fail():
+    with pytest.raises(ValueError):
+
+        @dace.program
+        def tester(a: dace.float64[K, M], b: dace.float64[N, K]):
+            return np.concatenate((a, b), axis=1)
+
+        aa = np.random.rand(K, M)
+        bb = np.random.rand(N, K)
+        tester(aa, bb)
+
+
+@compare_numpy_output()
+def test_concatenate_flatten():
+    a = np.zeros([1, 2, 3], dtype=np.float32)
+    b = np.ones([4, 5, 6], dtype=np.float32)
+    return np.concatenate([a, b], axis=None)
+
+
+@compare_numpy_output()
+def test_stack():
+    a = np.zeros([N, M, K], dtype=np.float32)
+    b = np.ones([N, M, K], dtype=np.float32)
+    return np.stack((a, b), axis=-1)
+
+
+@compare_numpy_output()
+def test_vstack():
+    a = np.zeros([N, M], dtype=np.float32)
+    b = np.ones([N, M], dtype=np.float32)
+    return np.vstack((a, b))
+
+
+@compare_numpy_output()
+def test_vstack_1d():
+    a = np.zeros([N], dtype=np.float32)
+    b = np.ones([N], dtype=np.float32)
+    return np.vstack((a, b))
+
+
+@compare_numpy_output()
+def test_hstack():
+    a = np.zeros([N, M], dtype=np.float32)
+    b = np.ones([N, M], dtype=np.float32)
+    return np.hstack((a, b))
+
+
+@compare_numpy_output()
+def test_hstack_1d():
+    a = np.zeros([N], dtype=np.float32)
+    b = np.ones([N], dtype=np.float32)
+    return np.hstack((a, b))
+
+
+@compare_numpy_output()
+def test_dstack():
+    a = np.zeros([N, M, K], dtype=np.float32)
+    b = np.ones([N, M, K], dtype=np.float32)
+    return np.dstack((a, b))
+
+
+@compare_numpy_output()
+def test_dstack_4d():
+    a = np.zeros([N, M, K, K], dtype=np.float32)
+    b = np.ones([N, M, K, K], dtype=np.float32)
+    return np.dstack((a, b))
+
+
+if __name__ == "__main__":
+    test_concatenate()
+    test_concatenate_four()
+    test_concatenate_out()
+    test_concatenate_symbolic()
+    test_concatenate_fail()
+    test_concatenate_flatten()
+    test_stack()
+    test_vstack()
+    test_vstack_1d()
+    test_hstack()
+    test_hstack_1d()
+    test_dstack()
+    test_dstack_4d()
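A plain-NumPy reference for the stack family exercised above: np.stack is equivalent to concatenating the inputs along a freshly inserted unit axis, which is why the vstack/hstack/dstack tests sit naturally next to the concatenate ones.

# Plain NumPy identity, not part of the patch:
import numpy as np

a, b = np.zeros((4, 5)), np.ones((4, 5))
s = np.stack((a, b), axis=-1)
c = np.concatenate((a[..., None], b[..., None]), axis=-1)
assert s.shape == (4, 5, 2) and np.array_equal(s, c)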
diff --git a/tests/numpy/nested_call_subarray_test.py b/tests/numpy/nested_call_subarray_test.py
index 6a92b004fa..7501652328 100644
--- a/tests/numpy/nested_call_subarray_test.py
+++ b/tests/numpy/nested_call_subarray_test.py
@@ -8,7 +8,7 @@
 @dace.program
 def dace_softmax_ncs(X_in: dace.float32[N], X_out: dace.float32[N]):
     tmp_max = dace.reduce(lambda a, b: a + b, X_in, identity=0)
-    X_out[:] = exp(X_in - tmp_max)
+    X_out[:] = np.exp(X_in - tmp_max)
     tmp_sum = dace.reduce(lambda a, b: max(a, b), X_in)
     X_out[:] /= tmp_sum
@@ -22,7 +22,7 @@ def test_ncs_local_program():
     @dace.program
     def dace_softmax_localprog(X_in: dace.float32[N], X_out: dace.float32[N]):
         tmp_max = dace.reduce(lambda a, b: a + b, X_in, identity=0)
-        X_out[:] = exp(X_in - tmp_max)
+        X_out[:] = np.exp(X_in - tmp_max)
         tmp_sum = dace.reduce(lambda a, b: max(a, b), X_in)
         X_out[:] /= tmp_sum
diff --git a/tests/numpy/split_test.py b/tests/numpy/split_test.py
new file mode 100644
index 0000000000..e4088754e8
--- /dev/null
+++ b/tests/numpy/split_test.py
@@ -0,0 +1,142 @@
+# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
+"""
+Tests variants of the NumPy split array-manipulation functions.
+"""
+import dace
+import numpy as np
+from common import compare_numpy_output
+import pytest
+
+M = 9
+N = 20
+K = 30
+
+
+@compare_numpy_output()
+def test_split():
+    arr = np.arange(M)
+    a, b, c = np.split(arr, 3)
+    return a + b + c
+
+
+def test_uneven_split_fail():
+    with pytest.raises(ValueError):
+
+        @dace.program
+        def tester():
+            arr = np.arange(N)
+            a, b, c = np.split(arr, 3)
+            return a + b + c
+
+        tester()
+
+
+def test_symbolic_split_fail():
+    with pytest.raises(ValueError):
+        n = dace.symbol('n')
+
+        @dace.program
+        def tester():
+            arr = np.arange(N)
+            a, b, c = np.split(arr, n)
+            return a + b + c
+
+        tester()
+
+
+def test_array_split_fail():
+    with pytest.raises(ValueError):
+
+        @dace.program
+        def tester():
+            arr = np.arange(N)
+            split = np.arange(N)
+            a, b, c = np.split(arr, split)
+            return a + b + c
+
+        tester()
+
+
+@compare_numpy_output()
+def test_array_split():
+    arr = np.arange(N)
+    a, b, c = np.array_split(arr, 3)
+    return a, b, c
+
+
+@compare_numpy_output()
+def test_array_split_multidim():
+    arr = np.ones((N, N))
+    a, b, c = np.array_split(arr, 3, axis=1)
+    return a, b, c
+
+
+@compare_numpy_output()
+def test_split_sequence():
+    arr = np.arange(N)
+    a, b = np.split(arr, [3])
+    return a, b
+
+
+@compare_numpy_output()
+def test_split_sequence_2():
+    arr = np.arange(M)
+    a, b, c = np.split(arr, [3, 6])
+    return a + b + c
+
+
+def test_split_sequence_symbolic():
+    n = dace.symbol('n')
+
+    @dace.program
+    def tester(arr: dace.float64[3 * n]):
+        a, b, c = np.split(arr, [n, n + 2])
+        return a, b, c
+
+    nval = K // 3
+    a = np.random.rand(K)
+    ra, rb, rc = tester(a)
+    assert ra.shape[0] == nval
+    assert rb.shape[0] == 2
+    assert rc.shape[0] == K - nval - 2
+    ref = np.split(a, [nval, nval + 2])
+    assert len(ref) == 3
+    assert np.allclose(ra, ref[0])
+    assert np.allclose(rb, ref[1])
+    assert np.allclose(rc, ref[2])
+
+
+@compare_numpy_output()
+def test_vsplit():
+    arr = np.ones((N, M))
+    a, b = np.vsplit(arr, 2)
+    return a, b
+
+
+@compare_numpy_output()
+def test_hsplit():
+    arr = np.ones((M, N))
+    a, b = np.hsplit(arr, 2)
+    return a, b
+
+
+@compare_numpy_output()
+def test_dsplit_4d():
+    arr = np.ones([N, M, K, K], dtype=np.float32)
+    a, b, c = np.dsplit(arr, 3)
+    return a, b, c
+
+
+if __name__ == "__main__":
+    test_split()
+    test_uneven_split_fail()
+    test_symbolic_split_fail()
+    test_array_split_fail()
+    test_array_split()
+    test_array_split_multidim()
+    test_split_sequence()
+    test_split_sequence_2()
+    test_split_sequence_symbolic()
+    test_vsplit()
+    test_hsplit()
+    test_dsplit_4d()
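A plain-NumPy reference for the uneven-division rule that test_array_split relies on (and that np.split rejects with a ValueError, as test_uneven_split_fail checks): splitting l elements into n chunks with r = l % n gives the first r chunks l // n + 1 elements and the remaining chunks l // n.

# Plain NumPy reference, not part of the patch:
import numpy as np

parts = np.array_split(np.arange(20), 3)  # 20 % 3 == 2, so sizes are 7, 7, 6
assert [p.size for p in parts] == [7, 7, 6]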
diff --git a/tests/numpy/ufunc_test.py b/tests/numpy/ufunc_test.py
index 06bd4c3189..b769ab1082 100644
--- a/tests/numpy/ufunc_test.py
+++ b/tests/numpy/ufunc_test.py
@@ -1304,6 +1304,11 @@ def test_ufunc_trunc_u(A: dace.uint32[10]):
     return np.trunc(A)
 
 
+@compare_numpy_output()
+def test_ufunc_clip(A: dace.float32[10]):
+    return np.clip(A, 0.2, 0.5)
+
+
 if __name__ == "__main__":
     test_ufunc_add_ff()
     test_ufunc_subtract_ff()
@@ -1542,3 +1547,4 @@ def test_ufunc_trunc_u(A: dace.uint32[10]):
     test_ufunc_trunc_c()
     test_ufunc_trunc_f()
     test_ufunc_trunc_u()
+    test_ufunc_clip()
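For reference, the clip semantics the new ufunc test checks reduce to an elementwise maximum (lower bound) followed by an elementwise minimum (upper bound), in plain NumPy:

# Plain NumPy identity, not part of the patch:
import numpy as np

A = np.random.rand(10).astype(np.float32)
assert np.allclose(np.clip(A, 0.2, 0.5), np.minimum(np.maximum(A, 0.2), 0.5))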