diff --git a/tests/myarray.py b/tests/myarray.py index 651f65f..c3d2f46 100644 --- a/tests/myarray.py +++ b/tests/myarray.py @@ -219,16 +219,8 @@ def _bitcast_convert_type_p() -> MyArray: @register(lax.broadcast_in_dim_p) -def _broadcast_in_dim_p( - operand: MyArray, - *, - shape: Any, - broadcast_dimensions: Any, -) -> MyArray: - return replace( - operand, - array=lax.broadcast_in_dim(operand.array, shape, broadcast_dimensions), - ) +def _broadcast_in_dim_p(operand: MyArray, **kwargs: Any) -> MyArray: + return replace(operand, array=lax.broadcast_in_dim(operand.array, **kwargs)) # ==============================================================================