-
Notifications
You must be signed in to change notification settings - Fork 620
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Stream aware outputs #5684
Stream aware outputs #5684
Conversation
dali/pipeline/data/dltensor_obj.h
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO: Remove this class entirely.
CI MESSAGE: [19571767]: BUILD STARTED |
acb54b0
to
672f6f9
Compare
CI MESSAGE: [19572115]: BUILD STARTED |
CI MESSAGE: [19572115]: BUILD FAILED |
CI MESSAGE: [19600864]: BUILD STARTED |
CI MESSAGE: [19600864]: BUILD FAILED |
# convert the tensors in the batch to DLPack | ||
batch = [torch.from_dlpack(t) for t in out] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: This comment is slightly confusing: "convert to DLPack", but the function is from_dlpack
. Maybe rephrase it a bit to indicate that we're not really converting to DLPack, but DALI->DLPack->Torch (without a copy)
for t in batch: | ||
means[flat_idx] = torch.mean(t) | ||
flat_idx += 1 | ||
# those are meant to overwrite the results if synchronization fails |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe check that we're actually sharing the memory:
batch_a = [torch.from_dlpack(t) for t in out]
batch_b = [torch.from_dlpack(t) for t in out]
# now change batch_b and make sure that batch_a is changed as well
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
assert jax_array.device() == jax.devices()[0] | ||
assert jax_array.device == jax.devices()[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A breaking change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Slipped in... That's a breaking change... in JAX 0.4.31 :\
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reverted.
#include "dali/core/static_switch.h" | ||
|
||
namespace dali { | ||
|
||
class DLTensorGraveyard { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe add some docs why we need this and how it works?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will do.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
9b42d28
to
71854e7
Compare
* Add output order handling to exec2 * Add CUDA stream to Outputs and SharedOutputs in Python bindings for Pipeline. * Refactor stream pointer handling in Python Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
71854e7
to
ee58ddf
Compare
Signed-off-by: Michał Zientkiewicz <mzient@gmail.com>
Signed-off-by: Michał Zientkiewicz <mzient@gmail.com>
CI MESSAGE: [19878075]: BUILD STARTED |
CI MESSAGE: [19878075]: BUILD FAILED |
Adjust dlpack tests. Signed-off-by: Michał Zientkiewicz <mzient@gmail.com>
CI MESSAGE: [19899374]: BUILD STARTED |
CI MESSAGE: [19899374]: BUILD PASSED |
Category:
New feature (non-breaking change which adds functionality)
Refactoring (Redesign of existing code that doesn't affect functionality)
Description:
This PR adds support for returning the pipeline outputs as DLPack without copying.
Additional information:
Affected modules and functionalities:
This PR adds the following features:
__dlpack__
interface for TensorsIt removes the
_expose_dlpack_capsule
as it was incomplete.Key points relevant for the review:
Tests:
Checklist
Documentation
DALI team only
Requirements
REQ IDs: N/A
JIRA TASK: DALI-4075