Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
Support --t2t_usr_dir also in t2t-datagen (#160)
Browse files Browse the repository at this point in the history
* Refactor user directory loading functionality and use it also from t2t-datagen

* Move flag declaration to the binary files
  • Loading branch information
noe authored and rsepassi committed Jul 18, 2017
1 parent c91989c commit 0d250b3
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 18 deletions.
9 changes: 9 additions & 0 deletions tensor2tensor/bin/t2t-datagen
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ from tensor2tensor.data_generators import wiki
from tensor2tensor.data_generators import wmt
from tensor2tensor.data_generators import wsj_parsing
from tensor2tensor.utils import registry
from tensor2tensor.utils import usr_dir

import tensorflow as tf

Expand All @@ -64,6 +65,13 @@ flags.DEFINE_integer("max_cases", 0,
"Maximum number of cases to generate (unbounded if 0).")
flags.DEFINE_integer("random_seed", 429459, "Random seed to use.")

flags.DEFINE_string("t2t_usr_dir", "",
"Path to a Python module that will be imported. The "
"__init__.py file should include the necessary imports. "
"The imported files should contain registrations, "
"e.g. @registry.register_model calls, that will then be "
"available to the t2t-datagen.")

# Mapping from problems that we can generate data for to their generators.
# pylint: disable=g-long-lambda
_SUPPORTED_PROBLEM_GENERATORS = {
Expand Down Expand Up @@ -273,6 +281,7 @@ def set_random_seed():

def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

# Calculate the list of problems to generate.
problems = sorted(
Expand Down
20 changes: 2 additions & 18 deletions tensor2tensor/bin/t2t-trainer
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import sys
# Dependency imports

from tensor2tensor.utils import trainer_utils as utils

from tensor2tensor.utils import usr_dir
import tensorflow as tf

flags = tf.flags
Expand All @@ -49,25 +49,9 @@ flags.DEFINE_string("t2t_usr_dir", "",
"e.g. @registry.register_model calls, that will then be "
"available to the t2t-trainer.")


def import_usr_dir():
"""Import module at FLAGS.t2t_usr_dir, if provided."""
if not FLAGS.t2t_usr_dir:
return
dir_path = os.path.expanduser(FLAGS.t2t_usr_dir)
if dir_path[-1] == "/":
dir_path = dir_path[:-1]
containing_dir, module_name = os.path.split(dir_path)
tf.logging.info("Importing user module %s from path %s", module_name,
containing_dir)
sys.path.insert(0, containing_dir)
importlib.import_module(module_name)
sys.path.pop(0)


def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
import_usr_dir()
usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
utils.log_registry()
utils.validate_flags()
utils.run(
Expand Down
35 changes: 35 additions & 0 deletions tensor2tensor/utils/usr_dir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility to load code from an external directory supplied by user."""

import os
import sys
import importlib
import tensorflow as tf


def import_usr_dir(usr_dir):
"""Import user module, if provided."""
if not usr_dir:
return
dir_path = os.path.expanduser(usr_dir)
if dir_path[-1] == "/":
dir_path = dir_path[:-1]
containing_dir, module_name = os.path.split(dir_path)
tf.logging.info("Importing user module %s from path %s", module_name,
containing_dir)
sys.path.insert(0, containing_dir)
importlib.import_module(module_name)
sys.path.pop(0)

0 comments on commit 0d250b3

Please sign in to comment.