From f40173555a33aac4505adaceaa2bc1050cdfff20 Mon Sep 17 00:00:00 2001 From: Osterfeld Date: Thu, 12 Sep 2024 03:25:09 +0200 Subject: [PATCH 1/2] raise error for batched vector inputs --- heat/core/linalg/basics.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/heat/core/linalg/basics.py b/heat/core/linalg/basics.py index 53e5e94e8..2f57fe774 100644 --- a/heat/core/linalg/basics.py +++ b/heat/core/linalg/basics.py @@ -430,9 +430,9 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray: Parameters ----------- a : DNDarray - matrix :math:`L \\times P` or vector :math:`P` or batch of matrices/vectors: :math:`B_1 \\times ... \\times B_k [\\times L] \\times P` + matrix :math:`L \\times P` or vector :math:`P` or batch of matrices: :math:`B_1 \\times ... \\times B_k \\times L \\times P` b : DNDarray - matrix :math:`P \\times Q` or vector :math:`P` or batch of matrices/vectors: :math:`B_1 \\times ... \\times B_k \\times P [\\times Q]` + matrix :math:`P \\times Q` or vector :math:`P` or batch of matrices: :math:`B_1 \\times ... \\times B_k \\times P \\times Q` allow_resplit : bool, optional Whether to distribute ``a`` in the case that both ``a.split is None`` and ``b.split is None``. Default is ``False``. If ``True``, if both are not split then ``a`` will be distributed in-place along axis 0. @@ -440,7 +440,7 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray: Notes ----------- - For batched inputs, batch dimensions must coincide and if one matrix is split along a batch axis the other must be split along the same axis. - - If ``a`` or ``b`` is a (possibly batched) vector the result will also be a (possibly batched) vector. + - If ``a`` or ``b`` is a vector the result will also be a vector. - We recommend to avoid the particular split combinations ``1``-``0``, ``None``-``0``, and ``1``-``None`` (for ``a.split``-``b.split``) due to their comparably high memory consumption, if possible. Applying ``DNDarray.resplit_`` or ``heat.resplit`` on one of the two factors before calling ``matmul`` in these situations might improve performance of your code / might avoid memory bottlenecks. References @@ -529,6 +529,10 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray: raise NotImplementedError( "Both input matrices have to be split along the same batch axis!" ) + if vector_flag: # batched matrix vector multiplication not supported + raise NotImplementedError( + "Batched matrix-vector multiplication is not supported, try using expand_dims to make it a batched matrix-matrix multiplication." + ) comm = a.comm ndim = max(a.ndim, b.ndim) From 7b110f09eef0fc14cfa4f1817a82aed472a1f6d0 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Tue, 1 Oct 2024 17:28:28 +0200 Subject: [PATCH 2/2] expand tests --- heat/core/linalg/tests/test_basics.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/heat/core/linalg/tests/test_basics.py b/heat/core/linalg/tests/test_basics.py index 39fc2583b..870d671f6 100644 --- a/heat/core/linalg/tests/test_basics.py +++ b/heat/core/linalg/tests/test_basics.py @@ -827,14 +827,16 @@ def test_matmul(self): a = ht.zeros((3, 3, 3), split=0) b = ht.zeros((4, 3, 3), split=0) ht.matmul(a, b) - # not implemented split - """ - todo + # split along different batch dimension with self.assertRaises(NotImplementedError): - a = ht.zeros((3, 3, 3)) - b = ht.zeros((3, 3, 3)) + a = ht.zeros((4, 3, 3, 3), split=0) + b = ht.zeros((4, 3, 3, 3), split=1) + ht.matmul(a, b) + # batched matrix-vector multiplication + with self.assertRaises(NotImplementedError): + a = ht.zeros((3, 3, 3), split=0) + b = ht.zeros((3, 3), split=0) ht.matmul(a, b) - """ # batched, split batch n = 11 # number of batches