Skip to content

Commit

Permalink
fix: bug with gemma cache on non-fp32
Browse files Browse the repository at this point in the history
  • Loading branch information
chanind committed Aug 7, 2024
1 parent a7aaad7 commit d379f1b
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 6 deletions.
7 changes: 7 additions & 0 deletions linear_relational/lib/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ def get_module(model: nn.Module, name: str) -> nn.Module:
raise LookupError(name)


def get_dtype(model: nn.Module) -> torch.dtype:
"""
Returns the dtype of the model.
"""
return next(model.parameters()).dtype


def get_device(model: nn.Module) -> torch.device:
"""
Returns the device on which the model is running.
Expand Down
13 changes: 11 additions & 2 deletions linear_relational/training/train_lre.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
find_final_word_token_index,
find_prompt_answer_data,
)
from linear_relational.lib.torch_utils import get_device, untuple_tensor
from linear_relational.lib.torch_utils import (
get_device,
get_dtype,
untuple_tensor,
)
from linear_relational.lib.TraceLayer import TraceLayer
from linear_relational.lib.TraceLayerDict import TraceLayerDict
from linear_relational.Lre import Lre
Expand Down Expand Up @@ -112,7 +116,12 @@ def order_1_approx(
hasattr(model, "config")
and getattr(model.config, "cache_implementation", None) == "hybrid"
):
cache = HybridCache(model.config, input_ids.shape[0], input_ids.shape[1] + 1)
cache = HybridCache(
model.config,
input_ids.shape[0],
input_ids.shape[1] + 1,
dtype=get_dtype(model),
)
cache_position = torch.arange(input_ids.shape[1], device=device)
precache_extra_params["past_key_values"] = cache
precache_extra_params["cache_position"] = cache_position[:subject_index]
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def model() -> GPT2LMHeadModel:
def empty_gemma2_model() -> Gemma2ForCausalLM:
config = Gemma2Config(
num_hidden_layers=3,
hidden_size=1024,
intermediate_size=2752,
hidden_size=64,
intermediate_size=128,
vocab_size=_tokenizer.vocab_size,
)
return Gemma2ForCausalLM(config).eval()
Expand Down
41 changes: 39 additions & 2 deletions tests/training/test_train_lre.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,43 @@ def test_train_lre_on_single_prompt_with_gemma2_perfectly_replicates_object(
texts=[prompt.text],
layers=["model.layers.2"],
)[0]["model.layers.2"]
print(lre(subj_act))
print(obj_act)
assert torch.allclose(lre(subj_act), obj_act, atol=1e-4)


def test_train_lre_works_with_gemma2_and_float16(
empty_gemma2_model: PreTrainedModel, tokenizer: GPT2TokenizerFast
) -> None:
model = empty_gemma2_model.half()
prompt = create_prompt(
text="Tokyo is located in the country of",
answer="Japan",
subject="Tokyo",
)
lre = train_lre(
model=model,
tokenizer=tokenizer,
layer_matcher="model.layers.{num}",
relation="city in country",
subject_layer=1,
object_layer=2,
prompts=[prompt],
).float()

subj_index = (
find_token_range(tokenizer, tokenizer.encode(prompt.text), prompt.subject)[-1]
- 1
)
subj_act = extract_token_activations(
model=model,
tokenizer=tokenizer,
texts=[prompt.text],
layers=["model.layers.1"],
token_indices=[subj_index],
)[0]["model.layers.1"][0]
obj_act = extract_final_token_activations(
model=model,
tokenizer=tokenizer,
texts=[prompt.text],
layers=["model.layers.2"],
)[0]["model.layers.2"]
assert torch.allclose(lre(subj_act), obj_act, atol=1e-4)

0 comments on commit d379f1b

Please sign in to comment.