Skip to content

Commit

Permalink
Review issues. Add notice to the other peek_image_shape.
Browse files Browse the repository at this point in the history
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
  • Loading branch information
mzient committed Nov 7, 2024
1 parent 4f78784 commit 28b2bc4
Show file tree
Hide file tree
Showing 22 changed files with 64 additions and 42 deletions.
10 changes: 9 additions & 1 deletion dali/operators/decoder/peek_shape/peek_image_shape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,15 @@
namespace dali {

DALI_SCHEMA(PeekImageShape)
.DocStr(R"code(Obtains the shape of the encoded image.)code")
.DocStr(R"(Obtains the shape of the encoded image.
This operator returns the shape that an image would have after decoding.
.. note::
This operator is not recommended for use with the dynamic executor (`exec_dynamic=True` in the
pipeline constructor).
Use :meth:`nvidia.dali.pipeline.DataNode.shape()` instead on the decoded images.
)")
.NumInput(1)
.NumOutput(1)
.AddOptionalTypeArg("dtype",
Expand Down
20 changes: 17 additions & 3 deletions dali/operators/generic/shapes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,33 @@
namespace dali {

DALI_SCHEMA(Shapes)
.DocStr(R"(Returns the shapes of tensors in the input batch.)")
.NumInput(1)
.InputDevice(0, InputDevice::Metadata)
.NumOutput(1)
.AllowSequences()
.SupportVolumetric()
.AddOptionalTypeArg("dtype", "Data type to which the sizes are converted.", DALI_INT64)
.DeprecateArgInFavorOf("type", "dtype") // deprecated since 0.27dev
// deprecated since 1.44dev
.Deprecate("", "Use :meth:`nvidia.dali.pipeline.DataNode.shape` instead.");

DALI_SCHEMA(_Shape)
.DocStr(R"(Returns the shapes of tensors in the input batch.
NOT RECOMMENDED FOR NEW CODE. Use `DataNode.shape()` instead.
INTERNAL ONLY; used by DataNode.shape()
)")
.NumInput(1)
.InputDevice(0, InputDevice::Metadata)
.NumOutput(1)
.AllowSequences()
.SupportVolumetric()
.AddOptionalTypeArg("dtype", "Data type to which the sizes are converted.", DALI_INT64)
.DeprecateArgInFavorOf("type", "dtype"); // deprecated since 0.27dev
.MakeDocHidden()
.AddOptionalTypeArg("dtype", "Data type to which the sizes are converted.", DALI_INT64);

DALI_REGISTER_OPERATOR(Shapes, Shapes<CPUBackend>, CPU);
DALI_REGISTER_OPERATOR(Shapes, Shapes<GPUBackend>, GPU);
DALI_REGISTER_OPERATOR(_Shape, Shapes<CPUBackend>, CPU);
DALI_REGISTER_OPERATOR(_Shape, Shapes<GPUBackend>, GPU);

} // namespace dali
3 changes: 2 additions & 1 deletion dali/operators/imgcodec/peek_image_shape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ This operator returns the shape that an image would have after decoding.
.. note::
This operator is not recommended for use with the dynamic executor (`exec_dynamic=True` in the
Pipeline). Use `images.shape()` instead on the decoded images.
pipeline constructor).
Use :meth:`nvidia.dali.pipeline.DataNode.shape()` instead on the decoded images.
)")
.NumInput(1)
.NumOutput(1)
Expand Down
2 changes: 1 addition & 1 deletion dali/python/nvidia/dali/data_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def shape(self, *, dtype=None, device="cpu"):

if device == "cpu":
self._check_gpu2cpu()
return fn.shapes(self, dtype=dtype, device=device)
return fn._shape(self, dtype=dtype, device=device)

def property(self, key, *, device="cpu"):
"""Returns a metadata property associated with a DataNode
Expand Down
11 changes: 6 additions & 5 deletions dali/python/nvidia/dali/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,17 @@ class Pipeline(object):
more resistant to uneven execution time of each batch, but it
also consumes more memory for internal buffers.
Specifying a dict:
``{ "cpu_size": x, "gpu_size": y }``
instead of an integer will cause the pipeline to use separated
queues executor, with buffer queue size `x` for cpu stage
and `y` for mixed and gpu stages. It is not supported when both `exec_async`
and `exec_pipelined` are set to `False`.
Executor will buffer cpu and gpu stages separatelly,
and `y` for mixed and gpu stages.
Executor will buffer cpu and gpu stages separately,
and will fill the buffer queues when the first :meth:`run`
is issued.
Separated execution is requires that `exec_async=True`, `exec_pipelined=True` and
`exec_dynamic=False`.
Separated execution requires that ``exec_async=True``, ``exec_pipelined=True`` and
``exec_dynamic=False``.
`exec_async` : bool, optional, default = True
Whether to execute the pipeline asynchronously.
This makes :meth:`run` method
Expand Down
4 changes: 2 additions & 2 deletions dali/test/python/auto_aug/test_augmentations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -77,7 +77,7 @@ def pipeline():
)
extra = {}
if use_shape:
shape = fn.shapes(data)
shape = data.shape()
extra["shape"] = shape[int(modality == "video") :]
output = dali_aug(op_data, num_magnitude_bins=batch_size, magnitude_bin=mag_bin, **extra)
return output, data
Expand Down
2 changes: 1 addition & 1 deletion dali/test/python/auto_aug/test_auto_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def pipeline():
batch=True,
layout="FHWC",
)
extra = {} if not use_shape else {"shape": fn.shapes(video)[1:]}
extra = {} if not use_shape else {"shape": video.shape()[1:]}
if device == "gpu":
video = video.gpu()
video = auto_augment.auto_augment(video, policy_name, **extra)
Expand Down
2 changes: 1 addition & 1 deletion dali/test/python/auto_aug/test_rand_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def pipeline():
batch=True,
layout="FHWC",
)
extra = {} if not use_shape else {"shape": fn.shapes(video)[1:]}
extra = {} if not use_shape else {"shape": video.shape()[1:]}
extra["monotonic_mag"] = monotonic_mag
if device == "gpu":
video = video.gpu()
Expand Down
2 changes: 1 addition & 1 deletion dali/test/python/auto_aug/test_trivial_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def pipeline():
batch=True,
layout="FHWC",
)
extra = {} if not use_shape else {"shape": fn.shapes(video)[1:]}
extra = {} if not use_shape else {"shape": video.shape()[1:]}
if num_magnitude_bins is not None:
extra["num_magnitude_bins"] = num_magnitude_bins
if device == "gpu":
Expand Down
4 changes: 2 additions & 2 deletions dali/test/python/operator_1/test_audio_resample.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -46,7 +46,7 @@ def audio_decoder_pipe(device):
audio0 = audio0.gpu()
audio2 = fn.audio_resample(audio0, in_rate=sr0, out_rate=out_sr)
audio3 = fn.audio_resample(audio0, scale=out_sr / sr0)
audio4 = fn.audio_resample(audio0, out_length=fn.shapes(audio1)[0])
audio4 = fn.audio_resample(audio0, out_length=audio1.shape()[0])
return audio1, audio2, audio3, audio4


Expand Down
4 changes: 2 additions & 2 deletions dali/test/python/operator_1/test_coin_flip.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -43,7 +43,7 @@ def shape_gen_f():
lambda: np.zeros(shape_gen_f()), device=device, batch=False
)
inputs += [shape_like_in]
shape_out = dali.fn.shapes(shape_like_in)
shape_out = shape_like_in.shape()
else:
shape_arg = dali.fn.external_source(shape_gen_f, batch=False)
shape_out = shape_arg
Expand Down
4 changes: 2 additions & 2 deletions dali/test/python/operator_1/test_normal_distribution.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -69,7 +69,7 @@ def shape_gen_f():
shape_like_in = fn.external_source(
lambda: np.zeros(shape_gen_f()), device=device, batch=False
)
shape_out = fn.shapes(shape_like_in)
shape_out = shape_like_in.shape()
else:
shape_arg = fn.external_source(shape_gen_f, batch=False)
shape_out = shape_arg
Expand Down
4 changes: 2 additions & 2 deletions dali/test/python/operator_1/test_pad.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -256,7 +256,7 @@ def check_pad_to_square(device="cpu", batch_size=3, ndim=2, num_iter=3):
with pipe:
in_shape = fn.cast(fn.random.uniform(range=(10, 20), shape=(ndim,)), dtype=types.INT32)
in_data = fn.reshape(fn.random.uniform(range=(0.0, 1.0), shape=in_shape), layout="HW")
shape = fn.shapes(in_data, dtype=types.INT32)
shape = in_data.shape(dtype=types.INT32)
h = fn.slice(shape, 0, 1, axes=[0])
w = fn.slice(shape, 1, 1, axes=[0])
side = math.max(h, w)
Expand Down
4 changes: 2 additions & 2 deletions dali/test/python/operator_1/test_peek_image_shape.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -29,7 +29,7 @@ def run_decode(data_path, out_type):
pipe = Pipeline(batch_size=batch_size, num_threads=4, device_id=0)
input, _ = fn.readers.file(file_root=data_path, shard_id=0, num_shards=1, name="reader")
decoded = fn.decoders.image(input, output_type=types.RGB)
decoded_shape = fn.shapes(decoded)
decoded_shape = decoded.shape()
raw_shape = fn.peek_image_shape(input, dtype=out_type)
pipe.set_outputs(decoded, decoded_shape, raw_shape)
pipe.build()
Expand Down
4 changes: 1 addition & 3 deletions dali/test/python/operator_2/test_random_crop_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,7 @@ def pipe():

# Second: Crop twice
images_crop0_B = fn.decoders.image_random_crop(encoded, seed=seed0)
crop_anchor1_B, crop_shape1_B = fn.random_crop_generator(
fn.shapes(images_crop0_B), seed=seed1
)
crop_anchor1_B, crop_shape1_B = fn.random_crop_generator(images_crop0_B.shape(), seed=seed1)
images_crop1_B = fn.slice(
images_crop0_B, start=crop_anchor1_B, shape=crop_shape1_B, axes=[0, 1]
)
Expand Down
4 changes: 2 additions & 2 deletions dali/test/python/operator_2/test_random_resized_crop.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -170,7 +170,7 @@ def _test_rrc(
input = fn.external_source(
source=generator(batch_size, max_frames, channel_dim, input_type), layout=layout
)
shape = fn.shapes(input)
shape = input.shape()
if device == "gpu":
input = input.gpu()
out = fn.random_resized_crop(
Expand Down
2 changes: 1 addition & 1 deletion dali/test/python/operator_2/test_resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def build_pipes(
if use_roi:
# Calculate absolute RoI
in_size = fn.slice(
fn.shapes(images_cpu),
images_cpu.shape(),
types.Constant(0, dtype=types.FLOAT, device="cpu"),
types.Constant(dim, dtype=types.FLOAT, device="cpu"),
axes=[0],
Expand Down
4 changes: 2 additions & 2 deletions dali/test/python/operator_2/test_resize_seq.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -132,7 +132,7 @@ def create_dali_pipe(channel_first, seq_len, interp, dtype, w, h, batch_size=2):
dali_resized_gpu, size_gpu = resize_gpu_out
# extract just HW part from the input shape
ext_size = fn.slice(
fn.cast(fn.shapes(ext), dtype=types.INT32), 2 if channel_first else 1, 2, axes=[0]
fn.cast(ext.shape(), dtype=types.INT32), 2 if channel_first else 1, 2, axes=[0]
)
pipe.set_outputs(dali_resized_cpu, dali_resized_gpu, ext_size, size_cpu, size_gpu)
return pipe
Expand Down
4 changes: 2 additions & 2 deletions dali/test/python/operator_2/test_roi_random_crop.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -59,7 +59,7 @@ def data_gen_f():
return batch_gen(max_batch_size, shape_gen_fn)

shape_like_in = dali.fn.external_source(data_gen_f, device="cpu")
in_shape = dali.fn.shapes(shape_like_in, dtype=types.INT32)
in_shape = shape_like_in.shape(dtype=types.INT32)

if random.choice([True, False]):
crop_shape = [(crop_min_extent + crop_max_extent) // 2] * ndim
Expand Down
4 changes: 2 additions & 2 deletions dali/test/python/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2255,8 +2255,8 @@ def pdef():
enc, _ = fn.readers.file(file_root=jpeg_folder)
img = fn.decoders.image(enc, device="mixed")
peek = fn.peek_image_shape(enc)
shapes_of_gpu = fn.shapes(img, device="cpu")
shapes_of_cpu = fn.shapes(img.cpu())
shapes_of_gpu = fn._shape(img, device="cpu")
shapes_of_cpu = fn._shape(img.cpu())
return peek, shapes_of_gpu, shapes_of_cpu, img.shape(), img.cpu().shape()

pipe = pdef()
Expand Down
4 changes: 2 additions & 2 deletions dali/test/python/test_torch_pipeline_rnnt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -374,7 +374,7 @@ def rnnt_train_pipe(
begin, length = fn.nonsilent_region(audio, cutoff_db=-80)
audio = audio[begin : begin + length]

audio_shape = fn.shapes(audio, dtype=types.INT32)
audio_shape = audio.shape(dtype=types.INT32)
orig_audio_len = audio_shape[0]

# If we couldn't move to GPU earlier, do it now
Expand Down
4 changes: 2 additions & 2 deletions docs/examples/use_cases/tensorflow/yolov4/src/dali/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def bbox_adjust_ltrb(bboxes, shape_x, shape_y, pos_x, pos_y):
# Note: this function is a workaround and should be replaced
# with the dedicated operator once available
def select(predicate, if_true, if_false):
true_shape = dali.fn.shapes(if_true, dtype=dali.types.DALIDataType.INT32)
false_shape = dali.fn.shapes(if_false, dtype=dali.types.DALIDataType.INT32)
true_shape = if_true.shape(dtype=dali.types.DALIDataType.INT32)
false_shape = if_false.shape(dtype=dali.types.DALIDataType.INT32)

joined = dali.fn.cat(if_true, if_false)
sh = predicate * true_shape + (1 - predicate) * false_shape
Expand Down

0 comments on commit 28b2bc4

Please sign in to comment.