From 60c13b238467deb6a651d9bcf7d0cdaa711b35c1 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Mon, 4 Nov 2024 20:44:38 -0500 Subject: [PATCH] Update to latest ws3 voting strategy (#669) * private ensemble using new voting strategy * private ensemble passes dict to new voting interface --- .../private_ensemble_annotator_set.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/modelgauge/private_ensemble_annotator_set.py b/src/modelgauge/private_ensemble_annotator_set.py index 8ff7044b..c6f0f7e1 100644 --- a/src/modelgauge/private_ensemble_annotator_set.py +++ b/src/modelgauge/private_ensemble_annotator_set.py @@ -12,9 +12,6 @@ from modelgauge.single_turn_prompt_response import TestItemAnnotations try: - from modelgauge.annotators.ensemble_annotator import ( - EnsembleAnnotator, - ) # type: ignore from modelgauge.annotators.llama_3_70b_chat_annotator import ( LLAMA_3_70B_PE_TAMALPAIS_2024_09_09_CONFIG as LLAMA_3_70B_CONFIG, ) # type: ignore @@ -34,6 +31,7 @@ WILDGUARD_ANNOTATOR_CONFIG, WildguardAnnotator, ) + from modelgauge.ensemble.majority_vote_ensemble_strategy import MajorityVoteEnsembleStrategy # type: ignore from modelgauge.safety_model_response import SafetyModelResponse # type: ignore except: raise NotImplementedError("Private annotators are not available.") @@ -51,6 +49,8 @@ class EnsembleAnnotatorSet(AnnotatorSet): def __init__(self, secrets): self.secrets = secrets + # TODO: Pass in the strategy as a parameter for easy swapping. + self.strategy = MajorityVoteEnsembleStrategy() self.__configure_vllm_annotators() self.__configure_huggingface_annotators() self.__configure_together_annotators() @@ -88,12 +88,13 @@ def __configure_together_annotators(self): def evaluate(self, item: TestItemAnnotations) -> Dict[str, float]: annotated_completion = item.interactions[0].response.completions[0] - individual_annotations = [ - annotated_completion.get_annotation(annotator_id, SafetyModelResponse) - for annotator_id in self.annotators - if annotator_id in annotated_completion.annotations - ] - annotation = EnsembleAnnotator.simple_join_evaluator_responses(individual_annotations) + individual_annotations = {} + for annotator_id in self.annotators: + if annotator_id in annotated_completion.annotations: + individual_annotations[annotator_id] = annotated_completion.get_annotation( + annotator_id, SafetyModelResponse + ) + annotation = self.strategy.compute_response(individual_annotations) return {"is_safe": float(annotation.is_safe)}