From 16fd91d44e1f3a23422c43b2495f8d24767f9a9e Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sun, 18 Feb 2024 12:07:02 -0800
Subject: [PATCH] fix attention scale

---
 assert.py                                | 5 +++++
 ring_attention_pytorch/ring_attention.py | 4 ++--
 setup.py                                 | 2 +-
 3 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/assert.py b/assert.py
index 4216550..bde46b9 100644
--- a/assert.py
+++ b/assert.py
@@ -24,6 +24,7 @@ def start(
     batch_size,
     batch_size_var_len,
     seq_len,
+    causal,
     dim,
     use_cuda
 ):
@@ -32,6 +33,7 @@ def start(
     ring_attention_net = RingTransformer(
         num_tokens = 256,
         dim = dim,
+        causal = causal,
         depth = 1,
         dim_head = 8,
         ring_attn = True,
@@ -43,6 +45,7 @@ def start(
     flash_attention_net = RingTransformer(
         num_tokens = 256,
         dim = dim,
+        causal = causal,
         depth = 1,
         dim_head = 8,
         ring_attn = False
@@ -111,6 +114,7 @@ def start(
     batch_size = 1
     batch_size_var_len = False
     use_cuda = False
+    causal = False

     assert not use_cuda or torch.cuda.device_count() <= world_size

@@ -124,6 +128,7 @@ def start(
         batch_size,
         batch_size_var_len,
         seq_len,
+        causal,
         dim,
         use_cuda
     ),
diff --git a/ring_attention_pytorch/ring_attention.py b/ring_attention_pytorch/ring_attention.py
index e41dbd1..f9d3f25 100644
--- a/ring_attention_pytorch/ring_attention.py
+++ b/ring_attention_pytorch/ring_attention.py
@@ -42,6 +42,8 @@ def default_attention(
    mask: Optional[Tensor],
    causal: bool = False
 ):
+    q = q * (q.shape[-1] ** -0.5)
+
     mask_value = -torch.finfo(q.dtype).max

     # similarity
@@ -214,8 +216,6 @@ def forward(
         qkv = self.to_qkv(x)
         q, k, v = rearrange('b n (qkv h d) -> qkv b h n d', qkv, qkv = 3, h = self.heads)

-        q = q * self.scale
-
         if self.force_regular_attn or not is_distributed():
             out = default_attention(q, k, v, mask = mask, causal = self.causal)
         else:
diff --git a/setup.py b/setup.py
index 5b52301..d672cb1 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ring-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.4',
+  version = '0.0.5',
   license='MIT',
   description = 'Ring Attention - Pytorch',
   author = 'Phil Wang',
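--
For context (ignored by git am, not part of the applied diff): the change moves
the query scaling out of Attention.forward and into default_attention, so the
non-distributed path scales q by dim_head ** -0.5 exactly once. A minimal sketch
of the resulting attention math, assuming q, k, v of shape
(batch, heads, seq, dim_head); the helper name is hypothetical:

    import torch

    def scaled_dot_product(q, k, v):
        # scale queries by 1 / sqrt(dim_head), as default_attention now does
        q = q * (q.shape[-1] ** -0.5)

        # similarity, softmax over keys, then aggregation of values
        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)
        attn = sim.softmax(dim = -1)
        return torch.einsum('b h i j, b h j d -> b h i d', attn, v)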