Skip to content

Commit

Permalink
Unit tests (#57)
Browse files Browse the repository at this point in the history
* added new spanish unit test

* new unit test

* formatting

* get rid of ds store

* add to gitignore ds_store

* empty commit

* format

* formatting

* more formatting

* transfer unit test to new folder

* change def policy in test_s2s

* delete eval2.sh

* added stateless case

* added test_s2s to main.yml

* formatting

* added fairseq download

* install huggingface-hub

* install packages
  • Loading branch information
mandyschen authored Aug 1, 2023
1 parent 075c4d3 commit f199b56
Show file tree
Hide file tree
Showing 8 changed files with 178 additions and 8 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ jobs:
sudo apt-get install libsndfile1
python -m pip install --upgrade pip
pip install flake8 pytest black
pip install g2p-en
pip install huggingface-hub
pip install fairseq
pip install -e .
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Lint with black
Expand All @@ -44,3 +47,4 @@ jobs:
pytest simuleval/test/test_agent_pipeline.py
pytest simuleval/test/test_evaluator.py
pytest simuleval/test/test_remote_evaluation.py
pytest simuleval/test/test_s2s.py
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,6 @@ dmypy.json
# Cython debug symbols
cython_debug/
.vscode

# Mac files
.DS_Store
74 changes: 74 additions & 0 deletions examples/speech_to_speech/english_alternate_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from simuleval.utils import entrypoint
from simuleval.data.segments import SpeechSegment
from simuleval.agents import SpeechToSpeechAgent
from simuleval.agents.actions import WriteAction, ReadAction
from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
from fairseq.models.text_to_speech.hub_interface import TTSHubInterface


class TTSModel:
def __init__(self):
models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
"facebook/fastspeech2-en-ljspeech",
arg_overrides={"vocoder": "hifigan", "fp16": False},
)
TTSHubInterface.update_cfg_with_data_cfg(cfg, task.data_cfg)
self.tts_generator = task.build_generator(models, cfg)
self.tts_task = task
self.tts_model = models[0]
self.tts_model.to("cpu")
self.tts_generator.vocoder.to("cpu")

def synthesize(self, text):
sample = TTSHubInterface.get_model_input(self.tts_task, text)
if sample["net_input"]["src_lengths"][0] == 0:
return [], 0
for key in sample["net_input"].keys():
if sample["net_input"][key] is not None:
sample["net_input"][key] = sample["net_input"][key].to("cpu")

wav, rate = TTSHubInterface.get_prediction(
self.tts_task, self.tts_model, self.tts_generator, sample
)
wav = wav.tolist()
return wav, rate


@entrypoint
class EnglishAlternateAgent(SpeechToSpeechAgent):
"""
Incrementally feed text to this offline Fastspeech2 TTS model,
with an alternating speech pattern that is decrementing.
"""

def __init__(self, args):
super().__init__(args)
self.wait_seconds = args.wait_seconds
self.tts_model = TTSModel()

@staticmethod
def add_args(parser):
parser.add_argument("--wait-seconds", default=1, type=int)

def policy(self):
length_in_seconds = round(
len(self.states.source) / self.states.source_sample_rate
)
if not self.states.source_finished and length_in_seconds < self.wait_seconds:
return ReadAction()
if length_in_seconds % 2 == 0:
samples, fs = self.tts_model.synthesize(
f"{8 - length_in_seconds} even even"
)
else:
samples, fs = self.tts_model.synthesize(f"{8 - length_in_seconds} odd odd")

# A SpeechSegment has to be returned for speech-to-speech translation system
return WriteAction(
SpeechSegment(
content=samples,
sample_rate=fs,
finished=self.states.source_finished,
),
finished=self.states.source_finished,
)
4 changes: 1 addition & 3 deletions examples/speech_to_speech/english_counter_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,7 @@ def policy(self, states: Optional[AgentStates] = None):
# empty source, source_sample_rate not set yet
length_in_seconds = 0
else:
length_in_seconds = round(
len(states.source) / states.source_sample_rate
)
length_in_seconds = round(len(states.source) / states.source_sample_rate)
if not states.source_finished and length_in_seconds < self.wait_seconds:
return ReadAction()
samples, fs = self.tts_model.synthesize(f"{length_in_seconds} mississippi")
Expand Down
2 changes: 1 addition & 1 deletion examples/speech_to_speech_demo/english_counter_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ class EnglishCounterAgentPipeline(AgentPipeline):
pipeline = [
SileroVADAgent,
EnglishSpeechCounter,
]
]
4 changes: 1 addition & 3 deletions examples/speech_to_text/english_counter_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ def policy(self, states: Optional[AgentStates] = None):
# empty source, source_sample_rate not set yet
length_in_seconds = 0
else:
length_in_seconds = round(
len(states.source) / states.source_sample_rate
)
length_in_seconds = round(len(states.source) / states.source_sample_rate)
if not states.source_finished and length_in_seconds < self.wait_seconds:
return ReadAction()

Expand Down
2 changes: 1 addition & 1 deletion examples/speech_to_text_demo/english_counter_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ class EnglishCounterAgentPipeline(AgentPipeline):
pipeline = [
SileroVADAgent,
EnglishSpeechCounter,
]
]
93 changes: 93 additions & 0 deletions simuleval/test/test_s2s.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import tempfile
from pathlib import Path
from typing import Optional
from simuleval.agents.states import AgentStates

import simuleval.cli as cli
from simuleval.agents import SpeechToSpeechAgent
from simuleval.agents.actions import ReadAction, WriteAction
from simuleval.data.segments import SpeechSegment

ROOT_PATH = Path(__file__).parents[2]


def test_s2s(root_path=ROOT_PATH):
args_path = Path.joinpath(root_path, "examples", "speech_to_speech")
os.chdir(args_path)
with tempfile.TemporaryDirectory() as tmpdirname:
cli.sys.argv[1:] = [
"--agent",
os.path.join(
root_path, "examples", "speech_to_speech", "english_alternate_agent.py"
),
"--user-dir",
os.path.join(root_path, "examples"),
"--agent-class",
"agents.EnglishAlternateAgent",
"--source-segment-size",
"1000",
"--source",
os.path.join(root_path, "examples", "speech_to_speech", "source.txt"),
"--target",
os.path.join(root_path, "examples", "speech_to_speech", "reference/en.txt"),
"--output",
tmpdirname,
]
cli.main()


def test_stateless_agent(root_path=ROOT_PATH):
class EnglishAlternateAgent(SpeechToSpeechAgent):
waitk = 0
wait_seconds = 3
vocab = [chr(i) for i in range(ord("A"), ord("Z") + 1)]

def policy(self, states: Optional[AgentStates] = None):
if states is None:
states = states

length_in_seconds = round(len(states.source) / states.source_sample_rate)
if (
not self.states.source_finished
and length_in_seconds < self.wait_seconds
):
return ReadAction()

if length_in_seconds % 2 == 0:
samples, fs = self.tts_model.synthesize(
f"{8 - length_in_seconds} even even"
)
else:
samples, fs = self.tts_model.synthesize(
f"{8 - length_in_seconds} odd odd"
)

prediction = f"{length_in_seconds} second"

return WriteAction(
SpeechSegment(
content=samples,
sample_rate=fs,
finished=self.states.source_finished,
),
content=prediction,
finished=self.states.source_finished,
)

args = None
agent_stateless = EnglishAlternateAgent.from_args(args)
agent_state = agent_stateless.build_states()
agent_stateful = EnglishAlternateAgent.from_args(args)

for _ in range(10):
segment = SpeechSegment(0, "A")
output_1 = agent_stateless.pushpop(segment, agent_state)
output_2 = agent_stateful.pushpop(segment)
assert output_1.content == output_2.content

0 comments on commit f199b56

Please sign in to comment.