Linear Relational Embeddings (LREs) and Linear Relational Concepts (LRCs) for LLMs using PyTorch and Huggingface Transformers.
Full docs: https://chanind.github.io/linear-relational
This library provides utilities and PyTorch modules for working with LREs and LRCs. LREs estimate the relation between a subject and object in a transformer language model (LM) as a linear map.
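This library assumes you're working with sentences that have a subject, relation, and object. For instance, the sentence "Lyon is located in the country of France" has the subject "Lyon", relation "located in country", and object "France". An LRE models a relation like "located in country" as a linear map consisting of a weight matrix $W$ and a bias term $b$, mapping the subject activations $s$ at one layer to an estimate of the object activations $o$ at a later layer: $\mathrm{LRE}(s) = W s + b$.
LREs can be inverted using a low-rank inverse, shown as $\mathrm{LRE}^{\dagger}$, to estimate $s$ from $o$: $\mathrm{LRE}^{\dagger}(o) = W^{\dagger}(o - b)$.
Linear Relational Concepts (LRCs) represent a concept, meaning a relation paired with one of its objects (e.g. "located in country: France"), as a direction in subject activation space, obtained by passing object activations through the inverted LRE.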
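One common way to build a low-rank inverse of a linear map is a truncated-SVD pseudo-inverse. The derivation below is a sketch under that assumption. Here $W$ and $b$ are the LRE weight matrix and bias, $s$ and $o$ are subject and object activations, and the symbols $U$, $\Sigma$, $V$, and the rank $k$ are introduced only for illustration; the library's exact construction may differ.

```latex
\begin{aligned}
o &\approx W s + b
  && \text{(LRE: subject activations to object activations)} \\
W &= U \Sigma V^{\top}
  && \text{(singular value decomposition)} \\
W^{\dagger}_{k} &= V_{k} \, \Sigma_{k}^{-1} \, U_{k}^{\top}
  && \text{(keep only the $k$ largest singular values)} \\
s &\approx W^{\dagger}_{k} \, (o - b)
  && \text{(inverted LRE: object activations back to subject activations)}
\end{aligned}
```

Under this reading, a smaller $k$ discards the directions in which $W$ has little effect, which would otherwise be amplified by $\Sigma^{-1}$ when inverting.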
For more information on LREs and LRCs, check out the following papers:
- Identifying Linear Relational Concepts in Large Language Models
- Linearity of Relation Decoding in Transformer Language Models
pip install linear-relational
This library assumes you're using PyTorch with a decoder-only generative language model (e.g. GPT, LLaMA, etc.), and a tokenizer from Huggingface.
To train an LRE for a relation, first collect prompts which elicit the relation. We provide a Prompt class to represent this data, and a Trainer class to make training an LRE easy. Below, we train an LRE to represent the "located in country" relation.
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from linear_relational import Prompt, Trainer
# We load a generative LM from huggingface. The LMHead must be included.
model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# Prompts consist of text, an answer, and subject.
# The subject must appear in the text. The answer
# is what the model should respond with, and corresponds to the "object"
prompts = [
    Prompt("Paris is located in the country of", "France", subject="Paris"),
    Prompt("Shanghai is located in the country of", "China", subject="Shanghai"),
    Prompt("Kyoto is located in the country of", "Japan", subject="Kyoto"),
    Prompt("San Jose is located in the country of", "Costa Rica", subject="San Jose"),
]
trainer = Trainer(model, tokenizer)
lre = trainer.train_lre(
    relation="located in country",
    subject_layer=8,  # subject layer must be before the object layer
    object_layer=10,
    prompts=prompts,
)
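When a relation has many subject/object pairs, writing each prompt by hand gets tedious. A minimal sketch of generating them from a template follows. PromptSketch is a hypothetical stand-in that mirrors the three fields used above (text, answer, subject) so the snippet runs without loading the library; build_prompts and the template string are likewise illustrative assumptions, not part of linear_relational.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass(frozen=True)
class PromptSketch:
    """Hypothetical stand-in for Prompt with the same three fields."""

    text: str
    answer: str
    subject: str

    def __post_init__(self) -> None:
        # The subject must appear in the prompt text.
        if self.subject not in self.text:
            raise ValueError(f"subject {self.subject!r} not found in {self.text!r}")


def build_prompts(
    template: str, pairs: Sequence[Tuple[str, str]]
) -> List[PromptSketch]:
    """Fill a template containing "{subject}" once per (subject, object) pair."""
    return [
        PromptSketch(text=template.format(subject=subj), answer=obj, subject=subj)
        for subj, obj in pairs
    ]


pairs = [("Paris", "France"), ("Shanghai", "China"), ("Kyoto", "Japan")]
prompts = build_prompts("{subject} is located in the country of", pairs)
print(prompts[0].text)  # Paris is located in the country of
```

The early check in `__post_init__` catches a mismatched subject before any model is loaded, which is cheaper than discovering it during training.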
An LRE is a PyTorch module, so once an LRE is trained, we can use it to predict object activations from subject activations:
object_acts_estimate = lre(subject_acts)
We can also create a low-rank estimate of the LRE:
low_rank_lre = lre.to_low_rank(50)
low_rank_obj_acts_estimate = low_rank_lre(subject_acts)
Finally we can invert the LRE:
inv_lre = lre.invert(rank=50)
subject_acts_estimate = inv_lre(object_acts)
The Trainer can also create LRCs for a relation. Internally, this first creates an LRE, inverts it, then generates LRCs from each object in the relation. Objects refer to the answers in the prompts, e.g. in the example above, "France" is an object, "Japan" is an object, etc.
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from linear_relational import Prompt, Trainer
# We load a generative LM from huggingface. The LMHead must be included.
model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# Prompts consist of text, an answer, and subject.
# The subject must appear in the text. The answer
# is what the model should respond with, and corresponds to the "object"
prompts = [
    Prompt("Paris is located in the country of", "France", subject="Paris"),
    Prompt("Shanghai is located in the country of", "China", subject="Shanghai"),
    Prompt("Kyoto is located in the country of", "Japan", subject="Kyoto"),
    Prompt("San Jose is located in the country of", "Costa Rica", subject="San Jose"),
]
trainer = Trainer(model, tokenizer)
concepts = trainer.train_relation_concepts(
    relation="located in country",
    subject_layer=8,
    object_layer=10,
    prompts=prompts,
    max_lre_training_samples=10,
    inv_lre_rank=50,
)
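To make the "one LRC per object" idea concrete, the sketch below lists the concepts we would expect from a set of prompt answers. It is not library code: expected_concept_names is a hypothetical helper, and the "relation: object" naming is an assumption based on the concept names that appear elsewhere in this README.

```python
from typing import List, Sequence


def expected_concept_names(relation: str, answers: Sequence[str]) -> List[str]:
    """One concept per distinct object, named "<relation>: <object>"."""
    # dict.fromkeys drops duplicates while keeping first-seen order.
    unique_objects = list(dict.fromkeys(answers))
    return [f"{relation}: {obj}" for obj in unique_objects]


# Answers from a set of prompts; "France" appears twice but yields one concept.
answers = ["France", "China", "Japan", "Costa Rica", "France"]
names = expected_concept_names("located in country", answers)
print(names)
```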
Once we have LRCs trained, we can use them to perform causal edits while the model is running. For instance, we can perform a causal edit to make the model output that "Shanghai is located in the country of France" by subtracting the "located in country: China" concept from "Shanghai" and adding the "located in country: France" concept. We can use the CausalEditor class to perform these edits.
from linear_relational import CausalEditor
concepts = trainer.train_relation_concepts(...)
editor = CausalEditor(model, tokenizer, concepts=concepts)
edited_answer = editor.swap_subject_concepts_and_predict_greedy(
    text="Shanghai is located in the country of",
    subject="Shanghai",
    remove_concept="located in country: China",
    add_concept="located in country: France",
    edit_single_layer=8,
    magnitude_multiplier=3.0,
    predict_num_tokens=1,
)
print(edited_answer) # " France"
Above we performed a single-layer edit, only modifying subject activations at layer 8. However, we may want to perform an edit at all subject layers at the same time instead. To do this, we can pass edit_single_layer=False to editor.swap_subject_concepts_and_predict_greedy(). We should also reduce the magnitude_multiplier: since the edit is now made at every layer, too large a multiplier will drown out the rest of the activations in the model. The magnitude_multiplier is a hyperparameter that requires tuning depending on the model being edited.
from linear_relational import CausalEditor
concepts = trainer.train_relation_concepts(...)
editor = CausalEditor(model, tokenizer, concepts=concepts)
edited_answer = editor.swap_subject_concepts_and_predict_greedy(
    text="Shanghai is located in the country of",
    subject="Shanghai",
    remove_concept="located in country: China",
    add_concept="located in country: France",
    edit_single_layer=False,
    magnitude_multiplier=0.1,
    predict_num_tokens=1,
)
print(edited_answer) # " France"
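Conceptually, swapping concepts shifts a subject activation away from one concept direction and toward another. The toy arithmetic below is a sketch of that idea only: swap_concepts is a hypothetical helper, the vectors are made up, and the scaling rule (multiplier times the norm of the original activation) is an assumption rather than a description of what CausalEditor does internally.

```python
import math
from typing import List

Vector = List[float]


def norm(vec: Vector) -> float:
    return math.sqrt(sum(x * x for x in vec))


def swap_concepts(
    activation: Vector, remove: Vector, add: Vector, magnitude_multiplier: float
) -> Vector:
    """Shift `activation` along (add - remove), scaled relative to its own norm."""
    scale = magnitude_multiplier * norm(activation)
    return [a + scale * (p - m) for a, p, m in zip(activation, add, remove)]


# Made-up unit-length concept directions in a 3-dimensional activation space.
china = [1.0, 0.0, 0.0]
france = [0.0, 1.0, 0.0]
shanghai_act = [3.0, 0.0, 4.0]  # norm is 5.0

single_layer = swap_concepts(shanghai_act, china, france, magnitude_multiplier=3.0)
print(single_layer)  # [-12.0, 15.0, 4.0]

per_layer = swap_concepts(shanghai_act, china, france, magnitude_multiplier=0.1)
print(per_layer)  # [2.5, 0.5, 4.0]
```

Under this toy rule, a multiplier of 3.0 moves the activation well beyond its original scale in a single step, while 0.1 only nudges it. A small nudge is the safer choice when the same shift is applied at every layer.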
We can use learned concepts (LRCs) like classifiers, matching them against subject activations in sentences. We can use the ConceptMatcher class to do this matching.
from linear_relational import ConceptMatcher
concepts = trainer.train_relation_concepts(...)
matcher = ConceptMatcher(model, tokenizer, concepts=concepts)
match_info = matcher.query("Beijing is a northern city", subject="Beijing")
print(match_info.best_match.concept) # located in country: China
print(match_info.best_match.score) # 0.832
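The idea of using concept directions as classifiers can be sketched in a few lines. This is an illustration, not the ConceptMatcher implementation: the vectors are made up, best_match is a hypothetical helper, and cosine similarity is assumed as the score.

```python
import math
from typing import Dict, List, Tuple

Vector = List[float]


def cosine_similarity(a: Vector, b: Vector) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def best_match(activation: Vector, concepts: Dict[str, Vector]) -> Tuple[str, float]:
    """Return the (name, score) of the concept direction closest to `activation`."""
    scores = {name: cosine_similarity(activation, vec) for name, vec in concepts.items()}
    name = max(scores, key=lambda n: scores[n])
    return name, scores[name]


# Made-up concept directions and a made-up subject activation.
concepts = {
    "located in country: China": [1.0, 0.0, 0.0],
    "located in country: France": [0.0, 1.0, 0.0],
}
beijing_act = [4.0, 3.0, 0.0]

name, score = best_match(beijing_act, concepts)
print(name, round(score, 3))  # located in country: China 0.8
```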
This library is inspired by and uses modified code from the following excellent projects:
Any contributions to improve this project are welcome! Please open an issue or pull request in this repo with any bugfixes / changes / improvements you have!
This project uses Black for code formatting, Flake8 for linting, and Pytest for tests. Make sure any changes you submit pass these code checks in your PR. If you have trouble getting these to run, feel free to open a pull request regardless and we can discuss further in the PR.
This code is released under an MIT license.
If you use this library in your work, please cite the following:
@article{chanin2023identifying,
  title={Identifying Linear Relational Concepts in Large Language Models},
  author={David Chanin and Anthony Hunter and Oana-Maria Camburu},
  journal={arXiv preprint arXiv:2311.08968},
  year={2023}
}