initial memleak fix attempts #323

Draft · wants to merge 3 commits into base: main
10 changes: 10 additions & 0 deletions scripts/protein/esm2/esm2_pretrain.py
@@ -90,6 +90,7 @@ def main(
hidden_size: int = 1280,
num_attention_heads: int = 20,
ffn_hidden_size: int = 1280 * 4,
+ torch_empty_cache_steps: int = 1_000,
) -> None:
"""Train an ESM2 model on UR data.

@@ -164,6 +165,7 @@ def main(

callbacks = [
PerplexityLoggingCallback(log_train=False, log_val=True),
+ # MemoryCleanupCallback(torch_empty_cache_steps),
RichModelSummary(max_depth=4),
LearningRateMonitor(),
]
@@ -557,6 +559,13 @@ def main(
default=4 * 1280,
help="FFN hidden size of the model. Default is 4 * 1280.",
)
+ parser.add_argument(
+     "--torch-empty-cache-steps",
+     type=int,
+     required=False,
+     default=1_000,
+     help="Clean up torch cache every N steps.",
+ )

if __name__ == "__main__":
args = parser.parse_args()
@@ -608,4 +617,5 @@ def main(
hidden_size=args.hidden_size,
num_attention_heads=args.num_attention_heads,
ffn_hidden_size=args.ffn_hidden_size,
+ torch_empty_cache_steps=args.torch_empty_cache_steps,
)
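
For readers following along, here is a minimal sketch (not part of the diff) of how the new `--torch-empty-cache-steps` flag would reach the currently commented-out `MemoryCleanupCallback`; the callback wiring is shown as comments because this draft leaves it disabled.

```python
# Sketch only: mirrors the argparse flag added above and shows (as comments) how
# main() would forward it to the callback list once the commented-out line is enabled.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--torch-empty-cache-steps",
    type=int,
    required=False,
    default=1_000,
    help="Clean up torch cache every N steps.",
)
args = parser.parse_args(["--torch-empty-cache-steps", "500"])

# Inside main(...), per the draft:
# callbacks = [
#     PerplexityLoggingCallback(log_train=False, log_val=True),
#     MemoryCleanupCallback(torch_empty_cache_steps),  # still commented out in this PR
#     RichModelSummary(max_depth=4),
#     LearningRateMonitor(),
# ]
print(args.torch_empty_cache_steps)  # -> 500
```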
19 changes: 0 additions & 19 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py
@@ -15,12 +15,9 @@


import logging
- import math
from dataclasses import dataclass
from typing import Callable, Literal, Optional, Sequence, Type, TypeVar

- import torch
- import torch.distributed
from megatron.core import tensor_parallel
from megatron.core.models.bert.bert_lm_head import BertLMHead
from megatron.core.models.bert.pooler import Pooler
@@ -208,22 +205,6 @@ def embedding_forward(
)


- @torch.compile
- def esm_gelu_func(x: Tensor) -> Tensor:
-     """ESM2-specific gelu implementation from the original ESM repo.
-
-     !!! warning
-
-         Using F.gelu yields subtly wrong results, but only when used in combination with bias_activation_fusion=True
-         This variant will not allow you to use bias_activation_fusion=True, which may be the only accuracy benefit over
-         a native F.gelu.
-
-     Args:
-         x: input tensor of any given dimension
-     """
-     return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


ESM2ModelT = TypeVar("ESM2ModelT", bound=ESM2Model)


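For context on the deletion above: the removed `esm_gelu_func` is the exact erf-based GELU, which matches `torch.nn.functional.gelu`'s default (non-approximate) formulation in eager mode; the docstring's warning concerned the `bias_activation_fusion=True` fused path, not the eager op. A quick check, not part of the PR:

```python
# Sanity-check sketch: the deleted function vs. PyTorch's built-in exact GELU.
import math

import torch
import torch.nn.functional as F


def esm_gelu_func(x: torch.Tensor) -> torch.Tensor:
    # Formula copied from the removed function.
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


x = torch.randn(4, 8)
assert torch.allclose(esm_gelu_func(x), F.gelu(x, approximate="none"), atol=1e-6)
```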
@@ -436,13 +436,13 @@ class BioBertConfig(

# From megatron.core.models.gpt.bert_model.GPTModel
fp16_lm_cross_entropy: bool = False
- apply_rope_fusion: bool = True
+ apply_rope_fusion: bool = False
parallel_output: bool = True
- bias_dropout_fusion: bool = True
- bias_activation_fusion: bool = True
- masked_softmax_fusion: bool = True
- persist_layer_norm: bool = True
- get_attention_mask_from_fusion: bool = True
+ bias_dropout_fusion: bool = False
+ bias_activation_fusion: bool = False
+ masked_softmax_fusion: bool = False
+ persist_layer_norm: bool = False
+ get_attention_mask_from_fusion: bool = False

share_embeddings_and_output_weights: bool = False # try True
make_vocab_size_divisible_by: int = 128
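The block above flips several Megatron fusion defaults to False, presumably to rule the fused kernels out while chasing the leak. Below is an illustrative sketch only, using a hypothetical `FusionFlags` stand-in rather than the real `BioBertConfig`, of toggling the same flags per experiment instead of changing the defaults:

```python
# Illustrative sketch: disable the fusion flags for a leak-bisection run without
# editing the defaults. FusionFlags is a made-up stand-in for the config fields above.
from dataclasses import dataclass, replace


@dataclass
class FusionFlags:
    apply_rope_fusion: bool = True
    bias_dropout_fusion: bool = True
    bias_activation_fusion: bool = True
    masked_softmax_fusion: bool = True
    persist_layer_norm: bool = True
    get_attention_mask_from_fusion: bool = True


defaults = FusionFlags()
leak_debug = replace(
    defaults,
    apply_rope_fusion=False,
    bias_dropout_fusion=False,
    bias_activation_fusion=False,
    masked_softmax_fusion=False,
    persist_layer_norm=False,
    get_attention_mask_from_fusion=False,
)
print(leak_debug)
```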
2 changes: 1 addition & 1 deletion sub-packages/bionemo-llm/src/bionemo/llm/model/layers.py
@@ -57,6 +57,6 @@ def __init__(self, config: TransformerConfig, *args, **kwargs) -> None: # noqa:
self.hidden_size_per_attention_head = divide(projection_size, config.num_attention_heads)
self.sqrt_val = math.sqrt(self.hidden_size_per_attention_head)

- @torch.compile
+ # @torch.compile
def forward(self, query, *args, **kwargs): # noqa: D102
return query / self.sqrt_val
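
The forward being (un)decorated above only rescales the query by 1/sqrt(hidden_size_per_attention_head), so commenting out `@torch.compile` just runs the same division eagerly. A small sketch of the equivalent math, with made-up sizes:

```python
# Eager equivalent of the forward above; the head size is an example value, not from the repo.
import math

import torch

hidden_size_per_attention_head = 64
sqrt_val = math.sqrt(hidden_size_per_attention_head)

query = torch.randn(2, 16, hidden_size_per_attention_head)
scaled = query / sqrt_val  # same computation, no compiled graph kept around

# If compiler caches are suspected of holding memory, they can also be cleared
# explicitly between experiments (torch 2.x):
# torch._dynamo.reset()
```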
@@ -0,0 +1,47 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc

import torch
from nemo.lightning import io
from nemo.utils import logging
from pytorch_lightning.callbacks.callback import Callback


class MemoryCleanupCallback(Callback, io.IOMixin):
"""Class to print out memory usage at the end of each training batch."""

def __init__(self, cleanup_every_n_steps: int = 1_000):
"""Initialize the memory usage list."""
self._cleanup_every_n_steps = cleanup_every_n_steps

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int) -> None: # noqa: D102
if batch_idx and batch_idx % self._cleanup_every_n_steps == 0:
gc.collect()
torch.cuda.empty_cache()

logging.info(
    f"Cleaning up CUDA cache on batch {batch_idx}. "
    f"Mem: {torch.cuda.memory_allocated() / 1024 / 1024 / 1024:.2f} / "
    f"{torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024:.2f} GiB"
)

# self.memory_usage.append((batch_idx, torch.cuda.memory_allocated(), torch.cuda.max_memory_reserved()))

# logging.info(
# f"on_train_batch_end {batch_idx} mem: {torch.cuda.memory_allocated()/1024/1024/1024} /"
# f"{torch.cuda.max_memory_reserved()/1024/1024/1024}"
# )
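
A usage sketch for the new callback (not part of the diff): register it on the trainer's callback list, as the commented-out line in `esm2_pretrain.py` intends. The import path is an assumption, since the new file's location is not shown in this view.

```python
# Sketch: wiring MemoryCleanupCallback (defined above) into a Lightning trainer.
import pytorch_lightning as pl

# from bionemo...callbacks import MemoryCleanupCallback  # actual module path not shown in this diff

trainer = pl.Trainer(
    max_steps=10_000,
    callbacks=[MemoryCleanupCallback(cleanup_every_n_steps=1_000)],
)
# trainer.fit(model, datamodule=data)
```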