Skip to content

Commit

Permalink
feat: adding LRE forward() method (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
chanind authored Aug 4, 2024
1 parent 728ac81 commit 9558e5f
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 0 deletions.
20 changes: 20 additions & 0 deletions linear_relational/Lre.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,26 @@ def invert(self, rank: int) -> InvertedLre:
metadata=self.metadata,
)

def forward(
self, subject_activations: torch.Tensor, normalize: bool = False
) -> torch.Tensor:
return self.calculate_object_activation(
subject_activations=subject_activations, normalize=normalize
)

def calculate_object_activation(
self,
subject_activations: torch.Tensor, # a tensor of shape (num_activations, hidden_activation_size)
normalize: bool = False,
) -> torch.Tensor:
# match precision of weight_inverse and bias
vec = subject_activations @ self.weight.T + self.bias
if len(vec.shape) == 2:
vec = vec.mean(dim=0)
if normalize:
vec = vec / vec.norm()
return vec

def to_low_rank(self, rank: int) -> LowRankLre:
"""Create a low-rank approximation of this LRE"""
u, s, v = self._low_rank_svd(rank)
Expand Down
15 changes: 15 additions & 0 deletions tests/test_Lre.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,21 @@ def test_Lre_to_low_rank() -> None:
assert low_rank_lre.__repr__() == "LowRankLre(test, rank 2, layers 5 -> 10, mean)"


def test_Lre_to_low_rank_forward_matches_original_lre() -> None:
bias = torch.tensor([1.0, 0.0, 0.0])
lre = Lre(
relation="test",
subject_layer=5,
object_layer=10,
object_aggregation="mean",
bias=bias,
weight=torch.eye(3) + torch.randn(3, 3),
)
full_low_rank_lre = lre.to_low_rank(rank=3)
test_input = torch.rand(2, 3)
assert torch.allclose(full_low_rank_lre(test_input), lre(test_input), atol=1e-4)


def test_LowRankLre_calculate_object_activation_unnormalized() -> None:
acts = torch.stack(
[
Expand Down
53 changes: 53 additions & 0 deletions tests/training/test_train_lre.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

from linear_relational.lib.extract_token_activations import (
extract_final_token_activations,
extract_token_activations,
)
from linear_relational.lib.token_utils import find_token_range
from linear_relational.training.train_lre import train_lre
from tests.helpers import create_prompt

Expand Down Expand Up @@ -39,3 +45,50 @@ def test_train_lre(model: GPT2LMHeadModel, tokenizer: GPT2TokenizerFast) -> None
assert lre.object_layer == 9
assert lre.weight.shape == (768, 768)
assert lre.bias.shape == (768,)


def test_train_lre_on_single_prompt_perfectly_replicates_object(
model: GPT2LMHeadModel, tokenizer: GPT2TokenizerFast
) -> None:
fsl_prefixes = "\n".join(
[
"Berlin is located in the country of Germany",
"Toronto is located in the country of Canada",
"Lagos is located in the country of Nigeria",
]
)
prompt = create_prompt(
text=f"{fsl_prefixes}\nTokyo is located in the country of",
answer="Japan",
subject="Tokyo",
)
prompts = [prompt]
lre = train_lre(
model=model,
tokenizer=tokenizer,
layer_matcher="transformer.h.{num}",
relation="city in country",
subject_layer=5,
object_layer=9,
prompts=prompts,
object_aggregation="mean",
)

subj_index = (
find_token_range(tokenizer, tokenizer.encode(prompt.text), prompt.subject)[-1]
- 1
)
subj_act = extract_token_activations(
model=model,
tokenizer=tokenizer,
texts=[prompt.text],
layers=["transformer.h.5"],
token_indices=[subj_index],
)[0]["transformer.h.5"][0]
obj_act = extract_final_token_activations(
model=model,
tokenizer=tokenizer,
texts=[prompt.text],
layers=["transformer.h.9"],
)[0]["transformer.h.9"]
assert torch.allclose(lre(subj_act), obj_act, atol=1e-4)

0 comments on commit 9558e5f

Please sign in to comment.