Support Python dataclass containers as inputs and outputs of callable TensorFlow 2 graphs.
Make sure that tensorflow>=2.0.0
or tensorflow-gpu>=2.0.0
is installed.
$ pip install tf-dataclass
TensorFlow 2 AutoGraph functions support only nested structures of Python tuples as inputs and outputs
(outputs can also be Python dictionaries).
This becomes inconvenient once we go beyond small hello-world cases,
because we have to work with unstructured piles of tensors.
This small package fills this gap by letting `@tf.function`-decorated functions
take and return Python `dataclass` instances.
import tensorflow as tf
import tf_dataclass
# A batch of sequences that may have different lengths.
@tf_dataclass.dataclass
class Sequential:
    """Container pairing padded per-step features with each sequence's true length."""

    feature: tf.Tensor  # shape = [batch, length, channels], dtype = tf.float32
    length: tf.Tensor  # shape = [batch], dtype = tf.int32
# Initialize a batch of two sequences of lengths 6 and 4.
# (Named `sequences` rather than `input` to avoid shadowing the builtin.)
sequences = Sequential(
    feature=tf.random.normal(shape=[2, 6, 3]),
    length=tf.constant([6, 4], dtype=tf.int32),
)


# Define a strided convolution such that length -> ceil(length / stride).
@tf_dataclass.function
def convolution(input: Sequential, filters: tf.Tensor, stride: int) -> Sequential:
    """Apply a strided 1-D convolution to a batch of variable-length sequences.

    Args:
        input: batch whose ``feature`` is [batch, length, channels] float32 and
            whose ``length`` is [batch] int32.
        filters: convolution kernel of shape [width, in_channels, out_channels].
        stride: step of the convolution along the length axis.

    Returns:
        A ``Sequential`` holding the convolved features and the updated lengths.
    """
    return Sequential(
        # `padding` is a required argument of tf.nn.conv1d in TensorFlow 2.
        feature=tf.nn.conv1d(input.feature, filters, stride, padding="SAME"),
        # Ceil division follows the "SAME"-padding output-length rule
        # ceil(length / stride); floor division would drop the last frame of
        # odd-length sequences.
        length=tf.math.floordiv(input.length + stride - 1, stride),
    )


# Output is an instance of Sequential with lengths 3 and 2 due to convolution stride = 2
output = convolution(
    input=sequences,
    filters=tf.random.normal(shape=[1, 3, 7]),
    stride=2,
)
assert isinstance(output, Sequential)
print(output.length)  # -> tf.Tensor([3 2], shape=(2,), dtype=int32)
import tensorflow as tf
import tf_dataclass
@tf_dataclass.dataclass
class DataBatch:
    """A batch of labelled images together with where each image came from."""

    image: tf.Tensor  # shape = [batch, height, width, channels], dtype = tf.float32
    label: tf.Tensor  # shape = [batch], dtype = tf.int32
    image_file_path: tf.Tensor  # shape = [batch], dtype = tf.string
    dataset_name: tf.Tensor  # shape = [batch], dtype = tf.string
    ...  # further fields elided in this example
@tf_dataclass.function
def train_step(input: DataBatch) -> None:
    """Example graph function that consumes a ``DataBatch`` (body elided)."""
    ...
import tensorflow as tf
import tf_dataclass
@tf_dataclass.dataclass
class ModelOutput:
    """Per-example model outputs plus batch-level summary statistics."""

    loss_value: tf.Tensor  # shape = [batch], dtype = tf.float32
    label: tf.Tensor  # shape = [batch], dtype = tf.int32
    prediction: tf.Tensor  # shape = [batch], dtype = tf.int32
    ...

    # The properties below reduce over all elements without an `axis`
    # argument, so each one returns a scalar tensor (shape = []).

    @property
    def mean_loss(self) -> tf.Tensor:  # shape = [], dtype = tf.float32
        """Mean of ``loss_value`` over the whole batch."""
        return tf.reduce_mean(self.loss_value)

    @property
    def num_true_predictions(self) -> tf.Tensor:  # shape = [], dtype = tf.int32
        """Count of elements where ``prediction`` equals ``label``."""
        return tf.reduce_sum(tf.cast(self.label == self.prediction, dtype=tf.int32))

    @property
    def num_false_predictions(self) -> tf.Tensor:  # shape = [], dtype = tf.int32
        """Count of elements where ``prediction`` differs from ``label``."""
        return tf.reduce_sum(tf.cast(self.label != self.prediction, dtype=tf.int32))

    ...
@tf_dataclass.function
def get_loss(...) -> ModelOutput:
    # Placeholder: the parameter list and body are elided in this example;
    # the function returns a ModelOutput instance.
    ...
Such containers can be merged across datasets and workers.
import tensorflow as tf
import tf_dataclass
@tf_dataclass.dataclass
class Sequential:
    """Batch of variable-length sequences, validated on construction."""

    feature: tf.Tensor  # shape = [batch, length, channels], dtype = tf.float32
    length: tf.Tensor  # shape = [batch], dtype = tf.int32

    def __post_init__(self):
        # Explicit exceptions instead of `assert`: asserts are stripped when
        # Python runs with -O, which would silently disable validation.

        # Verify feature
        if self.feature.dtype != tf.float32:
            raise TypeError(f"feature must be tf.float32, got {self.feature.dtype}")
        if self.feature.shape.rank != 3:
            raise ValueError(f"feature must have rank 3, got shape {self.feature.shape}")

        # Verify length
        if self.length.dtype != tf.int32:
            raise TypeError(f"length must be tf.int32, got {self.length.dtype}")
        if self.length.shape.rank != 1:
            raise ValueError(f"length must have rank 1, got shape {self.length.shape}")

        # Verify batch size. This is a static-shape check only (no graph op is
        # added, for performance). Unknown dimensions, e.g. while tracing a
        # graph, are treated as compatible instead of failing spuriously.
        if not self.feature.shape[:1].is_compatible_with(self.length.shape):
            raise ValueError(
                f"batch size mismatch: feature shape {self.feature.shape}, "
                f"length shape {self.length.shape}"
            )

    @property
    def batch_size(self) -> tf.Tensor:  # shape = [], dtype = tf.int32
        """Dynamic batch size, taken from the leading dimension of ``feature``."""
        return tf.shape(self.feature)[0]
- Support hierarchical composition.
- Support inheritance, including multiple inheritance (for free from the original `dataclass`).
- Highlighting, autocomplete, and refactoring from your IDE.
- Import `dataclass` and `function` from `tf_dataclass`:
from tf_dataclass import dataclass, function  # module name is `tf_dataclass` (singular)
- Return type hints are mandatory for functions decorated with `@function`. For example,
from typing import Tuple
@dataclass
class MyDataclass:
    """Placeholder dataclass; its fields are elided in this example."""

    ...
@function
def my_func(...) -> Tuple[tf.Tensor, MyDataclass]:
    # The return annotation is required: it serves as the template for
    # rebuilding the returned tensor and MyDataclass instance from the
    # graph function's tuple outputs.
    ...
    return some_tensor, my_dataclass_instance
- Type hints for the arguments are optional but recommended.
- Positional arguments are not currently supported. For example, for
@function
def my_graph_func(x: ..., y: ...) -> ...:
    """Example graph function; it must be called with keyword arguments."""
    ...
call
my_graph_func(x=x, y=y)  # supported: arguments passed by keyword
but not
my_graph_func(x, y)  # NOT supported: positional arguments
- IDE autocomplete is currently not well supported, for example in PyCharm. Workaround: use the following import
from typing import TYPE_CHECKING

# typing.TYPE_CHECKING is True only for static analyzers (IDEs, type checkers)
# and False at runtime. The IDE therefore sees the standard-library dataclass,
# while the running program uses the tf_dataclass version.
if TYPE_CHECKING:
    from dataclasses import dataclass
else:
    from tf_dataclass import dataclass
in each `*.py` file where `dataclass` is used.
Dataclasses and their nested structures are simply converted into nested Python tuples and back.
This way, we wrap the given functions so that all inputs and outputs are nested tuples.
Then `@tf.function` is applied. Afterwards, the graph function is wrapped back into dataclass form.
Type hints are used at Python runtime during graph creation as templates to pack and unpack dataclass arguments.
- Support `tf.cond`, `tf.case`, `tf.switch_case`, `tf.while_loop`, `tf.Optional`, and `tf.data.Iterator`.
- Support positional arguments.
- Conversion to `tf.nest` structures.