
fix attention scale

lucidrains committed Feb 18, 2024
Parent: 4221ce6
Commit: 16fd91d

Showing 3 changed files with 8 additions and 3 deletions.
assert.py (5 additions, 0 deletions)
@@ -24,6 +24,7 @@ def start(
     batch_size,
     batch_size_var_len,
     seq_len,
+    causal,
     dim,
     use_cuda
 ):
@@ -32,6 +33,7 @@ def start(
     ring_attention_net = RingTransformer(
         num_tokens = 256,
         dim = dim,
+        causal = causal,
         depth = 1,
         dim_head = 8,
         ring_attn = True,
@@ -43,6 +45,7 @@ def start(
     flash_attention_net = RingTransformer(
         num_tokens = 256,
         dim = dim,
+        causal = causal,
         depth = 1,
         dim_head = 8,
         ring_attn = False
@@ -111,6 +114,7 @@ def start(
     batch_size = 1
     batch_size_var_len = False
     use_cuda = False
+    causal = False

     assert not use_cuda or torch.cuda.device_count() <= world_size

@@ -124,6 +128,7 @@ def start(
             batch_size,
             batch_size_var_len,
             seq_len,
+            causal,
             dim,
             use_cuda
         ),
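The assert.py hunks all do one thing: thread the new causal flag from the top-level config, through torch.multiprocessing's args tuple, into the spawned worker's signature. A minimal self-contained sketch of that pattern follows; the worker body, the argument values, and any arguments not shown in the hunks above are assumptions for illustration, not the repo's exact assert.py.

import torch.multiprocessing as mp

def start(rank, world_size, batch_size, seq_len, causal, dim):
    # mp.spawn always passes the process rank as the first argument;
    # everything after it comes from the `args` tuple, in order,
    # so a new flag like `causal` just slots into both places
    print(f'rank {rank}: causal = {causal}')

if __name__ == '__main__':
    world_size = 2
    causal = False   # the new flag, defaulted off as in the diff

    mp.spawn(
        start,
        args = (world_size, 1, 512, causal, 8),   # hypothetical values
        nprocs = world_size,
        join = True
    )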
ring_attention_pytorch/ring_attention.py (2 additions, 2 deletions)
@@ -42,6 +42,8 @@ def default_attention(
     mask: Optional[Tensor],
     causal: bool = False
 ):
+    q = q * (q.shape[-1] ** -0.5)

     mask_value = -torch.finfo(q.dtype).max

     # similarity
@@ -214,8 +216,6 @@ def forward(
         qkv = self.to_qkv(x)
         q, k, v = rearrange('b n (qkv h d) -> qkv b h n d', qkv, qkv = 3, h = self.heads)

-        q = q * self.scale
-
         if self.force_regular_attn or not is_distributed():
             out = default_attention(q, k, v, mask = mask, causal = self.causal)
         else:
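This is the substance of the commit: the `q = q * self.scale` in `forward` is removed, and the standard 1/sqrt(d) scaling is applied inside `default_attention` itself, so queries are scaled exactly once regardless of which code path runs. A minimal sketch of the resulting convention follows; the einsum patterns, mask construction, and function name are illustrative, not the repo's exact code.

import torch

def default_attention_sketch(q, k, v, causal = False):
    # scale queries by d ** -0.5 inside the attention function,
    # mirroring the line added in the hunk above
    q = q * (q.shape[-1] ** -0.5)

    # similarity
    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)

    # causal mask, if requested
    if causal:
        i, j = sim.shape[-2:]
        causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

    # attend and aggregate values
    attn = sim.softmax(dim = -1)
    return torch.einsum('b h i j, b h j d -> b h i d', attn, v)

q = k = v = torch.randn(1, 8, 1024, 64)   # (batch, heads, seq, dim_head)
out = default_attention_sketch(q, k, v, causal = True)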
setup.py (1 addition, 1 deletion)
@@ -3,7 +3,7 @@
 setup(
     name = 'ring-attention-pytorch',
     packages = find_packages(exclude=[]),
-    version = '0.0.4',
+    version = '0.0.5',
     license='MIT',
     description = 'Ring Attention - Pytorch',
     author = 'Phil Wang',
