diff --git a/.buildinfo b/.buildinfo
new file mode 100644
index 0000000..1ea0e9b
--- /dev/null
+++ b/.buildinfo
@@ -0,0 +1,4 @@
+# Sphinx build info version 1
+# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
+config: 63f72c355ea704a4d4b29de044612d50
+tags: 645f666f9bcd5a90fca523b33c5a78b7
diff --git a/.doctrees/about.doctree b/.doctrees/about.doctree
new file mode 100644
index 0000000..f1a53b5
Binary files /dev/null and b/.doctrees/about.doctree differ
diff --git a/.doctrees/advanced_usage.doctree b/.doctrees/advanced_usage.doctree
new file mode 100644
index 0000000..535b2ac
Binary files /dev/null and b/.doctrees/advanced_usage.doctree differ
diff --git a/.doctrees/api/causal_editor.doctree b/.doctrees/api/causal_editor.doctree
new file mode 100644
index 0000000..6e93948
Binary files /dev/null and b/.doctrees/api/causal_editor.doctree differ
diff --git a/.doctrees/api/concept.doctree b/.doctrees/api/concept.doctree
new file mode 100644
index 0000000..d8356cd
Binary files /dev/null and b/.doctrees/api/concept.doctree differ
diff --git a/.doctrees/api/concept_matcher.doctree b/.doctrees/api/concept_matcher.doctree
new file mode 100644
index 0000000..5d5a9d5
Binary files /dev/null and b/.doctrees/api/concept_matcher.doctree differ
diff --git a/.doctrees/api/lre.doctree b/.doctrees/api/lre.doctree
new file mode 100644
index 0000000..bdb5fec
Binary files /dev/null and b/.doctrees/api/lre.doctree differ
diff --git a/.doctrees/api/trainer.doctree b/.doctrees/api/trainer.doctree
new file mode 100644
index 0000000..4f51e1d
Binary files /dev/null and b/.doctrees/api/trainer.doctree differ
diff --git a/.doctrees/basic_usage.doctree b/.doctrees/basic_usage.doctree
new file mode 100644
index 0000000..569de28
Binary files /dev/null and b/.doctrees/basic_usage.doctree differ
diff --git a/.doctrees/environment.pickle b/.doctrees/environment.pickle
new file mode 100644
index 0000000..48683fb
Binary files /dev/null and b/.doctrees/environment.pickle differ
diff --git a/.doctrees/index.doctree b/.doctrees/index.doctree
new file mode 100644
index 0000000..0a647de
Binary files /dev/null and b/.doctrees/index.doctree differ
diff --git a/.nojekyll b/.nojekyll
new file mode 100644
index 0000000..e69de29
diff --git a/_modules/index.html b/_modules/index.html
new file mode 100644
index 0000000..88e4050
--- /dev/null
+++ b/_modules/index.html
@@ -0,0 +1,251 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Callable, Literal, Optional, Sequence, TypeVar, Union, cast
+
+import torch
+from tokenizers import Tokenizer
+from torch import nn
+
+from linear_relational.Concept import Concept
+from linear_relational.lib.layer_matching import (
+ LayerMatcher,
+ collect_matching_layers,
+ get_layer_name,
+ guess_hidden_layer_matcher,
+)
+from linear_relational.lib.token_utils import (
+ ensure_tokenizer_has_pad_token,
+ find_final_word_token_index,
+ make_inputs,
+ predict_all_token_probs_from_input,
+ predict_next_tokens_greedy,
+)
+from linear_relational.lib.torch_utils import get_device, untuple_tensor
+from linear_relational.lib.TraceLayerDict import TraceLayerDict
+from linear_relational.lib.util import batchify
+
+EditorSubject = Union[str, int, Callable[[str, list[int]], int]]
+
+
+
+@dataclass
+class ConceptSwapRequest:
+ text: str
+ subject: EditorSubject
+ remove_concept: str
+ add_concept: str
+
+
+
+
+@dataclass
+class ConceptSwapAndPredictGreedyRequest(ConceptSwapRequest):
+ predict_num_tokens: int = 1
+
+
+
+T = TypeVar("T")
+
+
+
+class CausalEditor:
+ """Modify model activations during inference to swap concepts"""
+
+ concepts: list[Concept]
+ model: nn.Module
+ tokenizer: Tokenizer
+ layer_matcher: LayerMatcher
+ layer_name_to_num: dict[str, int]
+
+ def __init__(
+ self,
+ model: nn.Module,
+ tokenizer: Tokenizer,
+ concepts: list[Concept],
+ layer_matcher: Optional[LayerMatcher] = None,
+ ) -> None:
+ self.concepts = concepts
+ self.model = model
+ self.tokenizer = tokenizer
+ self.layer_matcher = layer_matcher or guess_hidden_layer_matcher(model)
+ ensure_tokenizer_has_pad_token(tokenizer)
+ num_layers = len(collect_matching_layers(self.model, self.layer_matcher))
+ self.layer_name_to_num = {}
+ for layer_num in range(num_layers):
+ self.layer_name_to_num[
+ get_layer_name(model, self.layer_matcher, layer_num)
+ ] = layer_num
+
+ @property
+ def device(self) -> torch.device:
+ return get_device(self.model)
+
+
+ def swap_subject_concepts_and_predict_greedy(
+ self,
+ text: str,
+ subject: EditorSubject,
+ remove_concept: str,
+ add_concept: str,
+ # if False, edit the subject token at every layer
+ edit_single_layer: int | Literal[False] = False,
+ predict_num_tokens: int = 1,
+ magnitude_multiplier: float = 1.0,
+ # if True, use the magnitude of the projection of the remove_concept against the subject's original activation
+ # if False, use the magnitude of the subject's original activation
+ use_remove_concept_projection_magnitude: bool = False,
+ ) -> str:
+ results = self.swap_subject_concepts_and_predict_greedy_bulk(
+ [
+ ConceptSwapAndPredictGreedyRequest(
+ text, subject, remove_concept, add_concept, predict_num_tokens
+ )
+ ],
+ magnitude_multiplier=magnitude_multiplier,
+ edit_single_layer=edit_single_layer,
+ use_remove_concept_projection_magnitude=use_remove_concept_projection_magnitude,
+ )
+ return results[0]
+
+
+
+ def swap_subject_concepts_and_predict_greedy_bulk(
+ self,
+ requests: Sequence[ConceptSwapAndPredictGreedyRequest],
+ # if False, edit the subject token at every layer
+ edit_single_layer: int | Literal[False] = False,
+ magnitude_multiplier: float = 1.0,
+ # if True, use the magnitude of the projection of the remove_concept against the subject's original activation
+ # if False, use the magnitude of the subject's original activation
+ use_remove_concept_projection_magnitude: bool = False,
+ ) -> list[str]:
+ next_tokens = self.swap_subject_concepts_and_predict_tokens_greedy_bulk(
+ requests,
+ edit_single_layer=edit_single_layer,
+ magnitude_multiplier=magnitude_multiplier,
+ use_remove_concept_projection_magnitude=use_remove_concept_projection_magnitude,
+ )
+ return [self.tokenizer.decode(tokens) for tokens in next_tokens]
+
+
+
+ def swap_subject_concepts_and_predict_tokens_greedy_bulk(
+ self,
+ requests: Sequence[ConceptSwapAndPredictGreedyRequest],
+ # if False, edit the subject token at every layer
+ edit_single_layer: int | Literal[False],
+ magnitude_multiplier: float = 1.0,
+ # if True, use the magnitude of the projection of the remove_concept against the subject's original activation
+ # if False, use the magnitude of the subject's original activation
+ use_remove_concept_projection_magnitude: bool = False,
+ batch_size: int = 12,
+ show_progress: bool = False,
+ ) -> list[list[int]]:
+ results: list[list[int]] = []
+ for batch in batchify(requests, batch_size, show_progress=show_progress):
+
+ def run_batch_fn() -> list[list[int]]:
+ max_num_tokens = max(req.predict_num_tokens for req in batch)
+ next_tokens = predict_next_tokens_greedy(
+ self.model,
+ self.tokenizer,
+ [req.text for req in batch],
+ num_tokens=max_num_tokens,
+ device=self.device,
+ )
+ return [
+ tokens[: req.predict_num_tokens]
+ for tokens, req in zip(next_tokens, batch)
+ ]
+
+ results.extend(
+ self._swap_subject_concepts_and_run_batch(
+ batch,
+ run_fn=run_batch_fn,
+ edit_single_layer=edit_single_layer,
+ magnitude_multiplier=magnitude_multiplier,
+ use_remove_concept_projection_magnitude=use_remove_concept_projection_magnitude,
+ )
+ )
+ return results
+
+
+
+ def swap_subject_concepts_and_predict_all_token_probs_bulk(
+ self,
+ requests: Sequence[ConceptSwapRequest],
+ magnitude_multiplier: float = 1.0,
+ # if False, edit the subject token at every layer
+ edit_single_layer: int | Literal[False] = False,
+ # if True, use the magnitude of the projection of the remove_concept against the subject's original activation
+ # if False, use the magnitude of the subject's original activation
+ use_remove_concept_projection_magnitude: bool = False,
+ batch_size: int = 12,
+ show_progress: bool = False,
+ ) -> list[torch.Tensor]:
+ results: list[torch.Tensor] = []
+ for batch in batchify(requests, batch_size, show_progress=show_progress):
+
+ def run_batch_fn() -> list[torch.Tensor]:
+ inputs = make_inputs(
+ self.tokenizer, [req.text for req in batch], device=self.device
+ )
+ return predict_all_token_probs_from_input(self.model, inputs)
+
+ results.extend(
+ self._swap_subject_concepts_and_run_batch(
+ batch,
+ run_fn=run_batch_fn,
+ magnitude_multiplier=magnitude_multiplier,
+ edit_single_layer=edit_single_layer,
+ use_remove_concept_projection_magnitude=use_remove_concept_projection_magnitude,
+ )
+ )
+ return results
+
+
+ def _swap_subject_concepts_and_run_batch(
+ self,
+ requests: Sequence[ConceptSwapRequest],
+ run_fn: Callable[[], T],
+ edit_single_layer: int | Literal[False],
+ magnitude_multiplier: float = 1.0,
+ # if True, use the magnitude of the projection of the remove_concept against the subject's original activation
+ # if False, use the magnitude of the subject's original activation
+ use_remove_concept_projection_magnitude: bool = False,
+ ) -> T:
+ """
+ Helper to run the given run_fn while swapping the subject concept for each request.
+ The run_fn should run the model with the same batch of inputs as specified in the requests
+ """
+ subj_tokens = [self._find_subject_token(req) for req in requests]
+ with torch.no_grad():
+ remove_concept_vectors = [
+ (
+ self._find_concept(req.remove_concept)
+ .vector.detach()
+ .clone()
+ .type(cast(torch.dtype, self.model.dtype))
+ .to(self.device)
+ )
+ for req in requests
+ ]
+ add_concept_vectors = [
+ (
+ self._find_concept(req.add_concept)
+ .vector.detach()
+ .clone()
+ .type(cast(torch.dtype, self.model.dtype))
+ .to(self.device)
+ )
+ for req in requests
+ ]
+
+ def edit_model_output(output: torch.Tensor, layer_name: str) -> torch.Tensor:
+ if (
+ edit_single_layer is not False
+ and self.layer_name_to_num[layer_name] != edit_single_layer
+ ):
+ return output
+ fixed_output = untuple_tensor(output)
+ for i, subj_token in enumerate(subj_tokens):
+ remove_concept_vector = remove_concept_vectors[i]
+ add_concept_vector = add_concept_vectors[i]
+ original_subj_act = fixed_output[i][subj_token]
+ if use_remove_concept_projection_magnitude:
+ base_magnitude = original_subj_act.dot(remove_concept_vector)
+ else:
+ base_magnitude = original_subj_act.norm()
+ magnitude = base_magnitude * magnitude_multiplier
+ fixed_output[i][subj_token] = original_subj_act + magnitude * (
+ add_concept_vector - remove_concept_vector
+ )
+ return output
+
+ with torch.no_grad(), TraceLayerDict(
+ self.model,
+ layers=self.layer_name_to_num.keys(),
+ edit_output=edit_model_output,
+ ):
+ return run_fn()
+
+ def _find_subject_token(self, query: ConceptSwapRequest) -> int:
+ text = query.text
+ subject = query.subject
+ if isinstance(subject, int):
+ return subject
+ if isinstance(subject, str):
+ return find_final_word_token_index(self.tokenizer, text, subject)
+ if callable(subject):
+ return subject(text, self.tokenizer.encode(text))
+ raise ValueError(f"Unknown subject type: {type(subject)}")
+
+ def _find_concept(self, concept_name: str) -> Concept:
+ for concept in self.concepts:
+ if concept.name == concept_name:
+ return concept
+ raise ValueError(f"Unknown concept: {concept_name}")
+
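+# Usage sketch (illustrative, not part of the library source above): build a
+# CausalEditor from a Hugging Face causal LM and swap one concept for another on
+# the subject token, then greedily decode the next token. The GPT-2 checkpoint,
+# hidden size (768), layer number, and random concept vectors are assumptions
+# for demonstration; real concepts would come from the Trainer shown later in
+# this file.
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+demo_model = AutoModelForCausalLM.from_pretrained("gpt2")
+demo_tokenizer = AutoTokenizer.from_pretrained("gpt2")
+demo_concepts = [
+    # Concept names default to "{relation}: {object}", e.g. "located_in_country: China"
+    Concept(
+        layer=8,
+        vector=torch.nn.functional.normalize(torch.randn(768), dim=0),
+        object=country,
+        relation="located_in_country",
+    )
+    for country in ("China", "France")
+]
+editor = CausalEditor(demo_model, demo_tokenizer, concepts=demo_concepts)
+prediction = editor.swap_subject_concepts_and_predict_greedy(
+    text="Shanghai is located in the country of",
+    subject="Shanghai",
+    remove_concept="located_in_country: China",
+    add_concept="located_in_country: France",
+    edit_single_layer=False,  # apply the edit at every matched layer
+    magnitude_multiplier=1.0,
+    predict_num_tokens=1,
+)
+print(prediction)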
+
+from __future__ import annotations
+
+from typing import Any, Optional
+
+import torch
+from torch import nn
+
+
+
+class Concept(nn.Module):
+ """Linear Relation Concept (LRC)"""
+
+ layer: int
+ vector: torch.Tensor
+ object: str
+ relation: str
+ name: str
+ metadata: dict[str, Any]
+
+ def __init__(
+ self,
+ layer: int,
+ vector: torch.Tensor,
+ object: str,
+ relation: str,
+ metadata: Optional[dict[str, Any]] = None,
+ name: Optional[str] = None,
+ ) -> None:
+ super().__init__()
+ self.layer = layer
+ self.vector = vector
+ self.object = object
+ self.relation = relation
+ self.metadata = metadata or {}
+ self.name = name or f"{self.relation}: {self.object}"
+
+
+ def forward(self, activations: torch.Tensor) -> torch.Tensor:
+ vector = self.vector.to(activations.device, dtype=activations.dtype)
+ if len(activations.shape) == 1:
+ return vector @ activations
+ return vector @ activations.T
+
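+# Minimal sketch of how a Concept is used on its own (the 768-dimensional random
+# vectors are placeholders): calling the module computes a dot product between
+# the concept vector and one or more hidden-state vectors, giving one score per
+# activation.
+example_concept = Concept(
+    layer=8,
+    vector=torch.nn.functional.normalize(torch.randn(768), dim=0),
+    object="France",
+    relation="located_in_country",
+)
+print(example_concept.name)                         # "located_in_country: France"
+print(example_concept(torch.randn(768)).shape)      # torch.Size([]) - a single score
+print(example_concept(torch.randn(5, 768)).shape)   # torch.Size([5]) - one score per row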
+
+
+from dataclasses import dataclass
+from typing import Callable, Optional, Sequence, Union
+
+import torch
+from tokenizers import Tokenizer
+from torch import nn
+
+from linear_relational.Concept import Concept
+from linear_relational.lib.extract_token_activations import (
+ TokenLayerActivationsList,
+ extract_token_activations,
+)
+from linear_relational.lib.layer_matching import (
+ LayerMatcher,
+ collect_matching_layers,
+ get_layer_name,
+ guess_hidden_layer_matcher,
+)
+from linear_relational.lib.token_utils import (
+ ensure_tokenizer_has_pad_token,
+ find_final_word_token_index,
+)
+from linear_relational.lib.torch_utils import get_device
+from linear_relational.lib.util import batchify
+
+QuerySubject = Union[str, int, Callable[[str, list[int]], int]]
+
+
+
+
+# Query/result containers used by ConceptMatcher below (reconstructed from how
+# they are used in this module).
+@dataclass
+class ConceptMatchQuery:
+    text: str
+    subject: QuerySubject
+
+
+@dataclass
+class ConceptMatchResult:
+    concept: str
+    score: float
+
+
+@dataclass
+class QueryResult:
+ concept_results: dict[str, ConceptMatchResult]
+
+ @property
+ def best_match(self) -> ConceptMatchResult:
+ return max(self.concept_results.values(), key=lambda x: x.score)
+
+
+
+
+class ConceptMatcher:
+ """Match concepts against subject activations in a model"""
+
+ concepts: list[Concept]
+ model: nn.Module
+ tokenizer: Tokenizer
+ layer_matcher: LayerMatcher
+ layer_name_to_num: dict[str, int]
+ map_activations_fn: (
+ Callable[[TokenLayerActivationsList], TokenLayerActivationsList] | None
+ )
+
+ def __init__(
+ self,
+ model: nn.Module,
+ tokenizer: Tokenizer,
+ concepts: list[Concept],
+ layer_matcher: Optional[LayerMatcher] = None,
+ map_activations_fn: (
+ Callable[[TokenLayerActivationsList], TokenLayerActivationsList] | None
+ ) = None,
+ ) -> None:
+ self.concepts = concepts
+ self.model = model
+ self.tokenizer = tokenizer
+ self.layer_matcher = layer_matcher or guess_hidden_layer_matcher(model)
+ self.map_activations_fn = map_activations_fn
+ ensure_tokenizer_has_pad_token(tokenizer)
+ num_layers = len(collect_matching_layers(self.model, self.layer_matcher))
+ self.layer_name_to_num = {}
+ for layer_num in range(num_layers):
+ self.layer_name_to_num[
+ get_layer_name(model, self.layer_matcher, layer_num)
+ ] = layer_num
+
+
+ def query(self, query: str, subject: QuerySubject) -> QueryResult:
+ return self.query_bulk([ConceptMatchQuery(query, subject)])[0]
+
+
+
+ def query_bulk(
+ self,
+ queries: Sequence[ConceptMatchQuery],
+ batch_size: int = 4,
+ verbose: bool = False,
+ ) -> list[QueryResult]:
+ results: list[QueryResult] = []
+ for batch in batchify(queries, batch_size, show_progress=verbose):
+ results.extend(self._query_batch(batch))
+ return results
+
+
+ def _query_batch(self, queries: Sequence[ConceptMatchQuery]) -> list[QueryResult]:
+ subj_tokens = [self._find_subject_token(query) for query in queries]
+ with torch.no_grad():
+ batch_subj_token_activations = extract_token_activations(
+ self.model,
+ self.tokenizer,
+ layers=self.layer_name_to_num.keys(),
+ texts=[q.text for q in queries],
+ token_indices=subj_tokens,
+ device=get_device(self.model),
+ # batching is handled already, so no need to batch here too
+ batch_size=len(queries),
+ show_progress=False,
+ )
+ if self.map_activations_fn is not None:
+ batch_subj_token_activations = self.map_activations_fn(
+ batch_subj_token_activations
+ )
+
+ results: list[QueryResult] = []
+ for raw_subj_token_activations in batch_subj_token_activations:
+ concept_results: dict[str, ConceptMatchResult] = {}
+ # need to replace the layer name with the layer number
+ subj_token_activations = {
+ self.layer_name_to_num[layer_name]: layer_activations[0]
+ for layer_name, layer_activations in raw_subj_token_activations.items()
+ }
+ for concept in self.concepts:
+ concept_results[concept.name] = _apply_concept_to_activations(
+ concept, subj_token_activations
+ )
+ results.append(QueryResult(concept_results))
+ return results
+
+ def _find_subject_token(self, query: ConceptMatchQuery) -> int:
+ text = query.text
+ subject = query.subject
+ if isinstance(subject, int):
+ return subject
+ if isinstance(subject, str):
+ return find_final_word_token_index(self.tokenizer, text, subject)
+ if callable(subject):
+ return subject(text, self.tokenizer.encode(text))
+ raise ValueError(f"Unknown subject type: {type(subject)}")
+
+
+
+@torch.no_grad()
+def _apply_concept_to_activations(
+ concept: Concept, activations: dict[int, torch.Tensor]
+) -> ConceptMatchResult:
+ score = concept.forward(activations[concept.layer]).item()
+ return ConceptMatchResult(
+ concept=concept.name,
+ score=score,
+ )
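+
+# Usage sketch (the GPT-2 checkpoint and random placeholder concept vectors are
+# assumptions): ask which concept best matches the subject token's activations
+# in a prompt.
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+_model = AutoModelForCausalLM.from_pretrained("gpt2")
+_tokenizer = AutoTokenizer.from_pretrained("gpt2")
+_concepts = [
+    Concept(
+        layer=8,
+        vector=torch.nn.functional.normalize(torch.randn(768), dim=0),
+        object=country,
+        relation="located_in_country",
+    )
+    for country in ("China", "France")
+]
+matcher = ConceptMatcher(_model, _tokenizer, concepts=_concepts)
+result = matcher.query("Shanghai is a city in Asia", subject="Shanghai")
+print(result.best_match.concept, result.best_match.score)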
+
+from typing import Any, Literal
+
+import torch
+from torch import nn
+
+
+
+class InvertedLre(nn.Module):
+ """Low-rank inverted LRE, used for calculating subject activations from object activations"""
+
+ relation: str
+ subject_layer: int
+ object_layer: int
+ # store u, v, s, and bias separately to avoid storing the full weight matrix
+ u: nn.Parameter
+ s: nn.Parameter
+ v: nn.Parameter
+ bias: nn.Parameter
+ object_aggregation: Literal["mean", "first_token"]
+ metadata: dict[str, Any] | None = None
+
+ def __init__(
+ self,
+ relation: str,
+ subject_layer: int,
+ object_layer: int,
+ object_aggregation: Literal["mean", "first_token"],
+ u: torch.Tensor,
+ s: torch.Tensor,
+ v: torch.Tensor,
+ bias: torch.Tensor,
+ metadata: dict[str, Any] | None = None,
+ ) -> None:
+ super().__init__()
+ self.relation = relation
+ self.subject_layer = subject_layer
+ self.object_layer = object_layer
+ self.object_aggregation = object_aggregation
+ self.u = nn.Parameter(u, requires_grad=False)
+ self.s = nn.Parameter(s, requires_grad=False)
+ self.v = nn.Parameter(v, requires_grad=False)
+ self.bias = nn.Parameter(bias, requires_grad=False)
+ self.metadata = metadata
+
+ @property
+ def rank(self) -> int:
+ return self.s.shape[0]
+
+
+ def w_inv_times_vec(self, vec: torch.Tensor) -> torch.Tensor:
+ # group u.T @ vec to avoid calculating larger matrices than needed
+ return self.v @ torch.diag(1 / self.s) @ (self.u.T @ vec)
+
+
+
+ def forward(
+ self,
+ object_activations: torch.Tensor, # a tensor of shape (num_activations, hidden_activation_size)
+ normalize: bool = False,
+ ) -> torch.Tensor:
+ return self.calculate_subject_activation(
+ object_activations=object_activations,
+ normalize=normalize,
+ )
+
+
+
+ def calculate_subject_activation(
+ self,
+ object_activations: torch.Tensor, # a tensor of shape (num_activations, hidden_activation_size)
+ normalize: bool = False,
+ ) -> torch.Tensor:
+ # match precision of weight_inverse and bias
+ unbiased_acts = object_activations - self.bias.unsqueeze(0)
+ vec = self.w_inv_times_vec(unbiased_acts.T).mean(dim=1)
+
+ if normalize:
+ vec = vec / vec.norm()
+ return vec
+
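+# Numerical sketch of what the stored factors mean (sizes are arbitrary): with
+# W = U diag(S) V^T, w_inv_times_vec applies V diag(1/S) U^T, i.e. the low-rank
+# pseudo-inverse of W, so at full rank it recovers x from W @ x.
+_W = torch.randn(16, 16)
+_u, _s, _vh = torch.linalg.svd(_W)
+_inv = InvertedLre(
+    relation="demo",
+    subject_layer=5,
+    object_layer=10,
+    object_aggregation="mean",
+    u=_u,
+    s=_s,
+    v=_vh.T,
+    bias=torch.zeros(16),
+)
+_x = torch.randn(16)
+_object_acts = (_W @ _x).unsqueeze(0)      # shape (num_activations, hidden_size)
+print(torch.dist(_inv(_object_acts), _x))  # near zero at full rank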
+
+
+
+
+class LowRankLre(nn.Module):
+ """Low-rank approximation of a LRE"""
+
+ relation: str
+ subject_layer: int
+ object_layer: int
+ # store u, v, s, and bias separately to avoid storing the full weight matrix
+ u: nn.Parameter
+ s: nn.Parameter
+ v: nn.Parameter
+ bias: nn.Parameter
+ object_aggregation: Literal["mean", "first_token"]
+ metadata: dict[str, Any] | None = None
+
+ def __init__(
+ self,
+ relation: str,
+ subject_layer: int,
+ object_layer: int,
+ object_aggregation: Literal["mean", "first_token"],
+ u: torch.Tensor,
+ s: torch.Tensor,
+ v: torch.Tensor,
+ bias: torch.Tensor,
+ metadata: dict[str, Any] | None = None,
+ ) -> None:
+ super().__init__()
+ self.relation = relation
+ self.subject_layer = subject_layer
+ self.object_layer = object_layer
+ self.object_aggregation = object_aggregation
+ self.u = nn.Parameter(u, requires_grad=False)
+ self.s = nn.Parameter(s, requires_grad=False)
+ self.v = nn.Parameter(v, requires_grad=False)
+ self.bias = nn.Parameter(bias, requires_grad=False)
+ self.metadata = metadata
+
+ @property
+ def rank(self) -> int:
+ return self.s.shape[0]
+
+
+ def w_times_vec(self, vec: torch.Tensor) -> torch.Tensor:
+ # group v.T @ vec to avoid calculating larger matrices than needed
+ return self.u @ torch.diag(self.s) @ (self.v.T @ vec)
+
+
+
+ def forward(
+ self,
+ subject_activations: torch.Tensor, # a tensor of shape (num_activations, hidden_activation_size)
+ normalize: bool = False,
+ ) -> torch.Tensor:
+ return self.calculate_object_activation(
+ subject_activations=subject_activations,
+ normalize=normalize,
+ )
+
+
+
+ def calculate_object_activation(
+ self,
+ subject_activations: torch.Tensor, # a tensor of shape (num_activations, hidden_activation_size)
+ normalize: bool = False,
+ ) -> torch.Tensor:
+ # match precision of weight_inverse and bias
+ ws = self.w_times_vec(subject_activations.T)
+ vec = (ws + self.bias.unsqueeze(-1)).mean(dim=1)
+ if normalize:
+ vec = vec / vec.norm()
+ return vec
+
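+# Mirror-image sketch for LowRankLre (sizes arbitrary): w_times_vec applies the
+# rank-k reconstruction U diag(S) V^T of W, so at full rank it matches W @ x,
+# while calculate_object_activation adds the bias and averages over the given
+# subject activations.
+_W2 = torch.randn(16, 16)
+_u2, _s2, _vh2 = torch.linalg.svd(_W2)
+_low_rank_lre = LowRankLre(
+    relation="demo",
+    subject_layer=5,
+    object_layer=10,
+    object_aggregation="mean",
+    u=_u2,
+    s=_s2,
+    v=_vh2.T,
+    bias=torch.zeros(16),
+)
+_x2 = torch.randn(16)
+print(torch.dist(_low_rank_lre.w_times_vec(_x2), _W2 @ _x2))  # near zero at full rank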
+
+
+
+
+class Lre(nn.Module):
+ """Linear Relational Embedding"""
+
+ relation: str
+ subject_layer: int
+ object_layer: int
+ weight: nn.Parameter
+ bias: nn.Parameter
+ object_aggregation: Literal["mean", "first_token"]
+ metadata: dict[str, Any] | None = None
+
+ def __init__(
+ self,
+ relation: str,
+ subject_layer: int,
+ object_layer: int,
+ object_aggregation: Literal["mean", "first_token"],
+ weight: torch.Tensor,
+ bias: torch.Tensor,
+ metadata: dict[str, Any] | None = None,
+ ) -> None:
+ super().__init__()
+ self.relation = relation
+ self.subject_layer = subject_layer
+ self.object_layer = object_layer
+ self.object_aggregation = object_aggregation
+ self.weight = nn.Parameter(weight, requires_grad=False)
+ self.bias = nn.Parameter(bias, requires_grad=False)
+ self.metadata = metadata
+
+
+ def invert(self, rank: int) -> InvertedLre:
+ """Invert this LRE using a low-rank approximation"""
+ u, s, v = self._low_rank_svd(rank)
+ return InvertedLre(
+ relation=self.relation,
+ subject_layer=self.subject_layer,
+ object_layer=self.object_layer,
+ object_aggregation=self.object_aggregation,
+ u=u.detach().clone(),
+ s=s.detach().clone(),
+ v=v.detach().clone(),
+ bias=self.bias.detach().clone(),
+ metadata=self.metadata,
+ )
+
+
+
+ def to_low_rank(self, rank: int) -> LowRankLre:
+ """Create a low-rank approximation of this LRE"""
+ u, s, v = self._low_rank_svd(rank)
+ return LowRankLre(
+ relation=self.relation,
+ subject_layer=self.subject_layer,
+ object_layer=self.object_layer,
+ object_aggregation=self.object_aggregation,
+ u=u.detach().clone(),
+ s=s.detach().clone(),
+ v=v.detach().clone(),
+ bias=self.bias.detach().clone(),
+ metadata=self.metadata,
+ )
+
+
+ @torch.no_grad()
+ def _low_rank_svd(
+ self, rank: int
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ # use a float for the svd, then convert back to the original dtype
+ u, s, v = torch.svd(self.weight.float())
+ low_rank_u: torch.Tensor = u[:, :rank].to(self.weight.dtype)
+ low_rank_v: torch.Tensor = v[:, :rank].to(self.weight.dtype)
+ low_rank_s: torch.Tensor = s[:rank].to(self.weight.dtype)
+ return low_rank_u, low_rank_s, low_rank_v
+
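+# Sketch tying the three classes together (the random weight, bias, and layer
+# numbers are placeholders): an Lre holds the full weight and bias; invert(rank)
+# and to_low_rank(rank) both reuse the same truncated SVD of that weight.
+_lre = Lre(
+    relation="located_in_country",
+    subject_layer=8,
+    object_layer=10,
+    object_aggregation="mean",
+    weight=torch.randn(32, 32),
+    bias=torch.randn(32),
+)
+_inv_lre = _lre.invert(rank=16)      # InvertedLre: object activations -> subject direction
+_lr_lre = _lre.to_low_rank(rank=16)  # LowRankLre: subject activations -> object activations
+print(_inv_lre.rank, _lr_lre.rank)   # 16 16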
+
+from dataclasses import dataclass
+
+
+
+@dataclass(frozen=True, slots=True)
+class Prompt:
+ """A prompt for training LREs and LRCs"""
+
+ text: str
+ answer: str
+ subject: str
+ subject_name: str = "" # If not provided, will be set to subject
+ object_name: str = "" # If not provided, will be set to answer
+
+ def __post_init__(self) -> None:
+ if self.subject_name == "":
+ object.__setattr__(self, "subject_name", self.subject)
+ if self.object_name == "":
+ object.__setattr__(self, "object_name", self.answer)
+
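+# A Prompt pairs a text with its expected answer and the subject it mentions;
+# subject_name and object_name fall back to subject and answer when left empty.
+example_prompt = Prompt(
+    text="Shanghai is located in the country of",
+    answer="China",
+    subject="Shanghai",
+)
+print(example_prompt.subject_name, example_prompt.object_name)  # Shanghai China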
+
+from collections import defaultdict
+from time import time
+from typing import Callable, Literal, Optional
+
+import torch
+from tokenizers import Tokenizer
+from torch import nn
+
+from linear_relational.Concept import Concept
+from linear_relational.lib.balance_grouped_items import balance_grouped_items
+from linear_relational.lib.extract_token_activations import extract_token_activations
+from linear_relational.lib.layer_matching import (
+ LayerMatcher,
+ get_layer_name,
+ guess_hidden_layer_matcher,
+)
+from linear_relational.lib.logger import log_or_print, logger
+from linear_relational.lib.token_utils import PromptAnswerData, find_prompt_answer_data
+from linear_relational.lib.torch_utils import get_device
+from linear_relational.lib.util import group_items
+from linear_relational.Lre import InvertedLre, Lre
+from linear_relational.Prompt import Prompt
+from linear_relational.PromptValidator import PromptValidator
+from linear_relational.training.train_lre import ObjectAggregation, train_lre
+
+VectorAggregation = Literal["pre_mean", "post_mean"]
+
+
+
+class Trainer:
+ """Train LREs and concepts from prompts"""
+
+ model: nn.Module
+ tokenizer: Tokenizer
+ layer_matcher: LayerMatcher
+ prompt_validator: PromptValidator
+
+ def __init__(
+ self,
+ model: nn.Module,
+ tokenizer: Tokenizer,
+ layer_matcher: Optional[LayerMatcher] = None,
+ prompt_validator: Optional[PromptValidator] = None,
+ ) -> None:
+ self.model = model
+ self.tokenizer = tokenizer
+ self.layer_matcher = layer_matcher or guess_hidden_layer_matcher(model)
+ self.prompt_validator = prompt_validator or PromptValidator(model, tokenizer)
+
+
+ def train_lre(
+ self,
+ relation: str,
+ subject_layer: int,
+ object_layer: int,
+ prompts: list[Prompt],
+ max_lre_training_samples: int | None = None,
+ object_aggregation: ObjectAggregation = "mean",
+ validate_prompts: bool = True,
+ validate_prompts_batch_size: int = 4,
+ move_to_cpu: bool = False,
+ verbose: bool = True,
+ seed: int | str | float = 42,
+ ) -> Lre:
+ processed_prompts = self._process_relation_prompts(
+ relation=relation,
+ prompts=prompts,
+ validate_prompts=validate_prompts,
+ validate_prompts_batch_size=validate_prompts_batch_size,
+ verbose=verbose,
+ )
+ prompts_by_object = group_items(processed_prompts, lambda p: p.object_name)
+ lre_train_prompts = balance_grouped_items(
+ items_by_group=prompts_by_object,
+ max_total=max_lre_training_samples,
+ seed=seed,
+ )
+ return train_lre(
+ model=self.model,
+ tokenizer=self.tokenizer,
+ layer_matcher=self.layer_matcher,
+ relation=relation,
+ subject_layer=subject_layer,
+ object_layer=object_layer,
+ prompts=lre_train_prompts,
+ object_aggregation=object_aggregation,
+ move_to_cpu=move_to_cpu,
+ )
+
+
+
+ def train_relation_concepts(
+ self,
+ relation: str,
+ subject_layer: int,
+ object_layer: int,
+ prompts: list[Prompt],
+ max_lre_training_samples: int | None = 20,
+ object_aggregation: ObjectAggregation = "mean",
+ vector_aggregation: VectorAggregation = "post_mean",
+ inv_lre_rank: int = 200,
+ validate_prompts_batch_size: int = 4,
+ validate_prompts: bool = True,
+ verbose: bool = True,
+ name_concept_fn: Optional[Callable[[str, str], str]] = None,
+ seed: int | str | float = 42,
+ ) -> list[Concept]:
+ processed_prompts = self._process_relation_prompts(
+ relation=relation,
+ prompts=prompts,
+ validate_prompts=validate_prompts,
+ validate_prompts_batch_size=validate_prompts_batch_size,
+ verbose=verbose,
+ )
+ prompts_by_object = group_items(processed_prompts, lambda p: p.object_name)
+ if len(prompts_by_object) == 1:
+ logger.warning(
+ f"Only one valid object found for {relation}. Results may be poor."
+ )
+ lre_train_prompts = balance_grouped_items(
+ items_by_group=prompts_by_object,
+ max_total=max_lre_training_samples,
+ seed=seed,
+ )
+ inv_lre = train_lre(
+ model=self.model,
+ tokenizer=self.tokenizer,
+ layer_matcher=self.layer_matcher,
+ relation=relation,
+ subject_layer=subject_layer,
+ object_layer=object_layer,
+ prompts=lre_train_prompts,
+ object_aggregation=object_aggregation,
+ ).invert(inv_lre_rank)
+
+ return self.train_relation_concepts_from_inv_lre(
+ relation=relation,
+ inv_lre=inv_lre,
+ prompts=processed_prompts,
+ vector_aggregation=vector_aggregation,
+ object_aggregation=object_aggregation,
+ object_layer=object_layer,
+ validate_prompts_batch_size=validate_prompts_batch_size,
+ validate_prompts=False, # we already validated the prompts above
+ name_concept_fn=name_concept_fn,
+ verbose=verbose,
+ )
+
+
+
+ def train_relation_concepts_from_inv_lre(
+ self,
+ inv_lre: InvertedLre | Callable[[str], InvertedLre],
+ prompts: list[Prompt],
+ vector_aggregation: VectorAggregation = "post_mean",
+ object_aggregation: ObjectAggregation | None = None,
+ relation: str | None = None,
+ object_layer: int | None = None,
+ validate_prompts_batch_size: int = 4,
+ extract_objects_batch_size: int = 4,
+ validate_prompts: bool = True,
+ name_concept_fn: Optional[Callable[[str, str], str]] = None,
+ verbose: bool = True,
+ ) -> list[Concept]:
+ if isinstance(inv_lre, InvertedLre):
+ if object_aggregation is None:
+ object_aggregation = inv_lre.object_aggregation
+ if object_layer is None:
+ object_layer = inv_lre.object_layer
+ if relation is None:
+ relation = inv_lre.relation
+ if object_aggregation is None:
+ raise ValueError(
+ "object_aggregation must be specified if inv_lre is a function"
+ )
+ if object_layer is None:
+ raise ValueError("object_layer must be specified if inv_lre is a function")
+ if relation is None:
+ raise ValueError("relation must be specified if inv_lre is a function")
+ processed_prompts = self._process_relation_prompts(
+ relation=relation,
+ prompts=prompts,
+ validate_prompts=validate_prompts,
+ validate_prompts_batch_size=validate_prompts_batch_size,
+ verbose=verbose,
+ )
+ start_time = time()
+ object_activations = self._extract_target_object_activations_for_inv_lre(
+ prompts=processed_prompts,
+ batch_size=extract_objects_batch_size,
+ object_aggregation=object_aggregation,
+ object_layer=object_layer,
+ show_progress=verbose,
+ move_to_cpu=True,
+ )
+ logger.info(
+ f"Extracted {len(object_activations)} object activations in {time() - start_time:.2f}s"
+ )
+ concepts: list[Concept] = []
+
+ with torch.no_grad():
+ for (
+ object_name,
+ activations,
+ ) in object_activations.items():
+ resolved_inv_lre = (
+ inv_lre
+ if isinstance(inv_lre, InvertedLre)
+ else inv_lre(object_name)
+ )
+ name = None
+ if name_concept_fn is not None:
+ name = name_concept_fn(relation, object_name)
+ concept = self._build_concept(
+ relation_name=relation,
+ layer=resolved_inv_lre.subject_layer,
+ inv_lre=resolved_inv_lre,
+ object_name=object_name,
+ activations=activations,
+ vector_aggregation=vector_aggregation,
+ name=name,
+ )
+ concepts.append(concept)
+ return concepts
+
+
+ def _process_relation_prompts(
+ self,
+ relation: str,
+ prompts: list[Prompt],
+ validate_prompts: bool,
+ validate_prompts_batch_size: int,
+ verbose: bool,
+ ) -> list[Prompt]:
+ valid_prompts = prompts
+ if validate_prompts:
+ log_or_print(f"validating {len(prompts)} prompts", verbose=verbose)
+ valid_prompts = self.prompt_validator.filter_prompts(
+ prompts, validate_prompts_batch_size, verbose
+ )
+ if len(valid_prompts) == 0:
+ raise ValueError(f"No valid prompts found for {relation}.")
+ return valid_prompts
+
+ def _build_concept(
+ self,
+ layer: int,
+ relation_name: str,
+ object_name: str,
+ activations: list[torch.Tensor],
+ inv_lre: InvertedLre,
+ vector_aggregation: VectorAggregation,
+ name: str | None,
+ ) -> Concept:
+ device = inv_lre.bias.device
+ dtype = inv_lre.bias.dtype
+ if vector_aggregation == "pre_mean":
+ acts = [torch.stack(activations).to(device=device, dtype=dtype).mean(dim=0)]
+ elif vector_aggregation == "post_mean":
+ acts = [act.to(device=device, dtype=dtype) for act in activations]
+ else:
+ raise ValueError(f"Unknown vector aggregation method {vector_aggregation}")
+ vecs = [
+ inv_lre.calculate_subject_activation(act, normalize=False) for act in acts
+ ]
+ vec = torch.stack(vecs).mean(dim=0)
+ vec = vec / vec.norm()
+ return Concept(
+ name=name,
+ object=object_name,
+ relation=relation_name,
+ layer=layer,
+ vector=vec.detach().clone().cpu(),
+ )
+
+ @torch.no_grad()
+ def _extract_target_object_activations_for_inv_lre(
+ self,
+ object_layer: int,
+ object_aggregation: Literal["mean", "first_token"],
+ prompts: list[Prompt],
+ batch_size: int,
+ show_progress: bool = False,
+ move_to_cpu: bool = True,
+ ) -> dict[str, list[torch.Tensor]]:
+ activations_by_object: dict[str, list[torch.Tensor]] = defaultdict(list)
+ prompt_answer_data: list[PromptAnswerData] = []
+ for prompt in prompts:
+ prompt_answer_data.append(
+ find_prompt_answer_data(self.tokenizer, prompt.text, prompt.answer)
+ )
+
+ layer_name = get_layer_name(self.model, self.layer_matcher, object_layer)
+ raw_activations = extract_token_activations(
+ self.model,
+ self.tokenizer,
+ layers=[layer_name],
+ texts=[prompt_answer.full_prompt for prompt_answer in prompt_answer_data],
+ token_indices=[
+ prompt_answer.output_answer_token_indices
+ for prompt_answer in prompt_answer_data
+ ],
+ device=get_device(self.model),
+ batch_size=batch_size,
+ show_progress=show_progress,
+ move_results_to_cpu=move_to_cpu,
+ )
+ for prompt, raw_activation in zip(prompts, raw_activations):
+ if object_aggregation == "mean":
+ activation = torch.stack(raw_activation[layer_name]).mean(dim=0)
+ elif object_aggregation == "first_token":
+ activation = raw_activation[layer_name][0]
+ else:
+ raise ValueError(
+ f"Unknown inv_lre.object_aggregation: {object_aggregation}"
+ )
+ activations_by_object[prompt.object_name].append(activation)
+ return activations_by_object
+
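+
+# End-to-end sketch (the GPT-2 checkpoint, layer choices, rank, and this tiny
+# hand-written prompt set are all illustrative assumptions): train concepts for
+# one relation, ready to use with ConceptMatcher or CausalEditor. Prompt
+# validation may filter out prompts the model answers incorrectly.
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+_model = AutoModelForCausalLM.from_pretrained("gpt2")
+_tokenizer = AutoTokenizer.from_pretrained("gpt2")
+_prompts = [
+    Prompt("Shanghai is located in the country of", "China", "Shanghai"),
+    Prompt("Beijing is located in the country of", "China", "Beijing"),
+    Prompt("Paris is located in the country of", "France", "Paris"),
+    Prompt("Lyon is located in the country of", "France", "Lyon"),
+]
+trainer = Trainer(_model, _tokenizer)
+concepts = trainer.train_relation_concepts(
+    relation="located_in_country",
+    subject_layer=8,
+    object_layer=10,
+    prompts=_prompts,
+    inv_lre_rank=50,
+)
+print([concept.name for concept in concepts])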
+