update ESM2TEDotProductAttention
sichu2023 committed Oct 12, 2024
1 parent 7ba0a16 commit dcd025e
Showing 1 changed file with 24 additions and 13 deletions.
sub-packages/bionemo-esm2/src/bionemo/esm2/model/attention.py (24 additions, 13 deletions)
@@ -30,14 +30,12 @@
 from megatron.core.transformer.dot_product_attention import DotProductAttention
 from megatron.core.transformer.enums import AttnMaskType
 from megatron.core.transformer.transformer_config import TransformerConfig
-from pkg_resources import packaging
+from megatron.core.utils import get_te_version, is_te_min_version
 from torch import Tensor


 __all__: Sequence[str] = ("ESM2DotProductAttention", "ESM2TEDotProductAttention")

-from megatron.core.extensions.transformer_engine import _te_version
-

 class ESM2TEDotProductAttention(TEDotProductAttention):
     """ESM2-Specific transformer engine core attention.
@@ -52,6 +50,9 @@ def __init__(
         attn_mask_type: AttnMaskType,
         attention_type: str,
         attention_dropout: float | None = None,
+        softmax_scale: float = 1.0,
+        k_channels: int | None = None,
+        v_channels: int | None = None,
     ):
         """Initialize ESM2TEDotProductAttention."""
         self.config = config
@@ -67,25 +68,25 @@ def __init__(
         )

         extra_kwargs = {}
-        if _te_version >= packaging.version.Version("0.11.0"):
+        if is_te_min_version("0.11.0"):
             extra_kwargs["num_gqa_groups"] = self.config.num_query_groups
         elif self.config.num_query_groups != self.config.num_attention_heads:
             raise ValueError(
-                f"Transformer Engine v{_te_version} does not support Grouped Query Attention, "
+                f"Transformer Engine v{get_te_version()} does not support Grouped Query Attention, "
                 f"use a newer version of Transformer Engine. "
                 f"(num_query_groups ({self.config.num_query_groups}) != "
                 f"num_attention_heads ({self.config.num_attention_heads}))"
             )

-        if _te_version >= packaging.version.Version("0.10.0"):
+        if is_te_min_version("0.10.0"):
             extra_kwargs["attention_type"] = attention_type
             # older version don't need attention_type

-        if _te_version > packaging.version.Version("0.12.0"):
+        if is_te_min_version("0.12.0", check_equality=False):
             self.te_forward_mask_type = True

         # Only Transformer-Engine version >= 1.0.0 supports context parallelism
-        if _te_version >= packaging.version.Version("1.0.0"):
+        if is_te_min_version("1.0.0"):
             if getattr(TEDotProductAttention, "cp_stream") is None:
                 TEDotProductAttention.cp_stream = torch.cuda.Stream()
             extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
@@ -106,23 +107,33 @@ def __init__(

         if config.window_size is not None:
             # Check version
-            assert _te_version >= packaging.version.Version("1.2.0"), (
-                f"Transformer-Engine version ({str(_te_version)}) must be >= 1.2.0 to support"
-                "sliding window attention."
+            assert is_te_min_version("1.2.0"), (
+                f"Transformer-Engine v{get_te_version()} must be >= 1.2.0 to support" "sliding window attention."
             )
             extra_kwargs["window_size"] = config.window_size

+        if is_te_min_version("1.10.0"):
+            # TE 1.10.0 introduces the ability to set the different k and v channels
+            kv_channels = (
+                (k_channels, v_channels)
+                if k_channels is not None and v_channels is not None
+                else self.config.kv_channels
+            )
+        else:
+            kv_channels = self.config.kv_channels
+
+        extra_kwargs["softmax_scale"] = softmax_scale
+
         super(TEDotProductAttention, self).__init__(
             num_attention_heads=self.config.num_attention_heads,
-            kv_channels=self.config.kv_channels,
+            kv_channels=kv_channels,
             attention_dropout=(self.config.attention_dropout if attention_dropout is None else attention_dropout),
             attn_mask_type=attn_mask_type.name,
             sequence_parallel=self.config.sequence_parallel,
             tp_size=self.config.tensor_model_parallel_size,
             get_rng_state_tracker=(get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None),
             tp_group=get_tensor_model_parallel_group(check_initialized=False),
             layer_number=layer_number,
-            softmax_scale=1.0,  # TODO subclassing only changes softmax_scale from None to 1.0. Upstream to make this exposed without subclassing
             **extra_kwargs,
         )


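For context, the `is_te_min_version` / `get_te_version` helpers from `megatron.core.utils` replace the direct `packaging.version` comparisons against `_te_version` that this commit removes. The sketch below is an illustration only, not Megatron-LM's actual implementation; it assumes Transformer Engine is installed and exposes `__version__`, and it only mirrors the call signatures used in the diff above.

# Illustration only: roughly what the removed packaging.version checks did,
# expressed as helpers with the same call signatures used in the diff above.
from packaging.version import Version

import transformer_engine as te  # assumption: TE is installed and exposes __version__


def get_te_version() -> Version:
    """Return the installed Transformer Engine version as a comparable Version."""
    return Version(str(te.__version__))


def is_te_min_version(minimum: str, check_equality: bool = True) -> bool:
    """Check the installed TE version against a minimum.

    check_equality=True corresponds to a ">=" comparison; False to a strict ">".
    """
    if check_equality:
        return get_te_version() >= Version(minimum)
    return get_te_version() > Version(minimum)


# Usage mirroring the strict check `_te_version > packaging.version.Version("0.12.0")`
# that the commit rewrites as is_te_min_version("0.12.0", check_equality=False):
if is_te_min_version("0.12.0", check_equality=False):
    print(f"TE {get_te_version()} supports forward mask type")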