
Commit

Merge pull request #154 from vthorsteinsson/iceparse
Source/target pair text files; Icelandic parsing support; fixes
lukaszkaiser authored Jul 14, 2017
2 parents 4617c01 + 5a72e5c commit 43bfb9f
Showing 9 changed files with 223 additions and 83 deletions.
8 changes: 8 additions & 0 deletions tensor2tensor/bin/t2t-datagen
100644 → 100755
@@ -102,6 +102,14 @@ _SUPPORTED_PROBLEM_GENERATORS = {
"algorithmic_algebra_inverse": (
lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000),
lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)),
"ice_parsing_tokens": (
lambda: wmt.tabbed_parsing_token_generator(FLAGS.tmp_dir,
True, "ice", 2**13, 2**8),
lambda: wmt.tabbed_parsing_token_generator(FLAGS.tmp_dir,
False, "ice", 2**13, 2**8)),
"ice_parsing_characters": (
lambda: wmt.tabbed_parsing_character_generator(FLAGS.tmp_dir, True),
lambda: wmt.tabbed_parsing_character_generator(FLAGS.tmp_dir, False)),
"wmt_parsing_tokens_8k": (
lambda: wmt.parsing_token_generator(FLAGS.tmp_dir, True, 2**13),
lambda: wmt.parsing_token_generator(FLAGS.tmp_dir, False, 2**13)),
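
The two lambdas registered for each problem are the training and dev generators; for ice_parsing_tokens the extra arguments are the requested source (2**13 = 8192) and target (2**8 = 256) vocabulary sizes passed through to tabbed_parsing_token_generator. A rough, hand-run sketch of how such an entry is consumed (the real t2t-datagen writes the generated examples to TFRecord files rather than printing them):

train_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS["ice_parsing_tokens"]
for example in train_gen():
  # Each example is a dict of integer lists: {"inputs": [...], "targets": [...]}.
  print(len(example["inputs"]), len(example["targets"]))
  break
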
Empty file modified tensor2tensor/bin/t2t-trainer
100644 → 100755
Empty file.
28 changes: 28 additions & 0 deletions tensor2tensor/data_generators/generator_utils.py
100644 → 100755
@@ -290,6 +290,34 @@ def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size, sources=None):
return vocab


def get_or_generate_tabbed_vocab(tmp_dir, source_filename, index, vocab_filename, vocab_size):
"""Generate a vocabulary from the source file. This is assumed to be
a file of source, target pairs, where each line contains a source string
and a target string, separated by a tab ('\t') character. The index
parameter specifies 0 for the source or 1 for the target."""
vocab_filepath = os.path.join(tmp_dir, vocab_filename)
if os.path.exists(vocab_filepath):
vocab = text_encoder.SubwordTextEncoder(vocab_filepath)
return vocab

# Use Tokenizer to count the word occurrences.
token_counts = defaultdict(int)
filepath = os.path.join(tmp_dir, source_filename)
with tf.gfile.GFile(filepath, mode="r") as source_file:
for line in source_file:
line = line.strip()
if line and '\t' in line:
parts = line.split('\t', 1)
part = parts[index].strip()
for tok in tokenizer.encode(text_encoder.native_to_unicode(part)):
token_counts[tok] += 1

vocab = text_encoder.SubwordTextEncoder.build_to_target_size(
vocab_size, token_counts, 1, 1e3)
vocab.store_to_file(vocab_filepath)
return vocab


def read_records(filename):
reader = tf.python_io.tf_record_iterator(filename)
records = []
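
To make the index argument of get_or_generate_tabbed_vocab concrete, here is a hypothetical pairs line (the Icelandic sentence and its parse are invented for illustration) and the split the function performs on it:

line = u"Hann las bókina .\t( S ( NP Hann ) ( VP las ( NP bókina ) ) )"
parts = line.split("\t", 1)
source_text = parts[0].strip()  # counted when index == 0
target_text = parts[1].strip()  # counted when index == 1
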
59 changes: 48 additions & 11 deletions tensor2tensor/data_generators/problem_hparams.py
100644 → 100755
@@ -66,14 +66,13 @@ def parse_problem_name(problem_name):
was_copy: A boolean.
"""
# Recursively strip tags until we reach a base name.
if len(problem_name) > 4 and problem_name[-4:] == "_rev":
if problem_name.endswith("_rev"):
base, _, was_copy = parse_problem_name(problem_name[:-4])
return base, True, was_copy
elif len(problem_name) > 5 and problem_name[-5:] == "_copy":
if problem_name.endswith("_copy"):
base, was_reversed, _ = parse_problem_name(problem_name[:-5])
return base, was_reversed, True
else:
return problem_name, False, False
return problem_name, False, False
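
The endswith() simplification keeps the recursive tag-stripping contract; a few hypothetical checks (problem names invented for illustration):

assert parse_problem_name("wmt_ende_tokens_32k") == ("wmt_ende_tokens_32k", False, False)
assert parse_problem_name("ice_parsing_tokens_rev") == ("ice_parsing_tokens", True, False)
assert parse_problem_name("ice_parsing_tokens_rev_copy") == ("ice_parsing_tokens", True, True)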


def _lookup_problem_hparams_fn(name):
@@ -178,6 +177,9 @@ def default_problem_hparams():
# 14: Parse characters
# 15: Parse tokens
# 16: Chinese tokens
# 17: Icelandic characters
# 18: Icelandic tokens
# 19: Icelandic parse tokens
# Add more above if needed.
input_space_id=0,
target_space_id=0,
@@ -198,7 +200,8 @@ def default_problem_hparams():
# the targets. For instance `problem_copy` will copy the inputs, but
# `problem_rev_copy` will copy the targets.
was_reversed=False,
was_copy=False,)
was_copy=False,
)


def test_problem_hparams(unused_model_hparams, input_vocab_size,
@@ -532,7 +535,7 @@ def wmt_concat(model_hparams, wrong_vocab_size):
return p


def wmt_parsing_characters(unused_model_hparams):
def wmt_parsing_characters(model_hparams):
"""English to parse tree translation benchmark."""
p = default_problem_hparams()
p.input_modality = {"inputs": (registry.Modalities.SYMBOL, 256)}
@@ -576,7 +579,8 @@ def wmt_parsing_tokens(model_hparams, wrong_vocab_size):
return p


def wsj_parsing_tokens(model_hparams, wrong_source_vocab_size,
def wsj_parsing_tokens(model_hparams, prefix,
wrong_source_vocab_size,
wrong_target_vocab_size):
"""English to parse tree translation benchmark.
@@ -595,10 +599,10 @@ def wsj_parsing_tokens(model_hparams, wrong_source_vocab_size,
# This vocab file must be present within the data directory.
source_vocab_filename = os.path.join(
model_hparams.data_dir,
"wsj_source.tokens.vocab.%d" % wrong_source_vocab_size)
prefix + "_source.tokens.vocab.%d" % wrong_source_vocab_size)
target_vocab_filename = os.path.join(
model_hparams.data_dir,
"wsj_target.tokens.vocab.%d" % wrong_target_vocab_size)
prefix + "_target.tokens.vocab.%d" % wrong_target_vocab_size)
source_subtokenizer = text_encoder.SubwordTextEncoder(source_vocab_filename)
target_subtokenizer = text_encoder.SubwordTextEncoder(target_vocab_filename)
p.input_modality = {
@@ -615,6 +619,37 @@ def wsj_parsing_tokens(model_hparams, wrong_source_vocab_size,
return p


def ice_parsing_tokens(model_hparams, wrong_source_vocab_size):
"""Icelandic to parse tree translation benchmark.
Args:
model_hparams: a tf.contrib.training.HParams
wrong_source_vocab_size: a number used in the filename indicating the approximate vocabulary size.
Returns:
a tf.contrib.training.HParams
"""
p = default_problem_hparams()
# This vocab file must be present within the data directory.
source_vocab_filename = os.path.join(
model_hparams.data_dir,
"ice_source.tokens.vocab.%d" % wrong_source_vocab_size)
target_vocab_filename = os.path.join(
model_hparams.data_dir,
"ice_target.tokens.vocab.256")
source_subtokenizer = text_encoder.SubwordTextEncoder(source_vocab_filename)
target_subtokenizer = text_encoder.SubwordTextEncoder(target_vocab_filename)
p.input_modality = {
"inputs": (registry.Modalities.SYMBOL, source_subtokenizer.vocab_size)
}
p.target_modality = (registry.Modalities.SYMBOL, 256)
p.vocabulary = {
"inputs": source_subtokenizer,
"targets": target_subtokenizer,
}
p.input_space_id = 18 # Icelandic tokens
p.target_space_id = 19 # Icelandic parse tokens
return p
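
These hparams assume the data generation step has already written both vocabulary files into model_hparams.data_dir; the file names come from the code above, and the sizes follow from the registration further down (2**13 source subwords, 256 target tokens). A sketch, with model_hparams assumed to be set up elsewhere:

# Expected files in model_hparams.data_dir:
#   ice_source.tokens.vocab.8192  -- Icelandic subword vocabulary
#   ice_target.tokens.vocab.256   -- parse-tree token vocabulary
p = ice_parsing_tokens(model_hparams, 2**13)
assert p.input_space_id == 18 and p.target_space_id == 19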


def image_cifar10(unused_model_hparams):
"""CIFAR-10."""
p = default_problem_hparams()
@@ -733,9 +768,11 @@ def img2img_imagenet(unused_model_hparams):
"wiki_32k": wiki_32k,
"lmptb_10k": lmptb_10k,
"wmt_parsing_characters": wmt_parsing_characters,
"ice_parsing_characters": wmt_parsing_characters,
"ice_parsing_tokens": lambda p: ice_parsing_tokens(p, 2**13),
"wmt_parsing_tokens_8k": lambda p: wmt_parsing_tokens(p, 2**13),
"wsj_parsing_tokens_16k": lambda p: wsj_parsing_tokens(p, 2**14, 2**9),
"wsj_parsing_tokens_32k": lambda p: wsj_parsing_tokens(p, 2**15, 2**9),
"wsj_parsing_tokens_16k": lambda p: wsj_parsing_tokens(p, "wsj", 2**14, 2**9),
"wsj_parsing_tokens_32k": lambda p: wsj_parsing_tokens(p, "wsj", 2**15, 2**9),
"wmt_enfr_characters": wmt_enfr_characters,
"wmt_enfr_tokens_8k": lambda p: wmt_enfr_tokens(p, 2**13),
"wmt_enfr_tokens_32k": lambda p: wmt_enfr_tokens(p, 2**15),
37 changes: 22 additions & 15 deletions tensor2tensor/data_generators/text_encoder.py
100644 → 100755
@@ -37,27 +37,32 @@


# Conversion between Unicode and UTF-8, if required (on Python2)
def native_to_unicode(s):
return s.decode("utf-8") if (PY2 and not isinstance(s, unicode)) else s


unicode_to_native = (lambda s: s.encode("utf-8")) if PY2 else (lambda s: s)
if PY2:
native_to_unicode = lambda s: s if isinstance(s, unicode) else s.decode("utf-8")
unicode_to_native = lambda s: s.encode("utf-8")
else:
# No conversion required on Python3
native_to_unicode = lambda s: s
unicode_to_native = lambda s: s


# Reserved tokens for things like padding and EOS symbols.
PAD = "<pad>"
EOS = "<EOS>"
RESERVED_TOKENS = [PAD, EOS]
if six.PY2:
NUM_RESERVED_TOKENS = len(RESERVED_TOKENS)
PAD_TOKEN = RESERVED_TOKENS.index(PAD) # Normally 0
EOS_TOKEN = RESERVED_TOKENS.index(EOS) # Normally 1

if PY2:
RESERVED_TOKENS_BYTES = RESERVED_TOKENS
else:
RESERVED_TOKENS_BYTES = [bytes(PAD, "ascii"), bytes(EOS, "ascii")]
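
The new module-level constants give the reserved ids symbolic names; with the RESERVED_TOKENS list above they work out as follows (illustrative check only):

assert NUM_RESERVED_TOKENS == 2
assert PAD_TOKEN == 0  # id of "<pad>"
assert EOS_TOKEN == 1  # id of "<EOS>"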


class TextEncoder(object):
"""Base class for converting from ints to/from human readable strings."""

def __init__(self, num_reserved_ids=2):
def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS):
self._num_reserved_ids = num_reserved_ids

def encode(self, s):
@@ -105,7 +110,7 @@ class ByteTextEncoder(TextEncoder):

def encode(self, s):
numres = self._num_reserved_ids
if six.PY2:
if PY2:
return [ord(c) + numres for c in s]
# Python3: explicitly convert to UTF-8
return [c + numres for c in s.encode("utf-8")]
@@ -119,10 +124,10 @@ def decode(self, ids):
decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)])
else:
decoded_ids.append(int2byte(id_ - numres))
if six.PY2:
if PY2:
return "".join(decoded_ids)
# Python3: join byte arrays and then decode string
return b"".join(decoded_ids).decode("utf-8")
return b"".join(decoded_ids).decode("utf-8", "replace")

@property
def vocab_size(self):
@@ -132,7 +137,7 @@ def vocab_size(self):
class TokenTextEncoder(TextEncoder):
"""Encoder based on a user-supplied vocabulary."""

def __init__(self, vocab_filename, reverse=False, num_reserved_ids=2):
def __init__(self, vocab_filename, reverse=False, num_reserved_ids=NUM_RESERVED_TOKENS):
"""Initialize from a file, one token per line."""
super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
self._reverse = reverse
@@ -345,7 +350,7 @@ def build_from_token_counts(self,
token_counts,
min_count,
num_iterations=4,
num_reserved_ids=2):
num_reserved_ids=NUM_RESERVED_TOKENS):
"""Train a SubwordTextEncoder based on a dictionary of word counts.
Args:
@@ -371,6 +376,8 @@ def build_from_token_counts(self,
# We build iteratively. On each iteration, we segment all the words,
# then count the resulting potential subtokens, keeping the ones
# with high enough counts for our new vocabulary.
if min_count < 1:
min_count = 1
for i in xrange(num_iterations):
tf.logging.info("Iteration {0}".format(i))
counts = defaultdict(int)
@@ -462,7 +469,7 @@ def store_to_file(self, filename):
f.write("'" + unicode_to_native(subtoken_string) + "'\n")

def _escape_token(self, token):
r"""Escape away underscores and OOV characters and append '_'.
"""Escape away underscores and OOV characters and append '_'.
This allows the token to be expressed as the concatenation of a list
of subtokens from the vocabulary. The underscore acts as a sentinel
@@ -484,7 +491,7 @@ def _escape_token(self, token):
return ret

def _unescape_token(self, escaped_token):
r"""Inverse of _escape_token().
"""Inverse of _escape_token().
Args:
escaped_token: a unicode string
1 change: 1 addition & 0 deletions tensor2tensor/data_generators/tokenizer.py
100644 → 100755
@@ -141,6 +141,7 @@ def read_corpus():
if corpus_max_lines > 0 and lines_read > corpus_max_lines:
return docs
return docs

counts = defaultdict(int)
for doc in read_corpus():
for tok in encode(_native_to_unicode(doc)):
68 changes: 64 additions & 4 deletions tensor2tensor/data_generators/wmt.py
100644 → 100755
@@ -38,9 +38,8 @@
FLAGS = tf.flags.FLAGS


# End-of-sentence marker (should correspond to the position of EOS in the
# RESERVED_TOKENS list in text_encoder.py)
EOS = 1
# End-of-sentence marker
EOS = text_encoder.EOS_TOKEN


def character_generator(source_path, target_path, character_vocab, eos=None):
@@ -72,6 +71,35 @@ def character_generator(source_path, target_path, character_vocab, eos=None):
source, target = source_file.readline(), target_file.readline()


def tabbed_generator(source_path, source_vocab, target_vocab, eos=None):
"""Generator for sequence-to-sequence tasks using tokens derived from
text files where each line contains both a source and a target string.
The two strings are separated by a tab character ('\t'). It yields
dictionaries of "inputs" and "targets" where inputs are characters
from the source lines converted to integers, and targets are
characters from the target lines, also converted to integers.
Args:
source_path: path to the file with source and target sentences.
source_vocab: a SubwordTextEncoder to encode the source string.
target_vocab: a SubwordTextEncoder to encode the target string.
eos: integer to append at the end of each sequence (default: None).
Yields:
A dictionary {"inputs": source-line, "targets": target-line} where
the lines are integer lists converted from characters in the file lines.
"""
eos_list = [] if eos is None else [eos]
with tf.gfile.GFile(source_path, mode="r") as source_file:
for line in source_file:
if line and '\t' in line:
parts = line.split('\t', 1)
source, target = parts[0].strip(), parts[1].strip()
source_ints = source_vocab.encode(source) + eos_list
target_ints = target_vocab.encode(target) + eos_list
yield {"inputs": source_ints, "targets": target_ints}


def token_generator(source_path, target_path, token_vocab, eos=None):
"""Generator for sequence-to-sequence tasks that uses tokens.
@@ -154,7 +182,7 @@ def ende_bpe_token_generator(tmp_dir, train):
train_path = _get_wmt_ende_dataset(tmp_dir, dataset_path)
token_path = os.path.join(tmp_dir, "vocab.bpe.32000")
token_vocab = text_encoder.TokenTextEncoder(vocab_filename=token_path)
return token_generator(train_path + ".en", train_path + ".de", token_vocab, 1)
return token_generator(train_path + ".en", train_path + ".de", token_vocab, EOS)


_ENDE_TRAIN_DATASETS = [
Expand Down Expand Up @@ -339,6 +367,38 @@ def enfr_character_generator(tmp_dir, train):
return character_generator(data_path + ".lang1", data_path + ".lang2",
character_vocab, EOS)

def parsing_character_generator(tmp_dir, train):
character_vocab = text_encoder.ByteTextEncoder()
filename = "parsing_%s" % ("train" if train else "dev")
text_filepath = os.path.join(tmp_dir, filename + ".text")
tags_filepath = os.path.join(tmp_dir, filename + ".tags")
return character_generator(text_filepath, tags_filepath, character_vocab, EOS)


def tabbed_parsing_token_generator(tmp_dir, train, prefix, source_vocab_size, target_vocab_size):
"""Generate source and target data from a single file with source/target pairs
separated by a tab character ('\t')"""
source_vocab = generator_utils.get_or_generate_tabbed_vocab(
tmp_dir, "parsing_train.pairs", 0,
prefix + "_source.tokens.vocab.%d" % source_vocab_size,
source_vocab_size)
target_vocab = generator_utils.get_or_generate_tabbed_vocab(
tmp_dir, "parsing_train.pairs", 1,
prefix + "_target.tokens.vocab.%d" % target_vocab_size,
target_vocab_size)
filename = "parsing_%s" % ("train" if train else "dev")
pair_filepath = os.path.join(tmp_dir, filename + ".pairs")
return tabbed_generator(pair_filepath, source_vocab, target_vocab, EOS)


def tabbed_parsing_character_generator(tmp_dir, train):
"""Generate source and target data from a single file with source/target pairs
separated by a tab character ('\t')"""
character_vocab = text_encoder.ByteTextEncoder()
filename = "parsing_%s" % ("train" if train else "dev")
pair_filepath = os.path.join(tmp_dir, filename + ".pairs")
return tabbed_generator(pair_filepath, character_vocab, character_vocab, EOS)
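
Both tabbed generators expect the pair files to already sit in tmp_dir under the names used in the code; a sketch of the assumed layout and a call matching the t2t-datagen registration above:

# Assumed contents of tmp_dir (file names from the code, data illustrative):
#   parsing_train.pairs   "source sentence<TAB>target parse", one pair per line
#   parsing_dev.pairs     same format, used when train is False
gen = tabbed_parsing_token_generator("/tmp/t2t_ice", True, "ice", 2**13, 2**8)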


def parsing_token_generator(tmp_dir, train, vocab_size):
symbolizer_vocab = generator_utils.get_or_generate_vocab(
9 changes: 9 additions & 0 deletions tensor2tensor/models/transformer.py
100644 → 100755
@@ -353,6 +353,15 @@ def transformer_parsing_base():
return hparams


@registry.register_hparams
def transformer_parsing_ice():
"""Hparams for parsing Icelandic text."""
hparams = transformer_base_single_gpu()
hparams.batch_size = 4096
hparams.shared_embedding_and_softmax_weights = int(False)
return hparams
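
Shared embedding and softmax weights are disabled here, presumably because the Icelandic parsing problem uses different source and target vocabularies (subwords in, parse-tree tokens out). A small sanity check, assuming the function above is importable:

hparams = transformer_parsing_ice()
assert hparams.batch_size == 4096
assert hparams.shared_embedding_and_softmax_weights == 0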


@registry.register_hparams
def transformer_parsing_big():
"""HParams for parsing on wsj semi-supervised."""