Skip to content

Commit

Permalink
add dummy segmenter agent
Browse files Browse the repository at this point in the history
  • Loading branch information
annasun28 committed Aug 30, 2023
1 parent cd81265 commit 1ced99a
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 0 deletions.
33 changes: 33 additions & 0 deletions examples/quick_start/spm_detokenizer_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,38 @@

from simuleval.agents import TextToTextAgent
from simuleval.agents.actions import ReadAction, WriteAction
from simuleval.agents.pipeline import AgentPipeline
from simuleval.agents.states import AgentStates


class DummySegmentAgent(TextToTextAgent):
    """
    Toy segmenter that regroups incoming source items into fixed-size chunks.

    The agent keeps reading until ``segment_k`` source items are buffered in
    ``states.source`` (or the source is finished), then writes the buffered
    items joined by a single space and clears the buffer.
    """

    def __init__(self, args):
        """Store ``args.segment_k``, the number of source items per segment."""
        super().__init__(args)
        # Expected to be a positive integer (see ``--segment-k``).
        self.segment_k = args.segment_k

    @classmethod
    def from_args(cls, args, **kwargs):
        """Build the agent from parsed arguments; extra ``kwargs`` are ignored."""
        return cls(args)

    # Declared static because the function takes no ``self``/``cls``; without
    # the decorator it only works when called on the class, not on an instance.
    @staticmethod
    def add_args(parser: ArgumentParser):
        """Register the required ``--segment-k`` option on ``parser``."""
        parser.add_argument(
            "--segment-k",
            type=int,
            help="Output segments with this many words",
            required=True,
        )

    def policy(self, states: AgentStates):
        """
        Decide whether to emit a segment or request more input.

        Returns a ``WriteAction`` carrying the space-joined buffer once at
        least ``segment_k`` items are buffered or the source has finished;
        otherwise returns a ``ReadAction``. The write is marked finished
        exactly when the source is finished.
        """
        # ``>=`` instead of ``==``: if the buffer ever grows past segment_k
        # between two policy calls, an equality test would not flush again
        # until the very end of the source.
        if states.source_finished or len(states.source) >= self.segment_k:
            out = " ".join(states.source)
            # Reset the buffer so the next segment starts empty.
            states.source = []
            return WriteAction(out, finished=states.source_finished)
        return ReadAction()


class SentencePieceModelDetokenizerAgent(TextToTextAgent):
def __init__(self, args):
super().__init__(args)
Expand Down Expand Up @@ -59,3 +88,7 @@ def policy(self, states: AgentStates):
return WriteAction(" ".join(full_words), finished=False)
else:
return ReadAction()


class DummyPipeline(AgentPipeline):
    """
    Example pipeline listing the dummy segmenter followed by the
    SentencePiece detokenizer.

    NOTE(review): presumably ``AgentPipeline`` feeds each agent's output to
    the next one in ``pipeline`` order -- confirm against its implementation.
    """

    pipeline = [
        DummySegmentAgent,
        SentencePieceModelDetokenizerAgent,
    ]
1 change: 1 addition & 0 deletions examples/quick_start/spm_source.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
▁Let ' s ▁do ▁it ▁with out ▁hesitation .
1 change: 1 addition & 0 deletions examples/quick_start/spm_target.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Let's do it without hesitation.
28 changes: 28 additions & 0 deletions simuleval/test/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,31 @@ def test_spm_detokenizer_agent(detokenize_only):
else:
assert output == ["Let's do it", "without hesitation."]
assert delays == [1, 1, 1, 2, 2]


@pytest.mark.parametrize("detokenize_only", [True, False])
def test_spm_detokenizer_agent_pipeline(detokenize_only, root_path=ROOT_PATH):
    """
    Smoke-test the CLI end to end with ``DummyPipeline``.

    Downloads the SentencePiece tokenizer model into a temporary directory
    (requires network access), points the CLI at the quick-start source and
    target files, and runs ``cli.main()``. The test makes no assertions on
    the produced output; it only checks that the run completes.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        # Fetch the tokenizer model next to the evaluation output.
        tokenizer_url = "https://huggingface.co/facebook/seamless-m4t-large/resolve/main/tokenizer.model"
        tokenizer_file = f"{tmpdirname}/tokenizer.model"
        urllib.request.urlretrieve(tokenizer_url, tokenizer_file)

        quick_start_dir = os.path.join(root_path, "examples", "quick_start")
        cli_args = [
            "--user-dir",
            os.path.join(root_path, "examples"),
            "--agent-class",
            "examples.quick_start.spm_detokenizer_agent.DummyPipeline",
            "--source",
            os.path.join(quick_start_dir, "spm_source.txt"),
            "--target",
            os.path.join(quick_start_dir, "spm_target.txt"),
            "--output",
            tmpdirname,
            "--segment-k",
            "3",
            "--sentencepiece-model",
            tokenizer_file,
        ]
        if detokenize_only:
            cli_args.append("--detokenize-only")

        # Replace everything after the program name with the test arguments.
        cli.sys.argv[1:] = cli_args
        cli.main()

0 comments on commit 1ced99a

Please sign in to comment.