Skip to content

Commit

Permalink
clear graph cache to avoid OOM
Browse files Browse the repository at this point in the history
Signed-off-by: xinhe3 <xinhe3@habana.ai>
  • Loading branch information
xinhe3 committed Oct 26, 2024
1 parent b1a5a92 commit dbb8041
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
8 changes: 6 additions & 2 deletions neural_compressor/evaluation/lm_eval/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,14 @@ def __init__(
peft: Optional[str] = None,
autogptq: Optional[Union[bool, str]] = False,
pad_to_buckets: Optional[Union[bool]] = False,
buckets: Optional[list] = [64, 128, 256, 512, 1024, 2048],
buckets: Optional[list] = [32, 64, 128, 256, 512, 1024, 2048, 4096],
model_format: Optional[str] = "torch",
**kwargs,
) -> None:
super().__init__()
self.pad_to_buckets = pad_to_buckets
self.buckets = buckets
self.last_bucket = -1
self.model_format = model_format
# optionally: take in an already-initialized transformers.PreTrainedModel
if not isinstance(pretrained, str):
Expand Down Expand Up @@ -883,7 +884,10 @@ def find_bucket(self, length):
eval_logger.error("Please add a higher value into the buckets list for this case.")
exit(0)
else:
return suitable_buckets[0]
if self.last_bucket != suitable_buckets[0]:
self.model.clear_cache() # clear graph cache to avoid OOM
self.last_bucket = suitable_buckets[0]
return self.last_bucket

def _model_call(self, inps, attn_mask=None, labels=None):
"""
Expand Down
2 changes: 1 addition & 1 deletion neural_compressor/evaluation/lm_eval/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
seed=[0, 1234, 1234],
trust_remote_code=False,
pad_to_buckets=None, # used by HPU to align input length for performance.
buckets=[64, 128, 256, 512, 1024, 2048], # used by HPU to limit input length range.
buckets=[32, 64, 128, 256, 512, 1024, 2048, 4096], # used by HPU to limit input length range.
):
self.model = model
self.tasks = tasks
Expand Down

0 comments on commit dbb8041

Please sign in to comment.