From cf74122c0b88daf5170702d7e5b0dec01e1c7f4a Mon Sep 17 00:00:00 2001 From: StoneT2000 Date: Tue, 5 Mar 2024 02:24:10 -0800 Subject: [PATCH] Update sapien_utils.py --- mani_skill2/utils/sapien_utils.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/mani_skill2/utils/sapien_utils.py b/mani_skill2/utils/sapien_utils.py index f4ea1de87..0a922ff9a 100644 --- a/mani_skill2/utils/sapien_utils.py +++ b/mani_skill2/utils/sapien_utils.py @@ -36,10 +36,15 @@ def to_tensor(array: Union[torch.Tensor, np.array, Sequence]): return torch.Tensor(array).cuda() elif get_backend_name() == "numpy": if isinstance(array, np.ndarray): - return torch.from_numpy(array) - # TODO (arth): better way to address torch "UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow" ? + ret = torch.from_numpy(array) + if ret.dtype == torch.float64: + ret = ret.float() + return ret elif isinstance(array, list) and isinstance(array[0], np.ndarray): - return torch.from_numpy(np.array(array)) + ret = torch.from_numpy(np.array(array)) + if ret.dtype == torch.float64: + ret = ret.float() + return ret elif np.iterable(array): return torch.Tensor(array) else: