Add test for rmm
Signed-off-by: Vibhu Jawa <vibhujawa@gmail.com>
VibhuJawa committed Oct 3, 2024
1 parent 5569026 commit 69d306e
Showing 1 changed file with 4 additions and 1 deletion.
crossfit/utils/torch_utils.py (4 additions, 1 deletion)

@@ -15,7 +15,6 @@
 import gc
 from typing import List, Union

-import rmm
 import torch
 import torch.nn.functional as F

@@ -120,6 +119,8 @@ def reset_memory_tracking() -> None:
     # TODO: This is hacky, we need to check if the allocator is rmm
     # and then reset the peak memory stats
     if torch.cuda.memory.get_allocator_backend() == "pluggable":
+        import rmm
+
         rmm.statistics.enable_statistics()
         rmm.statistics.push_statistics()
     else:
@@ -137,6 +138,8 @@ def get_peak_memory_used() -> int:
         int: Peak memory usage in bytes.
     """
     if torch.cuda.memory.get_allocator_backend() == "pluggable":
+        import rmm
+
         stats = rmm.statistics.pop_statistics()
         return stats.peak_bytes
     else:
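The change itself only makes the rmm import lazy, deferring it to the branch that runs when PyTorch reports a pluggable allocator backend, so crossfit no longer requires rmm unless RMM is actually the active allocator. As a rough, hypothetical sketch (not the test added in this commit), a test exercising reset_memory_tracking and get_peak_memory_used might look like the following; the rmm_torch_allocator hookup, the tensor size, and the test name are assumptions, while the two helpers come from the file shown above.

```python
# Hypothetical sketch, not the commit's actual test.
# Assumes rmm, torch, and crossfit are installed and a CUDA GPU is available.
import rmm
import torch
from rmm.allocators.torch import rmm_torch_allocator

from crossfit.utils.torch_utils import get_peak_memory_used, reset_memory_tracking

# Route PyTorch's CUDA allocations through RMM so the helpers take the
# "pluggable" branch. This must happen before any CUDA memory is allocated.
rmm.reinitialize(pool_allocator=False)
torch.cuda.memory.change_current_allocator(rmm_torch_allocator)


def test_peak_memory_tracking_with_rmm():
    reset_memory_tracking()  # enables RMM statistics and pushes a new frame
    x = torch.empty(1024 * 1024, dtype=torch.float32, device="cuda")  # ~4 MiB
    peak = get_peak_memory_used()  # pops the frame and returns peak_bytes
    assert peak >= x.element_size() * x.nelement()
```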
