add ability to run new/multiple benchmarks (#376)
dhosterman authored May 24, 2024
1 parent 91f71f5 · commit 1488950
Showing 2 changed files with 53 additions and 7 deletions.
src/modelbench/run.py (15 changes: 11 additions & 4 deletions)
@@ -28,7 +28,6 @@

 from modelbench.benchmarks import (
     BenchmarkDefinition,
-    GeneralPurposeAiChatBenchmark,
 )
 from modelbench.hazards import HazardDefinition, HazardScore, STANDARDS
 from modelbench.modelgauge_runner import ModelGaugeSut, SutDescription
@@ -64,7 +63,7 @@ def cli() -> None:
     print()


-@cli.command(help="run the standard benchmark")
+@cli.command(help="run a benchmark")
 @click.option(
     "--output-dir", "-o", default="./web", type=click.Path(file_okay=False, dir_okay=True, path_type=pathlib.Path)
 )
@@ -77,9 +76,17 @@ def cli() -> None:
 )
 @click.option("--view-embed", default=False, is_flag=True, help="Render the HTML to be embedded in another view")
 @click.option("--anonymize", type=int, help="Random number seed for consistent anonymization of SUTs")
-@click.option("--parallel", default=False, help="experimentally run SUTs in parallel")
+@click.option("--parallel", default=False, help="Experimentally run SUTs in parallel")
+@click.option(
+    "--benchmark",
+    type=click.Choice([c.__name__ for c in BenchmarkDefinition.__subclasses__()]),
+    default=["GeneralPurposeAiChatBenchmark"],
+    help="Benchmark to run (Default: GeneralPurposeAiChatBenchmark)",
+    multiple=True,
+)
 @local_plugin_dir_option
 def benchmark(
+    benchmark: str,
     output_dir: pathlib.Path,
     max_instances: int,
     debug: bool,
@@ -89,7 +96,7 @@ def benchmark(
     parallel=False,
 ) -> None:
     suts = find_suts_for_sut_argument(sut)
-    benchmarks = [GeneralPurposeAiChatBenchmark()]
+    benchmarks = [b() for b in BenchmarkDefinition.__subclasses__() if b.__name__ in benchmark]
     benchmark_scores = score_benchmarks(benchmarks, suts, max_instances, debug, parallel)
     generate_content(benchmark_scores, output_dir, anonymize, view_embed)

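The core of the change is a discovery pattern: the allowed values for the new --benchmark option are the names of all BenchmarkDefinition subclasses known at import time, the option accepts multiple occurrences, and the command instantiates exactly the subclasses whose names were passed. Below is a minimal, self-contained sketch of that name-to-class resolution; BaseBenchmark and resolve_benchmarks are hypothetical stand-ins for BenchmarkDefinition and the list comprehension added to the benchmark command.

# Sketch of the subclass-discovery pattern behind the new --benchmark option.
# BaseBenchmark, SomeOtherBenchmark, and resolve_benchmarks are made up for
# illustration; only the pattern mirrors the commit.

class BaseBenchmark:
    """Stand-in for BenchmarkDefinition."""


class GeneralPurposeAiChatBenchmark(BaseBenchmark):
    pass


class SomeOtherBenchmark(BaseBenchmark):
    pass


def resolve_benchmarks(names):
    # Instantiate every registered subclass whose class name was requested,
    # mirroring the list comprehension added to the benchmark command.
    return [cls() for cls in BaseBenchmark.__subclasses__() if cls.__name__ in names]


if __name__ == "__main__":
    selected = resolve_benchmarks(["GeneralPurposeAiChatBenchmark", "SomeOtherBenchmark"])
    print([type(b).__name__ for b in selected])
    # ['GeneralPurposeAiChatBenchmark', 'SomeOtherBenchmark']

One consequence of building the click.Choice from __subclasses__() at import time is that only benchmark classes already imported when run.py loads appear as valid --benchmark values; the tests below work around exactly that.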
tests/test_run.py (45 changes: 42 additions & 3 deletions)
@@ -1,15 +1,18 @@
 import json
 import pathlib
 import unittest.mock
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch

 import click
 import pytest
 from click.testing import CliRunner

+from modelbench.benchmarks import BenchmarkDefinition
+from modelbench.hazards import HazardScore, SafeCbrHazard
+from modelbench.hazards import SafeHazard
 from modelbench.modelgauge_runner import ModelGaugeSut
+from modelbench.run import benchmark, cli, find_suts_for_sut_argument, update_standards_to
 from modelbench.scoring import ValueEstimate
-from modelbench.hazards import HazardScore, SafeCbrHazard
-from modelbench.run import update_standards_to, find_suts_for_sut_argument

+
 @patch("modelbench.run.run_tests")
@@ -44,3 +47,39 @@ def test_find_suts():

     with pytest.raises(click.BadParameter):
         find_suts_for_sut_argument(["something nonexistent"])
+
+
+class TestCli:
+
+    @pytest.fixture(autouse=True)
+    def mock_score_benchmarks(self, monkeypatch):
+        import modelbench
+
+        mock_obj = MagicMock()
+
+        monkeypatch.setattr(modelbench.run, "score_benchmarks", mock_obj)
+        return mock_obj
+
+    @pytest.fixture(autouse=True)
+    def do_not_make_static_site(self, monkeypatch):
+        import modelbench
+
+        monkeypatch.setattr(modelbench.run, "generate_content", MagicMock())
+
+    @pytest.fixture
+    def runner(self):
+        return CliRunner()
+
+    def test_nonexistent_benchmarks_can_not_be_called(self, runner):
+        result = runner.invoke(cli, ["benchmark", "--benchmark", "NotARealBenchmark"])
+        assert result.exit_code == 2
+        assert "Invalid value for '--benchmark'" in result.output
+
+    def test_calls_score_benchmark_with_correct_benchmark(self, runner, mock_score_benchmarks):
+        class MyBenchmark(BenchmarkDefinition):
+            def __init__(self):
+                super().__init__([c() for c in SafeHazard.__subclasses__()])
+
+        cli.commands["benchmark"].params[-2].type.choices += ["MyBenchmark"]
+        result = runner.invoke(cli, ["benchmark", "--benchmark", "MyBenchmark"])
+        assert isinstance(mock_score_benchmarks.call_args.args[0][0], MyBenchmark)
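Note the choices tweak in the last test: click.Choice captures its allowed values when run.py is imported, so a benchmark class defined inside the test body has to be appended to the option's choices before the CLI will accept its name. Below is a small sketch of that behaviour using click directly rather than the modelbench CLI; the toy run command and the LaterBenchmark value are made up for illustration.

# Sketch: click.Choice values are fixed when the option is declared, so a value
# introduced later must be appended to .choices before the CLI accepts it.
# The command below is a hypothetical toy, not the real modelbench CLI.
import click
from click.testing import CliRunner


@click.command()
@click.option("--benchmark", type=click.Choice(["GeneralPurposeAiChatBenchmark"]), multiple=True)
def run(benchmark):
    click.echo(", ".join(benchmark))


runner = CliRunner()

# Unknown value: rejected with exit code 2, as in test_nonexistent_benchmarks_can_not_be_called.
assert runner.invoke(run, ["--benchmark", "LaterBenchmark"]).exit_code == 2

# Extend the Choice in place, as the test does via cli.commands["benchmark"].params[-2].type.choices.
run.params[0].type.choices = list(run.params[0].type.choices) + ["LaterBenchmark"]
assert runner.invoke(run, ["--benchmark", "LaterBenchmark"]).exit_code == 0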
