diff --git a/.circleci/config.yml b/.circleci/config.yml new file mode 100644 index 00000000..f0d9287b --- /dev/null +++ b/.circleci/config.yml @@ -0,0 +1,34 @@ +name: TestPyPI Publish + +on: + push: + branches: + - main + +jobs: + test_pypi_publish: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: 3.x + + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y python3-pip + pip install pipenv + pipenv install twine + + - name: Build and Publish + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.TEST_PYPI_TOKEN }} + REPOSITORY_URL: https://test.pypi.org/legacy/ + run: | + python setup.py sdist bdist_wheel + pipenv run twine upload --repository-url $REPOSITORY_URL dist/* diff --git a/examples/demo/silero_vad.py b/examples/demo/silero_vad.py deleted file mode 100644 index fa678e70..00000000 --- a/examples/demo/silero_vad.py +++ /dev/null @@ -1,237 +0,0 @@ -import logging -import queue -import time -import torch -import numpy as np -import soundfile -from argparse import Namespace, ArgumentParser -from simuleval.agents import SpeechToSpeechAgent, AgentStates -from simuleval.agents.actions import WriteAction, ReadAction -from simuleval.data.segments import Segment, EmptySegment, SpeechSegment - -logger = logging.getLogger() - - -class SileroVADStates(AgentStates): - def __init__(self, args): - self.model, utils = torch.hub.load( - repo_or_dir="snakers4/silero-vad", - model="silero_vad", - force_reload=False, - onnx=False, - ) - - ( - self.get_speech_timestamps, - self.save_audio, - self.read_audio, - self.VADIterator, - self.collect_chunks, - ) = utils - self.silence_limit_ms = args.silence_limit_ms - self.window_size_samples = args.window_size_samples - self.chunk_size_samples = args.chunk_size_samples - self.sample_rate = args.sample_rate - self.debug = args.debug - self.test_input_segments_wav = None - self.debug_log(args) - self.input_queue: queue.Queue[Segment] = queue.Queue() - self.next_input_queue: queue.Queue[Segment] = queue.Queue() - super().__init__() - - def clear_queues(self): - self.debug_log(f"clearing {self.input_queue.qsize()} chunks") - while not self.input_queue.empty(): - self.input_queue.get_nowait() - self.input_queue.task_done() - self.debug_log(f"moving {self.next_input_queue.qsize()} chunks") - # move everything from next_input_queue to input_queue - while not self.next_input_queue.empty(): - chunk = self.next_input_queue.get_nowait() - self.next_input_queue.task_done() - self.input_queue.put_nowait(chunk) - - def reset(self) -> None: - super().reset() - # TODO: in seamless_server, report latency for each new segment - self.first_input_ts = None - self.silence_acc_ms = 0 - self.input_chunk = np.empty(0, dtype=np.int16) - self.is_fresh_state = True - self.clear_queues() - self.model.reset_states() - - def get_speech_prob_from_np_float32(self, segment: np.ndarray): - t = torch.from_numpy(segment) - speech_probs = [] - # print("len(t): ", len(t)) - for i in range(0, len(t), self.window_size_samples): - chunk = t[i : i + self.window_size_samples] - if len(chunk) < self.window_size_samples: - break - speech_prob = self.model(chunk, self.sample_rate).item() - speech_probs.append(speech_prob) - return speech_probs - - def debug_log(self, m): - if self.debug: - logger.info(m) - - def process_speech(self, segment): - """ - Process a full or partial speech chunk - """ - queue = self.input_queue - if self.source_finished: - # current source is finished, but next speech starts to come in already - self.debug_log("use next_input_queue") - queue = self.next_input_queue - - # NOTE: we don't reset silence_acc_ms here so that once an utterance - # becomes longer (accumulating more silence), it has a higher chance - # of being segmented. - # self.silence_acc_ms = 0 - - if self.first_input_ts is None: - self.first_input_ts = time.time() * 1000 - - while len(segment) > 0: - # add chunks to states.buffer - i = self.chunk_size_samples - len(self.input_chunk) - self.input_chunk = np.concatenate((self.input_chunk, segment[:i])) - segment = segment[i:] - self.is_fresh_state = False - if len(self.input_chunk) == self.chunk_size_samples: - queue.put_nowait( - SpeechSegment(content=self.input_chunk, finished=False) - ) - self.input_chunk = np.empty(0, dtype=np.int16) - - def check_silence_acc(self): - if self.silence_acc_ms >= self.silence_limit_ms: - self.silence_acc_ms = 0 - if self.input_chunk.size > 0: - # flush partial input_chunk - self.input_queue.put_nowait( - SpeechSegment(content=self.input_chunk, finished=True) - ) - self.input_chunk = np.empty(0, dtype=np.int16) - self.input_queue.put_nowait(EmptySegment(finished=True)) - self.source_finished = True - - def update_source(self, segment: np.ndarray): - speech_probs = self.get_speech_prob_from_np_float32(segment) - chunk_size_ms = len(segment) * 1000 / self.sample_rate - self.debug_log( - f"{chunk_size_ms}, {len(speech_probs)} {[round(s, 2) for s in speech_probs]}" - ) - window_size_ms = self.window_size_samples * 1000 / self.sample_rate - if all(i <= 0.5 for i in speech_probs): - if self.source_finished: - return - self.debug_log("got silent chunk") - if not self.is_fresh_state: - self.silence_acc_ms += chunk_size_ms - self.check_silence_acc() - return - elif speech_probs[-1] <= 0.5: - self.debug_log("=== start of silence chunk") - # beginning = speech, end = silence - # pass to process_speech and accumulate silence - self.process_speech(segment) - # accumulate contiguous silence - for i in range(len(speech_probs) - 1, -1, -1): - if speech_probs[i] > 0.5: - break - self.silence_acc_ms += window_size_ms - self.check_silence_acc() - elif speech_probs[0] <= 0.5: - self.debug_log("=== start of speech chunk") - # beginning = silence, end = speech - # accumulate silence , pass next to process_speech - for i in range(0, len(speech_probs)): - if speech_probs[i] > 0.5: - break - self.silence_acc_ms += window_size_ms - self.check_silence_acc() - self.process_speech(segment) - else: - self.debug_log("======== got speech chunk") - self.process_speech(segment) - - def debug_write_wav(self, chunk): - if self.test_input_segments_wav is not None: - self.test_input_segments_wav.seek(0, soundfile.SEEK_END) - self.test_input_segments_wav.write(chunk) - - -class SileroVADAgent(SpeechToSpeechAgent): - def __init__(self, args: Namespace) -> None: - super().__init__(args) - self.chunk_size_samples = args.chunk_size_samples - self.args = args - - @staticmethod - def add_args(parser: ArgumentParser): - parser.add_argument( - "--sample-rate", - default=16000, - type=float, - ) - parser.add_argument( - "--window-size-samples", - default=512, # sampling_rate // 1000 * 32 => 32 ms at 16000 sample rate - type=int, - help="Window size for passing samples to VAD", - ) - parser.add_argument( - "--chunk-size-samples", - default=5120, # sampling_rate // 1000 * 320 => 320 ms at 16000 sample rate - type=int, - help="Chunk size for passing samples to model", - ) - parser.add_argument( - "--silence-limit-ms", - default=700, - type=int, - help="send EOS to the input_queue after this amount of silence", - ) - parser.add_argument( - "--debug", - default=False, - type=bool, - help="Enable debug logs", - ) - - def build_states(self) -> SileroVADStates: - return SileroVADStates(self.args) - - def policy(self, states: SileroVADStates): - states.debug_log( - f"queue size: {states.input_queue.qsize()}, input_chunk size: {len(states.input_chunk)}" - ) - content = np.empty(0, dtype=np.int16) - is_finished = states.source_finished - while not states.input_queue.empty(): - chunk = states.input_queue.get_nowait() - states.input_queue.task_done() - content = np.concatenate((content, chunk.content)) - - states.debug_write_wav(content) - if is_finished: - states.debug_write_wav(np.zeros(16000)) - - if len(content) == 0: # empty queue - if not states.source_finished: - return ReadAction() - else: - # NOTE: this should never happen, this logic is a safeguard - segment = EmptySegment(finished=True) - else: - segment = SpeechSegment( - content=content.tolist(), - finished=is_finished, - sample_rate=states.sample_rate, - ) - - return WriteAction(segment, finished=is_finished) diff --git a/examples/speech_to_speech/english_counter_agent.py b/examples/speech_to_speech/english_counter_agent.py deleted file mode 100644 index 66e8856c..00000000 --- a/examples/speech_to_speech/english_counter_agent.py +++ /dev/null @@ -1,75 +0,0 @@ -from typing import Optional -from simuleval.agents.states import AgentStates -from simuleval.utils import entrypoint -from simuleval.data.segments import SpeechSegment -from simuleval.agents import SpeechToSpeechAgent -from simuleval.agents.actions import WriteAction, ReadAction -from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub -from fairseq.models.text_to_speech.hub_interface import TTSHubInterface - - -class TTSModel: - def __init__(self): - models, cfg, task = load_model_ensemble_and_task_from_hf_hub( - "facebook/fastspeech2-en-ljspeech", - arg_overrides={"vocoder": "hifigan", "fp16": False}, - ) - TTSHubInterface.update_cfg_with_data_cfg(cfg, task.data_cfg) - self.tts_generator = task.build_generator(models, cfg) - self.tts_task = task - self.tts_model = models[0] - self.tts_model.to("cpu") - self.tts_generator.vocoder.to("cpu") - - def synthesize(self, text): - sample = TTSHubInterface.get_model_input(self.tts_task, text) - if sample["net_input"]["src_lengths"][0] == 0: - return [], 0 - for key in sample["net_input"].keys(): - if sample["net_input"][key] is not None: - sample["net_input"][key] = sample["net_input"][key].to("cpu") - - wav, rate = TTSHubInterface.get_prediction( - self.tts_task, self.tts_model, self.tts_generator, sample - ) - wav = wav.tolist() - return wav, rate - - -@entrypoint -class EnglishSpeechCounter(SpeechToSpeechAgent): - """ - Incrementally feed text to this offline Fastspeech2 TTS model, - with a minimum numbers of phonemes every chunk. - """ - - def __init__(self, args): - super().__init__(args) - self.wait_seconds = args.wait_seconds - self.tts_model = TTSModel() - - @staticmethod - def add_args(parser): - parser.add_argument("--wait-seconds", default=1, type=int) - - def policy(self, states: Optional[AgentStates] = None): - if states is None: - states = self.states - if states.source_sample_rate == 0: - # empty source, source_sample_rate not set yet - length_in_seconds = 0 - else: - length_in_seconds = round(len(states.source) / states.source_sample_rate) - if not states.source_finished and length_in_seconds < self.wait_seconds: - return ReadAction() - samples, fs = self.tts_model.synthesize(f"{length_in_seconds} mississippi") - - # A SpeechSegment has to be returned for speech-to-speech translation system - return WriteAction( - SpeechSegment( - content=samples, - sample_rate=fs, - finished=states.source_finished, - ), - finished=states.source_finished, - ) diff --git a/examples/speech_to_speech_demo/english_counter_pipeline.py b/examples/speech_to_speech_demo/english_counter_pipeline.py deleted file mode 100644 index d9baf2ad..00000000 --- a/examples/speech_to_speech_demo/english_counter_pipeline.py +++ /dev/null @@ -1,10 +0,0 @@ -from simuleval.agents import AgentPipeline -from examples.demo.silero_vad import SileroVADAgent -from examples.speech_to_speech.english_counter_agent import EnglishSpeechCounter - - -class EnglishCounterAgentPipeline(AgentPipeline): - pipeline = [ - SileroVADAgent, - EnglishSpeechCounter, - ] diff --git a/examples/speech_to_speech_demo/readme.md b/examples/speech_to_speech_demo/readme.md deleted file mode 100644 index 7f9c8071..00000000 --- a/examples/speech_to_speech_demo/readme.md +++ /dev/null @@ -1,11 +0,0 @@ -Running the demo: -1. Create a directory for the dummy model: `models/$DUMMY_MODEL` -2. Create a new yaml file `models/$DUMMY_MODEL/vad_main.yaml`, with the following: -``` -agent_class: examples.speech_to_speech_demo.english_counter_pipeline.EnglishCounterAgentPipeline -``` -3. Set the available agent in `SimulevalAgentDirectory.py` to `$DUMMY_MODEL` -4. Run `python app.py` - - -- Note: If you get an ImportError for `examples.speech_to_speech_demo`, run `python -c "import examples; print(examples.__file__)"`. If the file is something like `$PREFIX/site-packages/examples/__init__.py`, run `rm -r $PREFIX/site-packages/examples` and try again. \ No newline at end of file diff --git a/examples/speech_to_text/counter_in_tgt_lang_agent.py b/examples/speech_to_text/counter_in_tgt_lang_agent.py index 8cf8a4e5..eb0a968a 100644 --- a/examples/speech_to_text/counter_in_tgt_lang_agent.py +++ b/examples/speech_to_text/counter_in_tgt_lang_agent.py @@ -14,13 +14,16 @@ class CounterInTargetLanguage(SpeechToTextAgent): def __init__(self, args): super().__init__(args) self.wait_seconds = args.wait_seconds - self.tgt_lang = args.tgt_lang + # if args is not None: + # with open(args.tgt_lang, "r") as file: + # tgt_lang = file.read() + # self.tgt_lang = tgt_lang @staticmethod def add_args(parser): parser.add_argument("--wait-seconds", default=1, type=int) parser.add_argument( - "--tgt-lang", default="en", type=str, choices=["en", "es", "de"] + "--tgt-lang" ) def policy(self, states: Optional[AgentStates] = None): @@ -35,11 +38,12 @@ def policy(self, states: Optional[AgentStates] = None): return ReadAction() prediction = f"{length_in_seconds} " - if self.tgt_lang == "en": + tgt_lang = states.tgt_lang + if tgt_lang == "en": prediction += "seconds" - elif self.tgt_lang == "es": + elif tgt_lang == "es": prediction += "segundos" - elif self.tgt_lang == "de": + elif tgt_lang == "de": prediction += "sekunden" else: prediction += "" @@ -47,4 +51,4 @@ def policy(self, states: Optional[AgentStates] = None): return WriteAction( content=prediction, finished=states.source_finished, - ) + ) \ No newline at end of file diff --git a/examples/speech_to_text/english_counter_agent.py b/examples/speech_to_text/english_counter_agent.py deleted file mode 100644 index 5f0e9644..00000000 --- a/examples/speech_to_text/english_counter_agent.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Optional -from simuleval.agents.states import AgentStates -from simuleval.utils import entrypoint -from simuleval.agents import SpeechToTextAgent -from simuleval.agents.actions import WriteAction, ReadAction - - -@entrypoint -class EnglishSpeechCounter(SpeechToTextAgent): - """ - The agent generate the number of seconds from an input audio. - """ - - def __init__(self, args): - super().__init__(args) - self.wait_seconds = args.wait_seconds - - @staticmethod - def add_args(parser): - parser.add_argument("--wait-seconds", default=1, type=int) - - def policy(self, states: Optional[AgentStates] = None): - if states is None: - states = self.states - if states.source_sample_rate == 0: - # empty source, source_sample_rate not set yet - length_in_seconds = 0 - else: - length_in_seconds = round(len(states.source) / states.source_sample_rate) - if not states.source_finished and length_in_seconds < self.wait_seconds: - return ReadAction() - - prediction = f"{length_in_seconds} second" - - return WriteAction( - content=prediction, - finished=states.source_finished, - ) diff --git a/examples/speech_to_text/eval.sh b/examples/speech_to_text/eval.sh old mode 100644 new mode 100755 index 9c81178b..0c26c673 --- a/examples/speech_to_text/eval.sh +++ b/examples/speech_to_text/eval.sh @@ -1,5 +1,7 @@ simuleval \ - --agent english_counter_agent.py \ + --agent counter_in_tgt_lang_agent.py \ --source-segment-size 1000 \ --source source.txt --target reference/en.txt \ + --tgt-lang reference/tgt_lang.txt \ --output output + \ No newline at end of file diff --git a/examples/speech_to_text/reference/tgt_lang.txt b/examples/speech_to_text/reference/tgt_lang.txt new file mode 100644 index 00000000..2c4c454f --- /dev/null +++ b/examples/speech_to_text/reference/tgt_lang.txt @@ -0,0 +1 @@ +en \ No newline at end of file diff --git a/examples/speech_to_text_demo/english_counter_pipeline.py b/examples/speech_to_text_demo/english_counter_pipeline.py deleted file mode 100644 index 5b84d3db..00000000 --- a/examples/speech_to_text_demo/english_counter_pipeline.py +++ /dev/null @@ -1,10 +0,0 @@ -from simuleval.agents import AgentPipeline -from examples.demo.silero_vad import SileroVADAgent -from examples.speech_to_text.english_counter_agent import EnglishSpeechCounter - - -class EnglishCounterAgentPipeline(AgentPipeline): - pipeline = [ - SileroVADAgent, - EnglishSpeechCounter, - ] diff --git a/examples/speech_to_text_demo/readme.md b/examples/speech_to_text_demo/readme.md deleted file mode 100644 index 1f5441a6..00000000 --- a/examples/speech_to_text_demo/readme.md +++ /dev/null @@ -1,11 +0,0 @@ -Running the demo: -1. Create a directory for the dummy model: `models/$DUMMY_MODEL` -2. Create a new yaml file `models/$DUMMY_MODEL/vad_main.yaml`, with the following: -``` -agent_class: examples.speech_to_text_demo.english_counter_pipeline.EnglishCounterAgentPipeline -``` -3. Set the available agent in `SimulevalAgentDirectory.py` to `$DUMMY_MODEL` -4. Run `python app.py` - - -- Note: If you get an ImportError for `examples.speech_to_text_demo`, run `python -c "import examples; print(examples.__file__)"`. If the file is something like `$PREFIX/site-packages/examples/__init__.py`, run `rm -r $PREFIX/site-packages/examples` and try again. \ No newline at end of file diff --git a/simuleval/agents/agent.py b/simuleval/agents/agent.py index c1e7b183..cb4659a5 100644 --- a/simuleval/agents/agent.py +++ b/simuleval/agents/agent.py @@ -184,6 +184,7 @@ class SpeechToTextAgent(GenericAgent): source_type: str = "speech" target_type: str = "text" + tgt_lang: str = None class SpeechToSpeechAgent(GenericAgent): diff --git a/simuleval/agents/states.py b/simuleval/agents/states.py index 2c01e8e7..8a3b5652 100644 --- a/simuleval/agents/states.py +++ b/simuleval/agents/states.py @@ -29,6 +29,7 @@ def reset(self) -> None: self.target_finished = False self.source_sample_rate = 0 self.target_sample_rate = 0 + self.tgt_lang = None def update_source(self, segment: Segment): """ @@ -45,6 +46,7 @@ def update_source(self, segment: Segment): elif isinstance(segment, SpeechSegment): self.source += segment.content self.source_sample_rate = segment.sample_rate + self.tgt_lang = segment.tgt_lang else: raise NotImplementedError diff --git a/simuleval/data/dataloader/__init__.py b/simuleval/data/dataloader/__init__.py index e4c1dea8..90faaa18 100644 --- a/simuleval/data/dataloader/__init__.py +++ b/simuleval/data/dataloader/__init__.py @@ -5,6 +5,8 @@ # LICENSE file in the root directory of this source tree. import logging +from argparse import Namespace + from .dataloader import ( # noqa GenericDataloader, register_dataloader, @@ -21,7 +23,7 @@ logger = logging.getLogger("simuleval.dataloader") -def build_dataloader(args) -> GenericDataloader: +def build_dataloader(args: Namespace) -> GenericDataloader: dataloader_key = getattr(args, "dataloader", None) if dataloader_key is not None: assert dataloader_key in DATALOADER_DICT, f"{dataloader_key} is not defined" diff --git a/simuleval/data/dataloader/dataloader.py b/simuleval/data/dataloader/dataloader.py index 518ed751..57ca5af8 100644 --- a/simuleval/data/dataloader/dataloader.py +++ b/simuleval/data/dataloader/dataloader.py @@ -37,7 +37,7 @@ class GenericDataloader: """ def __init__( - self, source_list: List[str], target_list: Union[List[str], List[None]] + self, source_list: List[str], target_list: Union[List[str], List[None]], ) -> None: self.source_list = source_list self.target_list = target_list @@ -53,7 +53,9 @@ def get_target(self, index: int) -> Any: return self.preprocess_target(self.target_list[index]) def __getitem__(self, index: int) -> Dict[str, Any]: - return {"source": self.get_source(index), "target": self.get_target(index)} + return {"source": self.get_source(index), + "target": self.get_target(index), + } def preprocess_source(self, source: Any) -> Any: raise NotImplementedError @@ -95,3 +97,4 @@ def add_args(parser: ArgumentParser): default=1, help="Source segment size, For text the unit is # token, for speech is ms", ) + \ No newline at end of file diff --git a/simuleval/data/dataloader/s2t_dataloader.py b/simuleval/data/dataloader/s2t_dataloader.py index e5c238cc..4da1f084 100644 --- a/simuleval/data/dataloader/s2t_dataloader.py +++ b/simuleval/data/dataloader/s2t_dataloader.py @@ -58,6 +58,10 @@ def get_video_id(url): @register_dataloader("speech-to-text") class SpeechToTextDataloader(GenericDataloader): + def __init__(self, source_list: List[str], target_list: List[str], tgt_lang): + super().__init__(source_list, target_list) + self.tgt_lang = tgt_lang + def preprocess_source(self, source: Union[Path, str]) -> List[float]: assert IS_IMPORT_SOUNDFILE, "Please make sure soundfile is properly installed." samples, _ = soundfile.read(source, dtype="float32") @@ -75,58 +79,66 @@ def get_source_audio_path(self, index: int): @classmethod def from_files( - cls, source: Union[Path, str], target: Union[Path, str] + cls, source: Union[Path, str], target: Union[Path, str], + tgt_lang: str ) -> SpeechToTextDataloader: with open(source) as f: source_list = [line.strip() for line in f] with open(target) as f: target_list = [line.strip() for line in f] - dataloader = cls(source_list, target_list) + with open(tgt_lang, "r") as f: + tgt_lang = f.read() + print(type(tgt_lang)) + print(tgt_lang) + dataloader = cls(source_list, target_list, tgt_lang) return dataloader @classmethod def from_args(cls, args: Namespace): args.source_type = "speech" args.target_type = "text" - return cls.from_files(args.source, args.target) + return cls.from_files(args.source, args.target, args.tgt_lang) @register_dataloader("speech-to-speech") class SpeechToSpeechDataloader(SpeechToTextDataloader): @classmethod def from_files( - cls, source: Union[Path, str], target: Union[Path, str] + cls, source: Union[Path, str], target: Union[Path, str], + tgt_lang: str ) -> SpeechToSpeechDataloader: with open(source) as f: source_list = [line.strip() for line in f] with open(target) as f: target_list = [line.strip() for line in f] - dataloader = cls(source_list, target_list) + with open(tgt_lang) as f: + tgt_lang = [line.strip() for line in f] + dataloader = cls(source_list, target_list, tgt_lang) return dataloader @classmethod def from_args(cls, args: Namespace): args.source_type = "speech" args.target_type = "speech" - return cls.from_files(args.source, args.target) + return cls.from_files(args.source, args.target, args.tgt_lang) @register_dataloader("youtube-to-text") class YoutubeToTextDataloader(SpeechToTextDataloader): @classmethod def from_youtube( - cls, source: Union[Path, str], target: Union[Path, str] + cls, source: Union[Path, str], target: Union[Path, str], tgt_lang: str ) -> YoutubeToTextDataloader: source_list = [download_youtube_video(source)] target_list = [target] - dataloader = cls(source_list, target_list) + dataloader = cls(source_list, target_list, tgt_lang) return dataloader @classmethod def from_args(cls, args: Namespace): args.source_type = "youtube" args.target_type = "text" - return cls.from_youtube(args.source, args.target) + return cls.from_youtube(args.source, args.target, args.tgt_lang) @register_dataloader("youtube-to-speech") @@ -135,4 +147,4 @@ class YoutubeToSpeechDataloader(YoutubeToTextDataloader): def from_args(cls, args: Namespace): args.source_type = "youtube" args.target_type = "speech" - return cls.from_youtube(args.source, args.target) + return cls.from_youtube(args.source, args.target, args.tgt_lang) diff --git a/simuleval/data/dataloader/t2t_dataloader.py b/simuleval/data/dataloader/t2t_dataloader.py index 16d5d414..18f91a1c 100644 --- a/simuleval/data/dataloader/t2t_dataloader.py +++ b/simuleval/data/dataloader/t2t_dataloader.py @@ -11,6 +11,7 @@ from simuleval.data.dataloader import register_dataloader from argparse import Namespace +tgt_lang = "en" @register_dataloader("text-to-text") class TextToTextDataloader(GenericDataloader): @@ -33,7 +34,7 @@ def preprocess_target(self, target: str) -> List: @classmethod def from_files( - cls, source: Union[Path, str], target: Optional[Union[Path, str]] + cls, source: Union[Path, str], target: Optional[Union[Path, str]], tgt_lang ) -> TextToTextDataloader: assert source with open(source) as f: @@ -43,11 +44,12 @@ def from_files( target_list = f.readlines() else: target_list = [None for _ in source_list] - dataloader = cls(source_list, target_list) + dataloader = cls(source_list, target_list, tgt_lang) return dataloader @classmethod def from_args(cls, args: Namespace): args.source_type = "text" args.target_type = "text" - return cls.from_files(args.source, args.target) + tgt_lang = tgt_lang + return cls.from_files(args.source, args.target, tgt_lang) diff --git a/simuleval/data/segments.py b/simuleval/data/segments.py index c823695e..d0f25b4b 100644 --- a/simuleval/data/segments.py +++ b/simuleval/data/segments.py @@ -15,6 +15,7 @@ class Segment: finished: bool = False is_empty: bool = False data_type: str = None + tgt_lang: str = None def json(self) -> str: info_dict = {attribute: value for attribute, value in self.__dict__.items()} @@ -34,12 +35,14 @@ class EmptySegment(Segment): class TextSegment(Segment): content: str = "" data_type: str = "text" + tgt_lang: str = "" @dataclass class SpeechSegment(Segment): sample_rate: int = -1 data_type: str = "speech" + tgt_lang: str = "" def segment_from_json_string(string: str): @@ -49,4 +52,4 @@ def segment_from_json_string(string: str): elif info_dict["data_type"] == "speech": return SpeechSegment.from_json(string) else: - return EmptySegment.from_json(string) + return EmptySegment.from_json(string) \ No newline at end of file diff --git a/simuleval/evaluator/instance.py b/simuleval/evaluator/instance.py index c08fd978..48d267be 100644 --- a/simuleval/evaluator/instance.py +++ b/simuleval/evaluator/instance.py @@ -232,13 +232,15 @@ def __init__( self, index: int, dataloader: Optional[SpeechToTextDataloader], - args: Optional[Namespace], + args: Optional[Namespace], ): super().__init__(index, dataloader, args) + self.args = args self.sample_rate_value = None self.sample_list = None self.source_finished_reading = False self.dataloader: SpeechToTextDataloader + self.tgt_lang = self.dataloader.tgt_lang @property def sample_rate(self): @@ -282,6 +284,7 @@ def send_source(self, segment_size=10): content=samples, sample_rate=self.audio_info.samplerate, finished=is_finished, + tgt_lang=self.tgt_lang ) else: @@ -454,3 +457,4 @@ def __init__(self, info: str) -> None: self.source_length = self.info.get("source_length") # just for testing! self.finish_prediction = True self.metrics = {} + \ No newline at end of file