From 1ced99aea359765fbe61a269c5e7c7485c13dcb0 Mon Sep 17 00:00:00 2001 From: Anna Sun <13106449+annasun28@users.noreply.github.com> Date: Wed, 30 Aug 2023 18:51:55 -0400 Subject: [PATCH] add dummy segmenter agent --- examples/quick_start/spm_detokenizer_agent.py | 33 +++++++++++++++++++ examples/quick_start/spm_source.txt | 1 + examples/quick_start/spm_target.txt | 1 + simuleval/test/test_agent.py | 28 ++++++++++++++++ 4 files changed, 63 insertions(+) create mode 100644 examples/quick_start/spm_source.txt create mode 100644 examples/quick_start/spm_target.txt diff --git a/examples/quick_start/spm_detokenizer_agent.py b/examples/quick_start/spm_detokenizer_agent.py index 06a9f571..2cf6bee4 100644 --- a/examples/quick_start/spm_detokenizer_agent.py +++ b/examples/quick_start/spm_detokenizer_agent.py @@ -4,9 +4,38 @@ from simuleval.agents import TextToTextAgent from simuleval.agents.actions import ReadAction, WriteAction +from simuleval.agents.pipeline import AgentPipeline from simuleval.agents.states import AgentStates +class DummySegmentAgent(TextToTextAgent): + """ + This agent just splits on space + """ + def __init__(self, args): + super().__init__(args) + self.segment_k = args.segment_k + + @classmethod + def from_args(cls, args, **kwargs): + return cls(args) + + def add_args(parser: ArgumentParser): + parser.add_argument( + "--segment-k", + type=int, + help="Output segments with this many words", + required=True, + ) + + def policy(self, states: AgentStates): + if len(states.source) == self.segment_k or states.source_finished: + out = " ".join(states.source) + states.source = [] + return WriteAction(out, finished=states.source_finished) + return ReadAction() + + class SentencePieceModelDetokenizerAgent(TextToTextAgent): def __init__(self, args): super().__init__(args) @@ -59,3 +88,7 @@ def policy(self, states: AgentStates): return WriteAction(" ".join(full_words), finished=False) else: return ReadAction() + + +class DummyPipeline(AgentPipeline): + pipeline = [DummySegmentAgent, SentencePieceModelDetokenizerAgent] \ No newline at end of file diff --git a/examples/quick_start/spm_source.txt b/examples/quick_start/spm_source.txt new file mode 100644 index 00000000..fe099b99 --- /dev/null +++ b/examples/quick_start/spm_source.txt @@ -0,0 +1 @@ +▁Let ' s ▁do ▁it ▁with out ▁hesitation . \ No newline at end of file diff --git a/examples/quick_start/spm_target.txt b/examples/quick_start/spm_target.txt new file mode 100644 index 00000000..e3b2ae21 --- /dev/null +++ b/examples/quick_start/spm_target.txt @@ -0,0 +1 @@ +Let's do it without hesitation. \ No newline at end of file diff --git a/simuleval/test/test_agent.py b/simuleval/test/test_agent.py index cf298a2c..a561949c 100644 --- a/simuleval/test/test_agent.py +++ b/simuleval/test/test_agent.py @@ -101,3 +101,31 @@ def test_spm_detokenizer_agent(detokenize_only): else: assert output == ["Let's do it", "without hesitation."] assert delays == [1, 1, 1, 2, 2] + + +@pytest.mark.parametrize("detokenize_only", [True, False]) +def test_spm_detokenizer_agent_pipeline(detokenize_only, root_path=ROOT_PATH): + with tempfile.TemporaryDirectory() as tmpdirname: + tokenizer_file = f"{tmpdirname}/tokenizer.model" + tokenizer_url = "https://huggingface.co/facebook/seamless-m4t-large/resolve/main/tokenizer.model" + urllib.request.urlretrieve(tokenizer_url, tokenizer_file) + + cli.sys.argv[1:] = [ + "--user-dir", + os.path.join(root_path, "examples"), + "--agent-class", + "examples.quick_start.spm_detokenizer_agent.DummyPipeline", + "--source", + os.path.join(root_path, "examples", "quick_start", "spm_source.txt"), + "--target", + os.path.join(root_path, "examples", "quick_start", "spm_target.txt"), + "--output", + tmpdirname, + "--segment-k", + "3", + "--sentencepiece-model", + tokenizer_file, + ] + if detokenize_only: + cli.sys.argv.append("--detokenize-only") + cli.main()