Fix llama gqa attention bias (IBM#88)
To support IBM Granite Code 8B models.

Signed-off-by: Nick Hill <nickhill@us.ibm.com>
njhill authored May 8, 2024
1 parent f091ad5 commit e87d462
Showing 2 changed files with 16 additions and 8 deletions.
@@ -156,10 +156,9 @@ def _load_gqa(config, prefix: str, weights):
     assert config.hidden_size % config.num_attention_heads == 0
     assert config.num_attention_heads % weights.process_group.size() == 0
 
+    prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
     weight = weights.get_multi_weights_col(
-        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-        quantize=config.quantize,
-        dim=0
+        prefixes=prefixes, quantize=config.quantize, dim=0
     )
 
     if config.quantize != "gptq":
@@ -173,7 +172,12 @@ def _load_gqa(config, prefix: str, weights):
         config.hidden_size,
     ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
 
-    return TensorParallelColumnLinear(get_linear(weight, bias=config.attention_bias, quantize=config.quantize))
+    if config.attention_bias:
+        bias = torch.cat([weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes], dim=0)
+    else:
+        bias = None
+
+    return TensorParallelColumnLinear(get_linear(weight, bias=bias, quantize=config.quantize))
 
 
 class FlashLlamaAttention(torch.nn.Module):

@@ -156,10 +156,9 @@ def _load_gqa(config, prefix: str, weights):
     assert config.hidden_size % config.num_attention_heads == 0
     assert config.num_attention_heads % weights.process_group.size() == 0
 
+    prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
     weight = weights.get_multi_weights_col(
-        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-        quantize=config.quantize,
-        dim=0
+        prefixes=prefixes, quantize=config.quantize, dim=0
     )
 
     if config.quantize != "gptq":
@@ -173,7 +172,12 @@ def _load_gqa(config, prefix: str, weights):
         config.hidden_size,
     ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
 
-    return TensorParallelColumnLinear(get_linear(weight, bias=config.attention_bias, quantize=config.quantize))
+    if config.attention_bias:
+        bias = torch.cat([weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes], dim=0)
+    else:
+        bias = None
+
+    return TensorParallelColumnLinear(get_linear(weight, bias=bias, quantize=config.quantize))
 
 
 class PagedLlamaAttention(torch.nn.Module):
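
Note on the change: the previous code forwarded config.attention_bias (a config flag) directly as the bias argument of get_linear; the fix instead loads the per-projection q/k/v bias tensors, shards them along dim 0, and concatenates them so the fused QKV projection receives an actual bias vector when the flag is set. The sketch below is illustrative only, not code from this repository (the dimensions are made up, not taken from the Granite Code 8B config); it shows the shape relationship the concatenated bias has to satisfy against the fused weight's output rows.

# Illustrative sketch: why the fused QKV bias must be a concatenated tensor
# whose length matches (num_heads + 2 * num_kv_heads) * head_size.
# Not the repository's sharded loading code; dimensions are hypothetical.
import torch

num_heads, num_kv_heads, head_size = 32, 8, 128

# Per-projection biases as stored in the checkpoint (q_proj.bias, k_proj.bias, v_proj.bias)
q_bias = torch.zeros(num_heads * head_size)
k_bias = torch.zeros(num_kv_heads * head_size)
v_bias = torch.zeros(num_kv_heads * head_size)

# The fused QKV linear needs one bias vector covering all of its output rows,
# so the three biases are concatenated along dim 0 (the commit does this via
# weights.get_sharded(...) followed by torch.cat).
qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=0)
fused_rows = (num_heads + 2 * num_kv_heads) * head_size
assert qkv_bias.shape[0] == fused_rows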
