Merge pull request #1646 from helmholtz-analytics/1645-bug-batched-matrix-vector-multiplication-does-not-work-correctly

Raise Error for batched vector inputs on matmul
mrfh92 authored Oct 4, 2024
2 parents 576878e + 7b110f0 commit e1a36a7
Showing 2 changed files with 15 additions and 9 deletions.
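
The linked issue (#1645) reports that batched matrix-vector multiplication did not work correctly, returning wrong results rather than failing. A minimal sketch of the call pattern this commit now rejects; the shapes are illustrative, following the matmul docstring in the diff below:

    import heat as ht

    # a batch of four 3x3 matrices and a matching batch of four 3-vectors
    a = ht.zeros((4, 3, 3), split=0)
    v = ht.zeros((4, 3), split=0)

    # previously this could silently return an incorrect result;
    # with this commit it raises NotImplementedError
    ht.matmul(a, v)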
10 changes: 7 additions & 3 deletions heat/core/linalg/basics.py
@@ -430,17 +430,17 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
     Parameters
     -----------
     a : DNDarray
-        matrix :math:`L \\times P` or vector :math:`P` or batch of matrices/vectors: :math:`B_1 \\times ... \\times B_k [\\times L] \\times P`
+        matrix :math:`L \\times P` or vector :math:`P` or batch of matrices: :math:`B_1 \\times ... \\times B_k \\times L \\times P`
     b : DNDarray
-        matrix :math:`P \\times Q` or vector :math:`P` or batch of matrices/vectors: :math:`B_1 \\times ... \\times B_k \\times P [\\times Q]`
+        matrix :math:`P \\times Q` or vector :math:`P` or batch of matrices: :math:`B_1 \\times ... \\times B_k \\times P \\times Q`
     allow_resplit : bool, optional
         Whether to distribute ``a`` in the case that both ``a.split is None`` and ``b.split is None``.
         Default is ``False``. If ``True``, if both are not split then ``a`` will be distributed in-place along axis 0.
     Notes
     -----------
     - For batched inputs, batch dimensions must coincide and if one matrix is split along a batch axis the other must be split along the same axis.
-    - If ``a`` or ``b`` is a (possibly batched) vector the result will also be a (possibly batched) vector.
+    - If ``a`` or ``b`` is a vector the result will also be a vector.
     - We recommend to avoid the particular split combinations ``1``-``0``, ``None``-``0``, and ``1``-``None`` (for ``a.split``-``b.split``) due to their comparably high memory consumption, if possible. Applying ``DNDarray.resplit_`` or ``heat.resplit`` on one of the two factors before calling ``matmul`` in these situations might improve performance of your code / might avoid memory bottlenecks.
     References
@@ -529,6 +529,10 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
             raise NotImplementedError(
                 "Both input matrices have to be split along the same batch axis!"
             )
+        if vector_flag:  # batched matrix vector multiplication not supported
+            raise NotImplementedError(
+                "Batched matrix-vector multiplication is not supported, try using expand_dims to make it a batched matrix-matrix multiplication."
+            )
 
     comm = a.comm
     ndim = max(a.ndim, b.ndim)
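The new error message points to expand_dims as the migration path. A small sketch of that workaround, assuming ht.expand_dims and ht.squeeze behave like their NumPy counterparts (shapes illustrative):

    import heat as ht

    a = ht.zeros((4, 3, 3), split=0)  # batch of four 3x3 matrices
    v = ht.zeros((4, 3), split=0)     # batch of four 3-vectors

    # promote each vector to a 3x1 column matrix, run a batched
    # matrix-matrix product, then drop the trailing axis again
    res = ht.matmul(a, ht.expand_dims(v, -1))  # shape (4, 3, 1)
    res = ht.squeeze(res, -1)                  # shape (4, 3)
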
14 changes: 8 additions & 6 deletions heat/core/linalg/tests/test_basics.py
@@ -827,14 +827,16 @@ def test_matmul(self):
             a = ht.zeros((3, 3, 3), split=0)
             b = ht.zeros((4, 3, 3), split=0)
             ht.matmul(a, b)
-        # not implemented split
-        """
-        todo
+        # split along different batch dimension
         with self.assertRaises(NotImplementedError):
-            a = ht.zeros((3, 3, 3))
-            b = ht.zeros((3, 3, 3))
+            a = ht.zeros((4, 3, 3, 3), split=0)
+            b = ht.zeros((4, 3, 3, 3), split=1)
             ht.matmul(a, b)
+        # batched matrix-vector multiplication
+        with self.assertRaises(NotImplementedError):
+            a = ht.zeros((3, 3, 3), split=0)
+            b = ht.zeros((3, 3), split=0)
+            ht.matmul(a, b)
-        """
 
         # batched, split batch
         n = 11  # number of batches
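Separate from the batched-vector fix, the Notes section in basics.py above also advises resplitting one factor before matmul for the memory-hungry split combinations 1-0, None-0, and 1-None. A brief sketch of that advice using the in-place DNDarray.resplit_ it names; the shapes and the target split axis are illustrative:

    import heat as ht

    a = ht.zeros((100, 50), split=1)
    b = ht.zeros((50, 80), split=0)

    # the 1-0 split combination is flagged as memory-hungry;
    # redistribute one factor so the product runs on a cheaper combination
    a.resplit_(0)
    c = ht.matmul(a, b)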
