diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/attention.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/attention.py
index 3f6688ba5..e579f9350 100644
--- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/attention.py
+++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/attention.py
@@ -30,14 +30,12 @@
 from megatron.core.transformer.dot_product_attention import DotProductAttention
 from megatron.core.transformer.enums import AttnMaskType
 from megatron.core.transformer.transformer_config import TransformerConfig
-from pkg_resources import packaging
+from megatron.core.utils import get_te_version, is_te_min_version
 from torch import Tensor
 
 
 __all__: Sequence[str] = ("ESM2DotProductAttention", "ESM2TEDotProductAttention")
 
-from megatron.core.extensions.transformer_engine import _te_version
-
 
 class ESM2TEDotProductAttention(TEDotProductAttention):
     """ESM2-Specific transformer engine core attention.
@@ -52,6 +50,9 @@ def __init__(
         attn_mask_type: AttnMaskType,
         attention_type: str,
         attention_dropout: float | None = None,
+        softmax_scale: float = 1.0,
+        k_channels: int | None = None,
+        v_channels: int | None = None,
     ):
         """Initialize ESM2TEDotProductAttention."""
         self.config = config
@@ -67,25 +68,25 @@ def __init__(
         )
 
         extra_kwargs = {}
-        if _te_version >= packaging.version.Version("0.11.0"):
+        if is_te_min_version("0.11.0"):
             extra_kwargs["num_gqa_groups"] = self.config.num_query_groups
         elif self.config.num_query_groups != self.config.num_attention_heads:
             raise ValueError(
-                f"Transformer Engine v{_te_version} does not support Grouped Query Attention, "
+                f"Transformer Engine v{get_te_version()} does not support Grouped Query Attention, "
                 f"use a newer version of Transformer Engine. "
                 f"(num_query_groups ({self.config.num_query_groups}) != "
                 f"num_attention_heads ({self.config.num_attention_heads}))"
             )
 
-        if _te_version >= packaging.version.Version("0.10.0"):
+        if is_te_min_version("0.10.0"):
             extra_kwargs["attention_type"] = attention_type
             # older version don't need attention_type
 
-        if _te_version > packaging.version.Version("0.12.0"):
+        if is_te_min_version("0.12.0", check_equality=False):
             self.te_forward_mask_type = True
 
         # Only Transformer-Engine version >= 1.0.0 supports context parallelism
-        if _te_version >= packaging.version.Version("1.0.0"):
+        if is_te_min_version("1.0.0"):
             if getattr(TEDotProductAttention, "cp_stream") is None:
                 TEDotProductAttention.cp_stream = torch.cuda.Stream()
             extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
@@ -106,15 +107,26 @@ def __init__(
 
         if config.window_size is not None:
             # Check version
-            assert _te_version >= packaging.version.Version("1.2.0"), (
-                f"Transformer-Engine version ({str(_te_version)}) must be >= 1.2.0 to support"
-                "sliding window attention."
+            assert is_te_min_version("1.2.0"), (
+                f"Transformer-Engine v{get_te_version()} must be >= 1.2.0 to support sliding window attention."
             )
             extra_kwargs["window_size"] = config.window_size
 
+        if is_te_min_version("1.10.0"):
+            # TE 1.10.0 introduces the ability to set the different k and v channels
+            kv_channels = (
+                (k_channels, v_channels)
+                if k_channels is not None and v_channels is not None
+                else self.config.kv_channels
+            )
+        else:
+            kv_channels = self.config.kv_channels
+
+        extra_kwargs["softmax_scale"] = softmax_scale
+
         super(TEDotProductAttention, self).__init__(
             num_attention_heads=self.config.num_attention_heads,
-            kv_channels=self.config.kv_channels,
+            kv_channels=kv_channels,
             attention_dropout=(self.config.attention_dropout if attention_dropout is None else attention_dropout),
             attn_mask_type=attn_mask_type.name,
             sequence_parallel=self.config.sequence_parallel,
@@ -122,7 +134,6 @@ def __init__(
             get_rng_state_tracker=(get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None),
             tp_group=get_tensor_model_parallel_group(check_initialized=False),
             layer_number=layer_number,
-            softmax_scale=1.0,  # TODO subclassing only changes softmax_scale from None to 1.0. Upstream to make this exposed without subclassing
             **extra_kwargs,
         )
 
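Usage sketch (not part of the diff above): with the new keyword arguments, callers can override the softmax scale and request distinct key/value head dimensions. On Transformer Engine >= 1.10.0 the pair is forwarded to TE as a (k_channels, v_channels) tuple; on older TE versions the module falls back to config.kv_channels. The snippet assumes an already-built TransformerConfig named config and an initialized Megatron model-parallel state; the layer number and channel values are illustrative only.

    from megatron.core.transformer.enums import AttnMaskType
    from megatron.core.utils import get_te_version, is_te_min_version

    from bionemo.esm2.model.attention import ESM2TEDotProductAttention

    # is_te_min_version("X", check_equality=False) is a strict ">" comparison,
    # mirroring the old `_te_version > packaging.version.Version("0.12.0")` check.
    print(get_te_version(), is_te_min_version("1.10.0"))

    core_attention = ESM2TEDotProductAttention(
        config=config,                    # assumed: an existing TransformerConfig
        layer_number=1,
        attn_mask_type=AttnMaskType.padding,
        attention_type="self",
        softmax_scale=1.0,                # previously hard-coded inside the subclass
        k_channels=64,                    # passed through as a (k, v) pair on TE >= 1.10.0
        v_channels=64,
    )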