From 0d250b31234378f0687ee7e94db42202dab3a99d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?No=C3=A9=20Casas?= Date: Tue, 18 Jul 2017 20:25:19 +0200 Subject: [PATCH] Support --t2t_usr_dir also in t2t-datagen (#160) * Refactor user directory loading functionality and use it also from t2t-datagen * Move flag declaration to the binary files --- tensor2tensor/bin/t2t-datagen | 9 +++++++++ tensor2tensor/bin/t2t-trainer | 20 ++----------------- tensor2tensor/utils/usr_dir.py | 35 ++++++++++++++++++++++++++++++++++ 3 files changed, 46 insertions(+), 18 deletions(-) create mode 100644 tensor2tensor/utils/usr_dir.py diff --git a/tensor2tensor/bin/t2t-datagen b/tensor2tensor/bin/t2t-datagen index 44e4b34d3..63eb7e45e 100755 --- a/tensor2tensor/bin/t2t-datagen +++ b/tensor2tensor/bin/t2t-datagen @@ -48,6 +48,7 @@ from tensor2tensor.data_generators import wiki from tensor2tensor.data_generators import wmt from tensor2tensor.data_generators import wsj_parsing from tensor2tensor.utils import registry +from tensor2tensor.utils import usr_dir import tensorflow as tf @@ -64,6 +65,13 @@ flags.DEFINE_integer("max_cases", 0, "Maximum number of cases to generate (unbounded if 0).") flags.DEFINE_integer("random_seed", 429459, "Random seed to use.") +flags.DEFINE_string("t2t_usr_dir", "", + "Path to a Python module that will be imported. The " + "__init__.py file should include the necessary imports. " + "The imported files should contain registrations, " + "e.g. @registry.register_model calls, that will then be " + "available to the t2t-datagen.") + # Mapping from problems that we can generate data for to their generators. # pylint: disable=g-long-lambda _SUPPORTED_PROBLEM_GENERATORS = { @@ -273,6 +281,7 @@ def set_random_seed(): def main(_): tf.logging.set_verbosity(tf.logging.INFO) + usr_dir.import_usr_dir(FLAGS.t2t_usr_dir) # Calculate the list of problems to generate. problems = sorted( diff --git a/tensor2tensor/bin/t2t-trainer b/tensor2tensor/bin/t2t-trainer index 322957028..6b3f4de71 100755 --- a/tensor2tensor/bin/t2t-trainer +++ b/tensor2tensor/bin/t2t-trainer @@ -36,7 +36,7 @@ import sys # Dependency imports from tensor2tensor.utils import trainer_utils as utils - +from tensor2tensor.utils import usr_dir import tensorflow as tf flags = tf.flags @@ -49,25 +49,9 @@ flags.DEFINE_string("t2t_usr_dir", "", "e.g. @registry.register_model calls, that will then be " "available to the t2t-trainer.") - -def import_usr_dir(): - """Import module at FLAGS.t2t_usr_dir, if provided.""" - if not FLAGS.t2t_usr_dir: - return - dir_path = os.path.expanduser(FLAGS.t2t_usr_dir) - if dir_path[-1] == "/": - dir_path = dir_path[:-1] - containing_dir, module_name = os.path.split(dir_path) - tf.logging.info("Importing user module %s from path %s", module_name, - containing_dir) - sys.path.insert(0, containing_dir) - importlib.import_module(module_name) - sys.path.pop(0) - - def main(_): tf.logging.set_verbosity(tf.logging.INFO) - import_usr_dir() + usr_dir.import_usr_dir(FLAGS.t2t_usr_dir) utils.log_registry() utils.validate_flags() utils.run( diff --git a/tensor2tensor/utils/usr_dir.py b/tensor2tensor/utils/usr_dir.py new file mode 100644 index 000000000..ed5623c8e --- /dev/null +++ b/tensor2tensor/utils/usr_dir.py @@ -0,0 +1,35 @@ +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility to load code from an external directory supplied by user.""" + +import os +import sys +import importlib +import tensorflow as tf + + +def import_usr_dir(usr_dir): + """Import user module, if provided.""" + if not usr_dir: + return + dir_path = os.path.expanduser(usr_dir) + if dir_path[-1] == "/": + dir_path = dir_path[:-1] + containing_dir, module_name = os.path.split(dir_path) + tf.logging.info("Importing user module %s from path %s", module_name, + containing_dir) + sys.path.insert(0, containing_dir) + importlib.import_module(module_name) + sys.path.pop(0)