diff --git a/routellm/routers/matrix_factorization/model.py b/routellm/routers/matrix_factorization/model.py index 09fbb25..c6e1551 100644 --- a/routellm/routers/matrix_factorization/model.py +++ b/routellm/routers/matrix_factorization/model.py @@ -79,10 +79,11 @@ def __init__( text_dim, num_classes, use_proj, + collapse_linear=False, ): super().__init__() - self._name = "TextMF" self.use_proj = use_proj + self.collapse_linear = collapse_linear # collapse the linear transformations into a single linear layer self.P = torch.nn.Embedding(num_models, dim) self.embedding_model = "text-embedding-3-small" @@ -104,19 +105,20 @@ def get_device(self): return self.P.weight.device def forward(self, model_id, prompt): - model_id = torch.tensor(model_id, dtype=torch.long).to(self.get_device()) - - model_embed = self.P(model_id) - model_embed = torch.nn.functional.normalize(model_embed, p=2, dim=1) - prompt_embed = ( OPENAI_CLIENT.embeddings.create(input=[prompt], model=self.embedding_model) .data[0] .embedding ) prompt_embed = torch.tensor(prompt_embed, device=self.get_device()) - prompt_embed = self.text_proj(prompt_embed) + model_id = torch.tensor(model_id, dtype=torch.long).to(self.get_device()) + + if self.collapse_linear: + upscaled_model_embed = self.precompute_upscaled_embedding(model_id) + return upscaled_model_embed @ prompt_embed.squeeze(-1) + model_embed = self.P(model_id) + prompt_embed = self.text_proj(prompt_embed) return self.classifier(model_embed * prompt_embed).squeeze() @torch.no_grad() @@ -127,3 +129,22 @@ def pred_win_rate(self, model_a, model_b, prompt): def load(self, path): self.load_state_dict(torch.load(path)) + + def post_process_weight(self): + # since the current model consist of only linear transformations + # we can collapse the linear transformations into a single linear layer + # https://github.com/lm-sys/RouteLLM/issues/9 + num_models = self.P.weight.shape[0] + text_dim = self.text_proj[0].weight.shape[1] + + self.P.weight.data = torch.nn.functional.normalize( + self.P.weight.data, p=2, dim=1 + ) + + if self.collapse_linear: + self.precompute_upscaled_embedding = torch.nn.Embedding( + num_models, text_dim + ) + self.precompute_upscaled_embedding.weight.data = ( + self.P.weight * self.classifier[0].weight.data + ) @ self.text_proj[0].weight.data diff --git a/routellm/routers/routers.py b/routellm/routers/routers.py index 0096c0a..a1efeb3 100644 --- a/routellm/routers/routers.py +++ b/routellm/routers/routers.py @@ -231,6 +231,7 @@ def __init__( num_classes=num_classes, use_proj=use_proj, ) + self.model.post_process_weight() self.model = self.model.eval().to(device) self.strong_model_id = MODEL_IDS[strong_model] self.weak_model_id = MODEL_IDS[weak_model]