Skip to content

Commit

Permalink
Fix numpy vecdot to apply axis before broadcasting
Browse files Browse the repository at this point in the history
This is changed in the 2023 version of the spec, and matches the new np.vecdot
gufunc.
  • Loading branch information
asmeurer committed Feb 27, 2024
1 parent 45a8e27 commit 93ce826
Showing 1 changed file with 4 additions and 7 deletions.
11 changes: 4 additions & 7 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,20 +489,17 @@ def tensordot(x1: ndarray,
return xp.tensordot(x1, x2, axes=axes, **kwargs)

def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray:
ndim = max(x1.ndim, x2.ndim)
x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
if x1_shape[axis] != x2_shape[axis]:
if x1.shape[axis] != x2.shape[axis]:
raise ValueError("x1 and x2 must have the same size along the given axis")

if hasattr(xp, 'broadcast_tensors'):
_broadcast = xp.broadcast_tensors
else:
_broadcast = xp.broadcast_arrays

x1_, x2_ = _broadcast(x1, x2)
x1_ = xp.moveaxis(x1_, axis, -1)
x2_ = xp.moveaxis(x2_, axis, -1)
x1_ = xp.moveaxis(x1, axis, -1)
x2_ = xp.moveaxis(x2, axis, -1)
x1_, x2_ = _broadcast(x1_, x2_)

res = x1_[..., None, :] @ x2_[..., None]
return res[..., 0, 0]
Expand Down

0 comments on commit 93ce826

Please sign in to comment.