diff --git a/linear_relational/Lre.py b/linear_relational/Lre.py
index 252f017..4a25d9b 100644
--- a/linear_relational/Lre.py
+++ b/linear_relational/Lre.py
@@ -192,6 +192,26 @@ def invert(self, rank: int) -> InvertedLre:
             metadata=self.metadata,
         )
 
+    def forward(
+        self, subject_activations: torch.Tensor, normalize: bool = False
+    ) -> torch.Tensor:
+        return self.calculate_object_activation(
+            subject_activations=subject_activations, normalize=normalize
+        )
+
+    def calculate_object_activation(
+        self,
+        subject_activations: torch.Tensor,  # a tensor of shape (num_activations, hidden_activation_size)
+        normalize: bool = False,
+    ) -> torch.Tensor:
+        # apply the LRE affine map (W @ s + b); mean-pool if a batch of activations was given
+        vec = subject_activations @ self.weight.T + self.bias
+        if len(vec.shape) == 2:
+            vec = vec.mean(dim=0)
+        if normalize:
+            vec = vec / vec.norm()
+        return vec
+
     def to_low_rank(self, rank: int) -> LowRankLre:
         """Create a low-rank approximation of this LRE"""
         u, s, v = self._low_rank_svd(rank)
diff --git a/tests/test_Lre.py b/tests/test_Lre.py
index 8dd43b9..54325bb 100644
--- a/tests/test_Lre.py
+++ b/tests/test_Lre.py
@@ -51,6 +51,21 @@ def test_Lre_to_low_rank() -> None:
     assert low_rank_lre.__repr__() == "LowRankLre(test, rank 2, layers 5 -> 10, mean)"
 
 
+def test_Lre_to_low_rank_forward_matches_original_lre() -> None:
+    bias = torch.tensor([1.0, 0.0, 0.0])
+    lre = Lre(
+        relation="test",
+        subject_layer=5,
+        object_layer=10,
+        object_aggregation="mean",
+        bias=bias,
+        weight=torch.eye(3) + torch.randn(3, 3),
+    )
+    full_low_rank_lre = lre.to_low_rank(rank=3)
+    test_input = torch.rand(2, 3)
+    assert torch.allclose(full_low_rank_lre(test_input), lre(test_input), atol=1e-4)
+
+
 def test_LowRankLre_calculate_object_activation_unnormalized() -> None:
     acts = torch.stack(
         [
diff --git a/tests/training/test_train_lre.py b/tests/training/test_train_lre.py
index 7ff471a..de315b1 100644
--- a/tests/training/test_train_lre.py
+++ b/tests/training/test_train_lre.py
@@ -1,5 +1,11 @@
+import torch
 from transformers import GPT2LMHeadModel, GPT2TokenizerFast
 
+from linear_relational.lib.extract_token_activations import (
+    extract_final_token_activations,
+    extract_token_activations,
+)
+from linear_relational.lib.token_utils import find_token_range
 from linear_relational.training.train_lre import train_lre
 from tests.helpers import create_prompt
 
@@ -39,3 +45,50 @@ def test_train_lre(model: GPT2LMHeadModel, tokenizer: GPT2TokenizerFast) -> None:
     assert lre.object_layer == 9
     assert lre.weight.shape == (768, 768)
     assert lre.bias.shape == (768,)
+
+
+def test_train_lre_on_single_prompt_perfectly_replicates_object(
+    model: GPT2LMHeadModel, tokenizer: GPT2TokenizerFast
+) -> None:
+    fsl_prefixes = "\n".join(
+        [
+            "Berlin is located in the country of Germany",
+            "Toronto is located in the country of Canada",
+            "Lagos is located in the country of Nigeria",
+        ]
+    )
+    prompt = create_prompt(
+        text=f"{fsl_prefixes}\nTokyo is located in the country of",
+        answer="Japan",
+        subject="Tokyo",
+    )
+    prompts = [prompt]
+    lre = train_lre(
+        model=model,
+        tokenizer=tokenizer,
+        layer_matcher="transformer.h.{num}",
+        relation="city in country",
+        subject_layer=5,
+        object_layer=9,
+        prompts=prompts,
+        object_aggregation="mean",
+    )
+
+    subj_index = (
+        find_token_range(tokenizer, tokenizer.encode(prompt.text), prompt.subject)[-1]
+        - 1
+    )
+    subj_act = extract_token_activations(
+        model=model,
+        tokenizer=tokenizer,
+        texts=[prompt.text],
+        layers=["transformer.h.5"],
+        token_indices=[subj_index],
+    )[0]["transformer.h.5"][0]
+    obj_act = extract_final_token_activations(
+        model=model,
+        tokenizer=tokenizer,
+        texts=[prompt.text],
+        layers=["transformer.h.9"],
+    )[0]["transformer.h.9"]
+    assert torch.allclose(lre(subj_act), obj_act, atol=1e-4)