Add ability to switch output languages for multilingual models #69 #74

Merged
merged 85 commits into from
Sep 21, 2023
Changes from 76 commits
Commits
89fbc24
Testing Circleci on main
Jul 19, 2023
93f9643
Merge branch 'main' of https://github.com/SamDewriter/SimulEval
Jul 19, 2023
fd41dc2
Testing Circleci on main
Jul 19, 2023
4becf02
Testing Circleci on main
Jul 19, 2023
cc893ef
Testing Circleci on main
Jul 19, 2023
a6f00f8
Testing Circleci on main
Jul 19, 2023
0f41351
correct Circle config
Jul 20, 2023
78d13d0
correct Circle config
Jul 20, 2023
98ae6ec
correct Circle config
Jul 20, 2023
cdffc8d
correct Circle config
Jul 20, 2023
cacfbc9
Revert "[demo] s2t + s2s agent pipelines (#58)"
Jul 20, 2023
f97cdfa
resolve branch changes
Aug 14, 2023
4ddf84a
add target language
Aug 15, 2023
69e5816
add target language as a parameter
Aug 15, 2023
812bcbc
Test dynamic language
Aug 18, 2023
233aa35
Switch language dynamically
Aug 18, 2023
f15634c
Add ability to switch output language (#69)
Aug 18, 2023
c06fb0c
Add tgt language argument
Aug 18, 2023
5e0165a
Add Namespace to args argument (#69)
Aug 18, 2023
c43e9da
Modify code to read target language from a file
Aug 18, 2023
bbe6a88
Add ability to switch input language
Aug 18, 2023
03c06b1
Add a tgt-lang file to test
Aug 18, 2023
9d97d18
Add tgt_lang to instance
Aug 20, 2023
3e5fbe6
Add tgt_lang to AgentStates
Aug 25, 2023
6d06a8e
States
Aug 25, 2023
50efb42
Add tgt_lang from state to test (#69)
Aug 25, 2023
1c07b35
Target language to test
Aug 25, 2023
0627451
Format with Black (#69)
Aug 28, 2023
3a63e16
Delete circleci (69)
Aug 28, 2023
71bc44e
Remove unused tgt_lang (#69)
Aug 28, 2023
38be852
Refactor tgt_lang
Aug 28, 2023
e3de9f1
Remove tgt_lang (#69)
Aug 28, 2023
8fd77e1
Remove tgt_lang from S2S and Y2T Dataloaders (#69)
Aug 28, 2023
7599842
Remove tgt_lang from S2S (#69)
Aug 28, 2023
426c1e9
Format with Black (#69)
Aug 28, 2023
9bc2563
Add tgt_lang to S2S dataloader to pass test(#69)
Aug 28, 2023
a8f3e0e
Add tgt_lang to S2S dataloader to pass test(#69)
Aug 28, 2023
c4853b1
Add tgt_lang to S2S dataloader to pass test(#69)
Aug 28, 2023
53f42d1
Change tgt_lang to es (#69)
Aug 28, 2023
606fc84
Add tgt_lang to test suites
Aug 28, 2023
3dc4e22
Add tgt_lang to test suites
Aug 28, 2023
f2856fb
Fix tgt-lang issue (#69)
Aug 28, 2023
7cc8d4a
format with black
Aug 28, 2023
5e0af0b
Add tgt-lang arg (#69)
Aug 30, 2023
75fce41
(#69)
Aug 30, 2023
133c290
Add tgt-lang (#69)
Aug 30, 2023
bd65412
Change instance prediction (#69)
Aug 30, 2023
d782a4d
Format (#69)
Aug 30, 2023
790c179
Add tgt_lang argument
Aug 30, 2023
20a77df
Resolve tgt_lang (#69)
Aug 30, 2023
70f0dce
Remove tgt-lang argument (#69)
Sep 6, 2023
52777b4
Remove tgt-lang argument (#69)
Sep 6, 2023
dfb1069
Add tgt-lang arg to dataloader
Sep 6, 2023
4817998
Preprocess tgt-lang (#69)
Sep 6, 2023
b02f372
Handle tgt-lang list (#69)
Sep 6, 2023
8b7b3e8
Format with black (#69)
Sep 6, 2023
6060f97
Merge branch 'main' of https://github.com/SamDewriter/SimulEval into …
Sep 6, 2023
7dddb93
Testing Circleci on main
Jul 19, 2023
1527591
Testing Circleci on main
Jul 19, 2023
e31714b
Testing Circleci on main
Jul 19, 2023
a5a007e
Testing Circleci on main
Jul 19, 2023
d5890da
Testing Circleci on main
Jul 19, 2023
54936d2
correct Circle config
Jul 20, 2023
2352d22
correct Circle config
Jul 20, 2023
2ea30a9
correct Circle config
Jul 20, 2023
4377457
correct Circle config
Jul 20, 2023
0521818
Recover deleted files
Sep 6, 2023
fa3b94c
Move tgt-lang to DataLoader (#69)
Sep 13, 2023
5d72f50
Rewrite tgt_lang (#69)
Sep 13, 2023
a6d3a58
Initialize tgt-lang (#69)
Sep 13, 2023
e8b4517
Format with black
Sep 13, 2023
609d282
Fix tgt-lang (#69)
Sep 13, 2023
f6b8cde
Correct tgt_lang logic
Sep 17, 2023
cca97aa
Merge branch 'main' into dynamic_language
SamDewriter Sep 17, 2023
54b1fb6
Resolve merge conflict
Sep 17, 2023
6b9e465
Merge branch 'dynamic_language' of https://github.com/SamDewriter/Sim…
Sep 17, 2023
6e164a7
remove tgt_lang check to reduce redundancy (#69)
Sep 20, 2023
c5daa0c
Lint with black (#69)
Sep 20, 2023
0279a59
Merge branch 'main' of https://github.com/SamDewriter/SimulEval into …
Sep 20, 2023
72254a7
Import from typing (#69)
Sep 20, 2023
e3b90a9
Lint (#69)
Sep 20, 2023
b28570e
Handle when tgt_lang is not for s2s (#69)
Sep 20, 2023
dcc037d
Remove comments (#69)
Sep 20, 2023
c5e0010
Add check for tgt_lang s2t (#69)
Sep 20, 2023
1698117
Format with black (#69)
Sep 20, 2023
1 change: 1 addition & 0 deletions examples/speech_to_speech/reference/tgt_lang.txt
@@ -0,0 +1 @@
+es
11 changes: 4 additions & 7 deletions examples/speech_to_text/counter_in_tgt_lang_agent.py
@@ -14,14 +14,10 @@ class CounterInTargetLanguage(SpeechToTextAgent):
     def __init__(self, args):
         super().__init__(args)
         self.wait_seconds = args.wait_seconds
-        self.tgt_lang = args.tgt_lang

     @staticmethod
     def add_args(parser):
         parser.add_argument("--wait-seconds", default=1, type=int)
-        parser.add_argument(
-            "--tgt-lang", default="en", type=str, choices=["en", "es", "de"]
-        )

     def policy(self, states: Optional[AgentStates] = None):
         if states is None:
@@ -35,11 +31,12 @@ def policy(self, states: Optional[AgentStates] = None):
             return ReadAction()

         prediction = f"{length_in_seconds} "
-        if self.tgt_lang == "en":
+        tgt_lang = states.tgt_lang
+        if tgt_lang == "en":
             prediction += "seconds"
-        elif self.tgt_lang == "es":
+        elif tgt_lang == "es":
             prediction += "segundos"
-        elif self.tgt_lang == "de":
+        elif tgt_lang == "de":
             prediction += "sekunden"
         else:
             prediction += "<unknown>"
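Note: the change above moves the language lookup from a constructor argument to per-sample state. A minimal sketch of that pattern follows, assuming a dictionary lookup in place of the if/elif chain; `FakeStates`, `UNITS`, `unit_for` and `predict` are hypothetical names for illustration and are not part of SimulEval.

```python
from dataclasses import dataclass
from typing import Optional

# Unit word per language, mirroring the branches in the agent's policy.
UNITS = {"en": "seconds", "es": "segundos", "de": "sekunden"}


@dataclass
class FakeStates:
    """Hypothetical stand-in for AgentStates; it only carries tgt_lang."""

    tgt_lang: Optional[str] = None


def unit_for(states: FakeStates) -> str:
    # A missing or unsupported language falls back to "<unknown>".
    return UNITS.get(states.tgt_lang, "<unknown>")


def predict(length_in_seconds: int, states: FakeStates) -> str:
    # The language is read from the state on every call, so it can differ
    # from one sample to the next.
    return f"{length_in_seconds} {unit_for(states)}"
```

With this sketch, `predict(3, FakeStates("es"))` gives `"3 segundos"`, and a state with no language gives `"3 <unknown>"`.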
4 changes: 3 additions & 1 deletion examples/speech_to_text/eval.sh
100644 → 100755
@@ -1,5 +1,7 @@
 simuleval \
-    --agent english_counter_agent.py \
+    --agent counter_in_tgt_lang_agent.py \
     --source-segment-size 1000 \
     --source source.txt --target reference/en.txt \
+    --tgt-lang reference/tgt_lang.txt \
     --output output
+
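Note: in this script `--tgt-lang` receives a file path rather than a language code. A small sketch of how such an option parses with the standard library's `argparse`, assuming the same option definition as in the dataloader's `add_args`; `build_parser` is a hypothetical helper name.

```python
from argparse import ArgumentParser


def build_parser() -> ArgumentParser:
    # Hypothetical helper; the option mirrors the one added in add_args.
    parser = ArgumentParser()
    parser.add_argument("--tgt-lang", type=str, default=None, help="Target language")
    return parser


# argparse turns the dash into an underscore, so the value is read as args.tgt_lang.
args = build_parser().parse_args(["--tgt-lang", "reference/tgt_lang.txt"])
no_lang_args = build_parser().parse_args([])
```

When the flag is omitted, `no_lang_args.tgt_lang` is `None`, which is the case the dataloader has to tolerate.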
1 change: 1 addition & 0 deletions examples/speech_to_text/reference/tgt_lang.txt
@@ -0,0 +1 @@
+es
1 change: 1 addition & 0 deletions simuleval/agents/agent.py
@@ -196,6 +196,7 @@ class SpeechToTextAgent(GenericAgent):

     source_type: str = "speech"
     target_type: str = "text"
+    tgt_lang: Optional[str] = None


 class SpeechToSpeechAgent(GenericAgent):
3 changes: 3 additions & 0 deletions simuleval/agents/states.py
@@ -29,6 +29,7 @@ def reset(self) -> None:
         self.target_finished = False
         self.source_sample_rate = 0
         self.target_sample_rate = 0
+        self.tgt_lang = None
         self.upstream_states = []

     def update_source(self, segment: Segment):
@@ -43,9 +44,11 @@ def update_source(self, segment: Segment):
             return
         elif isinstance(segment, TextSegment):
             self.source.append(segment.content)
+            self.tgt_lang = segment.tgt_lang
         elif isinstance(segment, SpeechSegment):
             self.source += segment.content
             self.source_sample_rate = segment.sample_rate
+            self.tgt_lang = segment.tgt_lang
Comment on lines 48 to +51
Contributor

I don't see any reason to restrict this to just the SpeechSegment case. Can you add the tgt_lang part to the TextSegment case too?

Contributor Author

Done

         else:
             raise NotImplementedError

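Note: a simplified sketch of the behaviour discussed in the review thread above, where both text and speech segments update the state's target language. `TextSeg`, `SpeechSeg` and `States` are hypothetical, reduced stand-ins for the real classes, and the single assignment after the branches is one possible way to write it, not the code that was merged.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class TextSeg:
    content: str = ""
    tgt_lang: Optional[str] = None


@dataclass
class SpeechSeg:
    content: List[float] = field(default_factory=list)
    sample_rate: int = -1
    tgt_lang: Optional[str] = None


class States:
    def __init__(self) -> None:
        self.source: list = []
        self.source_sample_rate = 0
        self.tgt_lang: Optional[str] = None

    def update_source(self, segment) -> None:
        if isinstance(segment, TextSeg):
            self.source.append(segment.content)
        elif isinstance(segment, SpeechSeg):
            self.source += segment.content
            self.source_sample_rate = segment.sample_rate
        else:
            raise NotImplementedError
        # Both supported segment kinds carry tgt_lang, so copy it in one place.
        self.tgt_lang = segment.tgt_lang
```

Each incoming segment overwrites `tgt_lang`, so the state always reflects the most recent segment's language.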
4 changes: 3 additions & 1 deletion simuleval/data/dataloader/__init__.py
@@ -5,6 +5,8 @@
 # LICENSE file in the root directory of this source tree.

 import logging
+from argparse import Namespace
+
 from .dataloader import (  # noqa
     GenericDataloader,
     register_dataloader,
@@ -21,7 +23,7 @@
 logger = logging.getLogger("simuleval.dataloader")


-def build_dataloader(args) -> GenericDataloader:
+def build_dataloader(args: Namespace) -> GenericDataloader:
     dataloader_key = getattr(args, "dataloader", None)
     if dataloader_key is not None:
         assert dataloader_key in DATALOADER_DICT, f"{dataloader_key} is not defined"
26 changes: 23 additions & 3 deletions simuleval/data/dataloader/dataloader.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Union, Optional
 from argparse import Namespace, ArgumentParser

 SUPPORTED_MEDIUM = ["text", "speech"]
@@ -37,10 +37,14 @@ class GenericDataloader:
     """

     def __init__(
-        self, source_list: List[str], target_list: Union[List[str], List[None]]
+        self,
+        source_list: List[str],
+        target_list: Union[List[str], List[None]],
+        tgt_lang_list: Optional[List[str]] = None,
     ) -> None:
         self.source_list = source_list
         self.target_list = target_list
+        self.tgt_lang_list = tgt_lang_list
         assert len(self.source_list) == len(self.target_list)

     def __len__(self):
@@ -52,8 +56,18 @@ def get_source(self, index: int) -> Any:
     def get_target(self, index: int) -> Any:
         return self.preprocess_target(self.target_list[index])

+    def get_tgt_lang(self, index: int) -> Optional[str]:
+        if self.tgt_lang_list is None or index >= len(self.tgt_lang_list):
+            return None
+        else:
+            return self.tgt_lang_list[index]
+
     def __getitem__(self, index: int) -> Dict[str, Any]:
-        return {"source": self.get_source(index), "target": self.get_target(index)}
+        return {
+            "source": self.get_source(index),
+            "target": self.get_target(index),
+            "tgt_lang": self.get_tgt_lang(index),
+        }

     def preprocess_source(self, source: Any) -> Any:
         raise NotImplementedError
@@ -95,3 +109,9 @@ def add_args(parser: ArgumentParser):
             default=1,
             help="Source segment size, For text the unit is # token, for speech is ms",
         )
+        parser.add_argument(
+            "--tgt-lang",
+            type=str,
+            default=None,
+            help="Target language",
+        )
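Note: the lookup added in `get_tgt_lang` can be read in isolation as a bounds-checked list access. The sketch below restates it as a free function for illustration; the standalone signature is an assumption, and only the two conditions are taken from the diff.

```python
from typing import List, Optional


def get_tgt_lang(tgt_lang_list: Optional[List[str]], index: int) -> Optional[str]:
    # No language file was given, or it has fewer lines than there are samples.
    if tgt_lang_list is None or index >= len(tgt_lang_list):
        return None
    return tgt_lang_list[index]
```

For a one-line language file, index 0 returns that language and every later index returns `None`, so samples beyond the end of the list get no target language rather than raising `IndexError`.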
32 changes: 25 additions & 7 deletions simuleval/data/dataloader/s2t_dataloader.py
@@ -6,7 +6,7 @@

 from __future__ import annotations
 from pathlib import Path
-from typing import List, Union
+from typing import List, Union, Optional
 from .dataloader import GenericDataloader
 from simuleval.data.dataloader import register_dataloader
 from argparse import Namespace
@@ -58,6 +58,14 @@ def get_video_id(url):

 @register_dataloader("speech-to-text")
 class SpeechToTextDataloader(GenericDataloader):
+    def __init__(
+        self,
+        source_list: List[str],
+        target_list: List[str],
+        tgt_lang_list: Optional[List[str]] = None,
+    ) -> None:
+        super().__init__(source_list, target_list, tgt_lang_list)
+
     def preprocess_source(self, source: Union[Path, str]) -> List[float]:
         assert IS_IMPORT_SOUNDFILE, "Please make sure soundfile is properly installed."
         samples, _ = soundfile.read(source, dtype="float32")
@@ -75,40 +83,50 @@ def get_source_audio_path(self, index: int):

     @classmethod
     def from_files(
-        cls, source: Union[Path, str], target: Union[Path, str]
+        cls,
+        source: Union[Path, str],
+        target: Union[Path, str],
+        tgt_lang: Union[Path, str],
     ) -> SpeechToTextDataloader:
         with open(source) as f:
             source_list = [line.strip() for line in f]
         with open(target) as f:
             target_list = [line.strip() for line in f]
-        dataloader = cls(source_list, target_list)
+        with open(tgt_lang) as f:
+            tgt_lang_list = [line.strip() for line in f]
Contributor

Can you add the null check for tgt_lang, similar to what you added in the S2S dataloader? Good job on adding the null check in the S2S dataloader.

+        dataloader = cls(source_list, target_list, tgt_lang_list)
         return dataloader

     @classmethod
     def from_args(cls, args: Namespace):
         args.source_type = "speech"
         args.target_type = "text"
-        return cls.from_files(args.source, args.target)
+        return cls.from_files(args.source, args.target, args.tgt_lang)


 @register_dataloader("speech-to-speech")
 class SpeechToSpeechDataloader(SpeechToTextDataloader):
     @classmethod
     def from_files(
-        cls, source: Union[Path, str], target: Union[Path, str]
+        cls,
+        source: Union[Path, str],
+        target: Union[Path, str],
+        tgt_lang: Union[Path, str],
     ) -> SpeechToSpeechDataloader:
         with open(source) as f:
             source_list = [line.strip() for line in f]
         with open(target) as f:
             target_list = [line.strip() for line in f]
-        dataloader = cls(source_list, target_list)
+        with open(tgt_lang, "r") as f:
+            tgt_lang_list = [line.strip() for line in f]
+        dataloader = cls(source_list, target_list, tgt_lang_list)
         return dataloader

     @classmethod
     def from_args(cls, args: Namespace):
         args.source_type = "speech"
         args.target_type = "speech"
-        return cls.from_files(args.source, args.target)
+        return cls.from_files(args.source, args.target, args.tgt_lang)


 @register_dataloader("youtube-to-text")
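Note: the null check requested in the review comment is not visible in the hunks shown here. One way it could look is sketched below, assuming `--tgt-lang` defaults to `None` when omitted; `read_tgt_lang_list` is a hypothetical helper name, not a function from this PR.

```python
from pathlib import Path
from typing import List, Optional, Union


def read_tgt_lang_list(tgt_lang: Optional[Union[Path, str]]) -> Optional[List[str]]:
    # Without a language file there is nothing to open; returning None lets
    # the dataloader treat the target language as unspecified.
    if tgt_lang is None:
        return None
    with open(tgt_lang) as f:
        # One language code per line, aligned with the source and target files.
        return [line.strip() for line in f]
```

A `from_files` implementation could pass the result straight through as `tgt_lang_list`, since the constructor already accepts `None` for that parameter.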
4 changes: 4 additions & 0 deletions simuleval/data/segments.py
@@ -6,6 +6,7 @@

 import json
 from dataclasses import dataclass, field
+from typing import Optional


 @dataclass
@@ -15,6 +16,7 @@ class Segment:
     finished: bool = False
     is_empty: bool = False
     data_type: str = None
+    tgt_lang: str = None

     def json(self) -> str:
         info_dict = {attribute: value for attribute, value in self.__dict__.items()}
@@ -34,12 +36,14 @@ class EmptySegment(Segment):
 class TextSegment(Segment):
     content: str = ""
     data_type: str = "text"
+    tgt_lang: str = Optional[str]


 @dataclass
 class SpeechSegment(Segment):
     sample_rate: int = -1
     data_type: str = "speech"
+    tgt_lang: str = Optional[str]


 def segment_from_json_string(string: str):
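Note: in the `TextSegment` and `SpeechSegment` fields above, `Optional[str]` sits in the default-value position rather than in the annotation, so a segment built without `tgt_lang` would carry the typing object as its value. The sketch below illustrates the difference with hypothetical class names; that the intent was an optional string defaulting to `None` is an assumption.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class AsWritten:
    # The annotation is str; the default value is the typing construct itself.
    tgt_lang: str = Optional[str]


@dataclass
class LikelyIntended:
    # Assumed intent: an optional string that defaults to None.
    tgt_lang: Optional[str] = None


as_written_default = AsWritten().tgt_lang
intended_default = LikelyIntended().tgt_lang
```

Here `as_written_default` equals `Optional[str]` and is not `None`, so a check such as `tgt_lang is None` would not fire for it. Callers that always pass `tgt_lang=` explicitly, as `send_source` does in this PR, are unaffected.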
10 changes: 10 additions & 0 deletions simuleval/evaluator/instance.py
@@ -46,6 +46,14 @@ def __init__(
         if self.dataloader is not None:
             self.source = self.dataloader[self.index]["source"]
             self.reference = self.dataloader[self.index]["target"]
+
+            if self.dataloader.tgt_lang_list is not None and self.index < len(
+                self.dataloader.tgt_lang_list
+            ):
+                self.tgt_lang = self.dataloader[self.index]["tgt_lang"]
+            else:
+                self.tgt_lang = None
Contributor

On second thought, this check is redundant, as we already do it in the dataloader.

Suggested change
-            if self.dataloader.tgt_lang_list is not None and self.index < len(
-                self.dataloader.tgt_lang_list
-            ):
-                self.tgt_lang = self.dataloader[self.index]["tgt_lang"]
-            else:
-                self.tgt_lang = None
+            self.tgt_lang = self.dataloader[self.index]["tgt_lang"]

+
         self.reset()
         if args is not None:
             self.args = args
@@ -235,6 +243,7 @@ def __init__(
         args: Optional[Namespace],
     ):
         super().__init__(index, dataloader, args)
+        self.args = args
         self.sample_rate_value = None
         self.sample_list = None
         self.source_finished_reading = False
@@ -282,6 +291,7 @@ def send_source(self, segment_size=10):
                 content=samples,
                 sample_rate=self.audio_info.samplerate,
                 finished=is_finished,
+                tgt_lang=self.tgt_lang,
             )

         else:
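Note: the reviewer's point that the guard in `Instance.__init__` duplicates the dataloader's own check can be verified with a small stand-in. `TinyDataloader`, `guarded` and `direct` are hypothetical names; only the two conditions come from the diff.

```python
from typing import Any, Dict, List, Optional


class TinyDataloader:
    """Hypothetical stand-in that mirrors only the tgt_lang lookup."""

    def __init__(self, tgt_lang_list: Optional[List[str]] = None) -> None:
        self.tgt_lang_list = tgt_lang_list

    def __getitem__(self, index: int) -> Dict[str, Any]:
        if self.tgt_lang_list is None or index >= len(self.tgt_lang_list):
            return {"tgt_lang": None}
        return {"tgt_lang": self.tgt_lang_list[index]}


def guarded(dl: TinyDataloader, index: int) -> Optional[str]:
    # Original form: repeat the bounds check before indexing.
    if dl.tgt_lang_list is not None and index < len(dl.tgt_lang_list):
        return dl[index]["tgt_lang"]
    return None


def direct(dl: TinyDataloader, index: int) -> Optional[str]:
    # Suggested form: rely on the check the dataloader already performs.
    return dl[index]["tgt_lang"]
```

For non-negative indices the two forms return the same value, which is why the guard can be dropped.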
6 changes: 6 additions & 0 deletions simuleval/options.py
@@ -161,6 +161,12 @@ def general_parser():
         default="main.yaml",
         help="Name of the config yaml of the system configs.",
     )
+    # parser.add_argument(
+    #     "--tgt-lang",
+    #     type=str,
+    #     default=None,
+    #     help="Path to the Target language file.",
+    # )
Contributor

Suggested change
-    # parser.add_argument(
-    #     "--tgt-lang",
-    #     type=str,
-    #     default=None,
-    #     help="Path to the Target language file.",
-    # )

Contributor

Please clean up this remnant commented code.

     parser.add_argument("--dataloader", default=None, help="Dataloader to use")
     parser.add_argument(
         "--log-level",
4 changes: 4 additions & 0 deletions simuleval/test/test_s2s.py
@@ -39,6 +39,10 @@ def test_s2s(root_path=ROOT_PATH):
             os.path.join(root_path, "examples", "speech_to_speech", "reference/en.txt"),
             "--output",
             tmpdirname,
+            "--tgt-lang",
+            os.path.join(
+                root_path, "examples", "speech_to_speech", "reference/tgt_lang.txt"
+            ),
         ]
         cli.main()

12 changes: 9 additions & 3 deletions simuleval/test/test_s2t.py
@@ -24,7 +24,7 @@ def test_s2t(root_path=ROOT_PATH):
         cli.sys.argv[1:] = [
             "--agent",
             os.path.join(
-                root_path, "examples", "speech_to_text", "english_counter_agent.py"
+                root_path, "examples", "speech_to_text", "counter_in_tgt_lang_agent.py"
             ),
             "--user-dir",
             os.path.join(root_path, "examples"),
@@ -38,6 +38,10 @@ def test_s2t(root_path=ROOT_PATH):
             os.path.join(root_path, "examples", "speech_to_text", "reference/en.txt"),
             "--output",
             tmpdirname,
+            "--tgt-lang",
+            os.path.join(
+                root_path, "examples", "speech_to_text", "reference/tgt_lang.txt"
+            ),
         ]
         cli.main()

@@ -46,7 +50,7 @@ def test_s2t(root_path=ROOT_PATH):
                 instance = LogInstance(line.strip())
                 assert (
                     instance.prediction
-                    == "1 second 2 second 3 second 4 second 5 second 6 second 7 second"
+                    == "1 segundos 2 segundos 3 segundos 4 segundos 5 segundos 6 segundos 7 segundos"
                 )


@@ -103,7 +107,9 @@ def test_s2t_with_tgt_lang(root_path=ROOT_PATH):
             "--output",
             tmpdirname,
             "--tgt-lang",
-            "es",
+            os.path.join(
+                root_path, "examples", "speech_to_text", "reference/tgt_lang.txt"
+            ),
         ]
         cli.main()

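Note: the new expected string in `test_s2t` follows from the counter agent emitting one "<n> <unit>" pair per step with the unit chosen by target language. A short sketch of how that string is composed, assuming seven steps as in the assertion above; `expected_prediction` is a hypothetical helper and is not part of the test suite.

```python
def expected_prediction(unit: str, steps: int = 7) -> str:
    # One "<n> <unit>" pair per step, joined by single spaces.
    return " ".join(f"{i} {unit}" for i in range(1, steps + 1))


spanish = expected_prediction("segundos")
```

With `tgt_lang.txt` containing `es`, `spanish` matches the string asserted in the updated test.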