Support fp16 in new python engine #393
Comments
Hello @Xp-speit2018,
Changed error message regarding infinity and float16; for issue #393.
Thanks @joanglaunes for your clarification! Now I have a better idea of where the error originates. I created a minimal demo that reproduces it:

```python
import torch
from pykeops.torch import LazyTensor

N = 10
M = 20
D = 8
X = torch.randn(N, D).half().cuda()
Y = torch.randn(M, D).half().cuda()  # fixed: was randn(N, D), but Y_j is (1, M, D)

# find the L2-nearest neighbors of X in Y
X_i = LazyTensor(X[:, None, :])    # (N, 1, D)
Y_j = LazyTensor(Y[None, :, :])    # (1, M, D)
D_ij = ((X_i - Y_j) ** 2).sum(-1)  # (N, M)
ind_k = D_ij.argmin(dim=1)
print(ind_k)
```

The infinity is indeed created during `argmin`, to initialize the running minimum; see keops/keopscore/keopscore/formulas/reductions/Min_ArgMin_Reduction_Base.py, lines 12 to 25 at dfa60e2.
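For context, here is a minimal sketch (my own illustration, not the KeOps codegen itself) of why that infinity is needed: a min/argmin reduction initializes its accumulator to +inf so that the first real value always replaces it, which is exactly the role the `infinity(dtype)` helper plays in the generated code.

```python
import math

def argmin_reduce(row):
    # the accumulator starts at +inf so the first candidate always wins;
    # this mirrors what the generated reduction code does per output row
    best_val, best_idx = math.inf, -1
    for j, v in enumerate(row):
        if v < best_val:
            best_val, best_idx = v, j
    return best_idx
```

With a dtype that has no usable infinity literal in the generated C++ source, this initialization step is where compilation breaks.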
The tricky part is that on CUDA devices half values are packed into `half2`. Even after patching the infinity creation with:

```python
def infinity(dtype):
    if dtype == "float":
        code = "( 1.0f/0.0f )"
    elif dtype == "double":
        code = "( 1.0/0.0 )"
    elif dtype == "half2":
        # not sure if it's logically correct, but only used to bypass error checking
        code = "__half2(65504.0f, 65504.0f)"
    else:
        KeOps_Error(
            "only float and double dtypes are implemented in new python engine for now"
        )
    return c_variable(dtype, code)
```

we still hit the "not implemented" error first raised from keops/keopscore/keopscore/formulas/reductions/Min_ArgMin_Reduction_Base.py, lines 40 to 46 at dfa60e2.
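A side note on the `65504.0f` constant above (my own check, not from the thread): IEEE float16 does have an infinity encoding, and 65504 is merely its largest finite value, so a saturating initializer like this is only correct when every reduced value stays strictly below it.

```python
import numpy as np

# 65504 is the largest finite float16; values beyond it overflow to inf
half_max = float(np.finfo(np.float16).max)

# overflow on conversion saturates to inf, showing float16 can represent it;
# a 65504-based init therefore only misrepresents values at or above half_max
overflowed = np.float16(70000.0)
```

For squared distances of roughly standardized data this threshold is rarely reached, which is presumably why such a fallback can pass as a stopgap.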
Now it's clear that we lack support for the `half2` type introduced by CUDA. Interestingly, when I wanted to take a further look at how to patch `half`, I created the same demo with numpy:

```python
import numpy as np
from pykeops.numpy import LazyTensor

N = 10
M = 20
D = 8
X = np.random.randn(N, D).astype('float16')
Y = np.random.randn(M, D).astype('float16')

# find the L2-nearest neighbors of X in Y
X_i = LazyTensor(X[:, None, :])  # (N, 1, D)
Y_j = LazyTensor(Y[None, :, :])  # (1, M, D)
D_ij = ((X_i - Y_j) ** 2).sum(-1)
ind_k = D_ij.argmin(dim=1)
print(ind_k)
```

I encountered exactly the same error. It seems that we are still compiling from … Would it be possible to consider adding support for `half2` and `half` in the reduction methods where creating an infinity is a must? Your patience would be much appreciated.
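As a possible stopgap (my suggestion, not something confirmed in the thread): upcasting the inputs to float32 before the reduction keeps the engine on the plain-float infinity code path, and the returned indices are integers, so nothing is lost by running the argmin in higher precision. A numpy-only sketch of the same nearest-neighbor query:

```python
import numpy as np

N, M, D = 10, 20, 8
rng = np.random.default_rng(0)
X = rng.standard_normal((N, D)).astype(np.float16)
Y = rng.standard_normal((M, D)).astype(np.float16)

# upcast before the reduction so a plain float +inf initializer can be used
Xf, Yf = X.astype(np.float32), Y.astype(np.float32)
D_ij = ((Xf[:, None, :] - Yf[None, :, :]) ** 2).sum(-1)  # (N, M)
ind_k = D_ij.argmin(axis=1)  # integer indices for each row of X
```

The same cast-then-reduce pattern applies to the LazyTensor versions above, at the cost of the memory and bandwidth savings that motivated fp16 in the first place.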
Hi,
I recently encountered an issue where fp16 (half precision) is no longer supported by the new Python engine of KeOps (my version is 2.2.3). I received the following error:
However, I noticed that fp16 has been claimed as supported since version 1.4, as mentioned in this issue. I'm curious whether there is a specific reason why fp16 support was dropped in the new engine introduced in version 2.0.
Given that half precision is widely used in many prevalent models, I believe it would be beneficial to clarify the current status and the potential roadmap for fp16 support in the new engine. This could encourage the community to contribute and help improve the feature.
Thank you for your attention.