Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
Merge pull request #165 from chiphuyen/master
Browse files: browse the repository at this point in the history
Add "num_shards" to generate_data
  • Loading branch information
lukaszkaiser authored Jul 18, 2017
2 parents 0d250b3 + c3a59b4 commit 963730e
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion tensor2tensor/bin/t2t-datagen
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def generate_data_for_problem(problem):

def generate_data_for_registered_problem(problem_name):
problem = registry.problem(problem_name)
problem.generate_data(FLAGS.data_dir, FLAGS.tmp_dir)
problem.generate_data(FLAGS.data_dir, FLAGS.tmp_dir, FLAGS.num_shards)


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions tensor2tensor/data_generators/algorithmic.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ class AlgorithmicIdentityBinary40(problem.Problem):
def num_symbols(self):
return 2

def generate_data(self, data_dir, _):
def generate_data(self, data_dir, _, num_shards=100):
utils.generate_dataset_and_shuffle(
identity_generator(self.num_symbols, 40, 100000),
self.training_filepaths(data_dir, 100, shuffled=True),
self.training_filepaths(data_dir, num_shards, shuffled=True),
identity_generator(self.num_symbols, 400, 10000),
self.dev_filepaths(data_dir, 1, shuffled=True),
shuffle=False)
Expand Down
2 changes: 1 addition & 1 deletion tensor2tensor/data_generators/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class Problem(object):
# BEGIN SUBCLASS INTERFACE
# ============================================================================

def generate_data(self, data_dir, tmp_dir):
def generate_data(self, data_dir, tmp_dir, num_shards=100):
raise NotImplementedError()

def hparams(self, defaults, model_hparams):
Expand Down
8 changes: 4 additions & 4 deletions tensor2tensor/data_generators/wmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ def target_vocab_size(self):
def feature_encoders(self, data_dir):
return _default_wmt_feature_encoders(data_dir, self.target_vocab_size)

def generate_data(self, data_dir, tmp_dir):
def generate_data(self, data_dir, tmp_dir, num_shards=100):
generator_utils.generate_dataset_and_shuffle(
ende_wordpiece_token_generator(tmp_dir, True, self.target_vocab_size),
self.training_filepaths(data_dir, 100, shuffled=False),
self.training_filepaths(data_dir, num_shards, shuffled=False),
ende_wordpiece_token_generator(tmp_dir, False, self.target_vocab_size),
self.dev_filepaths(data_dir, 1, shuffled=False))

Expand Down Expand Up @@ -92,10 +92,10 @@ def target_vocab_size(self):
def feature_encoders(self, data_dir):
return _default_wmt_feature_encoders(data_dir, self.target_vocab_size)

def generate_data(self, data_dir, tmp_dir):
def generate_data(self, data_dir, tmp_dir, num_shards=100):
generator_utils.generate_dataset_and_shuffle(
mken_wordpiece_token_generator(tmp_dir, True, self.target_vocab_size),
self.training_filepaths(data_dir, 100, shuffled=False),
self.training_filepaths(data_dir, num_shards, shuffled=False),
mken_wordpiece_token_generator(tmp_dir, False, self.target_vocab_size),
self.dev_filepaths(data_dir, 1, shuffled=False))

Expand Down

0 comments on commit 963730e

Please sign in to comment.