Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to switch output languages for multilingual models #69 #74

Merged
merged 85 commits into from
Sep 21, 2023

Conversation

SamDewriter
Copy link
Contributor

Changes have been made to add the ability to switch output languages for multilingual models.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 28, 2023
Copy link
Contributor

@ibanesh ibanesh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @SamDewriter for making the requested changes.
This iteration seems to be on the right track, but we still needs some modification before it is ready to be merged.

Comment on lines 164 to 169
# parser.add_argument(
# "--tgt-lang",
# type=str,
# default=None,
# help="Path to the Target language file.",
# )
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# parser.add_argument(
# "--tgt-lang",
# type=str,
# default=None,
# help="Path to the Target language file.",
# )

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please clean up this remnant commented code.

Comment on lines 243 to 251
if self.dataloader is not None:
self.tgt_lang = self.dataloader.tgt_lang
if isinstance(self.tgt_lang, list):
if index < len(self.tgt_lang):
self.tgt_lang = self.tgt_lang[self.index]
else:
self.tgt_lang = None
else:
self.tgt_lang = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might work, but is not the cleanest way to implement the tgt_lang assignment for instances.

Like I had pointed out earlier, we want something similar to how the source and target getters work.
https://github.com/facebookresearch/SimulEval/blob/main/simuleval/evaluator/instance.py#L46-L48
https://github.com/facebookresearch/SimulEval/blob/main/simuleval/data/dataloader/dataloader.py#L55-L56

In case it is not clear, we want to be able to do something like self.tgt_lang = self.dataloader[index]["tgt_lang"] to access the target language of a particular index.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another cause for concern here is that all dataloaders are going to have the tgt_lang attribute.
There are some dataloaders (in internal fairseq repo), which don't have the tgt_lang attribute that you added in SpeechToTextDataloader or SpeechToSpeechDataloaderin this PR and self.dataloader.tgt_lang would throw an exception if the dataloader doesn't have this attribute defined correctly. So, please keep in mind while making these changes that not all dataloader have the tgt_lang attribute and handle it accordingly.

Also, at this point feel free to factor out and move necessary changes to the base class GenericDataloader whenever it makes sense to do so.

Copy link
Contributor

@ibanesh ibanesh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks mostly good! We are almost there, just need a few minor changes and cleanup.

@@ -184,6 +184,7 @@ class SpeechToTextAgent(GenericAgent):

source_type: str = "speech"
target_type: str = "text"
tgt_lang: str = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
tgt_lang: str = None
tgt_lang: Optional[str] = None

Comment on lines 46 to +49
elif isinstance(segment, SpeechSegment):
self.source += segment.content
self.source_sample_rate = segment.sample_rate
self.tgt_lang = segment.tgt_lang
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see any reason to restrict this just to the SpeechSegment case, can you add the tgt_lang part to TextSegment case too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@@ -52,15 +56,28 @@ def get_source(self, index: int) -> Any:
def get_target(self, index: int) -> Any:
return self.preprocess_target(self.target_list[index])

def get_tgt_lang(self, index: int) -> Any:
if self.tgt_lang_list is not None and index < len(self.tgt_lang_list):
return self.preprocess_tgt_lang(self.tgt_lang_list[index])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't thin we need a preprocessor for tgt_lang attribute, preprocessor makes more sense for source and targets which can potentially be sound files.

Comment on lines 78 to 80
def preprocess_tgt_lang(self, tgt_lang: Any) -> Any:
raise NotImplementedError

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def preprocess_tgt_lang(self, tgt_lang: Any) -> Any:
raise NotImplementedError

Also, even if we want a a preprocessor for tgt_lang, throwing a NotImplementedError here would break some of the dataloaders (not visible to you) that derives from this base class and don't have this new method implemented.

@@ -52,15 +56,28 @@ def get_source(self, index: int) -> Any:
def get_target(self, index: int) -> Any:
return self.preprocess_target(self.target_list[index])

def get_tgt_lang(self, index: int) -> Any:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def get_tgt_lang(self, index: int) -> Any:
def get_tgt_lang(self, index: int) -> Optional[str]:

@@ -34,12 +35,14 @@ class EmptySegment(Segment):
class TextSegment(Segment):
content: str = ""
data_type: str = "text"
tgt_lang: str = ""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
tgt_lang: str = ""
tgt_lang: Optional[str] = None



@dataclass
class SpeechSegment(Segment):
sample_rate: int = -1
data_type: str = "speech"
tgt_lang: str = ""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
tgt_lang: str = ""
tgt_lang: Optional[str] = None

dataloader = cls(source_list, target_list)
with open(tgt_lang) as f:
tgt_lang_list = [line.strip() for line in f]
print(tgt_lang_list)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cleanup

Suggested change
print(tgt_lang_list)

@@ -46,6 +46,13 @@ def __init__(
if self.dataloader is not None:
self.source = self.dataloader[self.index]["source"]
self.reference = self.dataloader[self.index]["target"]

# Handle when tgt_lang is not provided
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Handle when tgt_lang is not provided

The comment here seems redundant, it is obvious from the code.

Comment on lines 51 to 54
if self.dataloader.tgt_lang_list is not None:
self.tgt_lang = self.dataloader[self.index]["tgt_lang"]
else:
self.tgt_lang = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if self.dataloader.tgt_lang_list is not None:
self.tgt_lang = self.dataloader[self.index]["tgt_lang"]
else:
self.tgt_lang = None
if self.dataloader.tgt_lang_list is not None and self.index < len(self.dataloader.tgt_lang_list):
self.tgt_lang = self.dataloader[self.index]["tgt_lang"]
else:
self.tgt_lang = None

Comment on lines 49 to 55

if self.dataloader.tgt_lang_list is not None and self.index < len(
self.dataloader.tgt_lang_list
):
self.tgt_lang = self.dataloader[self.index]["tgt_lang"]
else:
self.tgt_lang = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On second thought, this check here is redundant as we already do this check in the dataloader

Suggested change
if self.dataloader.tgt_lang_list is not None and self.index < len(
self.dataloader.tgt_lang_list
):
self.tgt_lang = self.dataloader[self.index]["tgt_lang"]
else:
self.tgt_lang = None
self.tgt_lang = self.dataloader[self.index]["tgt_lang"]

Copy link
Contributor

@ibanesh ibanesh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SamDewriter Thanks for making the requested changes. This looks good and almost ready to be merged.

Comment on lines 124 to 128
tgt_lang_list = []
if tgt_lang is not None:
with open(tgt_lang) as f:
tgt_lang_list = [line.strip() for line in f]
dataloader = cls(source_list, target_list, tgt_lang_list)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work on catching that S2S dataloader is missing this part and adding it!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks

Comment on lines 99 to 100
with open(tgt_lang) as f:
tgt_lang_list = [line.strip() for line in f]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add the null check for tgt_lang similar to what you added in the S2S dataloader. Good job on adding the null check in the S2S dataloader.

Copy link
Contributor

@ibanesh ibanesh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me!

@ibanesh ibanesh merged commit b07447a into facebookresearch:main Sep 21, 2023
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants