From dfd0e6b0b4d6eabb513b350af62fd5f81ac096de Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Fri, 18 Oct 2024 09:22:56 -0700 Subject: [PATCH 1/3] initial memleak fix attempts --- scripts/protein/esm2/esm2_pretrain.py | 10 ++++- .../src/bionemo/esm2/model/model.py | 19 --------- .../src/bionemo/llm/model/layers.py | 2 +- .../src/bionemo/llm/utils/memory_callback.py | 39 +++++++++++++++++++ 4 files changed, 48 insertions(+), 22 deletions(-) create mode 100644 sub-packages/bionemo-llm/src/bionemo/llm/utils/memory_callback.py diff --git a/scripts/protein/esm2/esm2_pretrain.py b/scripts/protein/esm2/esm2_pretrain.py index 896a7e79b..3e4617f1e 100644 --- a/scripts/protein/esm2/esm2_pretrain.py +++ b/scripts/protein/esm2/esm2_pretrain.py @@ -17,6 +17,7 @@ from pathlib import Path from typing import List, Optional, Sequence, get_args +import pandas as pd from megatron.core.optimizer import OptimizerConfig from nemo import lightning as nl from nemo.collections import llm @@ -31,11 +32,11 @@ from bionemo.esm2.data.dataset import RandomMaskStrategy from bionemo.esm2.data.tokenizer import get_tokenizer from bionemo.esm2.model.lr_scheduler import WarmupAnnealDecayHoldScheduler -from bionemo.llm.lightning import PerplexityLoggingCallback from bionemo.llm.model.biobert.lightning import biobert_lightning_module from bionemo.llm.model.biobert.model import BiobertSpecOption from bionemo.llm.utils.datamodule_utils import float_or_int_or_none, infer_global_batch_size from bionemo.llm.utils.logger_utils import WandbLoggerOptions, setup_nemo_lightning_logger +from bionemo.llm.utils.memory_callback import MemoryCleanupCallback __all__: Sequence[str] = ("main", "parser") @@ -162,8 +163,11 @@ def main( ) ) + mem_callback = MemoryCleanupCallback() + callbacks = [ - PerplexityLoggingCallback(log_train=False, log_val=True), + # PerplexityLoggingCallback(log_train=False, log_val=True), + mem_callback, RichModelSummary(max_depth=4), LearningRateMonitor(), ] @@ -268,6 +272,8 @@ def main( ), ) + pd.DataFrame(mem_callback.memory_usage).to_csv(result_dir / "memory_usage.csv", index=False) + # TODO migrate to hydra config # Parse the arguments and pull them out into local variables for ease of future refactor to a diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py index 1d19c1543..0508d84f6 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py @@ -15,12 +15,9 @@ import logging -import math from dataclasses import dataclass from typing import Callable, Literal, Optional, Sequence, Type, TypeVar -import torch -import torch.distributed from megatron.core import tensor_parallel from megatron.core.models.bert.bert_lm_head import BertLMHead from megatron.core.models.bert.pooler import Pooler @@ -208,22 +205,6 @@ def embedding_forward( ) -@torch.compile -def esm_gelu_func(x: Tensor) -> Tensor: - """ESM2-specific gelu implementation from the original ESM repo. - - !!! warning - - Using F.gelu yields subtly wrong results, but only when used in combination with bias_activation_fusion=True - This variant will not allow you to use bias_activation_fusion=True, which may be the only accuracy benefit over - a native F.gelu. 
- - Args: - x: input tensor of any given dimension - """ - return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) - - ESM2ModelT = TypeVar("ESM2ModelT", bound=ESM2Model) diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/model/layers.py b/sub-packages/bionemo-llm/src/bionemo/llm/model/layers.py index 0f233c9bf..8b9fb2a6c 100644 --- a/sub-packages/bionemo-llm/src/bionemo/llm/model/layers.py +++ b/sub-packages/bionemo-llm/src/bionemo/llm/model/layers.py @@ -57,6 +57,6 @@ def __init__(self, config: TransformerConfig, *args, **kwargs) -> None: # noqa: self.hidden_size_per_attention_head = divide(projection_size, config.num_attention_heads) self.sqrt_val = math.sqrt(self.hidden_size_per_attention_head) - @torch.compile + # @torch.compile def forward(self, query, *args, **kwargs): # noqa: D102 return query / self.sqrt_val diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/utils/memory_callback.py b/sub-packages/bionemo-llm/src/bionemo/llm/utils/memory_callback.py new file mode 100644 index 000000000..d7149792e --- /dev/null +++ b/sub-packages/bionemo-llm/src/bionemo/llm/utils/memory_callback.py @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from nemo.lightning import io +from nemo.utils import logging +from pytorch_lightning.callbacks.callback import Callback + + +class MemoryCleanupCallback(Callback, io.IOMixin): + """Class to print out memory usage at the end of each training batch.""" + + def __init__(self): + """Initialize the memory usage list.""" + self.memory_usage = [] + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int) -> None: # noqa: D102 + # gc.collect() + # torch.cuda.empty_cache() + + self.memory_usage.append((batch_idx, torch.cuda.memory_allocated(), torch.cuda.max_memory_reserved())) + + logging.info( + f"on_train_batch_end {batch_idx} mem: {torch.cuda.memory_allocated()/1024/1024/1024} /" + f"{torch.cuda.max_memory_reserved()/1024/1024/1024}" + ) From 3fc8974ecb82e6c75944241bf0008347a4c38df3 Mon Sep 17 00:00:00 2001 From: "Peter St. 
John" Date: Fri, 18 Oct 2024 10:12:41 -0700 Subject: [PATCH 2/3] add callback for clearing memory --- scripts/protein/esm2/esm2_pretrain.py | 19 +++++++----- .../src/bionemo/llm/utils/memory_callback.py | 30 ++++++++++++------- 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/scripts/protein/esm2/esm2_pretrain.py b/scripts/protein/esm2/esm2_pretrain.py index 3e4617f1e..411af1c5f 100644 --- a/scripts/protein/esm2/esm2_pretrain.py +++ b/scripts/protein/esm2/esm2_pretrain.py @@ -17,7 +17,6 @@ from pathlib import Path from typing import List, Optional, Sequence, get_args -import pandas as pd from megatron.core.optimizer import OptimizerConfig from nemo import lightning as nl from nemo.collections import llm @@ -32,6 +31,7 @@ from bionemo.esm2.data.dataset import RandomMaskStrategy from bionemo.esm2.data.tokenizer import get_tokenizer from bionemo.esm2.model.lr_scheduler import WarmupAnnealDecayHoldScheduler +from bionemo.llm.lightning import PerplexityLoggingCallback from bionemo.llm.model.biobert.lightning import biobert_lightning_module from bionemo.llm.model.biobert.model import BiobertSpecOption from bionemo.llm.utils.datamodule_utils import float_or_int_or_none, infer_global_batch_size @@ -91,6 +91,7 @@ def main( hidden_size: int = 1280, num_attention_heads: int = 20, ffn_hidden_size: int = 1280 * 4, + torch_empty_cache_steps: int = 1_000, ) -> None: """Train an ESM2 model on UR data. @@ -163,11 +164,9 @@ def main( ) ) - mem_callback = MemoryCleanupCallback() - callbacks = [ - # PerplexityLoggingCallback(log_train=False, log_val=True), - mem_callback, + PerplexityLoggingCallback(log_train=False, log_val=True), + MemoryCleanupCallback(torch_empty_cache_steps), RichModelSummary(max_depth=4), LearningRateMonitor(), ] @@ -272,8 +271,6 @@ def main( ), ) - pd.DataFrame(mem_callback.memory_usage).to_csv(result_dir / "memory_usage.csv", index=False) - # TODO migrate to hydra config # Parse the arguments and pull them out into local variables for ease of future refactor to a @@ -563,6 +560,13 @@ def main( default=4 * 1280, help="FFN hidden size of the model. Default is 4 * 1280.", ) +parser.add_argument( + "--torch-empty-cache-steps", + type=int, + required=False, + default=1_000, + help="Clean up torch cache every N steps.", +) if __name__ == "__main__": args = parser.parse_args() @@ -614,4 +618,5 @@ def main( hidden_size=args.hidden_size, num_attention_heads=args.num_attention_heads, ffn_hidden_size=args.ffn_hidden_size, + torch_empty_cache_steps=args.torch_empty_cache_steps, ) diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/utils/memory_callback.py b/sub-packages/bionemo-llm/src/bionemo/llm/utils/memory_callback.py index d7149792e..99a7f0fcc 100644 --- a/sub-packages/bionemo-llm/src/bionemo/llm/utils/memory_callback.py +++ b/sub-packages/bionemo-llm/src/bionemo/llm/utils/memory_callback.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import gc import torch from nemo.lightning import io @@ -23,17 +24,24 @@ class MemoryCleanupCallback(Callback, io.IOMixin): """Class to print out memory usage at the end of each training batch.""" - def __init__(self): + def __init__(self, cleanup_every_n_steps: int = 1_000): """Initialize the memory usage list.""" - self.memory_usage = [] + self._cleanup_every_n_steps = cleanup_every_n_steps def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int) -> None: # noqa: D102 - # gc.collect() - # torch.cuda.empty_cache() - - self.memory_usage.append((batch_idx, torch.cuda.memory_allocated(), torch.cuda.max_memory_reserved())) - - logging.info( - f"on_train_batch_end {batch_idx} mem: {torch.cuda.memory_allocated()/1024/1024/1024} /" - f"{torch.cuda.max_memory_reserved()/1024/1024/1024}" - ) + if batch_idx and batch_idx % self._cleanup_every_n_steps == 0: + gc.collect() + torch.cuda.empty_cache() + + logging.info( + f" Cleaning up CUDA cache on batch {batch_idx}. " + f"Mem: {torch.cuda.memory_allocated()/1024/1024/1024:} /" + f"{torch.cuda.max_memory_reserved()/1024/1024/1024}" + ) + + # self.memory_usage.append((batch_idx, torch.cuda.memory_allocated(), torch.cuda.max_memory_reserved())) + + # logging.info( + # f"on_train_batch_end {batch_idx} mem: {torch.cuda.memory_allocated()/1024/1024/1024} /" + # f"{torch.cuda.max_memory_reserved()/1024/1024/1024}" + # ) From 6c95af745721db9b87cefe0c9a2df0adef442165 Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Fri, 18 Oct 2024 15:16:36 -0700 Subject: [PATCH 3/3] remove fusion layers --- scripts/protein/esm2/esm2_pretrain.py | 3 +-- .../src/bionemo/llm/model/biobert/model.py | 12 ++++++------ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/scripts/protein/esm2/esm2_pretrain.py b/scripts/protein/esm2/esm2_pretrain.py index 411af1c5f..dac2629bf 100644 --- a/scripts/protein/esm2/esm2_pretrain.py +++ b/scripts/protein/esm2/esm2_pretrain.py @@ -36,7 +36,6 @@ from bionemo.llm.model.biobert.model import BiobertSpecOption from bionemo.llm.utils.datamodule_utils import float_or_int_or_none, infer_global_batch_size from bionemo.llm.utils.logger_utils import WandbLoggerOptions, setup_nemo_lightning_logger -from bionemo.llm.utils.memory_callback import MemoryCleanupCallback __all__: Sequence[str] = ("main", "parser") @@ -166,7 +165,7 @@ def main( callbacks = [ PerplexityLoggingCallback(log_train=False, log_val=True), - MemoryCleanupCallback(torch_empty_cache_steps), + # MemoryCleanupCallback(torch_empty_cache_steps), RichModelSummary(max_depth=4), LearningRateMonitor(), ] diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py b/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py index e2259bb66..39b1f3608 100644 --- a/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py +++ b/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py @@ -436,13 +436,13 @@ class BioBertConfig( # From megatron.core.models.gpt.bert_model.GPTModel fp16_lm_cross_entropy: bool = False - apply_rope_fusion: bool = True + apply_rope_fusion: bool = False parallel_output: bool = True - bias_dropout_fusion: bool = True - bias_activation_fusion: bool = True - masked_softmax_fusion: bool = True - persist_layer_norm: bool = True - get_attention_mask_from_fusion: bool = True + bias_dropout_fusion: bool = False + bias_activation_fusion: bool = False + masked_softmax_fusion: bool = False + persist_layer_norm: bool = False + get_attention_mask_from_fusion: bool = False share_embeddings_and_output_weights: bool 
= False  # try True
     make_vocab_size_divisible_by: int = 128
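
For reference, a minimal sketch of the periodic-cleanup pattern that PATCH 2/3 introduces via MemoryCleanupCallback and the --torch-empty-cache-steps argument: a PyTorch Lightning callback that runs gc.collect() and torch.cuda.empty_cache() every N training batches. The class name, the print-based reporting, and the commented-out Trainer wiring are illustrative assumptions rather than code from the BioNeMo repository; only the gc, torch.cuda, and pytorch_lightning calls are real APIs.

import gc

import torch
from pytorch_lightning import Callback


class PeriodicCudaCacheCleanup(Callback):
    """Free Python garbage and unused CUDA allocator blocks every N training batches.

    Hypothetical stand-in for the MemoryCleanupCallback added in PATCH 2/3.
    """

    def __init__(self, cleanup_every_n_steps: int = 1_000):
        self.cleanup_every_n_steps = cleanup_every_n_steps

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int) -> None:
        if batch_idx and batch_idx % self.cleanup_every_n_steps == 0:
            gc.collect()              # drop unreachable Python objects that may still pin tensors
            torch.cuda.empty_cache()  # return cached, unused blocks to the CUDA driver
            print(
                f"batch {batch_idx}: "
                f"allocated={torch.cuda.memory_allocated() / 1024**3:.2f} GiB, "
                f"max_reserved={torch.cuda.max_memory_reserved() / 1024**3:.2f} GiB"
            )


# Hypothetical wiring (model/datamodule construction elided), mirroring the
# --torch-empty-cache-steps flag from PATCH 2/3:
# trainer = pytorch_lightning.Trainer(callbacks=[PeriodicCudaCacheCleanup(cleanup_every_n_steps=1_000)])
# trainer.fit(model, datamodule=datamodule)

Because empty_cache() only releases blocks the caching allocator is not using, the allocated/max_reserved pair logged above is what distinguishes a true leak (allocated keeps growing) from allocator fragmentation (reserved grows while allocated stays flat), which is the measurement the callback in these patches is aimed at.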