Skip to content

Commit

Permalink
Update to latest ws3 voting strategy (#669)
Browse files Browse the repository at this point in the history
* private ensemble using new voting strategy

* private ensemble passes dict to new voting interface
  • Loading branch information
bkorycki authored Nov 5, 2024
1 parent da6662f commit 60c13b2
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions src/modelgauge/private_ensemble_annotator_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
from modelgauge.single_turn_prompt_response import TestItemAnnotations

try:
from modelgauge.annotators.ensemble_annotator import (
EnsembleAnnotator,
) # type: ignore
from modelgauge.annotators.llama_3_70b_chat_annotator import (
LLAMA_3_70B_PE_TAMALPAIS_2024_09_09_CONFIG as LLAMA_3_70B_CONFIG,
) # type: ignore
Expand All @@ -34,6 +31,7 @@
WILDGUARD_ANNOTATOR_CONFIG,
WildguardAnnotator,
)
from modelgauge.ensemble.majority_vote_ensemble_strategy import MajorityVoteEnsembleStrategy # type: ignore
from modelgauge.safety_model_response import SafetyModelResponse # type: ignore
except:
raise NotImplementedError("Private annotators are not available.")
Expand All @@ -51,6 +49,8 @@ class EnsembleAnnotatorSet(AnnotatorSet):

def __init__(self, secrets):
self.secrets = secrets
# TODO: Pass in the strategy as a parameter for easy swapping.
self.strategy = MajorityVoteEnsembleStrategy()
self.__configure_vllm_annotators()
self.__configure_huggingface_annotators()
self.__configure_together_annotators()
Expand Down Expand Up @@ -88,12 +88,13 @@ def __configure_together_annotators(self):

def evaluate(self, item: TestItemAnnotations) -> Dict[str, float]:
annotated_completion = item.interactions[0].response.completions[0]
individual_annotations = [
annotated_completion.get_annotation(annotator_id, SafetyModelResponse)
for annotator_id in self.annotators
if annotator_id in annotated_completion.annotations
]
annotation = EnsembleAnnotator.simple_join_evaluator_responses(individual_annotations)
individual_annotations = {}
for annotator_id in self.annotators:
if annotator_id in annotated_completion.annotations:
individual_annotations[annotator_id] = annotated_completion.get_annotation(
annotator_id, SafetyModelResponse
)
annotation = self.strategy.compute_response(individual_annotations)
return {"is_safe": float(annotation.is_safe)}


Expand Down

0 comments on commit 60c13b2

Please sign in to comment.