-
Notifications
You must be signed in to change notification settings - Fork 36
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ability to switch output languages for multilingual models #69 #74
Changes from 76 commits
89fbc24
93f9643
fd41dc2
4becf02
cc893ef
a6f00f8
0f41351
78d13d0
98ae6ec
cdffc8d
cacfbc9
f97cdfa
4ddf84a
69e5816
812bcbc
233aa35
f15634c
c06fb0c
5e0165a
c43e9da
bbe6a88
03c06b1
9d97d18
3e5fbe6
6d06a8e
50efb42
1c07b35
0627451
3a63e16
71bc44e
38be852
e3de9f1
8fd77e1
7599842
426c1e9
9bc2563
a8f3e0e
c4853b1
53f42d1
606fc84
3dc4e22
f2856fb
7cc8d4a
5e0af0b
75fce41
133c290
bd65412
d782a4d
790c179
20a77df
70f0dce
52777b4
dfb1069
4817998
b02f372
8b7b3e8
6060f97
7dddb93
1527591
e31714b
a5a007e
d5890da
54936d2
2352d22
2ea30a9
4377457
0521818
fa3b94c
5d72f50
a6d3a58
e8b4517
609d282
f6b8cde
cca97aa
54b1fb6
6b9e465
6e164a7
c5daa0c
0279a59
72254a7
e3b90a9
b28570e
dcc037d
c5e0010
1698117
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
es |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
simuleval \ | ||
--agent english_counter_agent.py \ | ||
--agent counter_in_tgt_lang_agent.py \ | ||
--source-segment-size 1000 \ | ||
--source source.txt --target reference/en.txt \ | ||
--tgt-lang reference/tgt_lang.txt \ | ||
--output output | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
es |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ | |
|
||
from __future__ import annotations | ||
from pathlib import Path | ||
from typing import List, Union | ||
from typing import List, Union, Optional | ||
from .dataloader import GenericDataloader | ||
from simuleval.data.dataloader import register_dataloader | ||
from argparse import Namespace | ||
|
@@ -58,6 +58,14 @@ def get_video_id(url): | |
|
||
@register_dataloader("speech-to-text") | ||
class SpeechToTextDataloader(GenericDataloader): | ||
def __init__( | ||
self, | ||
source_list: List[str], | ||
target_list: List[str], | ||
tgt_lang_list: Optional[List[str]] = None, | ||
) -> None: | ||
super().__init__(source_list, target_list, tgt_lang_list) | ||
|
||
def preprocess_source(self, source: Union[Path, str]) -> List[float]: | ||
assert IS_IMPORT_SOUNDFILE, "Please make sure soundfile is properly installed." | ||
samples, _ = soundfile.read(source, dtype="float32") | ||
|
@@ -75,40 +83,50 @@ def get_source_audio_path(self, index: int): | |
|
||
@classmethod | ||
def from_files( | ||
cls, source: Union[Path, str], target: Union[Path, str] | ||
cls, | ||
source: Union[Path, str], | ||
target: Union[Path, str], | ||
tgt_lang: Union[Path, str], | ||
) -> SpeechToTextDataloader: | ||
with open(source) as f: | ||
source_list = [line.strip() for line in f] | ||
with open(target) as f: | ||
target_list = [line.strip() for line in f] | ||
dataloader = cls(source_list, target_list) | ||
with open(tgt_lang) as f: | ||
tgt_lang_list = [line.strip() for line in f] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you add a null check for tgt_lang, similar to the one you added in the S2S dataloader? Good job on adding the null check in the S2S dataloader. |
||
dataloader = cls(source_list, target_list, tgt_lang_list) | ||
return dataloader | ||
|
||
@classmethod | ||
def from_args(cls, args: Namespace): | ||
args.source_type = "speech" | ||
args.target_type = "text" | ||
return cls.from_files(args.source, args.target) | ||
return cls.from_files(args.source, args.target, args.tgt_lang) | ||
|
||
|
||
@register_dataloader("speech-to-speech") | ||
class SpeechToSpeechDataloader(SpeechToTextDataloader): | ||
@classmethod | ||
def from_files( | ||
cls, source: Union[Path, str], target: Union[Path, str] | ||
cls, | ||
source: Union[Path, str], | ||
target: Union[Path, str], | ||
tgt_lang: Union[Path, str], | ||
) -> SpeechToSpeechDataloader: | ||
with open(source) as f: | ||
source_list = [line.strip() for line in f] | ||
with open(target) as f: | ||
target_list = [line.strip() for line in f] | ||
dataloader = cls(source_list, target_list) | ||
with open(tgt_lang, "r") as f: | ||
tgt_lang_list = [line.strip() for line in f] | ||
dataloader = cls(source_list, target_list, tgt_lang_list) | ||
return dataloader | ||
|
||
@classmethod | ||
def from_args(cls, args: Namespace): | ||
args.source_type = "speech" | ||
args.target_type = "speech" | ||
return cls.from_files(args.source, args.target) | ||
return cls.from_files(args.source, args.target, args.tgt_lang) | ||
|
||
|
||
@register_dataloader("youtube-to-text") | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -46,6 +46,14 @@ def __init__( | |||||||||||||||||
if self.dataloader is not None: | ||||||||||||||||||
self.source = self.dataloader[self.index]["source"] | ||||||||||||||||||
self.reference = self.dataloader[self.index]["target"] | ||||||||||||||||||
|
||||||||||||||||||
if self.dataloader.tgt_lang_list is not None and self.index < len( | ||||||||||||||||||
self.dataloader.tgt_lang_list | ||||||||||||||||||
): | ||||||||||||||||||
self.tgt_lang = self.dataloader[self.index]["tgt_lang"] | ||||||||||||||||||
else: | ||||||||||||||||||
self.tgt_lang = None | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. On second thought, this check is redundant, as we already perform it in the dataloader.
Suggested change
|
||||||||||||||||||
|
||||||||||||||||||
self.reset() | ||||||||||||||||||
if args is not None: | ||||||||||||||||||
self.args = args | ||||||||||||||||||
|
@@ -235,6 +243,7 @@ def __init__( | |||||||||||||||||
args: Optional[Namespace], | ||||||||||||||||||
): | ||||||||||||||||||
super().__init__(index, dataloader, args) | ||||||||||||||||||
self.args = args | ||||||||||||||||||
self.sample_rate_value = None | ||||||||||||||||||
self.sample_list = None | ||||||||||||||||||
self.source_finished_reading = False | ||||||||||||||||||
|
@@ -282,6 +291,7 @@ def send_source(self, segment_size=10): | |||||||||||||||||
content=samples, | ||||||||||||||||||
sample_rate=self.audio_info.samplerate, | ||||||||||||||||||
finished=is_finished, | ||||||||||||||||||
tgt_lang=self.tgt_lang, | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
else: | ||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -161,6 +161,12 @@ def general_parser(): | |||||||||||||
default="main.yaml", | ||||||||||||||
help="Name of the config yaml of the system configs.", | ||||||||||||||
) | ||||||||||||||
# parser.add_argument( | ||||||||||||||
# "--tgt-lang", | ||||||||||||||
# type=str, | ||||||||||||||
# default=None, | ||||||||||||||
# help="Path to the Target language file.", | ||||||||||||||
# ) | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please clean up this leftover commented-out code. |
||||||||||||||
parser.add_argument("--dataloader", default=None, help="Dataloader to use") | ||||||||||||||
parser.add_argument( | ||||||||||||||
"--log-level", | ||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't see any reason to restrict this to just the SpeechSegment case. Can you add the tgt_lang part to the TextSegment case too?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done