From 1e1d31387aa594b2e745c8ed8964962c134d203d Mon Sep 17 00:00:00 2001 From: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> Date: Tue, 25 Jun 2024 18:10:09 +0200 Subject: [PATCH] minor fix for bfloat16 (#7003) --- src/datasets/utils/_dill.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/utils/_dill.py b/src/datasets/utils/_dill.py index 2dedf7f1fbc..2a414459266 100644 --- a/src/datasets/utils/_dill.py +++ b/src/datasets/utils/_dill.py @@ -165,7 +165,7 @@ def _save_torchTensor(pickler, obj): def create_torchTensor(np_array, dtype=None): tensor = torch.from_numpy(np_array) if dtype: - tensor = tensor.type(torch.bfloat16) + tensor = tensor.type(dtype) return tensor log(pickler, f"To: {obj}")