Add ability to switch output languages for multilingual models #69 #74
This reverts commit 075c4d3.
Thanks @SamDewriter for making the requested changes.
This iteration seems to be on the right track, but it still needs some modifications before it is ready to be merged.
simuleval/options.py
Outdated
# parser.add_argument(
#     "--tgt-lang",
#     type=str,
#     default=None,
#     help="Path to the Target language file.",
# )
Please clean up this remnant commented code.
simuleval/evaluator/instance.py
Outdated
if self.dataloader is not None:
    self.tgt_lang = self.dataloader.tgt_lang
    if isinstance(self.tgt_lang, list):
        if index < len(self.tgt_lang):
            self.tgt_lang = self.tgt_lang[self.index]
        else:
            self.tgt_lang = None
else:
    self.tgt_lang = None
This might work, but it is not the cleanest way to implement the tgt_lang assignment for instances.
As I pointed out earlier, we want something similar to how the source and target getters work.
https://github.com/facebookresearch/SimulEval/blob/main/simuleval/evaluator/instance.py#L46-L48
https://github.com/facebookresearch/SimulEval/blob/main/simuleval/data/dataloader/dataloader.py#L55-L56
In case it is not clear, we want to be able to do something like self.tgt_lang = self.dataloader[index]["tgt_lang"] to access the target language of a particular index.
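The access pattern requested here could be sketched roughly as follows. This is only an illustration: `SketchDataloader` and its constructor are hypothetical stand-ins, not SimulEval's actual `GenericDataloader`. It shows how a `__getitem__` that returns a dict lets an instance read `source`, `target`, and `tgt_lang` in the same way.

```python
from typing import Any, Dict, List, Optional


class SketchDataloader:
    """Hypothetical stand-in for a dataloader; not the real SimulEval class."""

    def __init__(
        self,
        source_list: List[str],
        target_list: List[str],
        tgt_lang_list: Optional[List[str]] = None,
    ) -> None:
        self.source_list = source_list
        self.target_list = target_list
        self.tgt_lang_list = tgt_lang_list

    def get_source(self, index: int) -> Any:
        return self.source_list[index]

    def get_target(self, index: int) -> Any:
        return self.target_list[index]

    def get_tgt_lang(self, index: int) -> Optional[str]:
        # None when no language list was given or the index is out of range.
        if self.tgt_lang_list is not None and index < len(self.tgt_lang_list):
            return self.tgt_lang_list[index]
        return None

    def __getitem__(self, index: int) -> Dict[str, Any]:
        # One dict per sample, so callers use the same lookup for every field.
        return {
            "source": self.get_source(index),
            "target": self.get_target(index),
            "tgt_lang": self.get_tgt_lang(index),
        }
```

With this shape, the instance side reduces to a single lookup such as `dataloader[index]["tgt_lang"]`, and a missing language simply comes back as `None`.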
Another cause for concern here is the assumption that all dataloaders have the tgt_lang attribute.
Some dataloaders (in the internal fairseq repo) don't have the tgt_lang attribute that you added to SpeechToTextDataloader or SpeechToSpeechDataloader in this PR, and self.dataloader.tgt_lang would throw an exception if the dataloader doesn't define this attribute correctly. So please keep in mind while making these changes that not all dataloaders have the tgt_lang attribute, and handle that case accordingly.
Also, at this point feel free to factor out the necessary changes and move them to the base class GenericDataloader whenever it makes sense to do so.
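One way to tolerate dataloaders that lack the attribute is a lookup with a default. The sketch below is an assumption-laden illustration: `LegacyDataloader`, `MultilingualDataloader`, and `read_tgt_lang` are hypothetical names, not part of SimulEval or fairseq. It relies only on the built-in `getattr` with a default value.

```python
from typing import Any, List, Optional, Union


class LegacyDataloader:
    """Hypothetical dataloader written before tgt_lang existed."""

    def __init__(self, source_list: List[str]) -> None:
        self.source_list = source_list


class MultilingualDataloader(LegacyDataloader):
    """Hypothetical dataloader carrying a target language (or one per sample)."""

    def __init__(self, source_list: List[str], tgt_lang: Union[str, List[str]]) -> None:
        super().__init__(source_list)
        self.tgt_lang = tgt_lang


def read_tgt_lang(dataloader: Any, index: int) -> Optional[str]:
    # getattr with a default avoids AttributeError on dataloaders
    # that never defined tgt_lang.
    tgt_lang = getattr(dataloader, "tgt_lang", None)
    if isinstance(tgt_lang, list):
        # Per-sample languages: fall back to None past the end of the list.
        return tgt_lang[index] if index < len(tgt_lang) else None
    return tgt_lang
```

Moving the attribute and a default getter into the shared base class, as suggested above, would make this defensive lookup unnecessary for dataloaders that inherit from it.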
Looks mostly good! We are almost there, just need a few minor changes and cleanup.
simuleval/agents/agent.py
Outdated
@@ -184,6 +184,7 @@ class SpeechToTextAgent(GenericAgent):

    source_type: str = "speech"
    target_type: str = "text"
    tgt_lang: str = None
-    tgt_lang: str = None
+    tgt_lang: Optional[str] = None
elif isinstance(segment, SpeechSegment):
    self.source += segment.content
    self.source_sample_rate = segment.sample_rate
    self.tgt_lang = segment.tgt_lang
I don't see any reason to restrict this to just the SpeechSegment case. Can you add the tgt_lang part to the TextSegment case too?
Done
@@ -52,15 +56,28 @@ def get_source(self, index: int) -> Any:
    def get_target(self, index: int) -> Any:
        return self.preprocess_target(self.target_list[index])

    def get_tgt_lang(self, index: int) -> Any:
        if self.tgt_lang_list is not None and index < len(self.tgt_lang_list):
            return self.preprocess_tgt_lang(self.tgt_lang_list[index])
I don't think we need a preprocessor for the tgt_lang attribute. A preprocessor makes more sense for sources and targets, which can potentially be sound files.
    def preprocess_tgt_lang(self, tgt_lang: Any) -> Any:
        raise NotImplementedError
-    def preprocess_tgt_lang(self, tgt_lang: Any) -> Any:
-        raise NotImplementedError
Also, even if we wanted a preprocessor for tgt_lang, throwing a NotImplementedError here would break some of the dataloaders (not visible to you) that derive from this base class and don't have this new method implemented.
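The breakage described here can be reproduced in a small sketch. All class names below are hypothetical and only mimic the shape of the code under review: a base class whose getter calls an unimplemented hook fails for every pre-existing subclass, whereas a base class that returns the raw value keeps them working.

```python
from typing import List, Optional


class StrictBase:
    """Hypothetical base class whose getter depends on an abstract hook."""

    def __init__(self, tgt_lang_list: Optional[List[str]] = None) -> None:
        self.tgt_lang_list = tgt_lang_list

    def preprocess_tgt_lang(self, tgt_lang: str) -> str:
        raise NotImplementedError

    def get_tgt_lang(self, index: int) -> Optional[str]:
        if self.tgt_lang_list is not None and index < len(self.tgt_lang_list):
            return self.preprocess_tgt_lang(self.tgt_lang_list[index])
        return None


class LenientBase:
    """Hypothetical base class that returns the language tag as-is."""

    def __init__(self, tgt_lang_list: Optional[List[str]] = None) -> None:
        self.tgt_lang_list = tgt_lang_list

    def get_tgt_lang(self, index: int) -> Optional[str]:
        if self.tgt_lang_list is not None and index < len(self.tgt_lang_list):
            return self.tgt_lang_list[index]
        return None


class ExistingStrictLoader(StrictBase):
    """Subclass written before the hook existed; it does not override it."""


class ExistingLenientLoader(LenientBase):
    """Same subclass against the lenient base; nothing to override."""
```

Calling `ExistingStrictLoader(["de"]).get_tgt_lang(0)` raises `NotImplementedError`, while the lenient variant returns the tag unchanged, which is the behavior the review asks for.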
@@ -52,15 +56,28 @@ def get_source(self, index: int) -> Any:
    def get_target(self, index: int) -> Any:
        return self.preprocess_target(self.target_list[index])

    def get_tgt_lang(self, index: int) -> Any:
-    def get_tgt_lang(self, index: int) -> Any:
+    def get_tgt_lang(self, index: int) -> Optional[str]:
simuleval/data/segments.py
Outdated
@@ -34,12 +35,14 @@ class EmptySegment(Segment):
class TextSegment(Segment):
    content: str = ""
    data_type: str = "text"
    tgt_lang: str = ""
-    tgt_lang: str = ""
+    tgt_lang: Optional[str] = None
simuleval/data/segments.py
Outdated
@dataclass
class SpeechSegment(Segment):
    sample_rate: int = -1
    data_type: str = "speech"
    tgt_lang: str = ""
-    tgt_lang: str = ""
+    tgt_lang: Optional[str] = None
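A likely motivation for preferring `None` over an empty string is that `None` states "no target language was given" unambiguously and matches the `Optional[str]` returned by the dataloader getter. The sketch below uses a hypothetical `SketchSegment` dataclass and a hypothetical `describe` helper, neither of which is part of SimulEval, to show how a consumer would branch on it.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchSegment:
    """Hypothetical segment; mirrors only the fields relevant here."""

    content: str = ""
    tgt_lang: Optional[str] = None  # None means "no target language requested"


def describe(segment: SketchSegment) -> str:
    # An explicit None check keeps "not provided" distinct from any real tag.
    if segment.tgt_lang is None:
        return "use the model's default output language"
    return f"translate into {segment.tgt_lang}"
```

For example, `describe(SketchSegment(content="hello", tgt_lang="de"))` yields a message naming `de`, while a segment built without `tgt_lang` takes the default branch.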
dataloader = cls(source_list, target_list)
with open(tgt_lang) as f:
    tgt_lang_list = [line.strip() for line in f]
print(tgt_lang_list)
Please clean up this debug print.
- print(tgt_lang_list)
simuleval/evaluator/instance.py
Outdated
@@ -46,6 +46,13 @@ def __init__(
        if self.dataloader is not None:
            self.source = self.dataloader[self.index]["source"]
            self.reference = self.dataloader[self.index]["target"]

            # Handle when tgt_lang is not provided
-            # Handle when tgt_lang is not provided
The comment here seems redundant; it is obvious from the code.
simuleval/evaluator/instance.py
Outdated
if self.dataloader.tgt_lang_list is not None:
    self.tgt_lang = self.dataloader[self.index]["tgt_lang"]
else:
    self.tgt_lang = None
- if self.dataloader.tgt_lang_list is not None:
-     self.tgt_lang = self.dataloader[self.index]["tgt_lang"]
- else:
-     self.tgt_lang = None
+ if self.dataloader.tgt_lang_list is not None and self.index < len(self.dataloader.tgt_lang_list):
+     self.tgt_lang = self.dataloader[self.index]["tgt_lang"]
+ else:
+     self.tgt_lang = None
simuleval/evaluator/instance.py
Outdated
if self.dataloader.tgt_lang_list is not None and self.index < len(
    self.dataloader.tgt_lang_list
):
    self.tgt_lang = self.dataloader[self.index]["tgt_lang"]
else:
    self.tgt_lang = None
On second thought, this check is redundant, as we already do it in the dataloader:
- if self.dataloader.tgt_lang_list is not None and self.index < len(
-     self.dataloader.tgt_lang_list
- ):
-     self.tgt_lang = self.dataloader[self.index]["tgt_lang"]
- else:
-     self.tgt_lang = None
+ self.tgt_lang = self.dataloader[self.index]["tgt_lang"]
…dynamic_language
@SamDewriter Thanks for making the requested changes. This looks good and is almost ready to be merged.
tgt_lang_list = []
if tgt_lang is not None:
    with open(tgt_lang) as f:
        tgt_lang_list = [line.strip() for line in f]
dataloader = cls(source_list, target_list, tgt_lang_list)
Great work on catching that the S2S dataloader was missing this part and adding it!
Thanks
with open(tgt_lang) as f:
    tgt_lang_list = [line.strip() for line in f]
Can you add the null check for tgt_lang here, similar to what you added in the S2S dataloader? Good job on adding the null check there.
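The requested null check could be factored into a small helper along these lines. `load_tgt_lang_list` is a hypothetical name used only for illustration; the body follows the S2S snippet reviewed above, with an explicit `encoding` added as an assumption about the file format.

```python
from typing import List, Optional


def load_tgt_lang_list(tgt_lang: Optional[str]) -> List[str]:
    """Read one language tag per line, or return [] when no path is given."""
    tgt_lang_list: List[str] = []
    if tgt_lang is not None:
        # Only touch the filesystem when a path was actually supplied.
        with open(tgt_lang, encoding="utf-8") as f:
            tgt_lang_list = [line.strip() for line in f]
    return tgt_lang_list
```

Sharing one helper between the speech-to-text and speech-to-speech dataloaders would keep the two `from_files` paths from drifting apart again.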
Looks good to me!
Changes have been made to add the ability to switch output languages for multilingual models.