diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index c2be21f..a6d642e 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -145,6 +145,9 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: # Basic renames bitwise_invert = torch.bitwise_not newaxis = None +# torch.conj sets the conjugation bit, which breaks conversion to other +# libraries. See https://github.com/data-apis/array-api-compat/issues/173 +conj = torch.conj_physical # Two-arg elementwise functions # These require a wrapper to do the correct type promotion on 0-D tensors @@ -704,18 +707,18 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) - return torch.index_select(x, axis, indices, **kwargs) __all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', - 'newaxis', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', - 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign', - 'divide', 'equal', 'floor_divide', 'greater', 'greater_equal', - 'less', 'less_equal', 'logaddexp', 'multiply', 'not_equal', 'pow', - 'remainder', 'subtract', 'max', 'min', 'clip', 'sort', 'prod', - 'sum', 'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze', - 'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape', - 'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty', - 'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays', - 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', - 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', - 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype', - 'take'] + 'newaxis', 'conj', 'add', 'atan2', 'bitwise_and', + 'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift', + 'bitwise_xor', 'copysign', 'divide', 'equal', 'floor_divide', + 'greater', 'greater_equal', 'less', 'less_equal', 'logaddexp', + 'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max', + 'min', 'clip', 'sort', 'prod', 'sum', 'any', 'all', 'mean', 'std', + 'var', 'concat', 'squeeze', 'broadcast_to', 'flip', 'roll', + 'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace', 'full', + 'ones', 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype', + 'broadcast_arrays', 'UniqueAllResult', 'UniqueCountsResult', + 'UniqueInverseResult', 'unique_all', 'unique_counts', + 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', + 'vecdot', 'tensordot', 'isdtype', 'take'] _all_ignore = ['torch', 'get_xp']