diff --git a/compiler/src/iree/compiler/ConstEval/BUILD.bazel b/compiler/src/iree/compiler/ConstEval/BUILD.bazel index 7147a6919965..998b2c2b0c15 100644 --- a/compiler/src/iree/compiler/ConstEval/BUILD.bazel +++ b/compiler/src/iree/compiler/ConstEval/BUILD.bazel @@ -79,7 +79,6 @@ iree_compiler_cc_library( "//runtime/src/iree/hal", "//runtime/src/iree/hal/drivers/local_task/registration", "//runtime/src/iree/modules/hal", - "//runtime/src/iree/tooling:vm_util", "//runtime/src/iree/vm", "//runtime/src/iree/vm/bytecode:module", "@llvm-project//llvm:Support", diff --git a/compiler/src/iree/compiler/ConstEval/CMakeLists.txt b/compiler/src/iree/compiler/ConstEval/CMakeLists.txt index d0fe421a0476..f8679bd7ec27 100644 --- a/compiler/src/iree/compiler/ConstEval/CMakeLists.txt +++ b/compiler/src/iree/compiler/ConstEval/CMakeLists.txt @@ -71,7 +71,6 @@ iree_cc_library( iree::hal iree::hal::drivers::local_task::registration iree::modules::hal - iree::tooling::vm_util iree::vm iree::vm::bytecode::module PUBLIC diff --git a/runtime/src/iree/base/string_builder.c b/runtime/src/iree/base/string_builder.c index 3056ae4e951b..abcbb9fefc61 100644 --- a/runtime/src/iree/base/string_builder.c +++ b/runtime/src/iree/base/string_builder.c @@ -104,6 +104,10 @@ IREE_API_EXPORT iree_status_t iree_string_builder_reserve( return iree_ok_status(); } +IREE_API_EXPORT void iree_string_builder_reset(iree_string_builder_t* builder) { + builder->size = 0; +} + IREE_API_EXPORT iree_status_t iree_string_builder_append_inline( iree_string_builder_t* builder, iree_host_size_t count, char** out_head) { *out_head = NULL; diff --git a/runtime/src/iree/base/string_builder.h b/runtime/src/iree/base/string_builder.h index 007f87cb1d61..48a3fdb3aed9 100644 --- a/runtime/src/iree/base/string_builder.h +++ b/runtime/src/iree/base/string_builder.h @@ -106,6 +106,9 @@ IREE_API_EXPORT IREE_MUST_USE_RESULT char* iree_string_builder_take_storage( IREE_API_EXPORT iree_status_t iree_string_builder_reserve( 
iree_string_builder_t* builder, iree_host_size_t minimum_capacity); +// Resets the string builder length to 0 without releasing storage. +IREE_API_EXPORT void iree_string_builder_reset(iree_string_builder_t* builder); + // Reserves storage for |count| characters (including NUL) and returns a mutable // pointer in |out_head| for the caller to write the characters. // The pointer is only valid so long as the string builder is initialized and diff --git a/runtime/src/iree/io/BUILD.bazel b/runtime/src/iree/io/BUILD.bazel index 7e8171582e43..7d4c0e1005a0 100644 --- a/runtime/src/iree/io/BUILD.bazel +++ b/runtime/src/iree/io/BUILD.bazel @@ -14,13 +14,10 @@ package( iree_runtime_cc_library( name = "file_handle", - srcs = [ - "file_handle.c", - ], - hdrs = [ - "file_handle.h", - ], + srcs = ["file_handle.c"], + hdrs = ["file_handle.h"], deps = [ + ":memory_stream", ":stream", "//runtime/src/iree/base", "//runtime/src/iree/base/internal", @@ -28,13 +25,31 @@ iree_runtime_cc_library( ) iree_runtime_cc_library( - name = "parameter_index", - srcs = [ - "parameter_index.c", + name = "memory_stream", + srcs = ["memory_stream.c"], + hdrs = ["memory_stream.h"], + deps = [ + ":stream", + "//runtime/src/iree/base", + "//runtime/src/iree/base/internal", ], - hdrs = [ - "parameter_index.h", +) + +iree_runtime_cc_test( + name = "memory_stream_test", + srcs = ["memory_stream_test.cc"], + deps = [ + ":memory_stream", + "//runtime/src/iree/base", + "//runtime/src/iree/testing:gtest", + "//runtime/src/iree/testing:gtest_main", ], +) + +iree_runtime_cc_library( + name = "parameter_index", + srcs = ["parameter_index.c"], + hdrs = ["parameter_index.h"], deps = [ ":file_handle", "//runtime/src/iree/base", @@ -45,12 +60,8 @@ iree_runtime_cc_library( iree_runtime_cc_library( name = "parameter_index_provider", - srcs = [ - "parameter_index_provider.c", - ], - hdrs = [ - "parameter_index_provider.h", - ], + srcs = ["parameter_index_provider.c"], + hdrs = ["parameter_index_provider.h"], deps = [ 
":parameter_index", ":parameter_provider", @@ -62,12 +73,8 @@ iree_runtime_cc_library( iree_runtime_cc_library( name = "parameter_provider", - srcs = [ - "parameter_provider.c", - ], - hdrs = [ - "parameter_provider.h", - ], + srcs = ["parameter_provider.c"], + hdrs = ["parameter_provider.h"], deps = [ "//runtime/src/iree/base", "//runtime/src/iree/hal", @@ -76,12 +83,8 @@ iree_runtime_cc_library( iree_runtime_cc_library( name = "scope_map", - srcs = [ - "scope_map.c", - ], - hdrs = [ - "scope_map.h", - ], + srcs = ["scope_map.c"], + hdrs = ["scope_map.h"], deps = [ ":parameter_index", "//runtime/src/iree/base", @@ -90,24 +93,42 @@ iree_runtime_cc_library( ) iree_runtime_cc_library( - name = "stream", - srcs = [ - "stream.c", + name = "stdio_stream", + srcs = ["stdio_stream.c"], + hdrs = ["stdio_stream.h"], + deps = [ + ":stream", + "//runtime/src/iree/base", + "//runtime/src/iree/base/internal", ], - hdrs = [ - "stream.h", +) + +iree_runtime_cc_library( + name = "stream", + srcs = ["stream.c"], + hdrs = ["stream.h"], + deps = [ + "//runtime/src/iree/base", + "//runtime/src/iree/base/internal", ], +) + +iree_runtime_cc_library( + name = "vec_stream", + srcs = ["vec_stream.c"], + hdrs = ["vec_stream.h"], deps = [ + ":stream", "//runtime/src/iree/base", "//runtime/src/iree/base/internal", ], ) iree_runtime_cc_test( - name = "stream_test", - srcs = ["stream_test.cc"], + name = "vec_stream_test", + srcs = ["vec_stream_test.cc"], deps = [ - ":stream", + ":vec_stream", "//runtime/src/iree/base", "//runtime/src/iree/testing:gtest", "//runtime/src/iree/testing:gtest_main", diff --git a/runtime/src/iree/io/CMakeLists.txt b/runtime/src/iree/io/CMakeLists.txt index af41f50a4a1c..cf4fd9f88acd 100644 --- a/runtime/src/iree/io/CMakeLists.txt +++ b/runtime/src/iree/io/CMakeLists.txt @@ -18,12 +18,39 @@ iree_cc_library( SRCS "file_handle.c" DEPS + ::memory_stream ::stream iree::base iree::base::internal PUBLIC ) +iree_cc_library( + NAME + memory_stream + HDRS + "memory_stream.h" + 
SRCS + "memory_stream.c" + DEPS + ::stream + iree::base + iree::base::internal + PUBLIC +) + +iree_cc_test( + NAME + memory_stream_test + SRCS + "memory_stream_test.cc" + DEPS + ::memory_stream + iree::base + iree::testing::gtest + iree::testing::gtest_main +) + iree_cc_library( NAME parameter_index @@ -82,6 +109,20 @@ iree_cc_library( PUBLIC ) +iree_cc_library( + NAME + stdio_stream + HDRS + "stdio_stream.h" + SRCS + "stdio_stream.c" + DEPS + ::stream + iree::base + iree::base::internal + PUBLIC +) + iree_cc_library( NAME stream @@ -95,14 +136,28 @@ iree_cc_library( PUBLIC ) -iree_cc_test( +iree_cc_library( NAME - stream_test + vec_stream + HDRS + "vec_stream.h" SRCS - "stream_test.cc" + "vec_stream.c" DEPS ::stream iree::base + iree::base::internal + PUBLIC +) + +iree_cc_test( + NAME + vec_stream_test + SRCS + "vec_stream_test.cc" + DEPS + ::vec_stream + iree::base iree::testing::gtest iree::testing::gtest_main ) diff --git a/runtime/src/iree/io/file_handle.c b/runtime/src/iree/io/file_handle.c index dc06b78b5c04..ee45ce490065 100644 --- a/runtime/src/iree/io/file_handle.c +++ b/runtime/src/iree/io/file_handle.c @@ -7,6 +7,7 @@ #include "iree/io/file_handle.h" #include "iree/base/internal/atomics.h" +#include "iree/io/memory_stream.h" //===----------------------------------------------------------------------===// // iree_io_file_handle_t diff --git a/runtime/src/iree/io/file_handle.h b/runtime/src/iree/io/file_handle.h index 6279b6973616..6831cbe6c3e1 100644 --- a/runtime/src/iree/io/file_handle.h +++ b/runtime/src/iree/io/file_handle.h @@ -155,6 +155,9 @@ iree_io_file_handle_flush(iree_io_file_handle_t* handle); // iree_io_stream_t utilities //===----------------------------------------------------------------------===// +// TODO(benvanik): remove/rework iree_io_stream_open so that it doesn't pull in +// any implementations by putting callbacks on the file handle constructors. + // Opens a stream from the given |file_handle| at the absolute |file_offset|. 
// The returned stream will retain the file until it is released. IREE_API_EXPORT iree_status_t iree_io_stream_open( diff --git a/runtime/src/iree/io/memory_stream.c b/runtime/src/iree/io/memory_stream.c new file mode 100644 index 000000000000..046799d8cf5e --- /dev/null +++ b/runtime/src/iree/io/memory_stream.c @@ -0,0 +1,350 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/io/memory_stream.h" + +//===----------------------------------------------------------------------===// +// Utilities +//===----------------------------------------------------------------------===// + +static const char* iree_io_stream_seek_mode_string( + iree_io_stream_seek_mode_t seek_mode) { + switch (seek_mode) { + case IREE_IO_STREAM_SEEK_SET: + return "set"; + case IREE_IO_STREAM_SEEK_FROM_CURRENT: + return "from-current"; + case IREE_IO_STREAM_SEEK_FROM_END: + return "from-end"; + default: + return "?"; + } +} + +// Validates that at least |access_length| bytes are available at the current +// stream offset. If the optional |out_available_length| is provided the +// available length will be returned matching the requested |access_length| or +// the maximum remaining length to the caller with an OK status even if the full +// |access_length| is not available and otherwise an error is returned. +static iree_status_t iree_io_stream_validate_fixed_range( + iree_io_stream_pos_t stream_offset, iree_io_stream_pos_t stream_length, + iree_io_stream_pos_t access_length, + iree_io_stream_pos_t* out_available_length) { + if (out_available_length) *out_available_length = 0; + + iree_io_stream_pos_t remaining_length = stream_length - stream_offset; + if (access_length > remaining_length) { + // Access exceeds remaining length. 
+ if (out_available_length) { + // Let caller know how much is available and return OK so they can use it. + *out_available_length = remaining_length; + return iree_ok_status(); + } + return iree_make_status(IREE_STATUS_OUT_OF_RANGE, + "access to range [%" PRIu64 ", %" PRIu64 + ") (%" PRIu64 + " bytes) out of range; stream offset %" PRIu64 + " and length %" PRIu64 " insufficient", + stream_offset, stream_offset + access_length, + access_length, stream_offset, stream_length); + } + + if (out_available_length) *out_available_length = access_length; + return iree_ok_status(); +} + +// Applies a seek operation with |seek_mode| and |seek_offset| against a stream. +// |out_stream_offset| will contain the position after the seek or an error will +// be returned if the seek could not be completed. +static iree_status_t iree_io_stream_apply_fixed_seek( + iree_io_stream_pos_t stream_offset, iree_io_stream_pos_t stream_length, + iree_io_stream_seek_mode_t seek_mode, iree_io_stream_pos_t seek_offset, + iree_io_stream_pos_t* out_stream_offset) { + *out_stream_offset = stream_offset; + + iree_io_stream_pos_t new_offset = stream_offset; + switch (seek_mode) { + case IREE_IO_STREAM_SEEK_SET: + new_offset = seek_offset; + break; + case IREE_IO_STREAM_SEEK_FROM_CURRENT: + new_offset = stream_offset + seek_offset; + break; + case IREE_IO_STREAM_SEEK_FROM_END: + new_offset = stream_length + seek_offset; + break; + default: + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "unrecognized seek mode %u", (uint32_t)seek_mode); + } + + if (new_offset < 0 || new_offset > stream_length) { + return iree_make_status(IREE_STATUS_OUT_OF_RANGE, + "seek %s offset %" PRIi64 + " out of stream bounds; expected 0 <= %" PRIi64 + " < %" PRIi64, + iree_io_stream_seek_mode_string(seek_mode), + seek_offset, new_offset, stream_length); + } + + *out_stream_offset = new_offset; + return iree_ok_status(); +} + +//===----------------------------------------------------------------------===// +// 
iree_io_memory_stream_t +//===----------------------------------------------------------------------===// + +typedef struct iree_io_memory_stream_t { + iree_io_stream_t base; + iree_allocator_t host_allocator; + iree_io_memory_stream_release_callback_t release_callback; + iree_io_stream_pos_t offset; + iree_io_stream_pos_t length; + uint8_t* contents; +} iree_io_memory_stream_t; + +static const iree_io_stream_vtable_t iree_io_memory_stream_vtable; + +static iree_io_memory_stream_t* iree_io_memory_stream_cast( + iree_io_stream_t* IREE_RESTRICT base_stream) { + return (iree_io_memory_stream_t*)base_stream; +} + +IREE_API_EXPORT iree_status_t iree_io_memory_stream_wrap( + iree_io_stream_mode_t mode, iree_byte_span_t contents, + iree_io_memory_stream_release_callback_t release_callback, + iree_allocator_t host_allocator, iree_io_stream_t** out_stream) { + IREE_ASSERT_ARGUMENT(out_stream); + *out_stream = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)contents.data_length); + + iree_io_memory_stream_t* stream = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + iree_allocator_malloc(host_allocator, sizeof(*stream), (void**)&stream)); + iree_atomic_ref_count_init(&stream->base.ref_count); + stream->base.vtable = &iree_io_memory_stream_vtable; + stream->base.mode = mode; + stream->host_allocator = host_allocator; + stream->release_callback = release_callback; + stream->offset = 0; + stream->length = contents.data_length; + stream->contents = contents.data; + + *out_stream = &stream->base; + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +static void iree_io_memory_stream_destroy( + iree_io_stream_t* IREE_RESTRICT base_stream) { + iree_io_memory_stream_t* stream = iree_io_memory_stream_cast(base_stream); + iree_allocator_t host_allocator = stream->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + if (stream->release_callback.fn) { + stream->release_callback.fn(stream->release_callback.user_data, + base_stream); + } + + 
iree_allocator_free(host_allocator, stream); + + IREE_TRACE_ZONE_END(z0); +} + +static iree_io_stream_pos_t iree_io_memory_stream_offset( + iree_io_stream_t* base_stream) { + IREE_ASSERT_ARGUMENT(base_stream); + iree_io_memory_stream_t* stream = iree_io_memory_stream_cast(base_stream); + return stream->offset; +} + +static iree_io_stream_pos_t iree_io_memory_stream_length( + iree_io_stream_t* base_stream) { + IREE_ASSERT_ARGUMENT(base_stream); + iree_io_memory_stream_t* stream = iree_io_memory_stream_cast(base_stream); + return stream->length; +} + +static iree_status_t iree_io_memory_stream_seek( + iree_io_stream_t* base_stream, iree_io_stream_seek_mode_t seek_mode, + iree_io_stream_pos_t seek_offset) { + IREE_ASSERT_ARGUMENT(base_stream); + iree_io_memory_stream_t* stream = iree_io_memory_stream_cast(base_stream); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_status_t status = iree_io_stream_apply_fixed_seek( + stream->offset, stream->length, seek_mode, seek_offset, &stream->offset); + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_status_t iree_io_memory_stream_read( + iree_io_stream_t* base_stream, iree_host_size_t buffer_capacity, + void* buffer, iree_host_size_t* out_buffer_length) { + IREE_ASSERT_ARGUMENT(base_stream); + IREE_ASSERT_ARGUMENT(buffer); + if (out_buffer_length) *out_buffer_length = 0; + iree_io_memory_stream_t* stream = iree_io_memory_stream_cast(base_stream); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_io_stream_pos_t read_length = 0; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_io_stream_validate_fixed_range(stream->offset, stream->length, + buffer_capacity, &read_length)); + if (!out_buffer_length && read_length != buffer_capacity) { + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + iree_make_status(IREE_STATUS_OUT_OF_RANGE, + "read of range [%" PRIu64 ", %" PRIu64 ") (%" PRIu64 + " bytes) out of range; stream offset %" PRIu64 + " and length %" PRIu64 " insufficient", + stream->offset, stream->offset + buffer_capacity, + 
(iree_io_stream_pos_t)buffer_capacity, stream->offset, + stream->length)); + } + + memcpy(buffer, stream->contents + stream->offset, + (iree_host_size_t)read_length); + stream->offset += read_length; + + if (out_buffer_length) *out_buffer_length = (iree_host_size_t)read_length; + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +static iree_status_t iree_io_memory_stream_write(iree_io_stream_t* base_stream, + iree_host_size_t buffer_length, + const void* buffer) { + IREE_ASSERT_ARGUMENT(base_stream); + IREE_ASSERT_ARGUMENT(buffer); + iree_io_memory_stream_t* stream = iree_io_memory_stream_cast(base_stream); + IREE_TRACE_ZONE_BEGIN(z0); + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_io_stream_validate_fixed_range(stream->offset, stream->length, + buffer_length, NULL)); + + memcpy(stream->contents + stream->offset, buffer, buffer_length); + stream->offset += buffer_length; + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +static iree_status_t iree_io_memory_stream_fill( + iree_io_stream_t* base_stream, iree_io_stream_pos_t count, + const void* pattern, iree_host_size_t pattern_length) { + IREE_ASSERT_ARGUMENT(base_stream); + IREE_ASSERT_ARGUMENT(pattern); + iree_io_memory_stream_t* stream = iree_io_memory_stream_cast(base_stream); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_io_stream_pos_t access_length = count * pattern_length; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_io_stream_validate_fixed_range(stream->offset, stream->length, + access_length, NULL)); + + iree_status_t status = iree_ok_status(); + uint8_t* data_ptr = stream->contents + stream->offset; + switch (pattern_length) { + case 1: { + uint8_t* data = (uint8_t*)data_ptr; + uint8_t value_bits = *(const uint8_t*)(pattern); + memset(data, value_bits, count); + break; + } + case 2: { + uint16_t* data = (uint16_t*)data_ptr; + uint16_t value_bits = *(const uint16_t*)(pattern); + for (iree_device_size_t i = 0; i < count; ++i) { + iree_unaligned_store(&data[i], value_bits); + } + break; + } 
+ case 4: { + uint32_t* data = (uint32_t*)data_ptr; + uint32_t value_bits = *(const uint32_t*)(pattern); + for (iree_device_size_t i = 0; i < count; ++i) { + iree_unaligned_store(&data[i], value_bits); + } + break; + } + case 8: { + uint64_t* data = (uint64_t*)data_ptr; + uint64_t value_bits = *(const uint64_t*)(pattern); + for (iree_device_size_t i = 0; i < count; ++i) { + iree_unaligned_store(&data[i], value_bits); + } + break; + } + default: + IREE_ASSERT_UNREACHABLE("verified in iree_io_stream_fill"); + break; + } + if (iree_status_is_ok(status)) { + stream->offset += access_length; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_status_t iree_io_memory_stream_map_read( + iree_io_stream_t* base_stream, iree_host_size_t length, + iree_const_byte_span_t* out_span) { + IREE_ASSERT_ARGUMENT(base_stream); + IREE_ASSERT_ARGUMENT(out_span); + *out_span = iree_const_byte_span_empty(); + iree_io_memory_stream_t* stream = iree_io_memory_stream_cast(base_stream); + IREE_TRACE_ZONE_BEGIN(z0); + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_io_stream_validate_fixed_range(stream->offset, stream->length, + length, NULL)); + + *out_span = + iree_make_const_byte_span(stream->contents + stream->offset, length); + stream->offset += length; + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +static iree_status_t iree_io_memory_stream_map_write( + iree_io_stream_t* base_stream, iree_host_size_t length, + iree_byte_span_t* out_span) { + IREE_ASSERT_ARGUMENT(base_stream); + IREE_ASSERT_ARGUMENT(out_span); + *out_span = iree_byte_span_empty(); + iree_io_memory_stream_t* stream = iree_io_memory_stream_cast(base_stream); + IREE_TRACE_ZONE_BEGIN(z0); + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_io_stream_validate_fixed_range(stream->offset, stream->length, + length, NULL)); + + *out_span = iree_make_byte_span(stream->contents + stream->offset, length); + stream->offset += length; + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +static 
const iree_io_stream_vtable_t iree_io_memory_stream_vtable = { + .destroy = iree_io_memory_stream_destroy, + .offset = iree_io_memory_stream_offset, + .length = iree_io_memory_stream_length, + .seek = iree_io_memory_stream_seek, + .read = iree_io_memory_stream_read, + .write = iree_io_memory_stream_write, + .fill = iree_io_memory_stream_fill, + .map_read = iree_io_memory_stream_map_read, + .map_write = iree_io_memory_stream_map_write, +}; diff --git a/runtime/src/iree/io/memory_stream.h b/runtime/src/iree/io/memory_stream.h new file mode 100644 index 000000000000..f55b947eea24 --- /dev/null +++ b/runtime/src/iree/io/memory_stream.h @@ -0,0 +1,52 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_IO_MEMORY_STREAM_H_ +#define IREE_IO_MEMORY_STREAM_H_ + +#include "iree/base/api.h" +#include "iree/io/stream.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +//===----------------------------------------------------------------------===// +// iree_io_memory_stream_t +//===----------------------------------------------------------------------===// + +typedef void(IREE_API_PTR* iree_io_memory_stream_release_fn_t)( + void* user_data, iree_io_stream_t* stream); + +// A callback issued when a memory stream is released. +typedef struct { + // Callback function pointer. + iree_io_memory_stream_release_fn_t fn; + // User data passed to the callback function. Unowned. + void* user_data; +} iree_io_memory_stream_release_callback_t; + +// Returns a no-op file release callback that implies that no cleanup is +// required. 
+static inline iree_io_memory_stream_release_callback_t +iree_io_memory_stream_release_callback_null(void) { + iree_io_memory_stream_release_callback_t callback = {NULL, NULL}; + return callback; +} + +// Wraps a fixed-size host memory allocation |contents| in a stream. +// |release_callback| can be used to receive a callback when the stream is +// destroyed and the reference to the contents is no longer required. +IREE_API_EXPORT iree_status_t iree_io_memory_stream_wrap( + iree_io_stream_mode_t mode, iree_byte_span_t contents, + iree_io_memory_stream_release_callback_t release_callback, + iree_allocator_t host_allocator, iree_io_stream_t** out_stream); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_IO_MEMORY_STREAM_H_ diff --git a/runtime/src/iree/io/stream_test.cc b/runtime/src/iree/io/memory_stream_test.cc similarity index 99% rename from runtime/src/iree/io/stream_test.cc rename to runtime/src/iree/io/memory_stream_test.cc index 94cc554a35eb..ebd519275380 100644 --- a/runtime/src/iree/io/stream_test.cc +++ b/runtime/src/iree/io/memory_stream_test.cc @@ -4,7 +4,7 @@ // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "iree/io/stream.h" +#include "iree/io/memory_stream.h" #include #include @@ -23,10 +23,6 @@ using testing::ElementsAre; using testing::ElementsAreArray; using testing::Eq; -//===----------------------------------------------------------------------===// -// iree_io_memory_stream_t -//===----------------------------------------------------------------------===// - TEST(MemoryStreamTest, Wrap) { uint8_t data[5] = {0, 1, 2, 3, 4}; iree_io_stream_t* stream = NULL; diff --git a/runtime/src/iree/io/stdio_stream.c b/runtime/src/iree/io/stdio_stream.c new file mode 100644 index 000000000000..209be50dcfe2 --- /dev/null +++ b/runtime/src/iree/io/stdio_stream.c @@ -0,0 +1,407 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/io/stdio_stream.h" + +#include +#include +#include + +#if defined(IREE_PLATFORM_WINDOWS) + +#include +#include + +#define IREE_SET_BINARY_MODE(handle) _setmode(_fileno(handle), O_BINARY) + +#define iree_fseek _fseeki64 +#define iree_ftell _ftelli64 + +#else + +#define IREE_SET_BINARY_MODE(handle) ((void)0) + +#if _FILE_OFFSET_BITS == 64 || _POSIX_C_SOURCE >= 200112L +#define iree_fseek fseeko +#define iree_ftell ftello +#else +#define iree_fseek fseek +#define iree_ftell ftell +#endif // 64-bit file offset support + +#endif // IREE_PLATFORM_WINDOWS + +// Makes a new status message ala iree_make_status but includes the error number +// and optional string message on platforms that support it. +#if defined(IREE_PLATFORM_WINDOWS) +#define iree_make_stdio_status(message) \ + iree_make_status(iree_status_code_from_errno(errno), message " (%d: %s)", \ + errno, strerror(errno)) +#define iree_make_stdio_statusf(format, ...) 
\ + iree_make_status(iree_status_code_from_errno(errno), format " (%d: %s)", \ + __VA_ARGS__, errno, strerror(errno)) +#else +#define iree_make_stdio_status(...) \ + iree_make_status(IREE_STATUS_UNKNOWN, __VA_ARGS__) +#define iree_make_stdio_statusf iree_make_stdio_status +#endif // IREE_PLATFORM_* + +//===----------------------------------------------------------------------===// +// iree_io_stdio_stream_t +//===----------------------------------------------------------------------===// + +#define IREE_MAX_PATH ((size_t)2048) + +typedef struct iree_io_stdio_stream_t { + iree_io_stream_t base; + iree_allocator_t host_allocator; + FILE* handle; + bool owns_handle; +} iree_io_stdio_stream_t; + +static const iree_io_stream_vtable_t iree_io_stdio_stream_vtable; + +static iree_io_stdio_stream_t* iree_io_stdio_stream_cast( + iree_io_stream_t* IREE_RESTRICT base_stream) { + return (iree_io_stdio_stream_t*)base_stream; +} + +IREE_API_EXPORT iree_status_t iree_io_stdio_stream_wrap( + iree_io_stream_mode_t mode, FILE* handle, bool owns_handle, + iree_allocator_t host_allocator, iree_io_stream_t** out_stream) { + IREE_ASSERT_ARGUMENT(out_stream); + *out_stream = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_io_stdio_stream_t* stream = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + iree_allocator_malloc(host_allocator, sizeof(*stream), (void**)&stream)); + iree_atomic_ref_count_init(&stream->base.ref_count); + stream->base.vtable = &iree_io_stdio_stream_vtable; + stream->base.mode = mode; + stream->host_allocator = host_allocator; + stream->handle = handle; + stream->owns_handle = owns_handle; + + *out_stream = &stream->base; + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +#if IREE_FILE_IO_ENABLE +IREE_API_EXPORT iree_status_t iree_io_stdio_stream_open( + iree_io_stdio_stream_mode_t mode, iree_string_view_t path, + iree_allocator_t host_allocator, iree_io_stream_t** out_stream) { + IREE_ASSERT_ARGUMENT(out_stream); + *out_stream = NULL; + 
IREE_TRACE_ZONE_BEGIN(z0); + IREE_TRACE_ZONE_APPEND_TEXT(z0, path.data, path.size); + + iree_io_stream_mode_t stream_mode = IREE_IO_STREAM_MODE_SEEKABLE; + if (iree_all_bits_set(mode, IREE_IO_STDIO_STREAM_MODE_READ)) { + stream_mode |= IREE_IO_STREAM_MODE_READABLE; + } + if (iree_all_bits_set(mode, IREE_IO_STDIO_STREAM_MODE_WRITE)) { + stream_mode |= IREE_IO_STREAM_MODE_WRITABLE; + } + + // NOTE: not all implementations support all mode flags and this may have + // different behavior. We should paper over it here but don't today given the + // limited usage of this and our intent to rewrite it all using + // platform-optimal APIs instead of stdio. + char fopen_mode[16] = {0}; + if (iree_all_bits_set(mode, IREE_IO_STDIO_STREAM_MODE_READ | + IREE_IO_STDIO_STREAM_MODE_WRITE | + IREE_IO_STDIO_STREAM_MODE_APPEND)) { + strcat(fopen_mode, "a+"); + } else if (iree_all_bits_set(mode, IREE_IO_STDIO_STREAM_MODE_READ | + IREE_IO_STDIO_STREAM_MODE_WRITE | + IREE_IO_STDIO_STREAM_MODE_DISCARD)) { + strcat(fopen_mode, "w+"); + } else if (iree_all_bits_set(mode, IREE_IO_STDIO_STREAM_MODE_READ | + IREE_IO_STDIO_STREAM_MODE_WRITE)) { + strcat(fopen_mode, "r+"); + } else if (iree_all_bits_set(mode, IREE_IO_STDIO_STREAM_MODE_WRITE | + IREE_IO_STDIO_STREAM_MODE_APPEND)) { + strcat(fopen_mode, "a"); + } else if (iree_all_bits_set(mode, IREE_IO_STDIO_STREAM_MODE_WRITE)) { + strcat(fopen_mode, "w"); + } else if (iree_all_bits_set(mode, IREE_IO_STDIO_STREAM_MODE_READ)) { + strcat(fopen_mode, "r"); + } + if (iree_all_bits_set(stream_mode, IREE_IO_STREAM_MODE_WRITABLE) && + !iree_all_bits_set(mode, IREE_IO_STDIO_STREAM_MODE_DISCARD)) { + // If writable and not discard then the file must not exist. + // TODO(benvanik): actually observe this; the C11 spec says `x` is supported + // but at least on MSVC's CRT it isn't. We can emulate this with stat and + // such but today we don't have any uses that require it. 
+ // strcat(fopen_mode, "x"); + } + // Force binary mode (avoid Windows CRLF expansion). + strcat(fopen_mode, "b"); + + // Since we stack alloc the path we want to keep it reasonable. + // We could heap allocate instead but a few thousand chars is quite long and + // since Windows doesn't support more than ~256 we generally keep them short + // anyway. + if (path.size > IREE_MAX_PATH) { + IREE_TRACE_ZONE_END(z0); + return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, + "path exceeds reasonable maximum (%" PRIhsz + " > %" PRIhsz ")", + path.size, IREE_MAX_PATH); + } + char* fopen_path = (char*)iree_alloca(path.size + 1); + memcpy(fopen_path, path.data, path.size); + fopen_path[path.size] = 0; // NUL + + iree_status_t status = iree_ok_status(); + FILE* handle = fopen(fopen_path, fopen_mode); + if (handle == NULL) { + // NOTE: for some crazy reason errno isn't set by all implementations. We + // know it is on Windows but currently leave all others to :shrug:. We could + // check C library implementations and versions to make this better. 
+ status = iree_make_stdio_statusf("unable to open file `%.*s` with mode %d", + (int)path.size, path.data, mode); + } + + iree_io_stream_t* stream = NULL; + if (iree_status_is_ok(status)) { + status = iree_io_stdio_stream_wrap( + stream_mode, handle, /*owns_handle=*/true, host_allocator, &stream); + } + + if (iree_status_is_ok(status)) { + *out_stream = stream; + } else { + if (stream) { + iree_io_stream_release(stream); + } else { + fclose(handle); + } + } + IREE_TRACE_ZONE_END(z0); + return status; +} +#else +IREE_API_EXPORT iree_status_t iree_io_stdio_stream_open( + iree_io_stdio_stream_mode_t mode, iree_string_view_t path, + iree_allocator_t host_allocator, iree_io_stream_t** out_stream) { + return iree_make_status(IREE_STATUS_UNAVAILABLE, + "file support has been compiled out of this binary; " + "set IREE_FILE_IO_ENABLE=1 to include it"); +} +#endif // IREE_FILE_IO_ENABLE + +static void iree_io_stdio_stream_destroy( + iree_io_stream_t* IREE_RESTRICT base_stream) { + iree_io_stdio_stream_t* stream = iree_io_stdio_stream_cast(base_stream); + iree_allocator_t host_allocator = stream->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + fflush(stream->handle); + if (stream->owns_handle) { + fclose(stream->handle); + } + + iree_allocator_free(host_allocator, stream); + + IREE_TRACE_ZONE_END(z0); +} + +static iree_io_stream_pos_t iree_io_stdio_stream_offset( + iree_io_stream_t* base_stream) { + IREE_ASSERT_ARGUMENT(base_stream); + iree_io_stdio_stream_t* stream = iree_io_stdio_stream_cast(base_stream); + int64_t pos = iree_ftell(stream->handle); + if (pos == -1) return 0; + return (iree_io_stream_pos_t)pos; +} + +static iree_io_stream_pos_t iree_io_stdio_stream_length( + iree_io_stream_t* base_stream) { + IREE_ASSERT_ARGUMENT(base_stream); + iree_io_stdio_stream_t* stream = iree_io_stdio_stream_cast(base_stream); + + // Capture original offset so we can return to it. 
+ int64_t origin = iree_ftell(stream->handle); + if (origin == -1) return 0; + + // Seek to the end of the file. + if (iree_fseek(stream->handle, 0, SEEK_END) != 0) return 0; + + // Query the position, telling us the total file length in bytes. + int64_t length = iree_ftell(stream->handle); + if (length == -1) return 0; + + // Seek back to the file origin. + if (iree_fseek(stream->handle, origin, SEEK_SET) != 0) return 0; + + return (iree_io_stream_pos_t)length; +} + +static iree_status_t iree_io_stdio_stream_seek( + iree_io_stream_t* base_stream, iree_io_stream_seek_mode_t seek_mode, + iree_io_stream_pos_t seek_offset) { + IREE_ASSERT_ARGUMENT(base_stream); + iree_io_stdio_stream_t* stream = iree_io_stdio_stream_cast(base_stream); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_status_t status = iree_ok_status(); + + int origin = 0; + switch (seek_mode) { + case IREE_IO_STREAM_SEEK_SET: + origin = SEEK_SET; + break; + case IREE_IO_STREAM_SEEK_FROM_CURRENT: + origin = SEEK_CUR; + break; + case IREE_IO_STREAM_SEEK_FROM_END: + origin = SEEK_END; + break; + default: + status = + iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "invalid seek mode"); + break; + } + + if (iree_status_is_ok(status)) { + if (iree_fseek(stream->handle, seek_offset, origin) != 0) { + status = iree_make_stdio_status("failed to seek"); + } + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_status_t iree_io_stdio_stream_read( + iree_io_stream_t* base_stream, iree_host_size_t buffer_capacity, + void* buffer, iree_host_size_t* out_buffer_length) { + IREE_ASSERT_ARGUMENT(base_stream); + IREE_ASSERT_ARGUMENT(buffer); + if (out_buffer_length) *out_buffer_length = 0; + iree_io_stdio_stream_t* stream = iree_io_stdio_stream_cast(base_stream); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_status_t status = iree_ok_status(); + + // Read in ~2GB chunks - even platforms with 64-bit support sometimes don't + // like read lengths >2GB and there's not really any benefit to doing 12GB + // reads in one go anyway. 
+ iree_host_size_t bytes_read = 0; + while (bytes_read < buffer_capacity) { + iree_host_size_t chunk_size = + iree_min(buffer_capacity - bytes_read, INT_MAX); + iree_host_size_t read_size = + fread((uint8_t*)buffer + bytes_read, 1, chunk_size, stream->handle); + bytes_read += read_size; + if (read_size != chunk_size) { + // Short read: EOF (fine if partial reads are accepted) or an error. + if (!feof(stream->handle)) { + status = iree_make_stdio_status("read failed"); + } else if (!out_buffer_length) { + status = iree_make_status(IREE_STATUS_OUT_OF_RANGE, + "end-of-file encountered during read"); + } + break; + } + } + + // Report the bytes read for both full and EOF-truncated reads. + if (iree_status_is_ok(status) && out_buffer_length) { + *out_buffer_length = bytes_read; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_status_t iree_io_stdio_stream_write(iree_io_stream_t* base_stream, + iree_host_size_t buffer_length, + const void* buffer) { + IREE_ASSERT_ARGUMENT(base_stream); + IREE_ASSERT_ARGUMENT(buffer); + iree_io_stdio_stream_t* stream = iree_io_stdio_stream_cast(base_stream); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_status_t status = iree_ok_status(); + + // Write in ~2GB chunks - even platforms with 64-bit support sometimes don't + // like write lengths >2GB and there's not really any benefit to doing 12GB + // writes in one go anyway. + iree_host_size_t bytes_written = 0; + while (bytes_written < buffer_length) { + iree_host_size_t chunk_size = + iree_min(buffer_length - bytes_written, INT_MAX); + iree_host_size_t write_size = + fwrite((uint8_t*)buffer + bytes_written, 1, chunk_size, stream->handle); + if (write_size != chunk_size) { + // Failed to write chunk; likely exhausted disk space.
+ status = iree_make_stdio_status( + "write failed, possibly out of disk space or device lost"); + break; + } + bytes_written += write_size; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_status_t iree_io_stdio_stream_fill( + iree_io_stream_t* base_stream, iree_io_stream_pos_t count, + const void* pattern, iree_host_size_t pattern_length) { + IREE_ASSERT_ARGUMENT(base_stream); + IREE_ASSERT_ARGUMENT(pattern); + iree_io_stdio_stream_t* stream = iree_io_stdio_stream_cast(base_stream); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_status_t status = iree_ok_status(); + + // There's not an stdio API for filling contents. When using platform APIs we + // can extend files when the pattern is zeros to quickly do things but here + // we just bash fwrite (could batch ~4096 bytes of repeated pattern if this + // ever matters). NOTE: fwrite returns the item count (1 here), not bytes. + for (iree_io_stream_pos_t i = 0; i < count; ++i) { + if (fwrite(pattern, pattern_length, 1, stream->handle) != 1) { + status = iree_make_stdio_status( + "write failed, possibly out of disk space or device lost"); + break; + } + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_status_t iree_io_stdio_stream_map_read( + iree_io_stream_t* stream, iree_host_size_t length, + iree_const_byte_span_t* out_span) { + return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, + "stdio streams do not support mapping"); +} + +static iree_status_t iree_io_stdio_stream_map_write( + iree_io_stream_t* stream, iree_host_size_t length, + iree_byte_span_t* out_span) { + return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, + "stdio streams do not support mapping"); +} + +static const iree_io_stream_vtable_t iree_io_stdio_stream_vtable = { + .destroy = iree_io_stdio_stream_destroy, + .offset = iree_io_stdio_stream_offset, + .length = iree_io_stdio_stream_length, + .seek = iree_io_stdio_stream_seek, + .read = iree_io_stdio_stream_read, + .write = iree_io_stdio_stream_write, + .fill
= iree_io_stdio_stream_fill, + .map_read = iree_io_stdio_stream_map_read, + .map_write = iree_io_stdio_stream_map_write, +}; diff --git a/runtime/src/iree/io/stdio_stream.h b/runtime/src/iree/io/stdio_stream.h new file mode 100644 index 000000000000..ddb277332322 --- /dev/null +++ b/runtime/src/iree/io/stdio_stream.h @@ -0,0 +1,58 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_IO_STDIO_STREAM_H_ +#define IREE_IO_STDIO_STREAM_H_ + +#include <stdio.h> + +#include "iree/base/api.h" +#include "iree/io/stream.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +//===----------------------------------------------------------------------===// +// iree_io_stdio_stream_t +//===----------------------------------------------------------------------===// + +// TODO(benvanik): rework this to handle optional features like EXISTS and such. +// fopen support is not very well specced and some impls don't support all modes +// and we'll have to emulate. +// +// Roughly aligns to the fopen modes: +// READ: open existing file for reading. +// WRITE: open existing file for writing. +// READ|WRITE: open existing file for reading/writing. +// WRITE|DISCARD: open new file (discarding existing) for writing. +// WRITE|APPEND: open existing file for writing, start at end. +// READ|WRITE|DISCARD: open new file (discarding existing) for reading/writing. +// READ|WRITE|APPEND: open existing file for reading/writing, start at end. +enum iree_io_stdio_stream_mode_bits_t { + IREE_IO_STDIO_STREAM_MODE_DISCARD = 1u << 0, + IREE_IO_STDIO_STREAM_MODE_READ = 1u << 1, + IREE_IO_STDIO_STREAM_MODE_WRITE = 1u << 2, + IREE_IO_STDIO_STREAM_MODE_APPEND = 1u << 3, +}; +typedef uint32_t iree_io_stdio_stream_mode_t; + +// Wraps an existing stdio |handle|.
If |owns_handle| is true then the +// file will be closed when the stream is destroyed. +IREE_API_EXPORT iree_status_t iree_io_stdio_stream_wrap( + iree_io_stream_mode_t mode, FILE* handle, bool owns_handle, + iree_allocator_t host_allocator, iree_io_stream_t** out_stream); + +// Opens a file at |path| using fopen with the mode determined by |mode|. +IREE_API_EXPORT iree_status_t iree_io_stdio_stream_open( + iree_io_stdio_stream_mode_t mode, iree_string_view_t path, + iree_allocator_t host_allocator, iree_io_stream_t** out_stream); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_IO_STDIO_STREAM_H_ diff --git a/runtime/src/iree/io/stream.c b/runtime/src/iree/io/stream.c index 08471a78e051..6c59d98f8815 100644 --- a/runtime/src/iree/io/stream.c +++ b/runtime/src/iree/io/stream.c @@ -63,76 +63,6 @@ static iree_status_t iree_io_stream_validate_mode( return iree_ok_status(); } -// Validates that at least |access_length| bytes are available at the current -// stream offset. If the optional |out_available_length| is provided the -// available length will be returned matching the requested |access_length| or -// the maximum remaining length to the caller with an OK status even if the full -// |access_length| is not available and otherwise an error is returned. -static iree_status_t iree_io_stream_validate_fixed_range( - iree_io_stream_pos_t stream_offset, iree_io_stream_pos_t stream_length, - iree_io_stream_pos_t access_length, - iree_io_stream_pos_t* out_available_length) { - if (out_available_length) *out_available_length = 0; - - iree_io_stream_pos_t remaining_length = stream_length - stream_offset; - if (access_length > remaining_length) { - // Access exceeds remaining length. - if (out_available_length) { - // Let caller know how much is available and return OK so they can use it. 
- *out_available_length = remaining_length; - return iree_ok_status(); - } - return iree_make_status(IREE_STATUS_OUT_OF_RANGE, - "access to range [%" PRIu64 ", %" PRIu64 - ") (%" PRIu64 - " bytes) out of range; stream offset %" PRIu64 - " and length %" PRIu64 " insufficient", - stream_offset, stream_offset + access_length, - access_length, stream_offset, stream_length); - } - - if (out_available_length) *out_available_length = access_length; - return iree_ok_status(); -} - -// Applies a seek operation with |seek_mode| and |seek_offset| against a stream. -// |out_stream_offset| will contain the position after the seek or an error will -// be returned if the seek could not be completed. -static iree_status_t iree_io_stream_apply_fixed_seek( - iree_io_stream_pos_t stream_offset, iree_io_stream_pos_t stream_length, - iree_io_stream_seek_mode_t seek_mode, iree_io_stream_pos_t seek_offset, - iree_io_stream_pos_t* out_stream_offset) { - *out_stream_offset = stream_offset; - - iree_io_stream_pos_t new_offset = stream_offset; - switch (seek_mode) { - case IREE_IO_STREAM_SEEK_SET: - new_offset = seek_offset; - break; - case IREE_IO_STREAM_SEEK_FROM_CURRENT: - new_offset = stream_offset + seek_offset; - break; - case IREE_IO_STREAM_SEEK_FROM_END: - new_offset = stream_length + seek_offset; - break; - default: - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "unrecognized seek mode %u", (uint32_t)seek_mode); - } - - if (new_offset < 0 || new_offset > stream_length) { - return iree_make_status(IREE_STATUS_OUT_OF_RANGE, - "seek %s offset %" PRIi64 - " out of stream bounds; expected 0 <= %" PRIi64 - " < %" PRIi64, - iree_io_stream_seek_mode_string(seek_mode), - seek_offset, new_offset, stream_length); - } - - *out_stream_offset = new_offset; - return iree_ok_status(); -} - //===----------------------------------------------------------------------===// // iree_io_stream_t //===----------------------------------------------------------------------===// @@ -249,6 +179,17 
@@ iree_io_stream_write(iree_io_stream_t* stream, iree_host_size_t buffer_length, return status; } +IREE_API_EXPORT iree_status_t +iree_io_stream_write_char(iree_io_stream_t* stream, char c) { + return iree_io_stream_write(stream, sizeof(c), &c); +} + +IREE_API_EXPORT iree_status_t iree_io_stream_write_string( + iree_io_stream_t* stream, iree_string_view_t value) { + if (!value.size) return iree_ok_status(); + return iree_io_stream_write(stream, value.size, value.data); +} + IREE_API_EXPORT iree_status_t iree_io_stream_fill(iree_io_stream_t* stream, iree_io_stream_pos_t count, const void* pattern, iree_host_size_t pattern_length) { @@ -351,258 +292,3 @@ IREE_API_EXPORT iree_status_t iree_io_stream_copy( IREE_TRACE_ZONE_END(z0); return status; } - -//===----------------------------------------------------------------------===// -// iree_io_memory_stream_t -//===----------------------------------------------------------------------===// - -typedef struct iree_io_memory_stream_t { - iree_io_stream_t base; - iree_allocator_t host_allocator; - iree_io_memory_stream_release_callback_t release_callback; - iree_io_stream_pos_t offset; - iree_io_stream_pos_t length; - uint8_t* contents; -} iree_io_memory_stream_t; - -static const iree_io_stream_vtable_t iree_io_memory_stream_vtable; - -static iree_io_memory_stream_t* iree_io_memory_stream_cast( - iree_io_stream_t* IREE_RESTRICT base_stream) { - return (iree_io_memory_stream_t*)base_stream; -} - -IREE_API_EXPORT iree_status_t iree_io_memory_stream_wrap( - iree_io_stream_mode_t mode, iree_byte_span_t contents, - iree_io_memory_stream_release_callback_t release_callback, - iree_allocator_t host_allocator, iree_io_stream_t** out_stream) { - IREE_ASSERT_ARGUMENT(out_stream); - *out_stream = NULL; - IREE_TRACE_ZONE_BEGIN(z0); - IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)contents.data_length); - - iree_io_memory_stream_t* stream = NULL; - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, - iree_allocator_malloc(host_allocator, 
sizeof(*stream), (void**)&stream)); - iree_atomic_ref_count_init(&stream->base.ref_count); - stream->base.vtable = &iree_io_memory_stream_vtable; - stream->base.mode = mode; - stream->host_allocator = host_allocator; - stream->release_callback = release_callback; - stream->offset = 0; - stream->length = contents.data_length; - stream->contents = contents.data; - - *out_stream = &stream->base; - IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); -} - -static void iree_io_memory_stream_destroy( - iree_io_stream_t* IREE_RESTRICT base_stream) { - iree_io_memory_stream_t* stream = iree_io_memory_stream_cast(base_stream); - iree_allocator_t host_allocator = stream->host_allocator; - IREE_TRACE_ZONE_BEGIN(z0); - - if (stream->release_callback.fn) { - stream->release_callback.fn(stream->release_callback.user_data, - base_stream); - } - - iree_allocator_free(host_allocator, stream); - - IREE_TRACE_ZONE_END(z0); -} - -static iree_io_stream_pos_t iree_io_memory_stream_offset( - iree_io_stream_t* base_stream) { - IREE_ASSERT_ARGUMENT(base_stream); - iree_io_memory_stream_t* stream = iree_io_memory_stream_cast(base_stream); - return stream->offset; -} - -static iree_io_stream_pos_t iree_io_memory_stream_length( - iree_io_stream_t* base_stream) { - IREE_ASSERT_ARGUMENT(base_stream); - iree_io_memory_stream_t* stream = iree_io_memory_stream_cast(base_stream); - return stream->length; -} - -static iree_status_t iree_io_memory_stream_seek( - iree_io_stream_t* base_stream, iree_io_stream_seek_mode_t seek_mode, - iree_io_stream_pos_t offset) { - IREE_ASSERT_ARGUMENT(base_stream); - iree_io_memory_stream_t* stream = iree_io_memory_stream_cast(base_stream); - IREE_TRACE_ZONE_BEGIN(z0); - - iree_status_t status = iree_io_stream_apply_fixed_seek( - stream->offset, stream->length, seek_mode, offset, &stream->offset); - - IREE_TRACE_ZONE_END(z0); - return status; -} - -static iree_status_t iree_io_memory_stream_read( - iree_io_stream_t* base_stream, iree_host_size_t buffer_capacity, - 
void* buffer, iree_host_size_t* out_buffer_length) { - IREE_ASSERT_ARGUMENT(base_stream); - IREE_ASSERT_ARGUMENT(buffer); - if (out_buffer_length) *out_buffer_length = 0; - iree_io_memory_stream_t* stream = iree_io_memory_stream_cast(base_stream); - IREE_TRACE_ZONE_BEGIN(z0); - - iree_io_stream_pos_t read_length = 0; - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_io_stream_validate_fixed_range(stream->offset, stream->length, - buffer_capacity, &read_length)); - if (!out_buffer_length && read_length != buffer_capacity) { - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, - iree_make_status(IREE_STATUS_OUT_OF_RANGE, - "read of range [%" PRIu64 ", %" PRIu64 ") (%" PRIu64 - " bytes) out of range; stream offset %" PRIu64 - " and length %" PRIu64 " insufficient", - stream->offset, stream->offset + buffer_capacity, - (iree_io_stream_pos_t)buffer_capacity, stream->offset, - stream->length)); - } - - memcpy(buffer, stream->contents + stream->offset, - (iree_host_size_t)read_length); - stream->offset += read_length; - - if (out_buffer_length) *out_buffer_length = (iree_host_size_t)read_length; - IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); -} - -static iree_status_t iree_io_memory_stream_write(iree_io_stream_t* base_stream, - iree_host_size_t buffer_length, - const void* buffer) { - IREE_ASSERT_ARGUMENT(base_stream); - IREE_ASSERT_ARGUMENT(buffer); - iree_io_memory_stream_t* stream = iree_io_memory_stream_cast(base_stream); - IREE_TRACE_ZONE_BEGIN(z0); - - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_io_stream_validate_fixed_range(stream->offset, stream->length, - buffer_length, NULL)); - - memcpy(stream->contents + stream->offset, buffer, buffer_length); - stream->offset += buffer_length; - - IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); -} - -static iree_status_t iree_io_memory_stream_fill( - iree_io_stream_t* base_stream, iree_io_stream_pos_t count, - const void* pattern, iree_host_size_t pattern_length) { - IREE_ASSERT_ARGUMENT(base_stream); - 
IREE_ASSERT_ARGUMENT(pattern); - iree_io_memory_stream_t* stream = iree_io_memory_stream_cast(base_stream); - IREE_TRACE_ZONE_BEGIN(z0); - - iree_io_stream_pos_t access_length = count * pattern_length; - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_io_stream_validate_fixed_range(stream->offset, stream->length, - access_length, NULL)); - - iree_status_t status = iree_ok_status(); - uint8_t* data_ptr = stream->contents + stream->offset; - switch (pattern_length) { - case 1: { - uint8_t* data = (uint8_t*)data_ptr; - uint8_t value_bits = *(const uint8_t*)(pattern); - memset(data, value_bits, count); - break; - } - case 2: { - uint16_t* data = (uint16_t*)data_ptr; - uint16_t value_bits = *(const uint16_t*)(pattern); - for (iree_device_size_t i = 0; i < count; ++i) { - iree_unaligned_store(&data[i], value_bits); - } - break; - } - case 4: { - uint32_t* data = (uint32_t*)data_ptr; - uint32_t value_bits = *(const uint32_t*)(pattern); - for (iree_device_size_t i = 0; i < count; ++i) { - iree_unaligned_store(&data[i], value_bits); - } - break; - } - case 8: { - uint64_t* data = (uint64_t*)data_ptr; - uint64_t value_bits = *(const uint64_t*)(pattern); - for (iree_device_size_t i = 0; i < count; ++i) { - iree_unaligned_store(&data[i], value_bits); - } - break; - } - default: - IREE_ASSERT_UNREACHABLE("verified in iree_io_stream_fill"); - break; - } - if (iree_status_is_ok(status)) { - stream->offset += access_length; - } - - IREE_TRACE_ZONE_END(z0); - return status; -} - -static iree_status_t iree_io_memory_stream_map_read( - iree_io_stream_t* base_stream, iree_host_size_t length, - iree_const_byte_span_t* out_span) { - IREE_ASSERT_ARGUMENT(base_stream); - IREE_ASSERT_ARGUMENT(out_span); - *out_span = iree_const_byte_span_empty(); - iree_io_memory_stream_t* stream = iree_io_memory_stream_cast(base_stream); - IREE_TRACE_ZONE_BEGIN(z0); - - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_io_stream_validate_fixed_range(stream->offset, stream->length, - length, NULL)); - - 
*out_span = - iree_make_const_byte_span(stream->contents + stream->offset, length); - stream->offset += length; - - IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); -} - -static iree_status_t iree_io_memory_stream_map_write( - iree_io_stream_t* base_stream, iree_host_size_t length, - iree_byte_span_t* out_span) { - IREE_ASSERT_ARGUMENT(base_stream); - IREE_ASSERT_ARGUMENT(out_span); - *out_span = iree_byte_span_empty(); - iree_io_memory_stream_t* stream = iree_io_memory_stream_cast(base_stream); - IREE_TRACE_ZONE_BEGIN(z0); - - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_io_stream_validate_fixed_range(stream->offset, stream->length, - length, NULL)); - - *out_span = iree_make_byte_span(stream->contents + stream->offset, length); - stream->offset += length; - - IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); -} - -static const iree_io_stream_vtable_t iree_io_memory_stream_vtable = { - .destroy = iree_io_memory_stream_destroy, - .offset = iree_io_memory_stream_offset, - .length = iree_io_memory_stream_length, - .seek = iree_io_memory_stream_seek, - .read = iree_io_memory_stream_read, - .write = iree_io_memory_stream_write, - .fill = iree_io_memory_stream_fill, - .map_read = iree_io_memory_stream_map_read, - .map_write = iree_io_memory_stream_map_write, -}; diff --git a/runtime/src/iree/io/stream.h b/runtime/src/iree/io/stream.h index 91d05d7200d4..10834719b3f3 100644 --- a/runtime/src/iree/io/stream.h +++ b/runtime/src/iree/io/stream.h @@ -83,12 +83,12 @@ iree_io_stream_length(iree_io_stream_t* stream); // When at the end of stream reads will fail and writes will append. IREE_API_EXPORT bool iree_io_stream_is_eos(iree_io_stream_t* stream); -// Seeks within |stream| to |offset| based on the given |seek_mode|. +// Seeks within |stream| to |seek_offset| based on the given |seek_mode|. // If IREE_IO_STREAM_MODE_SEEKABLE is not set then only forward relative seeks // are supported. 
IREE_API_EXPORT iree_status_t iree_io_stream_seek( iree_io_stream_t* stream, iree_io_stream_seek_mode_t seek_mode, - iree_io_stream_pos_t offset); + iree_io_stream_pos_t seek_offset); // Seeks within |stream| to the next offset with the specified |alignment|. // The alignment is expected to be a power-of-two value. @@ -113,6 +113,16 @@ IREE_API_EXPORT iree_status_t iree_io_stream_write(iree_io_stream_t* stream, iree_host_size_t buffer_length, const void* buffer); +// Writes a single character/byte to the stream. +// Requires the stream have IREE_IO_STREAM_MODE_WRITABLE. +IREE_API_EXPORT iree_status_t +iree_io_stream_write_char(iree_io_stream_t* stream, char c); + +// Writes a string view to the stream (excluding NUL terminator). +// Requires the stream have IREE_IO_STREAM_MODE_WRITABLE. +IREE_API_EXPORT iree_status_t +iree_io_stream_write_string(iree_io_stream_t* stream, iree_string_view_t value); + // Writes |count| elements of |pattern_length| with the given |pattern| value. // Requires the stream have IREE_IO_STREAM_MODE_WRITABLE. IREE_API_EXPORT iree_status_t @@ -155,7 +165,7 @@ typedef struct iree_io_stream_vtable_t { iree_io_stream_pos_t(IREE_API_PTR* length)(iree_io_stream_t* stream); iree_status_t(IREE_API_PTR* seek)(iree_io_stream_t* stream, iree_io_stream_seek_mode_t seek_mode, - iree_io_stream_pos_t offset); + iree_io_stream_pos_t seek_offset); iree_status_t(IREE_API_PTR* read)(iree_io_stream_t* stream, iree_host_size_t buffer_capacity, void* buffer, @@ -181,46 +191,6 @@ struct iree_io_stream_t { iree_io_stream_mode_t mode; }; -//===----------------------------------------------------------------------===// -// iree_io_memory_stream_t -//===----------------------------------------------------------------------===// - -typedef void(IREE_API_PTR* iree_io_memory_stream_release_fn_t)( - void* user_data, iree_io_stream_t* stream); - -// A callback issued when a memory stream is released. -typedef struct { - // Callback function pointer. 
- iree_io_memory_stream_release_fn_t fn; - // User data passed to the callback function. Unowned. - void* user_data; -} iree_io_memory_stream_release_callback_t; - -// Returns a no-op file release callback that implies that no cleanup is -// required. -static inline iree_io_memory_stream_release_callback_t -iree_io_memory_stream_release_callback_null(void) { - iree_io_memory_stream_release_callback_t callback = {NULL, NULL}; - return callback; -} - -// Wraps a fixed-size host memory allocation |contents| in a stream. -// |release_callback| can be used to receive a callback when the stream is -// destroyed and the reference to the contents is no longer required. -IREE_API_EXPORT iree_status_t iree_io_memory_stream_wrap( - iree_io_stream_mode_t mode, iree_byte_span_t contents, - iree_io_memory_stream_release_callback_t release_callback, - iree_allocator_t host_allocator, iree_io_stream_t** out_stream); - -//===----------------------------------------------------------------------===// -// iree_io_buffer_stream_t -//===----------------------------------------------------------------------===// - -// TODO(benvanik): buffer stream that grows with a specified block size. Provide -// a iree_io_buffer_stream_enumerate_data_blocks(stream, callback) to enumerate -// the blocks in order (ala iovecs). Take an iree_allocator_t dedicated to the -// block storage separate from the iree_io_stream_t metadata. - #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/runtime/src/iree/io/vec_stream.c b/runtime/src/iree/io/vec_stream.c new file mode 100644 index 000000000000..32b87eb06c9c --- /dev/null +++ b/runtime/src/iree/io/vec_stream.c @@ -0,0 +1,503 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/io/vec_stream.h" + +//===----------------------------------------------------------------------===// +// iree_io_vec_stream_t +//===----------------------------------------------------------------------===// + +#define IREE_IO_VEC_BLOCK_ALIGNMENT 16 +#define IREE_IO_VEC_BLOCK_MIN_SIZE 1024 + +// Metadata and storage for a block. +// Block sizes provided by users include the metadata so that users can pick +// bucketing allocator-friendly sizes and not end up tripping into the next +// bucket. This means each block actually stores a bit less than whatever they +// request. Use IREE_IO_VEC_BLOCK_STORAGE_CAPACITY to determine the actual +// storage capacity per block based on a block size. +typedef struct iree_io_vec_block_t { + // Next block in the stream's block linked list. + struct iree_io_vec_block_t* next; + // Previous block in the stream's block linked list. + struct iree_io_vec_block_t* prev; + // Global offset within the stream. + iree_io_stream_pos_t offset; + // Capacity of the block storage in bytes. + iree_host_size_t capacity; + // Current length of the block storage in bytes. + // This will be under capacity if the block is at the end of the stream. + iree_host_size_t length; + // Block contents of size capacity. + iree_alignas(IREE_IO_VEC_BLOCK_ALIGNMENT) uint8_t contents[/*capacity*/]; +} iree_io_vec_block_t; + +#define IREE_IO_VEC_BLOCK_STORAGE_CAPACITY(block_size) \ + ((block_size)-offsetof(iree_io_vec_block_t, contents)) + +typedef struct iree_io_vec_stream_t { + iree_io_stream_t base; + iree_allocator_t host_allocator; + // Current offset within the stream. block_pos is the block containing the + // offset. + iree_io_stream_pos_t offset; + // Total length of the stream. The available capacity of all blocks allocated + // will be greater than or equal to this. + iree_io_stream_pos_t length; + // Size of each block allocated in bytes. 
+ // Uniform block sizing prevents allocator fragmentation. + iree_host_size_t block_size; + // Head of the block linked list, NULL if none allocated. + iree_io_vec_block_t* block_head; + // Tail of the block linked list, NULL if none allocated. + iree_io_vec_block_t* block_tail; + // Block containing the current stream offset, NULL if none allocated. + iree_io_vec_block_t* block_pos; +} iree_io_vec_stream_t; + +static const iree_io_stream_vtable_t iree_io_vec_stream_vtable; + +static iree_io_vec_stream_t* iree_io_vec_stream_cast( + iree_io_stream_t* IREE_RESTRICT base_stream) { + return (iree_io_vec_stream_t*)base_stream; +} + +IREE_API_EXPORT iree_status_t iree_io_vec_stream_create( + iree_io_stream_mode_t mode, iree_host_size_t block_size, + iree_allocator_t host_allocator, iree_io_stream_t** out_stream) { + IREE_ASSERT_ARGUMENT(out_stream); + *out_stream = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + block_size = + iree_max(IREE_IO_VEC_BLOCK_MIN_SIZE, + iree_host_align(block_size, IREE_IO_VEC_BLOCK_ALIGNMENT)); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)block_size); + + iree_io_vec_stream_t* stream = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + iree_allocator_malloc(host_allocator, sizeof(*stream), (void**)&stream)); + iree_atomic_ref_count_init(&stream->base.ref_count); + stream->base.vtable = &iree_io_vec_stream_vtable; + stream->base.mode = mode; + stream->host_allocator = host_allocator; + stream->offset = 0; + stream->length = 0; + stream->block_size = block_size; + stream->block_head = stream->block_tail = stream->block_pos = NULL; + + *out_stream = &stream->base; + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +static void iree_io_vec_stream_destroy( + iree_io_stream_t* IREE_RESTRICT base_stream) { + iree_io_vec_stream_t* stream = iree_io_vec_stream_cast(base_stream); + iree_allocator_t host_allocator = stream->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_io_vec_block_t* block = stream->block_head; + while (block) { + 
iree_io_vec_block_t* next = block->next; + iree_allocator_free(host_allocator, block); + block = next; + } + + iree_allocator_free(host_allocator, stream); + + IREE_TRACE_ZONE_END(z0); +} + +IREE_API_EXPORT iree_status_t iree_io_vec_stream_enumerate_blocks( + iree_io_stream_t* base_stream, iree_io_vec_stream_callback_fn_t callback, + void* user_data) { + IREE_ASSERT_ARGUMENT(base_stream); + iree_io_vec_stream_t* stream = iree_io_vec_stream_cast(base_stream); + IREE_TRACE_ZONE_BEGIN(z0); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)stream->length); + + iree_status_t status = iree_ok_status(); + for (iree_io_vec_block_t* block = stream->block_head; block != NULL; + block = block->next) { + status = callback( + user_data, iree_make_const_byte_span(block->contents, block->length)); + if (!iree_status_is_ok(status)) break; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_io_stream_pos_t iree_io_vec_stream_offset( + iree_io_stream_t* base_stream) { + IREE_ASSERT_ARGUMENT(base_stream); + iree_io_vec_stream_t* stream = iree_io_vec_stream_cast(base_stream); + return stream->offset; +} + +static iree_io_stream_pos_t iree_io_vec_stream_length( + iree_io_stream_t* base_stream) { + IREE_ASSERT_ARGUMENT(base_stream); + iree_io_vec_stream_t* stream = iree_io_vec_stream_cast(base_stream); + return stream->length; +} + +// Asserts the block list and current offset match. +static void iree_io_vec_stream_assert_valid(iree_io_vec_stream_t* stream) { + if (!stream->block_head) return; + IREE_ASSERT(stream->block_pos); + IREE_ASSERT_LE(stream->block_pos->offset, stream->offset); + IREE_ASSERT_GE(stream->block_pos->offset + stream->block_pos->length, + stream->offset); +} + +// Extends the stream up to the new total length. +// The current stream offset is not changed though both block_head and +// block_tail may be. 
+static iree_status_t iree_io_vec_stream_extend( + iree_io_vec_stream_t* stream, iree_io_stream_pos_t new_length) { + IREE_ASSERT_ARGUMENT(stream); + if (!new_length) return iree_ok_status(); + if (stream->length >= new_length) return iree_ok_status(); + IREE_TRACE_ZONE_BEGIN(z0); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, new_length); + + // Determine how many bytes we need to allocate and then allocate blocks up + // until we reach that new total. We'll fill the current block (if any) first + // and that may be all we need. + iree_io_stream_pos_t remaining_bytes = new_length - stream->length; + if (stream->block_tail != NULL) { + // Fill the current block first. This may satisfy the entire request and + // we can bail early. + iree_host_size_t block_bytes = + iree_min(remaining_bytes, + stream->block_tail->capacity - stream->block_tail->length); + stream->block_tail->length += block_bytes; + stream->length += block_bytes; + remaining_bytes -= block_bytes; + } + iree_status_t status = iree_ok_status(); + iree_host_size_t block_capacity = + IREE_IO_VEC_BLOCK_STORAGE_CAPACITY(stream->block_size); + while (remaining_bytes > 0) { + // Allocate a new block. + iree_io_vec_block_t* block = NULL; + status = iree_allocator_malloc(stream->host_allocator, stream->block_size, + (void**)&block); + if (!iree_status_is_ok(status)) break; + iree_host_size_t block_bytes = iree_min(remaining_bytes, block_capacity); + block->prev = stream->block_tail; + if (block->prev) { + block->prev->next = block; + } + stream->block_tail = block; + block->next = NULL; + if (!stream->block_head) { + // First block, set as head. + stream->block_head = block; + } + block->offset = stream->length; + stream->length += block_bytes; + block->capacity = block_capacity; + block->length = block_bytes; + remaining_bytes -= block_bytes; + // NOTE: iree_allocator_malloc guarantees contents are zeroed. 
+ } + IREE_ASSERT(!iree_status_is_ok(status) || stream->length == new_length); + if (!stream->block_pos) { + // If we just allocated the stream then set the offset 0 block. + stream->block_pos = stream->block_head; + } + iree_io_vec_stream_assert_valid(stream); + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_status_t iree_io_vec_stream_seek( + iree_io_stream_t* base_stream, iree_io_stream_seek_mode_t seek_mode, + iree_io_stream_pos_t seek_offset) { + IREE_ASSERT_ARGUMENT(base_stream); + iree_io_vec_stream_t* stream = iree_io_vec_stream_cast(base_stream); + IREE_TRACE_ZONE_BEGIN(z0); + + // We compute a new global offset and then navigate the list based on that. We + // could use the seek mode as a discriminator for that instead but before + // walking we have to handle extends on the common path anyway. + iree_io_stream_pos_t new_offset = stream->offset; + switch (seek_mode) { + case IREE_IO_STREAM_SEEK_SET: + new_offset = seek_offset; + break; + case IREE_IO_STREAM_SEEK_FROM_CURRENT: + new_offset = stream->offset + seek_offset; + break; + case IREE_IO_STREAM_SEEK_FROM_END: + new_offset = stream->length + seek_offset; + break; + default: + IREE_TRACE_ZONE_END(z0); + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "unrecognized seek mode %u", (uint32_t)seek_mode); + } + if (new_offset == stream->offset) { + // No change fast-path. + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); + } else if (new_offset < 0) { + // Trying to seek off the beginning of the stream. + IREE_TRACE_ZONE_END(z0); + return iree_make_status(IREE_STATUS_OUT_OF_RANGE, + "seek %u offset %" PRIi64 + " out of stream bounds; expected 0 <= %" PRIi64, + (uint32_t)seek_mode, seek_offset, new_offset); + } + + // Extend the stream if the new offset is off the current end. This will + // allocate new empty blocks with zeroed contents.
+ IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_io_vec_stream_extend(stream, new_offset)); + + // If the stream is not allocated then bail (seeking to offset 0 of an empty + // stream doesn't allocate anything). + if (!stream->block_head) { + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); + } + + // Seek to find the block containing the new offset. + // Since we only have a linked list we have to do a walk but the direction we + // walk will be based on where we are starting from. We special case some + // common cases like 0 and end to avoid walking the whole list. + if (new_offset == 0 || + (stream->block_head && new_offset < stream->block_head->length)) { + // Within the first block. + stream->block_pos = stream->block_head; + } else if (new_offset == stream->length || + (stream->block_tail && new_offset >= stream->block_tail->offset)) { + // Within the last block. + stream->block_pos = stream->block_tail; + } else { + // Somewhere in the middle of the list; walk forward or backward. + IREE_ASSERT(stream->block_pos); + if (new_offset < stream->offset) { + // Seeking backward. + iree_io_vec_block_t* block = stream->block_pos; + for (; block && block->offset > new_offset; block = block->prev) { + } + stream->block_pos = block; + } else { + // Seeking forward. 
+ iree_io_vec_block_t* block = stream->block_pos; + for (; block && block->offset + block->length < new_offset; + block = block->next) { + } + stream->block_pos = block; + } + } + stream->offset = new_offset; + iree_io_vec_stream_assert_valid(stream); + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +static iree_status_t iree_io_vec_stream_read( + iree_io_stream_t* base_stream, iree_host_size_t buffer_capacity, + void* buffer, iree_host_size_t* out_buffer_length) { + IREE_ASSERT_ARGUMENT(base_stream); + IREE_ASSERT_ARGUMENT(buffer); + if (out_buffer_length) *out_buffer_length = 0; + if (buffer_capacity == 0) return iree_ok_status(); + iree_io_vec_stream_t* stream = iree_io_vec_stream_cast(base_stream); + IREE_TRACE_ZONE_BEGIN(z0); + + // Determine how many bytes to read based on how many bytes are available + // in the stream from the current offset. + iree_io_stream_pos_t remaining_length = stream->length - stream->offset; + iree_host_size_t read_bytes = buffer_capacity; + if (buffer_capacity > remaining_length) { + // Access exceeds remaining length. + if (out_buffer_length) { + // Read-to-end; we'll return less than the full capacity. + read_bytes = remaining_length; + } else { + IREE_TRACE_ZONE_END(z0); + return iree_make_status(IREE_STATUS_OUT_OF_RANGE, + "access to range [%" PRIu64 ", %" PRIu64 + ") (%" PRIhsz + " bytes) out of range; stream offset %" PRIu64 + " and length %" PRIu64 " insufficient", + stream->offset, stream->offset + buffer_capacity, + buffer_capacity, stream->offset, stream->length); + } + } + + // Copy bytes from blocks for the entire read length. 
+ uint8_t* buffer_ptr = (uint8_t*)buffer; + iree_host_size_t read_offset = 0; + iree_io_stream_pos_t new_offset = stream->offset; + iree_io_vec_block_t* block = stream->block_pos; + iree_io_stream_pos_t block_offset = new_offset - block->offset; + while (read_offset < read_bytes) { + IREE_ASSERT(block); + if (new_offset >= block->offset + block->length) { + IREE_ASSERT(block->next, + "should have verified length and have a next block"); + block = block->next; + block_offset = 0; + } + IREE_ASSERT(block); + iree_host_size_t block_bytes = + iree_min(read_bytes - read_offset, block->length); + memcpy(buffer_ptr, &block->contents[block_offset], block_bytes); + buffer_ptr += block_bytes; + read_offset += block_bytes; + new_offset += block_bytes; + } + stream->offset = new_offset; + stream->block_pos = block; + iree_io_vec_stream_assert_valid(stream); + + if (out_buffer_length) *out_buffer_length = read_bytes; + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +static iree_status_t iree_io_vec_stream_write(iree_io_stream_t* base_stream, + iree_host_size_t buffer_length, + const void* buffer) { + IREE_ASSERT_ARGUMENT(base_stream); + IREE_ASSERT_ARGUMENT(buffer); + if (!buffer_length) return iree_ok_status(); + iree_io_vec_stream_t* stream = iree_io_vec_stream_cast(base_stream); + IREE_TRACE_ZONE_BEGIN(z0); + + // Extend the stream storage up to the final size from the current position. + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_io_vec_stream_extend(stream, stream->offset + buffer_length)); + + // Copy the source buffer to the blocks. 
+ iree_host_size_t remaining_bytes = buffer_length; + iree_io_vec_block_t* block = stream->block_pos; + iree_host_size_t block_offset = stream->offset - block->offset; + const uint8_t* buffer_ptr = (const uint8_t*)buffer; + while (remaining_bytes > 0) { + IREE_ASSERT(block); + if (block_offset == block->capacity) { + IREE_ASSERT(block->next, "should have resized and have a next block"); + block = block->next; + block_offset = 0; + } + IREE_ASSERT(block); + iree_host_size_t write_bytes = + iree_min(block->capacity - block_offset, remaining_bytes); + memcpy(&block->contents[block_offset], buffer_ptr, write_bytes); + buffer_ptr += write_bytes; + remaining_bytes -= write_bytes; + block_offset += write_bytes; + } + + // Update the offset and block containing it for future operations. + stream->offset += buffer_length; + stream->block_pos = block; + iree_io_vec_stream_assert_valid(stream); + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +static iree_status_t iree_io_vec_stream_fill_1(iree_io_vec_stream_t* stream, + iree_io_stream_pos_t count, + uint8_t pattern) { + // Copy the source buffer to the blocks. + iree_host_size_t remaining_bytes = count; + iree_io_vec_block_t* block = stream->block_pos; + iree_host_size_t block_offset = stream->offset - block->offset; + while (remaining_bytes > 0) { + IREE_ASSERT(block); + if (block_offset == block->capacity) { + IREE_ASSERT(block->next, "should have resized and have a next block"); + block = block->next; + block_offset = 0; + } + IREE_ASSERT(block); + iree_host_size_t write_bytes = + iree_min(block->capacity - block_offset, remaining_bytes); + memset(&block->contents[block_offset], pattern, write_bytes); + remaining_bytes -= write_bytes; + block_offset += write_bytes; + } + + // Update the offset and block containing it for future operations. 
+ stream->offset += count; + stream->block_pos = block; + iree_io_vec_stream_assert_valid(stream); + + return iree_ok_status(); +} + +static iree_status_t iree_io_vec_stream_fill(iree_io_stream_t* base_stream, + iree_io_stream_pos_t count, + const void* pattern, + iree_host_size_t pattern_length) { + IREE_ASSERT_ARGUMENT(base_stream); + IREE_ASSERT_ARGUMENT(pattern); + iree_io_vec_stream_t* stream = iree_io_vec_stream_cast(base_stream); + IREE_TRACE_ZONE_BEGIN(z0); + + // Grow the stream to the entire new length (if needed). + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + iree_io_vec_stream_extend(stream, + stream->offset + count * pattern_length), + "growing stream to fill bounds"); + + // TODO(benvanik): efficient fill - we should be able to partition into + // prior block and some new number of blocks. The tricky part is that the + // alignment is 1 so we may need to split the pattern across the boundary. + // For now we fast path pattern_length 1 and are slow for everything else. + if (pattern_length == 1) { + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + iree_io_vec_stream_fill_1(stream, count, *((const uint8_t*)pattern))); + } else { + for (iree_io_stream_pos_t i = 0; i < count; ++i) { + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_io_vec_stream_write(base_stream, pattern_length, pattern)); + } + } + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +static iree_status_t iree_io_vec_stream_map_read( + iree_io_stream_t* base_stream, iree_host_size_t length, + iree_const_byte_span_t* out_span) { + return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, + "vec streams do not support mapping"); +} + +static iree_status_t iree_io_vec_stream_map_write(iree_io_stream_t* base_stream, + iree_host_size_t length, + iree_byte_span_t* out_span) { + return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, + "vec streams do not support mapping"); +} + +static const iree_io_stream_vtable_t iree_io_vec_stream_vtable = { + .destroy = iree_io_vec_stream_destroy, + .offset 
= iree_io_vec_stream_offset, + .length = iree_io_vec_stream_length, + .seek = iree_io_vec_stream_seek, + .read = iree_io_vec_stream_read, + .write = iree_io_vec_stream_write, + .fill = iree_io_vec_stream_fill, + .map_read = iree_io_vec_stream_map_read, + .map_write = iree_io_vec_stream_map_write, +}; diff --git a/runtime/src/iree/io/vec_stream.h b/runtime/src/iree/io/vec_stream.h new file mode 100644 index 000000000000..e79c30bc7704 --- /dev/null +++ b/runtime/src/iree/io/vec_stream.h @@ -0,0 +1,44 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_IO_VEC_STREAM_H_ +#define IREE_IO_VEC_STREAM_H_ + +#include "iree/base/api.h" +#include "iree/io/stream.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +//===----------------------------------------------------------------------===// +// iree_io_vec_stream_t +//===----------------------------------------------------------------------===// + +// Creates an in-memory stream that grows as data is written. +// Blocks of |block_size| (16-byte aligned) are allocated each time growth is +// required and writes will be split to fit into blocks. To retrieve the data +// from the stream use iree_io_vec_stream_enumerate_blocks or seek and read it +// back. +IREE_API_EXPORT iree_status_t iree_io_vec_stream_create( + iree_io_stream_mode_t mode, iree_host_size_t block_size, + iree_allocator_t host_allocator, iree_io_stream_t** out_stream); + +// Called for each block in stream order. Blocks may be sized under the +// requested block size if they contain partial data. +typedef iree_status_t(IREE_API_PTR* iree_io_vec_stream_callback_fn_t)( + void* user_data, iree_const_byte_span_t block); + +// Issues |callback| for each block of data in the stream. 
+IREE_API_EXPORT iree_status_t iree_io_vec_stream_enumerate_blocks(
+    iree_io_stream_t* stream, iree_io_vec_stream_callback_fn_t callback,
+    void* user_data);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_IO_VEC_STREAM_H_
diff --git a/runtime/src/iree/io/vec_stream_test.cc b/runtime/src/iree/io/vec_stream_test.cc
new file mode 100644
index 000000000000..1626db0493aa
--- /dev/null
+++ b/runtime/src/iree/io/vec_stream_test.cc
@@ -0,0 +1,578 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/io/vec_stream.h"
+
+#include <cstdint>  // NOTE(review): header names were lost; confirm upstream.
+#include <cstring>
+#include <memory>
+
+#include "iree/base/api.h"
+#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
+
+namespace {
+
+using iree::Status;
+using iree::StatusCode;
+using iree::testing::status::StatusIs;
+using testing::ElementsAre;
+using testing::ElementsAreArray;
+using testing::Eq;
+
+using StreamPtr =
+    std::unique_ptr<iree_io_stream_t, void (*)(iree_io_stream_t*)>;
+
+static StreamPtr CreateStream(iree_io_stream_mode_t mode,
+                              size_t block_size = 1 * 1024) {
+  iree_io_stream_t* stream = NULL;
+  IREE_CHECK_OK(iree_io_vec_stream_create(mode, block_size,
+                                          iree_allocator_system(), &stream));
+  return StreamPtr(stream, iree_io_stream_release);
+}
+
+template <typename T, size_t N>
+static StreamPtr CreateStreamWithContents(iree_io_stream_mode_t mode,
+                                          T (&elements)[N],
+                                          size_t block_size = 1 * 1024) {
+  iree_io_stream_t* stream = NULL;
+  IREE_CHECK_OK(iree_io_vec_stream_create(mode | IREE_IO_STREAM_MODE_WRITABLE,
+                                          block_size, iree_allocator_system(),
+                                          &stream));
+  IREE_CHECK_OK(iree_io_stream_write(stream, sizeof(T) * N, elements));
+  IREE_CHECK_OK(iree_io_stream_seek(stream, IREE_IO_STREAM_SEEK_SET, 0));
+  return StreamPtr(stream, iree_io_stream_release);
+}
+
+TEST(VecStreamTest, Empty) {
+  auto stream = CreateStream(IREE_IO_STREAM_MODE_READABLE);
+  EXPECT_EQ(iree_io_stream_mode(stream.get()), IREE_IO_STREAM_MODE_READABLE);
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 0);
+  EXPECT_EQ(iree_io_stream_length(stream.get()), 0);
+  EXPECT_TRUE(iree_io_stream_is_eos(stream.get()));
+}
+
+TEST(VecStreamTest, SeekSet) {
+  uint8_t data[5] = {0, 1, 2, 3, 4};
+  auto stream = CreateStreamWithContents(IREE_IO_STREAM_MODE_READABLE, data);
+
+  // Streams start at origin 0.
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 0);
+  EXPECT_EQ(iree_io_stream_length(stream.get()), sizeof(data));
+  EXPECT_FALSE(iree_io_stream_is_eos(stream.get()));
+
+  // No-op seek to origin.
+  IREE_EXPECT_OK(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 0));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 0);
+  EXPECT_FALSE(iree_io_stream_is_eos(stream.get()));
+
+  // Seek to end-of-stream.
+  IREE_EXPECT_OK(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET,
+                                     iree_io_stream_length(stream.get())));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()),
+            iree_io_stream_length(stream.get()));
+  EXPECT_TRUE(iree_io_stream_is_eos(stream.get()));
+
+  // Seek to absolute offset 1.
+  IREE_EXPECT_OK(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 1));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 1);
+  EXPECT_FALSE(iree_io_stream_is_eos(stream.get()));
+
+  // Seek to absolute offset length-1 (last valid byte).
+  IREE_EXPECT_OK(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 4));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 4);
+  EXPECT_FALSE(iree_io_stream_is_eos(stream.get()));
+
+  // Try seeking out of bounds (off the front of the list).
+  EXPECT_THAT(
+      Status(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, -1)),
+      StatusIs(StatusCode::kOutOfRange));
+
+  // Seek off the end of the stream to extend it.
+  EXPECT_EQ(iree_io_stream_length(stream.get()), 5);
+  IREE_EXPECT_OK(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 6));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 6);
+  EXPECT_EQ(iree_io_stream_length(stream.get()), 6);
+  EXPECT_TRUE(iree_io_stream_is_eos(stream.get()));
+}
+
+TEST(VecStreamTest, SeekFromCurrent) {
+  uint8_t data[5] = {0, 1, 2, 3, 4};
+  auto stream = CreateStreamWithContents(IREE_IO_STREAM_MODE_READABLE, data);
+
+  // Streams start at origin 0.
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 0);
+  EXPECT_EQ(iree_io_stream_length(stream.get()), sizeof(data));
+  EXPECT_FALSE(iree_io_stream_is_eos(stream.get()));
+
+  // Seek to end-of-stream by jumping the full length.
+  IREE_EXPECT_OK(iree_io_stream_seek(stream.get(),
+                                     IREE_IO_STREAM_SEEK_FROM_CURRENT,
+                                     iree_io_stream_length(stream.get())));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()),
+            iree_io_stream_length(stream.get()));
+  EXPECT_TRUE(iree_io_stream_is_eos(stream.get()));
+
+  // Reset back to origin by seeking back the full length.
+  IREE_EXPECT_OK(iree_io_stream_seek(stream.get(),
+                                     IREE_IO_STREAM_SEEK_FROM_CURRENT,
+                                     -iree_io_stream_length(stream.get())));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 0);
+  EXPECT_FALSE(iree_io_stream_is_eos(stream.get()));
+
+  // Seek forward to absolute position 1.
+  IREE_EXPECT_OK(
+      iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_FROM_CURRENT, 1));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 1);
+  EXPECT_FALSE(iree_io_stream_is_eos(stream.get()));
+
+  // No-op seek to current location (absolute 1).
+  IREE_EXPECT_OK(
+      iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_FROM_CURRENT, 0));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 1);
+  EXPECT_FALSE(iree_io_stream_is_eos(stream.get()));
+
+  // Seek to absolute offset length-1 (last valid byte) - here (5-1) - 1 = 3.
+  IREE_EXPECT_OK(
+      iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_FROM_CURRENT, 3));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 4);
+  EXPECT_FALSE(iree_io_stream_is_eos(stream.get()));
+
+  // Seek forward 1 to absolute end-of-stream.
+  IREE_EXPECT_OK(
+      iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_FROM_CURRENT, 1));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()),
+            iree_io_stream_length(stream.get()));
+  EXPECT_TRUE(iree_io_stream_is_eos(stream.get()));
+
+  // Reset back to origin.
+  IREE_EXPECT_OK(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 0));
+
+  // Try seeking out of bounds.
+  EXPECT_THAT(Status(iree_io_stream_seek(
+                  stream.get(), IREE_IO_STREAM_SEEK_FROM_CURRENT, -100)),
+              StatusIs(StatusCode::kOutOfRange));
+
+  // Seek off the end of the stream to extend it.
+  EXPECT_EQ(iree_io_stream_length(stream.get()), 5);
+  IREE_EXPECT_OK(
+      iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_FROM_CURRENT, 600));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 600);
+  EXPECT_EQ(iree_io_stream_length(stream.get()), 600);
+  EXPECT_TRUE(iree_io_stream_is_eos(stream.get()));
+}
+
+TEST(VecStreamTest, SeekFromEnd) {
+  uint8_t data[5] = {0, 1, 2, 3, 4};
+  auto stream = CreateStreamWithContents(IREE_IO_STREAM_MODE_READABLE, data);
+
+  // Streams start at origin 0.
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 0);
+  EXPECT_EQ(iree_io_stream_length(stream.get()), sizeof(data));
+  EXPECT_FALSE(iree_io_stream_is_eos(stream.get()));
+
+  // Jump to end-of-stream.
+  IREE_EXPECT_OK(
+      iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_FROM_END, 0));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()),
+            iree_io_stream_length(stream.get()));
+  EXPECT_TRUE(iree_io_stream_is_eos(stream.get()));
+
+  // Reset back to origin by seeking back the full length.
+  IREE_EXPECT_OK(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_FROM_END,
+                                     -iree_io_stream_length(stream.get())));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 0);
+  EXPECT_FALSE(iree_io_stream_is_eos(stream.get()));
+
+  // Seek to absolute offset length-1 (last valid byte) - here 5 - 1 = 4.
+  IREE_EXPECT_OK(
+      iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_FROM_END, -1));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 4);
+  EXPECT_FALSE(iree_io_stream_is_eos(stream.get()));
+
+  // Reset back to origin.
+  IREE_EXPECT_OK(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 0));
+
+  // Try seeking out of bounds.
+  EXPECT_THAT(Status(iree_io_stream_seek(stream.get(),
+                                         IREE_IO_STREAM_SEEK_FROM_END, -100)),
+              StatusIs(StatusCode::kOutOfRange));
+
+  // Seek off the end of the stream to extend it.
+  EXPECT_EQ(iree_io_stream_length(stream.get()), 5);
+  IREE_EXPECT_OK(
+      iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_FROM_END, 100));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 105);
+  EXPECT_EQ(iree_io_stream_length(stream.get()), 105);
+  EXPECT_TRUE(iree_io_stream_is_eos(stream.get()));
+}
+
+TEST(VecStreamTest, SeekToAlignment) {
+  uint8_t data[5] = {0, 1, 2, 3, 4};
+  auto stream = CreateStreamWithContents(IREE_IO_STREAM_MODE_READABLE, data);
+
+  // Streams start at origin 0.
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 0);
+  EXPECT_EQ(iree_io_stream_length(stream.get()), sizeof(data));
+  EXPECT_FALSE(iree_io_stream_is_eos(stream.get()));
+
+  // Alignment must be a power of two.
+  EXPECT_THAT(Status(iree_io_stream_seek_to_alignment(stream.get(), 3)),
+              StatusIs(StatusCode::kInvalidArgument));
+  EXPECT_THAT(Status(iree_io_stream_seek_to_alignment(stream.get(), 63)),
+              StatusIs(StatusCode::kInvalidArgument));
+  EXPECT_THAT(Status(iree_io_stream_seek_to_alignment(stream.get(), -2)),
+              StatusIs(StatusCode::kInvalidArgument));
+
+  // Alignment at 0 should always be ok.
+  IREE_EXPECT_OK(iree_io_stream_seek_to_alignment(stream.get(), 0));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 0);
+  IREE_EXPECT_OK(iree_io_stream_seek_to_alignment(stream.get(), 1));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 0);
+  IREE_EXPECT_OK(iree_io_stream_seek_to_alignment(stream.get(), 2));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 0);
+
+  // Seek forward to an unaligned absolute offset 1.
+  IREE_EXPECT_OK(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 1));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 1);
+
+  // Seek forward to alignment 2, which should be absolute offset 2.
+  IREE_EXPECT_OK(iree_io_stream_seek_to_alignment(stream.get(), 2));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 2);
+
+  // Alignment that matches the current offset (2) should be a no-op.
+  IREE_EXPECT_OK(iree_io_stream_seek_to_alignment(stream.get(), 2));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 2);
+
+  // Align up from an aligned value.
+  IREE_EXPECT_OK(iree_io_stream_seek_to_alignment(stream.get(), 4));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 4);
+
+  // Align off the end of the stream to extend.
+  EXPECT_EQ(iree_io_stream_length(stream.get()), 5);
+  IREE_EXPECT_OK(iree_io_stream_seek_to_alignment(stream.get(), 16));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 16);
+  EXPECT_EQ(iree_io_stream_length(stream.get()), 16);
+  EXPECT_TRUE(iree_io_stream_is_eos(stream.get()));
+}
+
+TEST(VecStreamTest, ReadUpTo) {
+  uint8_t data[5] = {0, 1, 2, 3, 4};
+  auto stream = CreateStreamWithContents(IREE_IO_STREAM_MODE_READABLE, data);
+
+  // Streams start at origin 0.
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 0);
+  EXPECT_EQ(iree_io_stream_length(stream.get()), sizeof(data));
+  EXPECT_FALSE(iree_io_stream_is_eos(stream.get()));
+
+  uint8_t read_buffer[64] = {0xDD};
+  iree_host_size_t read_length = 0;
+
+  // Reads of zero length should no-op.
+  IREE_EXPECT_OK(
+      iree_io_stream_read(stream.get(), 0, read_buffer, &read_length));
+  EXPECT_EQ(read_length, 0);
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 0);
+
+  // Reads should advance the stream offset.
+  memset(read_buffer, 0xDD, sizeof(read_buffer));
+  IREE_EXPECT_OK(
+      iree_io_stream_read(stream.get(), 1, read_buffer, &read_length));
+  EXPECT_EQ(read_length, 1);
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 1);
+  EXPECT_EQ(read_buffer[0], 0);
+  EXPECT_EQ(read_buffer[1], 0xDD);
+
+  // Read another chunk of 2 bytes.
+  memset(read_buffer, 0xDD, sizeof(read_buffer));
+  IREE_EXPECT_OK(
+      iree_io_stream_read(stream.get(), 2, read_buffer, &read_length));
+  EXPECT_EQ(read_length, 2);
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 3);
+  EXPECT_EQ(read_buffer[0], 1);
+  EXPECT_EQ(read_buffer[1], 2);
+  EXPECT_EQ(read_buffer[2], 0xDD);
+
+  // Read up to the end of the stream (2 bytes remaining) by reading over.
+  memset(read_buffer, 0xDD, sizeof(read_buffer));
+  IREE_EXPECT_OK(iree_io_stream_read(stream.get(), sizeof(read_buffer),
+                                     read_buffer, &read_length));
+  EXPECT_EQ(read_length, 2);
+  EXPECT_EQ(iree_io_stream_offset(stream.get()),
+            iree_io_stream_length(stream.get()));
+  EXPECT_TRUE(iree_io_stream_is_eos(stream.get()));
+  EXPECT_EQ(read_buffer[0], 3);
+  EXPECT_EQ(read_buffer[1], 4);
+  EXPECT_EQ(read_buffer[2], 0xDD);
+
+  // Reading from the end of the stream should be a no-op.
+  memset(read_buffer, 0xDD, sizeof(read_buffer));
+  IREE_EXPECT_OK(iree_io_stream_read(stream.get(), sizeof(read_buffer),
+                                     read_buffer, &read_length));
+  EXPECT_EQ(read_length, 0);
+  EXPECT_EQ(iree_io_stream_offset(stream.get()),
+            iree_io_stream_length(stream.get()));
+  EXPECT_TRUE(iree_io_stream_is_eos(stream.get()));
+  EXPECT_EQ(read_buffer[0], 0xDD);
+}
+
+TEST(VecStreamTest, ReadExact) {
+  uint8_t data[5] = {0, 1, 2, 3, 4};
+  auto stream = CreateStreamWithContents(IREE_IO_STREAM_MODE_READABLE, data);
+
+  // Streams start at origin 0.
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 0);
+  EXPECT_EQ(iree_io_stream_length(stream.get()), sizeof(data));
+  EXPECT_FALSE(iree_io_stream_is_eos(stream.get()));
+
+  uint8_t read_buffer[64] = {0xDD};
+
+  // Reads of zero length should no-op.
+  IREE_EXPECT_OK(iree_io_stream_read(stream.get(), 0, read_buffer, NULL));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 0);
+
+  // Reads should advance the stream offset.
+  memset(read_buffer, 0xDD, sizeof(read_buffer));
+  IREE_EXPECT_OK(iree_io_stream_read(stream.get(), 1, read_buffer, NULL));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 1);
+  EXPECT_EQ(read_buffer[0], 0);
+  EXPECT_EQ(read_buffer[1], 0xDD);
+
+  // Read another chunk of 2 bytes.
+  memset(read_buffer, 0xDD, sizeof(read_buffer));
+  IREE_EXPECT_OK(iree_io_stream_read(stream.get(), 2, read_buffer, NULL));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 3);
+  EXPECT_EQ(read_buffer[0], 1);
+  EXPECT_EQ(read_buffer[1], 2);
+  EXPECT_EQ(read_buffer[2], 0xDD);
+
+  // Read exactly the 2 bytes remaining to reach the end of the stream.
+  memset(read_buffer, 0xDD, sizeof(read_buffer));
+  IREE_EXPECT_OK(iree_io_stream_read(stream.get(), 2, read_buffer, NULL));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()),
+            iree_io_stream_length(stream.get()));
+  EXPECT_TRUE(iree_io_stream_is_eos(stream.get()));
+  EXPECT_EQ(read_buffer[0], 3);
+  EXPECT_EQ(read_buffer[1], 4);
+  EXPECT_EQ(read_buffer[2], 0xDD);
+
+  // Reading from the end of the stream fails with no read length arg.
+  memset(read_buffer, 0xDD, sizeof(read_buffer));
+  EXPECT_THAT(Status(iree_io_stream_read(stream.get(), sizeof(read_buffer),
+                                         read_buffer, NULL)),
+              StatusIs(StatusCode::kOutOfRange));
+
+  // Reset back to the origin and try reading off the end.
+  IREE_EXPECT_OK(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 0));
+  EXPECT_THAT(Status(iree_io_stream_read(stream.get(), sizeof(read_buffer),
+                                         read_buffer, NULL)),
+              StatusIs(StatusCode::kOutOfRange));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 0);
+}
+
+TEST(VecStreamTest, Write) {
+  auto stream =
+      CreateStream(IREE_IO_STREAM_MODE_READABLE | IREE_IO_STREAM_MODE_WRITABLE);
+
+  uint8_t data[5] = {0xDD};
+  const uint8_t write_buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7};
+
+  // Writes of zero length should be a no-op.
+  memset(data, 0xDD, sizeof(data));
+  IREE_EXPECT_OK(iree_io_stream_write(stream.get(), 0, write_buffer));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 0);
+  EXPECT_EQ(iree_io_stream_length(stream.get()), 0);
+  EXPECT_EQ(data[0], 0xDD);
+
+  // Writes should advance the stream.
+  memset(data, 0xDD, sizeof(data));
+  IREE_EXPECT_OK(iree_io_stream_write(stream.get(), 1, write_buffer));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 1);
+  EXPECT_EQ(iree_io_stream_length(stream.get()), 1);
+  IREE_ASSERT_OK(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 0));
+  IREE_ASSERT_OK(iree_io_stream_read(stream.get(), 1, data, NULL));
+  EXPECT_EQ(data[0], 0);
+  EXPECT_EQ(data[1], 0xDD);
+
+  // Write 2 more bytes and ensure only those are mutated.
+  memset(data, 0xDD, sizeof(data));
+  IREE_EXPECT_OK(iree_io_stream_write(stream.get(), 2, &write_buffer[1]));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 1 + 2);
+  EXPECT_EQ(iree_io_stream_length(stream.get()), 1 + 2);
+  IREE_ASSERT_OK(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 0));
+  IREE_ASSERT_OK(iree_io_stream_read(stream.get(), 3, data, NULL));
+  EXPECT_EQ(data[0], 0);
+  EXPECT_EQ(data[1], 1);
+  EXPECT_EQ(data[2], 2);
+  EXPECT_EQ(data[3], 0xDD);
+
+  // Seek to the end of the stream and try to write 0 bytes (should be a no-op).
+  IREE_EXPECT_OK(
+      iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_FROM_END, 0));
+  EXPECT_TRUE(iree_io_stream_is_eos(stream.get()));
+  IREE_EXPECT_OK(iree_io_stream_write(stream.get(), 0, write_buffer));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 3);
+  EXPECT_EQ(iree_io_stream_length(stream.get()), 3);
+  EXPECT_TRUE(iree_io_stream_is_eos(stream.get()));
+
+  // Overwrite the entire contents of the storage.
+  IREE_EXPECT_OK(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 0));
+  IREE_EXPECT_OK(
+      iree_io_stream_write(stream.get(), sizeof(data), write_buffer));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 5);
+  EXPECT_EQ(iree_io_stream_length(stream.get()), 5);
+  IREE_ASSERT_OK(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 0));
+  IREE_ASSERT_OK(iree_io_stream_read(stream.get(), sizeof(data), data, NULL));
+  EXPECT_THAT(data,
+              ElementsAre(write_buffer[0], write_buffer[1], write_buffer[2],
+                          write_buffer[3], write_buffer[4]));
+}
+
+TEST(VecStreamTest, FillSizes) {
+  auto stream =
+      CreateStream(IREE_IO_STREAM_MODE_READABLE | IREE_IO_STREAM_MODE_WRITABLE);
+
+  uint8_t pattern[] = {0x80, 0x90, 0xA0, 0xB0, 0xC0, 0xD0, 0xE0, 0xF0};
+
+  // Fill patterns must be 1,2,4,8 bytes.
+  EXPECT_THAT(Status(iree_io_stream_fill(stream.get(), 1, pattern, 3)),
+              StatusIs(StatusCode::kInvalidArgument));
+  EXPECT_THAT(Status(iree_io_stream_fill(stream.get(), 1, pattern, 9)),
+              StatusIs(StatusCode::kInvalidArgument));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 0);
+}
+
+TEST(VecStreamTest, Fill1) {
+  auto stream =
+      CreateStream(IREE_IO_STREAM_MODE_READABLE | IREE_IO_STREAM_MODE_WRITABLE);
+
+  uint8_t pattern[] = {0x80, 0x90, 0xA0, 0xB0, 0xC0, 0xD0, 0xE0, 0xF0};
+
+  // Extend to 16 bytes for easy fill testing.
+  IREE_EXPECT_OK(
+      iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 16));
+  IREE_EXPECT_OK(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 0));
+
+  // Fill with pattern size 1.
+  IREE_EXPECT_OK(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 1));
+  IREE_EXPECT_OK(iree_io_stream_fill(stream.get(), 3, pattern, 1));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 1 + 3);
+  IREE_EXPECT_OK(
+      iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_FROM_END, -2));
+  IREE_EXPECT_OK(iree_io_stream_fill(stream.get(), 2, pattern, 1));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 16 - 2 + 2);
+
+  uint8_t data[16] = {0xDD};
+  IREE_ASSERT_OK(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 0));
+  IREE_ASSERT_OK(iree_io_stream_read(stream.get(), sizeof(data), data, NULL));
+  EXPECT_THAT(data,
+              ElementsAre(0x00, 0x80, 0x80, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00,
+                          0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x80));
+}
+
+TEST(VecStreamTest, Fill2) {
+  auto stream =
+      CreateStream(IREE_IO_STREAM_MODE_READABLE | IREE_IO_STREAM_MODE_WRITABLE);
+
+  uint8_t pattern[] = {0x80, 0x90, 0xA0, 0xB0, 0xC0, 0xD0, 0xE0, 0xF0};
+
+  // Extend to 16 bytes for easy fill testing.
+  IREE_EXPECT_OK(
+      iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 16));
+  IREE_EXPECT_OK(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 0));
+
+  // Fill with pattern size 2.
+  IREE_EXPECT_OK(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 1));
+  IREE_EXPECT_OK(iree_io_stream_fill(stream.get(), 3, pattern, 2));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 1 + 3 * 2);
+  IREE_EXPECT_OK(
+      iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_FROM_END, -4));
+  IREE_EXPECT_OK(iree_io_stream_fill(stream.get(), 2, pattern, 2));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 16 - 4 + 2 * 2);
+
+  uint8_t data[16] = {0xDD};
+  IREE_ASSERT_OK(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 0));
+  IREE_ASSERT_OK(iree_io_stream_read(stream.get(), sizeof(data), data, NULL));
+  EXPECT_THAT(data,
+              ElementsAre(0x00, 0x80, 0x90, 0x80, 0x90, 0x80, 0x90, 0x00, 0x00,
+                          0x00, 0x00, 0x00, 0x80, 0x90, 0x80, 0x90));
+}
+
+TEST(VecStreamTest, Fill4) {
+  auto stream =
+      CreateStream(IREE_IO_STREAM_MODE_READABLE | IREE_IO_STREAM_MODE_WRITABLE);
+
+  uint8_t pattern[] = {0x80, 0x90, 0xA0, 0xB0, 0xC0, 0xD0, 0xE0, 0xF0};
+
+  // Extend to 16 bytes for easy fill testing.
+  IREE_EXPECT_OK(
+      iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 16));
+  IREE_EXPECT_OK(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 0));
+
+  IREE_EXPECT_OK(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 1));
+  IREE_EXPECT_OK(iree_io_stream_fill(stream.get(), 2, pattern, 4));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 1 + 2 * 4);
+  IREE_EXPECT_OK(
+      iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_FROM_END, -4));
+  IREE_EXPECT_OK(iree_io_stream_fill(stream.get(), 1, pattern, 4));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 16 - 4 + 1 * 4);
+
+  uint8_t data[16] = {0xDD};
+  IREE_ASSERT_OK(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 0));
+  IREE_ASSERT_OK(iree_io_stream_read(stream.get(), sizeof(data), data, NULL));
+  EXPECT_THAT(data,
+              ElementsAre(0x00, 0x80, 0x90, 0xA0, 0xB0, 0x80, 0x90, 0xA0, 0xB0,
+                          0x00, 0x00, 0x00, 0x80, 0x90, 0xA0, 0xB0));
+}
+
+TEST(VecStreamTest, Fill8Unaligned) {
+  auto stream =
+      CreateStream(IREE_IO_STREAM_MODE_READABLE | IREE_IO_STREAM_MODE_WRITABLE);
+
+  uint8_t pattern[] = {0x80, 0x90, 0xA0, 0xB0, 0xC0, 0xD0, 0xE0, 0xF0};
+
+  // Extend to 16 bytes for easy fill testing.
+  IREE_EXPECT_OK(
+      iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 16));
+  IREE_EXPECT_OK(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 0));
+
+  IREE_EXPECT_OK(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 1));
+  IREE_EXPECT_OK(iree_io_stream_fill(stream.get(), 1, pattern, 8));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 1 + 1 * 8);
+
+  uint8_t data[16] = {0xDD};
+  IREE_ASSERT_OK(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 0));
+  IREE_ASSERT_OK(iree_io_stream_read(stream.get(), sizeof(data), data, NULL));
+  EXPECT_THAT(data,
+              ElementsAre(0x00, 0x80, 0x90, 0xA0, 0xB0, 0xC0, 0xD0, 0xE0, 0xF0,
+                          0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00));
+}
+
+TEST(VecStreamTest, Fill8End) {
+  auto stream =
+      CreateStream(IREE_IO_STREAM_MODE_READABLE | IREE_IO_STREAM_MODE_WRITABLE);
+
+  uint8_t pattern[] = {0x80, 0x90, 0xA0, 0xB0, 0xC0, 0xD0, 0xE0, 0xF0};
+
+  // Extend to 16 bytes for easy fill testing.
+  IREE_EXPECT_OK(
+      iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 16));
+  IREE_EXPECT_OK(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 0));
+
+  IREE_EXPECT_OK(
+      iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_FROM_END, -8));
+  IREE_EXPECT_OK(iree_io_stream_fill(stream.get(), 1, pattern, 8));
+  EXPECT_EQ(iree_io_stream_offset(stream.get()), 16);
+
+  uint8_t data[16] = {0xDD};
+  IREE_ASSERT_OK(iree_io_stream_seek(stream.get(), IREE_IO_STREAM_SEEK_SET, 0));
+  IREE_ASSERT_OK(iree_io_stream_read(stream.get(), sizeof(data), data, NULL));
+  EXPECT_THAT(data,
+              ElementsAre(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80,
+                          0x90, 0xA0, 0xB0, 0xC0, 0xD0, 0xE0, 0xF0));
+}
+
+}  // namespace
diff --git a/runtime/src/iree/tooling/BUILD.bazel b/runtime/src/iree/tooling/BUILD.bazel
index 8f99c9ad31db..07ebfabea8c5 100644
--- a/runtime/src/iree/tooling/BUILD.bazel
+++ b/runtime/src/iree/tooling/BUILD.bazel
@@ -43,7 +43,6 @@ iree_runtime_cc_library(
     hdrs = ["comparison.h"],
     deps = [
         ":buffer_view_matchers",
- ":vm_util", "//runtime/src/iree/base", "//runtime/src/iree/base/internal:flags", "//runtime/src/iree/base/internal:span", @@ -58,7 +57,7 @@ iree_runtime_cc_test( srcs = ["comparison_test.cc"], deps = [ ":comparison", - ":vm_util", + ":function_io", "//runtime/src/iree/base", "//runtime/src/iree/base/internal:span", "//runtime/src/iree/hal", @@ -108,6 +107,49 @@ iree_runtime_cc_library( ], ) +iree_runtime_cc_library( + name = "function_io", + srcs = ["function_io.c"], + hdrs = ["function_io.h"], + deps = [ + ":numpy_io", + "//runtime/src/iree/base", + "//runtime/src/iree/hal", + "//runtime/src/iree/io:stdio_stream", + "//runtime/src/iree/io:stream", + "//runtime/src/iree/io:vec_stream", + "//runtime/src/iree/modules/hal", + "//runtime/src/iree/vm", + ], +) + +iree_runtime_cc_test( + name = "function_io_test", + srcs = ["function_io_test.cc"], + deps = [ + ":function_io", + "//runtime/src/iree/base", + "//runtime/src/iree/hal", + "//runtime/src/iree/io:vec_stream", + "//runtime/src/iree/modules/hal", + "//runtime/src/iree/testing:gtest", + "//runtime/src/iree/testing:gtest_main", + "//runtime/src/iree/vm", + ], +) + +iree_runtime_cc_library( + name = "function_util", + srcs = ["function_util.c"], + hdrs = ["function_util.h"], + deps = [ + "//runtime/src/iree/base", + "//runtime/src/iree/hal", + "//runtime/src/iree/modules/hal", + "//runtime/src/iree/vm", + ], +) + iree_runtime_cc_library( name = "instrument_util", srcs = ["instrument_util.c"], @@ -129,6 +171,7 @@ iree_runtime_cc_library( deps = [ "//runtime/src/iree/base", "//runtime/src/iree/hal", + "//runtime/src/iree/io:stream", ], ) @@ -139,7 +182,8 @@ iree_runtime_cc_test( deps = [ ":device_util", ":numpy_io", - "//runtime/src/iree/base/internal:file_io", + "//runtime/src/iree/io:memory_stream", + "//runtime/src/iree/io:vec_stream", "//runtime/src/iree/testing:gtest", "//runtime/src/iree/testing:gtest_main", "//runtime/src/iree/tooling/testdata/npy", @@ -176,45 +220,15 @@ iree_runtime_cc_library( 
":comparison", ":context_util", ":device_util", + ":function_io", + ":function_util", ":instrument_util", - ":vm_util", "//runtime/src/iree/base", "//runtime/src/iree/base/internal:flags", "//runtime/src/iree/hal", + "//runtime/src/iree/io:stdio_stream", "//runtime/src/iree/modules/hal:types", "//runtime/src/iree/vm", "//runtime/src/iree/vm/bytecode:module", ], ) - -# TODO(benvanik): fold these into iree/runtime and use that instead. -iree_runtime_cc_library( - name = "vm_util", - srcs = ["vm_util.c"], - hdrs = ["vm_util.h"], - deps = [ - ":numpy_io", - "//runtime/src/iree/base", - "//runtime/src/iree/base/internal:file_io", - "//runtime/src/iree/hal", - "//runtime/src/iree/modules/hal", - "//runtime/src/iree/vm", - "//runtime/src/iree/vm/bytecode:module", - ], -) - -iree_runtime_cc_test( - name = "vm_util_test", - srcs = ["vm_util_test.cc"], - deps = [ - ":device_util", - ":vm_util", - "//runtime/src/iree/base", - "//runtime/src/iree/base/internal:span", - "//runtime/src/iree/hal", - "//runtime/src/iree/modules/hal", - "//runtime/src/iree/testing:gtest", - "//runtime/src/iree/testing:gtest_main", - "//runtime/src/iree/vm", - ], -) diff --git a/runtime/src/iree/tooling/CMakeLists.txt b/runtime/src/iree/tooling/CMakeLists.txt index 24531cbc0df9..2b5a1c83f53a 100644 --- a/runtime/src/iree/tooling/CMakeLists.txt +++ b/runtime/src/iree/tooling/CMakeLists.txt @@ -48,7 +48,6 @@ iree_cc_library( "comparison.cc" DEPS ::buffer_view_matchers - ::vm_util iree::base iree::base::internal::flags iree::base::internal::span @@ -65,7 +64,7 @@ iree_cc_test( "comparison_test.cc" DEPS ::comparison - ::vm_util + ::function_io iree::base iree::base::internal::span iree::hal @@ -120,6 +119,56 @@ iree_cc_library( PUBLIC ) +iree_cc_library( + NAME + function_io + HDRS + "function_io.h" + SRCS + "function_io.c" + DEPS + ::numpy_io + iree::base + iree::hal + iree::io::stdio_stream + iree::io::stream + iree::io::vec_stream + iree::modules::hal + iree::vm + PUBLIC +) + +iree_cc_test( + NAME + 
function_io_test + SRCS + "function_io_test.cc" + DEPS + ::function_io + iree::base + iree::hal + iree::io::vec_stream + iree::modules::hal + iree::testing::gtest + iree::testing::gtest_main + iree::vm +) + +iree_cc_library( + NAME + function_util + HDRS + "function_util.h" + SRCS + "function_util.c" + DEPS + iree::base + iree::hal + iree::modules::hal + iree::vm + PUBLIC +) + iree_cc_library( NAME instrument_util @@ -147,6 +196,7 @@ iree_cc_library( DEPS iree::base iree::hal + iree::io::stream PUBLIC ) @@ -158,7 +208,8 @@ iree_cc_test( DEPS ::device_util ::numpy_io - iree::base::internal::file_io + iree::io::memory_stream + iree::io::vec_stream iree::testing::gtest iree::testing::gtest_main iree::tooling::testdata::npy @@ -202,52 +253,19 @@ iree_cc_library( ::comparison ::context_util ::device_util + ::function_io + ::function_util ::instrument_util - ::vm_util iree::base iree::base::internal::flags iree::hal + iree::io::stdio_stream iree::modules::hal::types iree::vm iree::vm::bytecode::module PUBLIC ) -iree_cc_library( - NAME - vm_util - HDRS - "vm_util.h" - SRCS - "vm_util.c" - DEPS - ::numpy_io - iree::base - iree::base::internal::file_io - iree::hal - iree::modules::hal - iree::vm - iree::vm::bytecode::module - PUBLIC -) - -iree_cc_test( - NAME - vm_util_test - SRCS - "vm_util_test.cc" - DEPS - ::device_util - ::vm_util - iree::base - iree::base::internal::span - iree::hal - iree::modules::hal - iree::testing::gtest - iree::testing::gtest_main - iree::vm -) - ### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### # We're co-opting the VMVX module loader option for this as the inline-static diff --git a/runtime/src/iree/tooling/comparison.cc b/runtime/src/iree/tooling/comparison.cc index 65aab556f066..31f53e1db309 100644 --- a/runtime/src/iree/tooling/comparison.cc +++ b/runtime/src/iree/tooling/comparison.cc @@ -14,7 +14,6 @@ #include "iree/hal/api.h" #include "iree/modules/hal/module.h" #include "iree/tooling/buffer_view_matchers.h" -#include 
"iree/tooling/vm_util.h" #include "iree/vm/api.h" using namespace iree; diff --git a/runtime/src/iree/tooling/comparison_test.cc b/runtime/src/iree/tooling/comparison_test.cc index d5f65c07444e..b79364aae907 100644 --- a/runtime/src/iree/tooling/comparison_test.cc +++ b/runtime/src/iree/tooling/comparison_test.cc @@ -12,7 +12,7 @@ #include "iree/modules/hal/module.h" #include "iree/testing/gtest.h" #include "iree/testing/status_matchers.h" -#include "iree/tooling/vm_util.h" +#include "iree/tooling/function_io.h" #include "iree/vm/api.h" namespace iree { @@ -22,6 +22,7 @@ using ::testing::HasSubstr; static void ParseToVariantList(iree_hal_device_t* device, iree_hal_allocator_t* device_allocator, + iree_string_view_t cconv, iree::span input_strings, iree_allocator_t host_allocator, iree_vm_list_t** out_list) { @@ -30,9 +31,11 @@ static void ParseToVariantList(iree_hal_device_t* device, input_string_views[i].data = input_strings[i].data(); input_string_views[i].size = input_strings[i].size(); } - IREE_CHECK_OK(iree_tooling_parse_to_variant_list( - device, device_allocator, input_string_views.data(), - input_string_views.size(), host_allocator, out_list)); + IREE_CHECK_OK(iree_tooling_parse_variants( + cconv, + iree_string_view_list_t{input_string_views.size(), + input_string_views.data()}, + device, device_allocator, host_allocator, out_list)); } class ComparisonTest : public ::testing::Test { @@ -51,15 +54,15 @@ class ComparisonTest : public ::testing::Test { } bool ParseAndCompareVariantLists( - iree::span expected_strings, + iree_string_view_t cconv, iree::span expected_strings, iree::span actual_strings, std::string* out_string) { vm::ref expected_list; - ParseToVariantList(/*device=*/NULL, device_allocator_, expected_strings, - host_allocator_, &expected_list); + ParseToVariantList(/*device=*/NULL, device_allocator_, cconv, + expected_strings, host_allocator_, &expected_list); vm::ref actual_list; - ParseToVariantList(/*device=*/NULL, device_allocator_, 
actual_strings, - host_allocator_, &actual_list); + ParseToVariantList(/*device=*/NULL, device_allocator_, cconv, + actual_strings, host_allocator_, &actual_list); iree_string_builder_t builder; iree_string_builder_initialize(host_allocator_, &builder); @@ -82,7 +85,8 @@ TEST_F(ComparisonTest, CompareEqualLists) { std::string buf_string2 = "2x3xf64=[1 2 3][4 5 6]"; auto buf_strings = std::vector{buf_string1, buf_string2}; std::string result; - EXPECT_TRUE(ParseAndCompareVariantLists(buf_strings, buf_strings, &result)); + EXPECT_TRUE(ParseAndCompareVariantLists(IREE_SV("rr"), buf_strings, + buf_strings, &result)); EXPECT_EQ(result, ""); } @@ -94,8 +98,8 @@ TEST_F(ComparisonTest, CompareListsWithIgnored) { auto expected_strings = std::vector{buf_string1, buf_string2_ignored}; std::string result; - EXPECT_TRUE( - ParseAndCompareVariantLists(expected_strings, actual_strings, &result)); + EXPECT_TRUE(ParseAndCompareVariantLists(IREE_SV("rr"), expected_strings, + actual_strings, &result)); EXPECT_EQ(result, ""); } @@ -105,8 +109,8 @@ TEST_F(ComparisonTest, CompareTruncatedLists) { auto actual_strings = std::vector{buf_string1, buf_string2}; auto expected_strings = std::vector{buf_string1}; std::string result; - EXPECT_FALSE( - ParseAndCompareVariantLists(expected_strings, actual_strings, &result)); + EXPECT_FALSE(ParseAndCompareVariantLists(IREE_SV("rr"), expected_strings, + actual_strings, &result)); EXPECT_THAT(result, HasSubstr("expected 1 list elements but 2 provided")); } @@ -118,8 +122,8 @@ TEST_F(ComparisonTest, CompareDifferingLists) { auto expected_strings = std::vector{buf_string1, buf_string2_good}; std::string result; - EXPECT_FALSE( - ParseAndCompareVariantLists(expected_strings, actual_strings, &result)); + EXPECT_FALSE(ParseAndCompareVariantLists(IREE_SV("rr"), expected_strings, + actual_strings, &result)); EXPECT_THAT( result, HasSubstr("element at index 2 (999) does not match the expected (3)")); @@ -127,15 +131,17 @@ TEST_F(ComparisonTest, 
CompareDifferingLists) { TEST_F(ComparisonTest, CompareListsWithDifferingTypes) { std::string buf_string1 = "2x2xi32=[42 43][44 45]"; - std::string buf_string2 = "123"; + std::string buf_string2 = "2xi32"; std::string buf_string2_good = "2x3xf64=[1 2 3][4 5 6]"; auto actual_strings = std::vector{buf_string1, buf_string2}; auto expected_strings = std::vector{buf_string1, buf_string2_good}; std::string result; - EXPECT_FALSE( - ParseAndCompareVariantLists(expected_strings, actual_strings, &result)); - EXPECT_THAT(result, HasSubstr("variant types mismatch")); + EXPECT_FALSE(ParseAndCompareVariantLists(IREE_SV("rr"), expected_strings, + actual_strings, &result)); + EXPECT_THAT( + result, + HasSubstr("metadata is 2xi32; expected that the view matches 2x3xf64")); } } // namespace diff --git a/runtime/src/iree/tooling/function_io.c b/runtime/src/iree/tooling/function_io.c new file mode 100644 index 000000000000..d416a0dfd7ce --- /dev/null +++ b/runtime/src/iree/tooling/function_io.c @@ -0,0 +1,1198 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/tooling/function_io.h" + +#include "iree/io/stdio_stream.h" +#include "iree/io/stream.h" +#include "iree/modules/hal/module.h" +#include "iree/tooling/numpy_io.h" + +//===----------------------------------------------------------------------===// +// Utilities +//===----------------------------------------------------------------------===// + +// NOTE: this will get moved at some point but is staged here while it's figured +// out. I'm still not sure how best to factor things so that this doesn't get +// pulled in all the time even if IO is never used. 
We may end up needing some +// kind of registry that the main iree_io_stream_open uses or allow +// iree_io_file_handle_t to carry a factory function for opening the handles of +// certain types. For now we shim things here at the leaf. +static iree_status_t iree_io_stream_open_path(iree_io_stdio_stream_mode_t mode, + iree_string_view_t path, + uint64_t file_offset, + iree_allocator_t host_allocator, + iree_io_stream_t** out_stream) { + IREE_ASSERT_ARGUMENT(out_stream); + *out_stream = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + IREE_TRACE_ZONE_APPEND_TEXT(z0, path.data, path.size); + + iree_status_t status = iree_ok_status(); + iree_io_stream_t* stream = NULL; + + status = iree_io_stdio_stream_open(mode, path, host_allocator, &stream); + if (iree_status_is_ok(status) && file_offset > 0) { + status = iree_io_stream_seek(stream, IREE_IO_STREAM_SEEK_SET, file_offset); + } + + if (iree_status_is_ok(status)) { + *out_stream = stream; + } else { + iree_io_stream_release(stream); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +//===----------------------------------------------------------------------===// +// iree_io_stream_list_t +//===----------------------------------------------------------------------===// + +typedef struct iree_io_stream_list_entry_t { + iree_string_view_t path; + iree_io_stream_t* stream; + // + path char storage of path.size +} iree_io_stream_list_entry_t; + +// A list of streams indexed by their verbatim paths. +// Streams track their positions and repeated accesses will continue where +// they left off when reading and writing. +// +// NOTE: this is not thread-safe or safe for concurrent access to the same +// streams via different accesses. Re-opening a stream will reset the stream +// offset to 0 and any extant references may end up doing the wrong thing. 
+typedef struct iree_io_stream_list_t {
+  // Allocator used for the list, its entry table, and each entry.
+  iree_allocator_t host_allocator;
+  // Mode shared by every stream opened through this list.
+  iree_io_stdio_stream_mode_t mode;
+  // Number of slots available in |entries|.
+  iree_host_size_t capacity;
+  // Number of slots in |entries| that are in use.
+  iree_host_size_t count;
+  // Table of individually allocated entries.
+  iree_io_stream_list_entry_t** entries;
+} iree_io_stream_list_t;
+
+// Creates an empty stream list. Every stream later opened through the list
+// uses the same |mode|. The list must be released with
+// iree_io_stream_list_free.
+iree_status_t iree_io_stream_list_allocate(iree_io_stdio_stream_mode_t mode,
+                                           iree_allocator_t host_allocator,
+                                           iree_io_stream_list_t** out_list) {
+  IREE_ASSERT_ARGUMENT(out_list);
+  *out_list = NULL;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_io_stream_list_t* new_list = NULL;
+  iree_status_t status = iree_allocator_malloc(
+      host_allocator, sizeof(*new_list), (void**)&new_list);
+  if (iree_status_is_ok(status)) {
+    new_list->host_allocator = host_allocator;
+    new_list->mode = mode;
+    *out_list = new_list;
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Releases every stream held by |list| and then frees the entries, the entry
+// table, and the list itself. Passing NULL is a no-op.
+void iree_io_stream_list_free(iree_io_stream_list_t* list) {
+  if (list == NULL) return;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // Copy the allocator out first as it is stored inside the list being freed.
+  iree_allocator_t host_allocator = list->host_allocator;
+  for (iree_host_size_t index = 0; index < list->count; ++index) {
+    iree_io_stream_list_entry_t* entry = list->entries[index];
+    iree_io_stream_release(entry->stream);
+    iree_allocator_free(host_allocator, entry);
+  }
+  iree_allocator_free(host_allocator, list->entries);
+  iree_allocator_free(host_allocator, list);
+
+  IREE_TRACE_ZONE_END(z0);
+}
+
+// Linearly scans |list| for an entry whose path is exactly equal to |path|.
+// Returns the first match or NULL when no entry matches.
+static iree_io_stream_list_entry_t* iree_io_stream_list_find_entry(
+    iree_io_stream_list_t* list, iree_string_view_t path) {
+  IREE_ASSERT_ARGUMENT(list);
+  iree_io_stream_list_entry_t* match = NULL;
+  for (iree_host_size_t index = 0; !match && index < list->count; ++index) {
+    iree_io_stream_list_entry_t* candidate = list->entries[index];
+    if (iree_string_view_equal(path, candidate->path)) {
+      match = candidate;
+    }
+  }
+  return match;
+}
+
+// Appends a stream to the list. The |path| will be cloned into the list
+// storage and the stream will be retained until the list is freed.
+// |stream| is only retained when the entry is successfully stored; on failure
+// the caller still owns its reference and the entry count is unchanged.
+static iree_status_t iree_io_stream_list_append_entry(
+    iree_io_stream_list_t* list, iree_string_view_t path,
+    iree_io_stream_t* stream) {
+  IREE_ASSERT_ARGUMENT(list);
+  IREE_ASSERT_ARGUMENT(stream);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_TEXT(z0, path.data, path.size);
+
+  // Grow if needed. The table stores entry *pointers* so each slot is sized
+  // as a pointer and not as the entry struct it points at.
+  if (list->count + 1 > list->capacity) {
+    iree_host_size_t new_capacity = iree_max(list->capacity * 2, 16);
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, iree_allocator_realloc(list->host_allocator,
+                                   new_capacity * sizeof(list->entries[0]),
+                                   (void**)&list->entries));
+    list->capacity = new_capacity;
+  }
+
+  // Allocate the entry and its string storage. The path characters live
+  // immediately after the entry struct in the same allocation.
+  iree_io_stream_list_entry_t* entry = NULL;
+  iree_status_t status = iree_allocator_malloc(
+      list->host_allocator, sizeof(*entry) + path.size, (void**)&entry);
+  if (iree_status_is_ok(status)) {
+    entry->path.data = (const char*)entry + sizeof(*entry);
+    entry->path.size = path.size;
+    memcpy((void*)entry->path.data, path.data, path.size);
+    entry->stream = stream;
+    iree_io_stream_retain(entry->stream);
+  }
+
+  // Store in list.
+  if (iree_status_is_ok(status)) {
+    list->entries[list->count++] = entry;
+  } else {
+    iree_allocator_free(list->host_allocator, entry);
+  }
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Opens an |entry| that has already been opened, resetting its position to
+// 0 if needed or directly returning it if appending.
+//
+// On success |out_stream| receives a retained reference that the caller must
+// release. On failure |out_stream| is left NULL and no reference is taken so
+// that callers that bail out on error do not leak the stream.
+static iree_status_t iree_io_stream_list_open_existing(
+    iree_io_stream_list_t* list, iree_string_view_t path,
+    iree_io_stream_list_entry_t* entry, bool is_append,
+    iree_io_stream_t** out_stream) {
+  IREE_ASSERT_ARGUMENT(list);
+  IREE_ASSERT_ARGUMENT(entry);
+  IREE_ASSERT_ARGUMENT(out_stream);
+  *out_stream = NULL;
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_TEXT(z0, path.data, path.size);
+
+  iree_status_t status = iree_ok_status();
+
+  // Reset the stream back to 0 if not appending. Today we require seekable
+  // streams for this. We could re-open the file but the only non-seekable
+  // stream we have today is stdin and reopening that isn't possible without
+  // buffering that we don't/won't do.
+  if (!is_append) {
+    if (iree_all_bits_set(iree_io_stream_mode(entry->stream),
+                          IREE_IO_STREAM_MODE_SEEKABLE)) {
+      status = iree_io_stream_seek(entry->stream, IREE_IO_STREAM_SEEK_SET, 0);
+    } else {
+      status = iree_make_status(
+          IREE_STATUS_UNIMPLEMENTED,
+          "opened stream from `%.*s` is not seekable and cannot be reopened",
+          (int)path.size, path.data);
+    }
+  }
+
+  // Only hand out (and retain) the stream when it is actually usable.
+  if (iree_status_is_ok(status)) {
+    iree_io_stream_retain(entry->stream);
+    *out_stream = entry->stream;
+  }
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Opens a stream at |path|, possibly reusing an existing stream if it has
+// already been opened. If |is_append| is true the stream will be returned at
+// its existing position for continued reading/writing and otherwise it will be
+// reset to position 0.
+iree_status_t iree_io_stream_list_open(iree_io_stream_list_t* list,
+                                       iree_string_view_t path, bool is_append,
+                                       iree_io_stream_t** out_stream) {
+  IREE_ASSERT_ARGUMENT(list);
+  IREE_ASSERT_ARGUMENT(out_stream);
+  *out_stream = NULL;
+  IREE_TRACE_ZONE_BEGIN(z0);
+  IREE_TRACE_ZONE_APPEND_TEXT(z0, path.data, path.size);
+
+  // Lookup an existing entry - if found we can reuse it.
+  iree_io_stream_list_entry_t* entry =
+      iree_io_stream_list_find_entry(list, path);
+  if (entry) {
+    iree_status_t status = iree_io_stream_list_open_existing(
+        list, path, entry, is_append, out_stream);
+    IREE_TRACE_ZONE_END(z0);
+    return status;
+  }
+
+  // Open the file at the path specified.
+  iree_io_stream_t* stream = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_io_stream_open_path(list->mode, path, 0ull, list->host_allocator,
+                                   &stream));
+
+  // Append the stream entry to the list so it's retained for future opens.
+ iree_status_t status = iree_io_stream_list_append_entry(list, path, stream); + + if (iree_status_is_ok(status)) { + *out_stream = stream; + } else { + iree_io_stream_release(stream); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +//===----------------------------------------------------------------------===// +// Parsing +//===----------------------------------------------------------------------===// + +static iree_status_t iree_tooling_consume_cconv_any(iree_string_view_t* cconv, + char* out_type) { + IREE_ASSERT_ARGUMENT(out_type); + *out_type = 0; + if (cconv->size == 0) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "function expected fewer input values"); + } + *out_type = cconv->data[0]; + ++cconv->data; + --cconv->size; + return iree_ok_status(); +} + +static iree_status_t iree_tooling_consume_cconv(iree_string_view_t* cconv, + char expected_type) { + char actual_type = 0; + IREE_RETURN_IF_ERROR(iree_tooling_consume_cconv_any(cconv, &actual_type)); + if (actual_type != expected_type) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "function signature mismatch: expected cconv type " + "`%c` but provided type `%c`", + expected_type, actual_type); + } + return iree_ok_status(); +} + +static iree_status_t iree_tooling_parse_null_into(iree_string_view_t* cconv, + iree_string_view_t string, + iree_vm_list_t* list) { + IREE_TRACE_ZONE_BEGIN(z0); + IREE_TRACE_ZONE_APPEND_TEXT(z0, string.data, string.size); + + // Get the expected cconv type so we can handle 0 for primitives and + // NULL for refs. + char cconv_type = 0; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_tooling_consume_cconv_any(cconv, &cconv_type)); + + // Add the appropriate variant type to the list. 
+ iree_status_t status = iree_ok_status(); + switch (cconv_type) { + case IREE_VM_CCONV_TYPE_I32: { + iree_vm_value_t value = iree_vm_value_make_i32(0); + status = iree_vm_list_push_value(list, &value); + break; + } + case IREE_VM_CCONV_TYPE_F32: { + iree_vm_value_t value = iree_vm_value_make_f32(0.0f); + status = iree_vm_list_push_value(list, &value); + break; + } + case IREE_VM_CCONV_TYPE_I64: { + iree_vm_value_t value = iree_vm_value_make_i64(0ll); + status = iree_vm_list_push_value(list, &value); + break; + } + case IREE_VM_CCONV_TYPE_F64: { + iree_vm_value_t value = iree_vm_value_make_f64(0.0); + status = iree_vm_list_push_value(list, &value); + break; + } + case IREE_VM_CCONV_TYPE_REF: { + iree_vm_ref_t null_ref = iree_vm_ref_null(); + status = iree_vm_list_push_ref_retain(list, &null_ref); + break; + } + default: { + status = iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "unimplemented cconv type `%c`", cconv_type); + break; + } + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_status_t iree_tooling_parse_primitive_into( + iree_string_view_t* cconv, iree_string_view_t string, iree_vm_list_t* list, + iree_hal_device_t* device, iree_hal_allocator_t* device_allocator, + iree_allocator_t host_allocator) { + IREE_TRACE_ZONE_BEGIN(z0); + IREE_TRACE_ZONE_APPEND_TEXT(z0, string.data, string.size); + + // Get the expected cconv type to help us parse the primitive value. 
+ char cconv_type = 0; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_tooling_consume_cconv_any(cconv, &cconv_type)); + + iree_status_t status = iree_ok_status(); + switch (cconv_type) { + case IREE_VM_CCONV_TYPE_I32: { + iree_vm_value_t value = iree_vm_value_make_i32(0); + if (iree_string_view_atoi_int32(string, &value.i32)) { + status = iree_vm_list_push_value(list, &value); + } else { + status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "parsing value `%.*s` as i32", + (int)string.size, string.data); + } + break; + } + case IREE_VM_CCONV_TYPE_F32: { + iree_vm_value_t value = iree_vm_value_make_f32(0.0f); + if (iree_string_view_atof(string, &value.f32)) { + status = iree_vm_list_push_value(list, &value); + } else { + status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "parsing value `%.*s` as f32", + (int)string.size, string.data); + } + break; + } + case IREE_VM_CCONV_TYPE_I64: { + iree_vm_value_t value = iree_vm_value_make_i64(0ll); + if (iree_string_view_atoi_int64(string, &value.i64)) { + status = iree_vm_list_push_value(list, &value); + } else { + status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "parsing value `%.*s` as i64", + (int)string.size, string.data); + } + break; + } + case IREE_VM_CCONV_TYPE_F64: { + iree_vm_value_t value = iree_vm_value_make_f64(0.0); + if (iree_string_view_atod(string, &value.f64)) { + status = iree_vm_list_push_value(list, &value); + } else { + status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "parsing value `%.*s` as f64", + (int)string.size, string.data); + } + break; + } + default: { + status = iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "unimplemented cconv type `%c`", cconv_type); + break; + } + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_status_t iree_tooling_parse_buffer_view_file_callback( + iree_hal_buffer_mapping_t* mapping, void* user_data) { + iree_io_stream_t* stream = (iree_io_stream_t*)user_data; + return iree_io_stream_read(stream, 
mapping->contents.data_length, + mapping->contents.data, + /*out_buffer_length=*/NULL); +} + +// Creates a HAL buffer view with the given |metadata| and reads the contents +// from the file reference in |string| which has the prefix `@` to indicate +// the contents starting from 0 and `+` for the next contents in an already +// opened stream. +// The file contents are directly read in to memory with no processing. +static iree_status_t iree_tooling_parse_buffer_view_file( + iree_string_view_t metadata, iree_string_view_t string, + iree_hal_device_t* device, iree_hal_allocator_t* device_allocator, + iree_io_stream_list_t* stream_list, + iree_hal_buffer_view_t** out_buffer_view) { + IREE_TRACE_ZONE_BEGIN(z0); + IREE_TRACE_ZONE_APPEND_TEXT(z0, string.data, string.size); + + // Parse shape and element type used to allocate the buffer view. + iree_hal_element_type_t element_type = IREE_HAL_ELEMENT_TYPE_NONE; + iree_host_size_t shape_rank = 0; + iree_hal_dim_t shape[128] = {0}; + IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_parse_shape_and_element_type( + metadata, IREE_ARRAYSIZE(shape), + &shape_rank, shape, &element_type)); + + // TODO(benvanik): allow specifying the encoding. + iree_hal_encoding_type_t encoding_type = + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR; + + // @ = open new + // + = append + bool is_append = !iree_string_view_starts_with(string, IREE_SV("@")); + iree_string_view_t path = + iree_string_view_substr(string, 1, IREE_HOST_SIZE_MAX); + + // Open (or retrieve) the file. + iree_io_stream_t* stream = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_io_stream_list_open(stream_list, path, is_append, &stream)); + + // TODO(benvanik): support mapping on allocators that can handle importing + // host memory. We only want to do this when it won't hurt performance + // (unified memory systems and where the mapped memory meets alignment + // requirements). For now we always wire new device memory to read into. 
+ // A real application would want to either do the import or use + // iree_hal_file_t to stream the file contents into device memory without + // going through host memory. + + // Read the stream contents into the buffer. + iree_hal_buffer_params_t buffer_params = { + .usage = IREE_HAL_BUFFER_USAGE_DEFAULT, + .access = IREE_HAL_MEMORY_ACCESS_ALL, + .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, + }; + iree_status_t status = iree_hal_buffer_view_generate_buffer( + device, device_allocator, shape_rank, shape, element_type, encoding_type, + buffer_params, iree_tooling_parse_buffer_view_file_callback, stream, + out_buffer_view); + + iree_io_stream_release(stream); + IREE_TRACE_ZONE_END(z0); + return status; +} + +// Parses a shaped tensor type into a HAL buffer view. +static iree_status_t iree_tooling_parse_tensor( + iree_string_view_t string, iree_hal_device_t* device, + iree_hal_allocator_t* device_allocator, iree_io_stream_list_t* stream_list, + iree_allocator_t host_allocator, iree_hal_buffer_view_t** out_buffer_view) { + // If contents are sourced from a file then route to that, and otherwise + // parse as a normal HAL buffer view with inline contents (or none). + iree_string_view_t metadata, contents; + if (iree_string_view_split(string, '=', &metadata, &contents) != -1) { + if (iree_string_view_starts_with(contents, IREE_SV("@")) || + iree_string_view_starts_with(contents, IREE_SV("+"))) { + return iree_tooling_parse_buffer_view_file(metadata, contents, device, + device_allocator, stream_list, + out_buffer_view); + } + } + + IREE_TRACE_ZONE_BEGIN(z0); + IREE_TRACE_ZONE_APPEND_TEXT(z0, string.data, string.size); + iree_status_t status = iree_hal_buffer_view_parse( + string, device, device_allocator, out_buffer_view); + IREE_TRACE_ZONE_END(z0); + return status; +} + +// Parses a shaped tensor type into a HAL buffer view and appends it to |list|. 
+static iree_status_t iree_tooling_parse_tensor_into(
+    iree_string_view_t* cconv, iree_string_view_t string, iree_vm_list_t* list,
+    iree_hal_device_t* device, iree_hal_allocator_t* device_allocator,
+    iree_io_stream_list_t* stream_list, iree_allocator_t host_allocator) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // The calling convention must want a ref here (it will hold the view).
+  iree_status_t status = iree_tooling_consume_cconv(cconv, 'r');
+  if (!iree_status_is_ok(status)) {
+    IREE_TRACE_ZONE_END(z0);
+    return status;
+  }
+
+  // Cheap typo filter: a tensor spec carries a shape/type so it contains an
+  // `x` and/or an `=`. Bare scalar values have neither and are rejected.
+  const bool lacks_equal =
+      iree_string_view_find_char(string, '=', 0) == IREE_STRING_VIEW_NPOS;
+  const bool lacks_x =
+      iree_string_view_find_char(string, 'x', 0) == IREE_STRING_VIEW_NPOS;
+  if (lacks_equal && lacks_x) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "invalid tensor specification, requires at least a "
+                            "shape/type (`4x2xf32=...`)");
+  }
+
+  // Materialize the buffer view described by |string|.
+  iree_hal_buffer_view_t* parsed_view = NULL;
+  status = iree_tooling_parse_tensor(string, device, device_allocator,
+                                     stream_list, host_allocator, &parsed_view);
+  if (!iree_status_is_ok(status)) {
+    IREE_TRACE_ZONE_END(z0);
+    return status;
+  }
+
+  // The list takes its own reference; drop ours whether or not the push
+  // succeeded.
+  iree_vm_ref_t view_ref = iree_hal_buffer_view_move_ref(parsed_view);
+  status = iree_vm_list_push_ref_retain(list, &view_ref);
+  iree_hal_buffer_view_release(parsed_view);
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Parses a shaped tensor type with optional contents and appends only its
+// backing HAL buffer to |list| for use as output storage.
+static iree_status_t iree_tooling_parse_storage_into(
+    iree_string_view_t* cconv, iree_string_view_t string, iree_vm_list_t* list,
+    iree_hal_device_t* device, iree_hal_allocator_t* device_allocator,
+    iree_io_stream_list_t* stream_list, iree_allocator_t host_allocator) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // The calling convention must want a ref here (it will hold the buffer).
+  iree_status_t status = iree_tooling_consume_cconv(cconv, 'r');
+  if (!iree_status_is_ok(status)) {
+    IREE_TRACE_ZONE_END(z0);
+    return status;
+  }
+
+  // Materialize the buffer view described by |string|.
+  iree_hal_buffer_view_t* parsed_view = NULL;
+  status = iree_tooling_parse_tensor(string, device, device_allocator,
+                                     stream_list, host_allocator, &parsed_view);
+  if (!iree_status_is_ok(status)) {
+    IREE_TRACE_ZONE_END(z0);
+    return status;
+  }
+
+  // Only the storage is needed; the shape/type metadata is dropped along with
+  // the view once the list has retained the buffer.
+  iree_hal_buffer_t* storage = iree_hal_buffer_view_buffer(parsed_view);
+  iree_vm_ref_t storage_ref = iree_hal_buffer_move_ref(storage);
+  status = iree_vm_list_push_ref_retain(list, &storage_ref);
+  iree_hal_buffer_view_release(parsed_view);
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Loads one ndarray from the current position of |stream| as a HAL buffer
+// view and appends it to |list|.
+static iree_status_t iree_tooling_parse_ndarray_into(
+    iree_string_view_t* cconv, iree_vm_list_t* list, iree_io_stream_t* stream,
+    iree_hal_device_t* device, iree_hal_allocator_t* device_allocator) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // The calling convention must want a ref here (it will hold the view).
+  iree_status_t status = iree_tooling_consume_cconv(cconv, 'r');
+  if (!iree_status_is_ok(status)) {
+    IREE_TRACE_ZONE_END(z0);
+    return status;
+  }
+
+  iree_hal_buffer_params_t load_params = {0};
+  load_params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT;
+  load_params.access = IREE_HAL_MEMORY_ACCESS_READ;
+  load_params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
+
+  iree_hal_buffer_view_t* loaded_view = NULL;
+  status = iree_numpy_npy_load_ndarray(
+      stream, IREE_NUMPY_NPY_LOAD_OPTION_DEFAULT, load_params, device,
+      device_allocator, &loaded_view);
+  if (iree_status_is_ok(status)) {
+    iree_vm_ref_t view_ref = iree_hal_buffer_view_move_ref(loaded_view);
+    status = iree_vm_list_push_ref_retain(list, &view_ref);
+  }
+
+  // Always drop the local reference (if any); the list holds its own.
+  iree_hal_buffer_view_release(loaded_view);
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Parses zero or more variants from a file into the |list|.
+// The |string| defines the file mode (`@` new, `+` existing, `*` splat) and
+// the path to source from.
+static iree_status_t iree_tooling_parse_file_into( + iree_string_view_t* cconv, iree_string_view_t string, iree_vm_list_t* list, + iree_hal_device_t* device, iree_hal_allocator_t* device_allocator, + iree_io_stream_list_t* stream_list, iree_allocator_t host_allocator) { + IREE_TRACE_ZONE_BEGIN(z0); + IREE_TRACE_ZONE_APPEND_TEXT(z0, string.data, string.size); + + // @ = open new + // +/* = append + // * = splat + bool is_append = !iree_string_view_starts_with(string, IREE_SV("@")); + bool is_splat = iree_string_view_starts_with(string, IREE_SV("*")); + iree_string_view_t path = + iree_string_view_substr(string, 1, IREE_HOST_SIZE_MAX); + + // Today we only support numpy files here but could make this pluggable or at + // least a little smarter (sniff file header/etc) instead of relying on ext. + if (!iree_string_view_ends_with(path, IREE_SV(".npy"))) { + return iree_make_status( + IREE_STATUS_UNIMPLEMENTED, + "only numpy (.npy) files are supported for metadata-less variant I/O"); + } + + // Open (or retrieve) the file. + iree_io_stream_t* stream = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_io_stream_list_open(stream_list, path, is_append, &stream)); + + iree_status_t status = iree_ok_status(); + if (!is_splat) { + // Read a single ndarray from the stream at the current offset. + status = iree_tooling_parse_ndarray_into(cconv, list, stream, device, + device_allocator); + } else { + // Read zero or more ndarrays from the stream - note that it may already be + // at EOS. 
+ while (iree_status_is_ok(status) && !iree_io_stream_is_eos(stream)) { + status = iree_tooling_parse_ndarray_into(cconv, list, stream, device, + device_allocator); + } + } + + iree_io_stream_release(stream); + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_status_t iree_tooling_parse_variant_into( + iree_string_view_t* cconv, iree_string_view_t string, iree_vm_list_t* list, + iree_hal_device_t* device, iree_hal_allocator_t* device_allocator, + iree_io_stream_list_t* stream_list, iree_allocator_t host_allocator) { + if (iree_string_view_is_empty(string)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "no value specified for input"); + } else if (iree_string_view_equal(string, IREE_SV("(null)")) || + iree_string_view_equal(string, IREE_SV("(ignored)"))) { + return iree_tooling_parse_null_into(cconv, string, list); + } else if (iree_string_view_starts_with(string, IREE_SV("@")) || + iree_string_view_starts_with(string, IREE_SV("+")) || + iree_string_view_starts_with(string, IREE_SV("*"))) { + return iree_tooling_parse_file_into(cconv, string, list, device, + device_allocator, stream_list, + host_allocator); + } else if (iree_string_view_consume_prefix(&string, IREE_SV("&"))) { + return iree_tooling_parse_storage_into(cconv, string, list, device, + device_allocator, stream_list, + host_allocator); + } else if (!iree_string_view_starts_with(*cconv, IREE_SV("r"))) { + return iree_tooling_parse_primitive_into(cconv, string, list, device, + device_allocator, host_allocator); + } + // Shaped tensor as a buffer view. + // NOTE: we could support more things here - strings, VM buffers or lists, + // etc. Today if it's not a null or primitive value it's a tensor. 
+ return iree_tooling_parse_tensor_into(cconv, string, list, device, + device_allocator, stream_list, + host_allocator); +} + +static iree_status_t iree_tooling_parse_variants_into( + iree_string_view_t cconv, iree_string_view_list_t specs, + iree_vm_list_t* list, iree_hal_device_t* device, + iree_hal_allocator_t* device_allocator, iree_allocator_t host_allocator) { + IREE_ASSERT_ARGUMENT(list); + IREE_TRACE_ZONE_BEGIN(z0); + + // List of opened streams used for allowing multiple arguments to source from + // the same file sequentially. + iree_io_stream_list_t* stream_list = NULL; + IREE_RETURN_IF_ERROR(iree_io_stream_list_allocate( + IREE_IO_STDIO_STREAM_MODE_READ, host_allocator, &stream_list)); + + // Parse each variant string. Note that some strings may expand to zero or + // more variants and so we need to consume the cconv based on how many were + // parsed. + iree_status_t status = iree_ok_status(); + for (iree_host_size_t i = 0; i < specs.count; ++i) { + iree_string_view_t string = iree_string_view_trim(specs.values[i]); + status = iree_status_annotate_f( + iree_tooling_parse_variant_into(&cconv, string, list, device, + device_allocator, stream_list, + host_allocator), + "parsing input `%.*s`", (int)string.size, string.data); + if (!iree_status_is_ok(status)) break; + } + + iree_io_stream_list_free(stream_list); + IREE_TRACE_ZONE_END(z0); + return status; +} + +iree_status_t iree_tooling_parse_variants( + iree_string_view_t cconv, iree_string_view_list_t specs, + iree_hal_device_t* device, iree_hal_allocator_t* device_allocator, + iree_allocator_t host_allocator, iree_vm_list_t** out_list) { + IREE_ASSERT_ARGUMENT(out_list); + *out_list = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + // Argument list that will be populated - possibly returning with 0 entries. + iree_vm_list_t* list = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_vm_list_create(iree_vm_make_undefined_type_def(), specs.count, + host_allocator, &list)); + + // Parse into the argument list. 
+ iree_status_t status = iree_tooling_parse_variants_into( + cconv, specs, list, device, device_allocator, host_allocator); + + if (iree_status_is_ok(status)) { + *out_list = list; + } else { + iree_vm_list_release(list); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +//===----------------------------------------------------------------------===// +// Printing +//===----------------------------------------------------------------------===// + +#define IREE_PRINT_VARIANT_CASE_I(SIZE, B, V) \ + case IREE_VM_VALUE_TYPE_I##SIZE: \ + status = iree_string_builder_append_format(B, "i" #SIZE "=%" PRIi##SIZE, \ + (V).i##SIZE); \ + break; + +#define IREE_PRINT_VARIANT_CASE_F(SIZE, B, V) \ + case IREE_VM_VALUE_TYPE_F##SIZE: \ + status = \ + iree_string_builder_append_format(B, "f" #SIZE "=%g", (V).f##SIZE); \ + break; + +static iree_status_t iree_tooling_format_variant( + iree_vm_variant_t variant, iree_host_size_t max_element_count, + iree_string_builder_t* builder) { + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = iree_ok_status(); + if (iree_vm_variant_is_empty(variant)) { + status = iree_string_builder_append_string(builder, IREE_SV("(null)")); + } else if (iree_vm_variant_is_value(variant)) { + switch (iree_vm_type_def_as_value(variant.type)) { + IREE_PRINT_VARIANT_CASE_I(8, builder, variant) + IREE_PRINT_VARIANT_CASE_I(16, builder, variant) + IREE_PRINT_VARIANT_CASE_I(32, builder, variant) + IREE_PRINT_VARIANT_CASE_I(64, builder, variant) + IREE_PRINT_VARIANT_CASE_F(32, builder, variant) + IREE_PRINT_VARIANT_CASE_F(64, builder, variant) + default: + status = iree_string_builder_append_string(builder, IREE_SV("?")); + break; + } + } else if (iree_vm_variant_is_ref(variant)) { + iree_string_view_t type_name = + iree_vm_ref_type_name(iree_vm_type_def_as_ref(variant.type)); + status = iree_string_builder_append_string(builder, type_name); + if (iree_status_is_ok(status)) { + status = iree_string_builder_append_string(builder, IREE_SV("\n")); + } + if 
(iree_status_is_ok(status)) { + if (iree_vm_list_isa(variant.ref)) { + iree_vm_list_t* child_list = iree_vm_list_deref(variant.ref); + status = iree_tooling_format_variants(IREE_SV("child_list"), child_list, + max_element_count, builder); + } else if (iree_hal_buffer_view_isa(variant.ref)) { + iree_hal_buffer_view_t* buffer_view = + iree_hal_buffer_view_deref(variant.ref); + status = iree_hal_buffer_view_append_to_builder( + buffer_view, max_element_count, builder); + } else { + // TODO(benvanik): a way for ref types to describe themselves. + status = + iree_string_builder_append_string(builder, IREE_SV("(no printer)")); + } + } + } else { + status = iree_string_builder_append_string(builder, IREE_SV("(null)")); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +iree_status_t iree_tooling_format_variants(iree_string_view_t list_name, + iree_vm_list_t* list, + iree_host_size_t max_element_count, + iree_string_builder_t* builder) { + IREE_ASSERT_ARGUMENT(list); + IREE_ASSERT_ARGUMENT(builder); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_status_t status = iree_ok_status(); + for (iree_host_size_t i = 0; i < iree_vm_list_size(list); ++i) { + iree_vm_variant_t variant = iree_vm_variant_empty(); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_vm_list_get_variant_assign(list, i, &variant)); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_string_builder_append_format( + builder, "%.*s[%" PRIhsz "]: ", (int)list_name.size, + list_name.data, i)); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_tooling_format_variant(variant, max_element_count, builder)); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_string_builder_append_string(builder, IREE_SV("\n"))); + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_status_t iree_tooling_print_variant( + iree_vm_variant_t variant, iree_host_size_t max_element_count, + iree_io_stream_t* stream, iree_allocator_t host_allocator) { + IREE_ASSERT_ARGUMENT(stream); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_string_builder_t builder; 
+ iree_string_builder_initialize(host_allocator, &builder); + + iree_status_t status = + iree_tooling_format_variant(variant, max_element_count, &builder); + if (iree_status_is_ok(status)) { + status = iree_string_builder_append_string(&builder, IREE_SV("\n")); + } + if (iree_status_is_ok(status)) { + status = iree_io_stream_write(stream, iree_string_builder_size(&builder), + iree_string_builder_buffer(&builder)); + } + + iree_string_builder_deinitialize(&builder); + + IREE_TRACE_ZONE_END(z0); + return status; +} + +iree_status_t iree_tooling_print_variants(iree_string_view_t list_name, + iree_vm_list_t* list, + iree_host_size_t max_element_count, + iree_io_stream_t* stream, + iree_allocator_t host_allocator) { + IREE_ASSERT_ARGUMENT(list); + IREE_ASSERT_ARGUMENT(stream); + IREE_TRACE_ZONE_BEGIN(z0); + + // Reused across each variant print to amortize allocations. + iree_string_builder_t builder; + iree_string_builder_initialize(host_allocator, &builder); + + iree_status_t status = iree_ok_status(); + for (iree_host_size_t i = 0; i < iree_vm_list_size(list); ++i) { + iree_vm_variant_t variant = iree_vm_variant_empty(); + status = iree_vm_list_get_variant_assign(list, i, &variant); + if (!iree_status_is_ok(status)) break; + status = iree_string_builder_append_format( + &builder, "%.*s[%" PRIhsz "]: ", (int)list_name.size, list_name.data, + i); + if (!iree_status_is_ok(status)) break; + status = iree_tooling_format_variant(variant, max_element_count, &builder); + if (!iree_status_is_ok(status)) break; + status = iree_string_builder_append_string(&builder, IREE_SV("\n")); + if (!iree_status_is_ok(status)) break; + status = iree_io_stream_write(stream, iree_string_builder_size(&builder), + iree_string_builder_buffer(&builder)); + if (!iree_status_is_ok(status)) break; + iree_string_builder_reset(&builder); + } + + iree_string_builder_deinitialize(&builder); + + IREE_TRACE_ZONE_END(z0); + return status; +} + 
+//===----------------------------------------------------------------------===// +// Writing +//===----------------------------------------------------------------------===// + +static iree_status_t iree_tooling_create_buffer_view_with_hal_buffer( + iree_hal_buffer_t* hal_buffer, iree_allocator_t host_allocator, + iree_hal_buffer_view_t** out_buffer_view) { + iree_hal_dim_t shape[1] = { + (iree_hal_dim_t)iree_hal_buffer_byte_length(hal_buffer), + }; + return iree_hal_buffer_view_create( + hal_buffer, IREE_ARRAYSIZE(shape), shape, IREE_HAL_ELEMENT_TYPE_INT_8, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, host_allocator, out_buffer_view); +} + +static void iree_hal_buffer_release_vm_buffer( + void* user_data, struct iree_hal_buffer_t* buffer) { + iree_vm_buffer_release((iree_vm_buffer_t*)user_data); +} + +static iree_status_t iree_tooling_create_buffer_view_with_vm_buffer( + iree_vm_buffer_t* vm_buffer, iree_hal_allocator_t* device_allocator, + iree_allocator_t host_allocator, iree_hal_buffer_view_t** out_buffer_view) { + // Get read-only pointer to the underlying buffer heap memory. + iree_const_byte_span_t span = iree_const_byte_span_empty(); + IREE_RETURN_IF_ERROR(iree_vm_buffer_map_ro( + vm_buffer, 0, iree_vm_buffer_length(vm_buffer), 1, &span)); + + // Wrap the heap memory in a HAL buffer for read-only access. + iree_hal_buffer_release_callback_t release_callback = { + .fn = iree_hal_buffer_release_vm_buffer, + .user_data = vm_buffer, + }; + iree_vm_buffer_retain(vm_buffer); + iree_hal_buffer_t* hal_buffer = NULL; + iree_status_t status = iree_hal_heap_buffer_wrap( + device_allocator, IREE_HAL_MEMORY_TYPE_HOST_LOCAL, + IREE_HAL_MEMORY_ACCESS_READ, + IREE_HAL_BUFFER_USAGE_TRANSFER_SOURCE | IREE_HAL_BUFFER_USAGE_MAPPING, + span.data_length, iree_cast_const_byte_span(span), release_callback, + &hal_buffer); + iree_vm_buffer_release(vm_buffer); + + // Wrap the HAL buffer in a buffer view. 
+ if (iree_status_is_ok(status)) { + status = iree_tooling_create_buffer_view_with_hal_buffer( + hal_buffer, host_allocator, out_buffer_view); + } + + iree_hal_buffer_release(hal_buffer); + return status; +} + +static iree_status_t iree_tooling_create_buffer_view_empty( + iree_hal_allocator_t* device_allocator, iree_allocator_t host_allocator, + iree_hal_buffer_view_t** out_buffer_view) { + iree_hal_buffer_t* hal_buffer = NULL; + IREE_RETURN_IF_ERROR(iree_hal_heap_buffer_wrap( + device_allocator, IREE_HAL_MEMORY_TYPE_HOST_LOCAL, + IREE_HAL_MEMORY_ACCESS_READ, + IREE_HAL_BUFFER_USAGE_TRANSFER_SOURCE | IREE_HAL_BUFFER_USAGE_MAPPING, 0, + iree_byte_span_empty(), iree_hal_buffer_release_callback_null(), + &hal_buffer)); + iree_status_t status = iree_tooling_create_buffer_view_with_hal_buffer( + hal_buffer, host_allocator, out_buffer_view); + iree_hal_buffer_release(hal_buffer); + return status; +} + +static iree_status_t iree_tooling_create_buffer_view_with_value( + iree_vm_value_t value, iree_hal_allocator_t* device_allocator, + iree_allocator_t host_allocator, iree_hal_buffer_view_t** out_buffer_view) { + iree_device_size_t byte_length = 0; + iree_hal_element_type_t element_type = IREE_HAL_ELEMENT_TYPE_NONE; + switch (value.type) { + case IREE_VM_VALUE_TYPE_NONE: + return iree_tooling_create_buffer_view_empty( + device_allocator, host_allocator, out_buffer_view); + case IREE_VM_VALUE_TYPE_I8: + byte_length = sizeof(value.i8); + element_type = IREE_HAL_ELEMENT_TYPE_INT_8; + break; + case IREE_VM_VALUE_TYPE_I16: + byte_length = sizeof(value.i16); + element_type = IREE_HAL_ELEMENT_TYPE_INT_16; + break; + case IREE_VM_VALUE_TYPE_I32: + byte_length = sizeof(value.i32); + element_type = IREE_HAL_ELEMENT_TYPE_INT_32; + break; + case IREE_VM_VALUE_TYPE_I64: + byte_length = sizeof(value.i64); + element_type = IREE_HAL_ELEMENT_TYPE_INT_64; + break; + case IREE_VM_VALUE_TYPE_F32: + byte_length = sizeof(value.f32); + element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_32; + break; + case 
IREE_VM_VALUE_TYPE_F64: + byte_length = sizeof(value.f64); + element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_64; + break; + default: + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "unsupported value type"); + } + + iree_hal_buffer_params_t params = { + .usage = + IREE_HAL_BUFFER_USAGE_TRANSFER_SOURCE | IREE_HAL_BUFFER_USAGE_MAPPING, + .access = IREE_HAL_MEMORY_ACCESS_ALL, + .type = IREE_HAL_MEMORY_TYPE_HOST_LOCAL, + }; + iree_hal_buffer_t* hal_buffer = NULL; + IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer( + device_allocator, params, byte_length, &hal_buffer)); + + iree_status_t status = iree_hal_buffer_map_write( + hal_buffer, 0, value.value_storage, byte_length); + + if (iree_status_is_ok(status)) { + status = iree_hal_buffer_view_create(hal_buffer, /*shape_rank=*/0, + /*shape=*/NULL, element_type, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + host_allocator, out_buffer_view); + } + + iree_hal_buffer_release(hal_buffer); + return status; +} + +static iree_status_t iree_tooling_create_buffer_view_from_variant( + iree_vm_variant_t variant, iree_hal_allocator_t* device_allocator, + iree_allocator_t host_allocator, iree_hal_buffer_view_t** out_buffer_view) { + *out_buffer_view = NULL; + if (iree_vm_variant_is_empty(variant)) { + // Empty value - we need to emit a zero-length value to keep the npy file + // ordered when there are multiple entries. + return iree_tooling_create_buffer_view_empty( + device_allocator, host_allocator, out_buffer_view); + } else if (iree_vm_variant_is_ref(variant)) { + if (iree_hal_buffer_view_isa(variant.ref)) { + // Buffer view returned can provide the metadata required. + *out_buffer_view = iree_hal_buffer_view_deref(variant.ref); + iree_hal_buffer_view_retain(*out_buffer_view); + return iree_ok_status(); + } else if (iree_hal_buffer_isa(variant.ref)) { + // i8 buffer view of the total length of the HAL buffer. 
+ iree_hal_buffer_t* buffer = iree_hal_buffer_deref(variant.ref); + return iree_tooling_create_buffer_view_with_hal_buffer( + buffer, host_allocator, out_buffer_view); + } else if (iree_vm_buffer_isa(variant.ref)) { + // i8 buffer view of the total length of the VM buffer wrapped in a HAL + // buffer. + iree_vm_buffer_t* buffer = iree_vm_buffer_deref(variant.ref); + return iree_tooling_create_buffer_view_with_vm_buffer( + buffer, device_allocator, host_allocator, out_buffer_view); + } else { + // Unsupported type. + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "unsupported output source type; expected: " + "!hal.buffer, !hal.buffer_view, !vm.buffer"); + } + } else { + // Primitive value that we wrap in a scalar buffer view. + return iree_tooling_create_buffer_view_with_value( + iree_vm_variant_value(variant), device_allocator, host_allocator, + out_buffer_view); + } +} + +static iree_status_t iree_tooling_write_variant_to_npy_file( + iree_io_stream_t* stream, iree_vm_variant_t variant, + iree_hal_allocator_t* device_allocator, iree_allocator_t host_allocator) { + // npy files require buffer views so if we receive anything but a buffer view + // we wrap it in one typed as bytes. + iree_hal_buffer_view_t* buffer_view = NULL; + IREE_RETURN_IF_ERROR(iree_tooling_create_buffer_view_from_variant( + variant, device_allocator, host_allocator, &buffer_view)); + + // Append buffer view contents to the file stream. 
+ iree_numpy_npy_save_options_t options = IREE_NUMPY_NPY_SAVE_OPTION_DEFAULT; + iree_status_t status = + iree_numpy_npy_save_ndarray(stream, options, buffer_view, host_allocator); + + iree_hal_buffer_view_release(buffer_view); + return status; +} + +static iree_status_t iree_tooling_write_variant_to_binary_file( + iree_io_stream_t* stream, iree_vm_variant_t variant, + iree_hal_allocator_t* device_allocator, iree_allocator_t host_allocator) { + // Today we reuse the buffer view code to get the variant into a byte buffer + // to write out even though we don't use any of the metadata. This is a + // command line tool writing out files using stdio and not an example of how + // to create a high performance I/O mechanism. + iree_hal_buffer_view_t* buffer_view = NULL; + IREE_RETURN_IF_ERROR(iree_tooling_create_buffer_view_from_variant( + variant, device_allocator, host_allocator, &buffer_view)); + iree_device_size_t byte_length = + iree_hal_buffer_view_byte_length(buffer_view); + + // Map the buffer memory into a host pointer so we can access it. + iree_hal_buffer_mapping_t mapping; + iree_status_t status = iree_hal_buffer_map_range( + iree_hal_buffer_view_buffer(buffer_view), IREE_HAL_MAPPING_MODE_SCOPED, + IREE_HAL_MEMORY_ACCESS_READ, 0, IREE_WHOLE_BUFFER, &mapping); + + // Write to the file from the mapped memory. + if (iree_status_is_ok(status)) { + status = iree_io_stream_write(stream, byte_length, mapping.contents.data); + } + + iree_status_ignore(iree_hal_buffer_unmap_range(&mapping)); + + iree_hal_buffer_view_release(buffer_view); + return status; +} + +static iree_status_t iree_tooling_write_variant_to_file( + iree_vm_variant_t variant, iree_string_view_t spec, + iree_allocator_t host_allocator) { + // Open the output file based on the spec. 
+ iree_io_stdio_stream_mode_t mode = 0; + if (iree_string_view_consume_prefix(&spec, IREE_SV("@"))) { + mode |= IREE_IO_STDIO_STREAM_MODE_WRITE | IREE_IO_STDIO_STREAM_MODE_DISCARD; + } else if (iree_string_view_consume_prefix(&spec, IREE_SV("+"))) { + mode |= IREE_IO_STDIO_STREAM_MODE_WRITE | IREE_IO_STDIO_STREAM_MODE_APPEND; + } else { + // We only support overwrite and append for now. + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "unsupported output mode specification '%.*s'", + (int)spec.size, spec.data); + } + iree_io_stream_t* stream = NULL; + IREE_RETURN_IF_ERROR( + iree_io_stdio_stream_open(mode, spec, host_allocator, &stream)); + + // Dummy heap used for allocating transient variants. + // This is wasteful to cycle for each one but we don't care about it in the + // tooling. + iree_hal_allocator_t* device_allocator = NULL; + iree_status_t status = iree_hal_allocator_create_heap( + IREE_SV("tooling"), host_allocator, host_allocator, &device_allocator); + + // Output format is based on file extension with ones we don't know about + // going into binary mode. Some formats require metadata from buffer views + // but in binary mode we just dump whatever contents we have and leave it up + // to the user to handle the shape/type/encoding. 
+ if (iree_status_is_ok(status)) { + if (iree_string_view_ends_with(spec, IREE_SV(".npy"))) { + status = iree_tooling_write_variant_to_npy_file( + stream, variant, device_allocator, host_allocator); + } else { + status = iree_tooling_write_variant_to_binary_file( + stream, variant, device_allocator, host_allocator); + } + } + + iree_hal_allocator_release(device_allocator); + iree_io_stream_release(stream); + return status; +} + +static iree_status_t iree_tooling_write_variant( + iree_vm_variant_t variant, iree_string_view_t spec, + iree_host_size_t max_element_count, iree_io_stream_t* default_stream, + iree_allocator_t host_allocator) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_status_t status = iree_ok_status(); + if (iree_string_view_is_empty(spec)) { + // Send into the void. + } else if (iree_string_view_equal(spec, IREE_SV("-"))) { + // Route to the default stream (if provided). + if (default_stream) { + status = iree_tooling_print_variant(variant, max_element_count, + default_stream, host_allocator); + } + } else { + // Write to a file. 
+ status = iree_tooling_write_variant_to_file(variant, spec, host_allocator); + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +iree_status_t iree_tooling_write_variants(iree_vm_list_t* list, + iree_string_view_list_t specs, + iree_host_size_t max_element_count, + iree_io_stream_t* default_stream, + iree_allocator_t host_allocator) { + IREE_ASSERT_ARGUMENT(list); + IREE_TRACE_ZONE_BEGIN(z0); + + if (iree_vm_list_size(list) != specs.count) { + IREE_TRACE_ZONE_END(z0); + return iree_make_status( + IREE_STATUS_OUT_OF_RANGE, + "%" PRIhsz + " outputs specified but the provided variant list only has %" PRIhsz + " elements", + specs.count, iree_vm_list_size(list)); + } + + for (iree_host_size_t i = 0; i < specs.count; ++i) { + iree_vm_variant_t variant = iree_vm_variant_empty(); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_vm_list_get_variant_assign(list, i, &variant)); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + iree_tooling_write_variant(variant, specs.values[i], max_element_count, + default_stream, host_allocator)); + } + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} diff --git a/runtime/src/iree/tooling/function_io.h b/runtime/src/iree/tooling/function_io.h new file mode 100644 index 000000000000..d4dd382275bb --- /dev/null +++ b/runtime/src/iree/tooling/function_io.h @@ -0,0 +1,120 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_TOOLING_FUNCTION_IO_H_
#define IREE_TOOLING_FUNCTION_IO_H_

#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "iree/io/stream.h"
#include "iree/vm/api.h"

#ifdef __cplusplus
extern "C" {
#endif

//===----------------------------------------------------------------------===//
// Parsing
//===----------------------------------------------------------------------===//

// Parses zero or more variants from the provided |specs| list.
// |device_allocator| is used for any HAL buffers required and on devices that
// require it an optional |device| will be used for transfer operations.
//
// |cconv| is the calling convention string that is consumed as values are
// parsed; ref-typed values (tensors, storage buffers, file contents) each
// consume an `r` entry.
//
// On success |out_list| receives a new list (possibly with 0 entries) that the
// caller must release. On failure |out_list| is set to NULL.
//
// Supported input string specifiers (and examples):
//  - Special values:
//      `(null)` (a null vm ref)
//      `(ignored)` (routed to the same null-value parser as `(null)`)
//  - Primitive value types:
//      `123` (i32)
//      `3.14` (f32)
//  - Shaped tensor types (using HAL buffer view parsing):
//      `f32=1.2` (tensor<f32>)
//      `2x2xf32=1,2,3,4` (tensor<2x2xf32>)
//      `2x2xi32=[[1 2][3 4]]` (tensor<2x2xi32>)
//  - Numpy files:
//      `@file.npy` (first array from the file)
//      `+file.npy` (next array from the file)
//      `*file.npy` (all following arrays from the file)
//  - Binary files:
//      `2x2xf32=@file.ext` (dense tensor<2x2xf32> at the start of the file)
//      `4xf32=+file.ext` (dense tensor<4xf32> following the prior input)
//  - Storage buffers for output arguments (shape/type used for sizing):
//      `&4xf32` (tensor<4xf32> as a HAL buffer for output operands)
//      `&4xf32=1,2,3,4` (tensor<4xf32> storage with an initial value)
iree_status_t iree_tooling_parse_variants(
    iree_string_view_t cconv, iree_string_view_list_t specs,
    iree_hal_device_t* device, iree_hal_allocator_t* device_allocator,
    iree_allocator_t host_allocator, iree_vm_list_t** out_list);

//===----------------------------------------------------------------------===//
// Printing
//===----------------------------------------------------------------------===//

// Formats a variant list to |builder| as strings with newlines between them.
// |list_name| will be printed alongside each element ordinal. When printing
// |max_element_count| is used to limit the number of buffer view elements.
// The textual format of this is subject to change.
//
// Each element is emitted as `list_name[ordinal]: ` followed by its value and
// a newline.
//
// Prints null/empty values as:
//   (null)
// Prints primitive values in the format:
//   type=value (such as `i32=42`)
// Prints refs as their type name on its own line followed by the contents;
// buffer views use the IREE standard shaped buffer format:
//   [shape]xtype=[value]
// described in
// https://github.com/openxla/iree/tree/main/runtime/src/iree/hal/api.h
// Nested lists are formatted recursively and ref types without a printer are
// emitted as `(no printer)`.
iree_status_t iree_tooling_format_variants(iree_string_view_t list_name,
                                           iree_vm_list_t* list,
                                           iree_host_size_t max_element_count,
                                           iree_string_builder_t* builder);

// Prints a variant list to |stream|.
// |list_name| will be printed alongside each element ordinal. When printing
// |max_element_count| is used to limit the number of buffer view elements.
// The textual format of this is subject to change.
// |host_allocator| is used for the temporary string storage.
//
// The output matches iree_tooling_format_variants:
// Prints primitive values in the format:
//   type=value (such as `i32=42`)
// Prints refs as their type name on its own line followed by the contents;
// buffer views use the IREE standard shaped buffer format:
//   [shape]xtype=[value]
// described in
// https://github.com/openxla/iree/tree/main/runtime/src/iree/hal/api.h
iree_status_t iree_tooling_print_variants(iree_string_view_t list_name,
                                          iree_vm_list_t* list,
                                          iree_host_size_t max_element_count,
                                          iree_io_stream_t* stream,
                                          iree_allocator_t host_allocator);

//===----------------------------------------------------------------------===//
// Writing
//===----------------------------------------------------------------------===//

// Outputs a variant list to |default_stream| or the targets defined by
// |specs|. |specs| must contain exactly one entry per list element.
// Values whose spec is `-` are printed to |default_stream| (if one is
// provided) ala iree_tooling_print_variants. When printing
// |max_element_count| is used to limit the number of buffer
// view elements. The textual format of this is subject to change.
//
// Supported string output specifiers (and examples):
//  - Ignore a list element (don't output):
//      ``
//  - Print to the specified |default_stream| ala iree_tooling_print_variants:
//      `-`
//  - Numpy files:
//      `@file.npy` (write array to the specified file, discarding existing
//                   contents)
//      `+file.npy` (append array to the specified file)
//  - Binary files:
//      `@file.ext` (write buffer contents to the specified file, discarding
//                   existing contents)
//      `+file.ext` (append buffer contents to the specified file)
iree_status_t iree_tooling_write_variants(iree_vm_list_t* list,
                                          iree_string_view_list_t specs,
                                          iree_host_size_t max_element_count,
                                          iree_io_stream_t* default_stream,
                                          iree_allocator_t host_allocator);

#ifdef __cplusplus
}  // extern "C"
#endif

#endif  // IREE_TOOLING_FUNCTION_IO_H_
diff --git a/runtime/src/iree/tooling/function_io_test.cc b/runtime/src/iree/tooling/function_io_test.cc
new file mode 100644
index 000000000000..db2a1b043d58
--- /dev/null
+++ b/runtime/src/iree/tooling/function_io_test.cc
@@ -0,0 +1,134 @@
+// Copyright 2020 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/tooling/function_io.h" + +#include "iree/base/api.h" +#include "iree/hal/api.h" +#include "iree/io/vec_stream.h" +#include "iree/modules/hal/module.h" +#include "iree/testing/gtest.h" +#include "iree/testing/status_matchers.h" +#include "iree/vm/api.h" + +namespace iree { +namespace { + +struct FunctionIOTest : public ::testing::Test { + virtual void SetUp() { + host_allocator = iree_allocator_system(); + IREE_ASSERT_OK(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, + host_allocator, &instance)); + IREE_ASSERT_OK(iree_hal_module_register_all_types(instance)); + IREE_ASSERT_OK(iree_hal_allocator_create_heap( + IREE_SV("test"), host_allocator, host_allocator, &device_allocator)); + } + + virtual void TearDown() { + iree_hal_allocator_release(device_allocator); + iree_vm_instance_release(instance); + } + + Status ParseToVariantList(iree_string_view_t cconv, + std::vector<std::string> input_strings, + iree_vm_list_t** out_list) { + std::vector<iree_string_view_t> input_string_views(input_strings.size()); + for (size_t i = 0; i < input_strings.size(); ++i) { + input_string_views[i].data = input_strings[i].data(); + input_string_views[i].size = input_strings[i].size(); + } + return iree_tooling_parse_variants( + cconv, + iree_string_view_list_t{input_string_views.size(), + input_string_views.data()}, + /*device=*/NULL, device_allocator, host_allocator, out_list); + } + + Status PrintVariantList(iree_vm_list_t* variant_list, + std::string* out_string) { + iree_io_stream_t* stream = NULL; + IREE_RETURN_IF_ERROR(iree_io_vec_stream_create( + IREE_IO_STREAM_MODE_READABLE | IREE_IO_STREAM_MODE_WRITABLE | + IREE_IO_STREAM_MODE_SEEKABLE, + /*block_size=*/32 * 1024, host_allocator, &stream)); + iree_status_t status = iree_tooling_print_variants( + IREE_SV("result"), variant_list, /*max_element_count=*/1024, stream, + host_allocator); + if (iree_status_is_ok(status)) { + status = iree_io_stream_seek(stream, 
IREE_IO_STREAM_SEEK_SET, 0); + } + if (iree_status_is_ok(status)) { + out_string->resize(iree_io_stream_length(stream)); + status = iree_io_stream_read(stream, out_string->size(), + out_string->data(), NULL); + } + iree_io_stream_release(stream); + return status; + } + + iree_allocator_t host_allocator; + iree_vm_instance_t* instance = nullptr; + iree_hal_allocator_t* device_allocator = nullptr; +}; + +TEST_F(FunctionIOTest, ParsePrintBuffer) { + std::string buf_string = "&2x2xi32=[42 43][44 45]"; + vm::ref<iree_vm_list_t> variant_list; + IREE_ASSERT_OK(ParseToVariantList( + IREE_SV("r"), std::vector<std::string>{buf_string}, &variant_list)); + std::string result; + IREE_ASSERT_OK(PrintVariantList(variant_list.get(), &result)); + EXPECT_EQ(result, + std::string("result[0]: hal.buffer\n") + "(no printer)" + "\n"); +} + +TEST_F(FunctionIOTest, ParsePrintBufferView) { + std::string buf_string = "2x2xi32=[42 43][44 45]"; + vm::ref<iree_vm_list_t> variant_list; + IREE_ASSERT_OK(ParseToVariantList( + IREE_SV("r"), std::vector<std::string>{buf_string}, &variant_list)); + std::string result; + IREE_ASSERT_OK(PrintVariantList(variant_list.get(), &result)); + EXPECT_EQ(result, + std::string("result[0]: hal.buffer_view\n") + buf_string + "\n"); +} + +TEST_F(FunctionIOTest, ParsePrintScalar) { + std::string input_string = "42"; + vm::ref<iree_vm_list_t> variant_list; + IREE_ASSERT_OK(ParseToVariantList( + IREE_SV("i"), std::vector<std::string>{input_string}, &variant_list)); + std::string result; + IREE_ASSERT_OK(PrintVariantList(variant_list.get(), &result)); + EXPECT_EQ(result, std::string("result[0]: i32=") + input_string + "\n"); +} + +TEST_F(FunctionIOTest, ParsePrintRank0BufferView) { + std::string buf_string = "i32=42"; + vm::ref<iree_vm_list_t> variant_list; + IREE_ASSERT_OK(ParseToVariantList( + IREE_SV("r"), std::vector<std::string>{buf_string}, &variant_list)); + std::string result; + IREE_ASSERT_OK(PrintVariantList(variant_list.get(), &result)); + EXPECT_EQ(result, + std::string("result[0]: hal.buffer_view\n") + buf_string + "\n"); +} + +TEST_F(FunctionIOTest, 
ParsePrintMultipleBufferViews) { + std::string buf_string1 = "2x2xi32=[42 43][44 45]"; + std::string buf_string2 = "2x3xf64=[1 2 3][4 5 6]"; + vm::ref<iree_vm_list_t> variant_list; + IREE_ASSERT_OK(ParseToVariantList( + IREE_SV("rr"), std::vector<std::string>{buf_string1, buf_string2}, + &variant_list)); + std::string result; + IREE_ASSERT_OK(PrintVariantList(variant_list.get(), &result)); + EXPECT_EQ(result, std::string("result[0]: hal.buffer_view\n") + buf_string1 + + "\nresult[1]: hal.buffer_view\n" + buf_string2 + "\n"); +} + +} // namespace +} // namespace iree diff --git a/runtime/src/iree/tooling/function_util.c b/runtime/src/iree/tooling/function_util.c new file mode 100644 index 000000000000..4e4d71daea4d --- /dev/null +++ b/runtime/src/iree/tooling/function_util.c @@ -0,0 +1,239 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/tooling/function_util.h" + +#include "iree/modules/hal/module.h" + +iree_status_t iree_tooling_append_async_fences( + iree_vm_list_t* list, iree_vm_function_t function, + iree_hal_device_t* device, iree_hal_fence_t* wait_fence, + iree_hal_fence_t** out_signal_fence) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_string_view_t model = iree_vm_function_lookup_attr_by_name( + &function, IREE_SV("iree.abi.model")); + if (!iree_string_view_equal(model, IREE_SV("coarse-fences"))) { + // Ignore unknown models - the user may have provided their own fences. + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); + } + + // Create the signal fence as a 0->1 transition. The caller will wait on that. 
+ iree_hal_semaphore_t* semaphore = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_semaphore_create(device, 0ull, &semaphore)); + iree_hal_fence_t* signal_fence = NULL; + iree_status_t status = iree_hal_fence_create_at( + semaphore, 1ull, iree_hal_device_host_allocator(device), &signal_fence); + iree_hal_semaphore_release(semaphore); + + // Append (wait, signal) fences. + if (iree_status_is_ok(status)) { + iree_vm_ref_t wait_fence_ref = iree_hal_fence_retain_ref(wait_fence); + status = iree_vm_list_push_ref_move(list, &wait_fence_ref); + iree_vm_ref_release(&wait_fence_ref); + } + if (iree_status_is_ok(status)) { + iree_vm_ref_t signal_fence_ref = iree_hal_fence_retain_ref(signal_fence); + status = iree_vm_list_push_ref_move(list, &signal_fence_ref); + iree_vm_ref_release(&signal_fence_ref); + } + + if (iree_status_is_ok(status)) { + *out_signal_fence = signal_fence; + } else { + iree_hal_fence_release(signal_fence); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +static bool iree_tooling_requires_buffer_transfer( + iree_hal_buffer_t* source_buffer, iree_hal_device_t* target_device, + iree_hal_buffer_params_t target_params) { + // TODO(benvanik): if source/target devices don't match or can't be imported + // then we need a transfer. 
+ return !iree_all_bits_set(iree_hal_buffer_memory_type(source_buffer), + target_params.type) || + !iree_all_bits_set(iree_hal_buffer_allowed_usage(source_buffer), + target_params.usage); +} + +static iree_status_t iree_tooling_setup_buffer_transfer( + iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_t* source_buffer, + iree_hal_allocator_t* target_allocator, + iree_hal_buffer_params_t target_params, + iree_hal_buffer_t** out_target_buffer) { + IREE_ASSERT_ARGUMENT(command_buffer); + IREE_ASSERT_ARGUMENT(source_buffer); + IREE_ASSERT_ARGUMENT(target_allocator); + IREE_ASSERT_ARGUMENT(out_target_buffer); + *out_target_buffer = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_buffer_t* target_buffer = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_allocator_allocate_buffer( + target_allocator, target_params, + iree_hal_buffer_allocation_size(source_buffer), &target_buffer)); + + iree_status_t status = iree_hal_command_buffer_copy_buffer( + command_buffer, source_buffer, 0, target_buffer, 0, + iree_hal_buffer_byte_length(source_buffer)); + + if (iree_status_is_ok(status)) { + *out_target_buffer = target_buffer; + } else { + iree_hal_buffer_release(target_buffer); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_status_t iree_tooling_submit_transfer( + iree_hal_device_t* device, iree_hal_fence_t* wait_fence, + iree_hal_queue_affinity_t queue_affinity, + iree_hal_command_buffer_t* command_buffer, iree_hal_fence_t* signal_fence) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_status_t status = iree_ok_status(); + + bool needs_wait = signal_fence == NULL; + if (needs_wait) { + iree_hal_semaphore_t* semaphore = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_semaphore_create(device, 0ull, &semaphore)); + status = iree_hal_fence_create_at( + semaphore, 1ull, iree_hal_device_host_allocator(device), &signal_fence); + iree_hal_semaphore_release(semaphore); + } else { + iree_hal_fence_retain(signal_fence); + } + + if 
(iree_status_is_ok(status)) { + status = iree_hal_device_queue_execute( + device, queue_affinity, iree_hal_fence_semaphore_list(wait_fence), + iree_hal_fence_semaphore_list(signal_fence), 1, &command_buffer); + } + + if (iree_status_is_ok(status) && needs_wait) { + status = iree_hal_fence_wait(signal_fence, iree_infinite_timeout()); + } + + iree_hal_fence_release(signal_fence); + IREE_TRACE_ZONE_END(z0); + return status; +} + +iree_status_t iree_tooling_transfer_variants( + iree_vm_list_t* list, iree_hal_device_t* target_device, + iree_hal_allocator_t* target_allocator, + iree_hal_buffer_params_t target_params, iree_hal_fence_t* wait_fence, + iree_hal_fence_t* signal_fence) { + IREE_ASSERT_ARGUMENT(list); + IREE_ASSERT_ARGUMENT(target_device); + IREE_ASSERT_ARGUMENT(target_allocator); + IREE_TRACE_ZONE_BEGIN(z0); + + // If all buffers are already host-accessible we can skip the transfer. + bool requires_transfer = false; + for (iree_host_size_t i = 0; i < iree_vm_list_size(list); ++i) { + iree_vm_ref_t value = iree_vm_ref_null(); + IREE_IGNORE_ERROR(iree_vm_list_get_ref_assign(list, i, &value)); + if (iree_hal_buffer_isa(value)) { + iree_hal_buffer_t* source_buffer = iree_hal_buffer_deref(value); + if (iree_tooling_requires_buffer_transfer(source_buffer, target_device, + target_params)) { + requires_transfer = true; + break; + } + } else if (iree_hal_buffer_view_isa(value)) { + iree_hal_buffer_view_t* source_view = iree_hal_buffer_view_deref(value); + iree_hal_buffer_t* source_buffer = + iree_hal_buffer_view_buffer(source_view); + if (iree_tooling_requires_buffer_transfer(source_buffer, target_device, + target_params)) { + requires_transfer = true; + break; + } + } + } + if (!requires_transfer) { + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); + } + + iree_hal_command_buffer_t* command_buffer = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_command_buffer_create( + target_device, + IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT | + 
IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION, + IREE_HAL_COMMAND_CATEGORY_TRANSFER, target_params.queue_affinity, + /*binding_capacity=*/0, &command_buffer)); + + iree_status_t status = iree_hal_command_buffer_begin(command_buffer); + if (iree_status_is_ok(status)) { + for (iree_host_size_t i = 0; i < iree_vm_list_size(list); ++i) { + iree_vm_ref_t value = iree_vm_ref_null(); + IREE_IGNORE_ERROR(iree_vm_list_get_ref_assign(list, i, &value)); + if (iree_hal_buffer_isa(value)) { + iree_hal_buffer_t* source_buffer = iree_hal_buffer_deref(value); + if (!iree_tooling_requires_buffer_transfer(source_buffer, target_device, + target_params)) { + // Already ok. + continue; + } + iree_hal_buffer_t* target_buffer = NULL; + status = iree_tooling_setup_buffer_transfer( + command_buffer, source_buffer, target_allocator, target_params, + &target_buffer); + if (!iree_status_is_ok(status)) break; + status = iree_vm_list_set_buffer_retain(list, i, target_buffer); + iree_hal_buffer_release(target_buffer); + if (!iree_status_is_ok(status)) break; + } else if (iree_hal_buffer_view_isa(value)) { + iree_hal_buffer_view_t* source_view = iree_hal_buffer_view_deref(value); + iree_hal_buffer_t* source_buffer = + iree_hal_buffer_view_buffer(source_view); + if (!iree_tooling_requires_buffer_transfer(source_buffer, target_device, + target_params)) { + // Already ok. 
+ continue; + } + iree_hal_buffer_t* target_buffer = NULL; + status = iree_tooling_setup_buffer_transfer( + command_buffer, source_buffer, target_allocator, target_params, + &target_buffer); + if (!iree_status_is_ok(status)) break; + iree_hal_buffer_view_t* target_view = NULL; + status = iree_hal_buffer_view_create_like( + target_buffer, source_view, + iree_hal_allocator_host_allocator(target_allocator), &target_view); + iree_hal_buffer_release(target_buffer); + if (!iree_status_is_ok(status)) break; + status = iree_vm_list_set_buffer_view_retain(list, i, target_view); + iree_hal_buffer_view_release(target_view); + if (!iree_status_is_ok(status)) break; + } + } + } + if (iree_status_is_ok(status)) { + status = iree_hal_command_buffer_end(command_buffer); + } + + if (iree_status_is_ok(status)) { + status = iree_tooling_submit_transfer(target_device, wait_fence, + target_params.queue_affinity, + command_buffer, signal_fence); + } + + iree_hal_command_buffer_release(command_buffer); + + IREE_TRACE_ZONE_END(z0); + return status; +} diff --git a/runtime/src/iree/tooling/function_util.h b/runtime/src/iree/tooling/function_util.h new file mode 100644 index 000000000000..63369d526eea --- /dev/null +++ b/runtime/src/iree/tooling/function_util.h @@ -0,0 +1,43 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_TOOLING_FUNCTION_UTIL_H_ +#define IREE_TOOLING_FUNCTION_UTIL_H_ + +#include "iree/base/api.h" +#include "iree/hal/api.h" +#include "iree/vm/api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// Appends fences to |list| if the invocation model of |function| requires them +// (has the `iree.abi.model` as `coarse-fences`). +// If no |wait_fence| is provided then the invocation will begin immediately. 
+// Upon return if |out_signal_fence| is not NULL the caller must wait on the +// returned |out_signal_fence| before accessing the contents of any buffers +// returned from the invocation. +iree_status_t iree_tooling_append_async_fences( + iree_vm_list_t* list, iree_vm_function_t function, + iree_hal_device_t* device, iree_hal_fence_t* wait_fence, + iree_hal_fence_t** out_signal_fence); + +// Transfers all buffers in |list| to ones using |target_params|. +// If no |wait_fence| is provided then the transfer will begin immediately. +// If no |signal_fence| is provided then the call will block until the transfer +// completes. +iree_status_t iree_tooling_transfer_variants( + iree_vm_list_t* list, iree_hal_device_t* target_device, + iree_hal_allocator_t* target_allocator, + iree_hal_buffer_params_t target_params, iree_hal_fence_t* wait_fence, + iree_hal_fence_t* signal_fence); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // IREE_TOOLING_FUNCTION_UTIL_H_ diff --git a/runtime/src/iree/tooling/numpy_io.c b/runtime/src/iree/tooling/numpy_io.c index 747bae3b1f43..43b2be2c1b56 100644 --- a/runtime/src/iree/tooling/numpy_io.c +++ b/runtime/src/iree/tooling/numpy_io.c @@ -29,7 +29,7 @@ // start of the file payload. // |out_header_buffer| must be freed by the caller with |host_allocator|. 
static iree_status_t iree_numpy_npy_read_header( - FILE* stream, iree_allocator_t host_allocator, + iree_io_stream_t* stream, iree_allocator_t host_allocator, iree_host_size_t* out_header_length, char** out_header_buffer) { IREE_ASSERT_ARGUMENT(stream); IREE_ASSERT_ARGUMENT(out_header_length); @@ -45,10 +45,9 @@ static iree_status_t iree_numpy_npy_read_header( uint8_t version_minor; } header; static_assert(sizeof(header) == 8, "packing"); - if (fread(&header, 1, sizeof(header), stream) != sizeof(header)) { - return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, - "unable to read entire header prefix"); - } + IREE_RETURN_IF_ERROR( + iree_io_stream_read(stream, sizeof(header), &header, NULL), + "unable to read entire header prefix"); // Verify magic bytes to confirm this is an npy file. static const uint8_t kMagicBytes[6] = {0x93, 'N', 'U', 'M', 'P', 'Y'}; @@ -73,23 +72,17 @@ static iree_status_t iree_numpy_npy_read_header( iree_host_size_t header_length = 0; if (header.version_major == 1) { uint16_t header_length_u16 = 0; - if (fread(&header_length_u16, 1, sizeof(header_length_u16), stream) != - sizeof(header_length_u16)) { - return iree_make_status( - IREE_STATUS_RESOURCE_EXHAUSTED, - "failed to read version %d.%d 2-byte header length", - header.version_major, header.version_minor); - } + IREE_RETURN_IF_ERROR(iree_io_stream_read(stream, sizeof(header_length_u16), + &header_length_u16, NULL), + "failed to read version %d.%d 2-byte header length", + header.version_major, header.version_minor); header_length = header_length_u16; } else { uint32_t header_length_u32 = 0; - if (fread(&header_length_u32, 1, sizeof(header_length_u32), stream) != - sizeof(header_length_u32)) { - return iree_make_status( - IREE_STATUS_RESOURCE_EXHAUSTED, - "failed to read version %d.%d 4-byte header length", - header.version_major, header.version_minor); - } + IREE_RETURN_IF_ERROR(iree_io_stream_read(stream, sizeof(header_length_u32), + &header_length_u32, NULL), + "failed to read version 
%d.%d 4-byte header length", + header.version_major, header.version_minor); header_length = header_length_u32; } @@ -101,11 +94,9 @@ static iree_status_t iree_numpy_npy_read_header( // Read entire header string, including padding/newline. iree_status_t status = iree_ok_status(); - if (fread(header_buffer, 1, header_length, stream) != header_length) { - status = iree_make_status( - IREE_STATUS_RESOURCE_EXHAUSTED, - "failed to read header string of %" PRIhsz " bytes", header_length); - } + status = iree_status_annotate_f( + iree_io_stream_read(stream, header_length, header_buffer, NULL), + "failed to read header string of %" PRIhsz " bytes", header_length); if (iree_status_is_ok(status)) { // Caller must free the string buffer. @@ -118,19 +109,17 @@ static iree_status_t iree_numpy_npy_read_header( } typedef struct { - FILE* stream; + iree_io_stream_t* stream; } iree_numpy_npy_read_params_t; static iree_status_t iree_numpy_npy_read_into_mapping( iree_hal_buffer_mapping_t* mapping, void* user_data) { iree_numpy_npy_read_params_t* params = (iree_numpy_npy_read_params_t*)user_data; - iree_host_size_t contents_length = mapping->contents.data_length; - if (fread(mapping->contents.data, 1, contents_length, params->stream) != - contents_length) { - return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, - "failed to read npy contents of %" PRIhsz " bytes", - contents_length); - } + IREE_RETURN_IF_ERROR( + iree_io_stream_read(params->stream, mapping->contents.data_length, + mapping->contents.data, NULL), + "failed to read npy contents of %" PRIhsz " bytes", + mapping->contents.data_length); return iree_ok_status(); } @@ -293,7 +282,7 @@ static iree_status_t iree_numpy_parse_shape_dims(iree_string_view_t shape, } IREE_API_EXPORT iree_status_t iree_numpy_npy_load_ndarray( - FILE* stream, iree_numpy_npy_load_options_t options, + iree_io_stream_t* stream, iree_numpy_npy_load_options_t options, iree_hal_buffer_params_t buffer_params, iree_hal_device_t* device, iree_hal_allocator_t* 
device_allocator, iree_hal_buffer_view_t** out_buffer_view) { @@ -309,7 +298,7 @@ IREE_API_EXPORT iree_status_t iree_numpy_npy_load_ndarray( // if we failed trying to parse the header. Since npy files are often // concatenated callers are likely to be using this in a loop and checking for // this condition, even if it'd be better if they did it themselves. - if (feof(stream)) { + if (iree_io_stream_is_eos(stream)) { IREE_TRACE_ZONE_END(z0); return iree_make_status(IREE_STATUS_OUT_OF_RANGE, "end-of-file"); } @@ -502,7 +491,7 @@ static iree_status_t iree_numpy_npy_build_header( // alignment. |header_dict| should not have any trailing padding or the newline // character. static iree_status_t iree_numpy_npy_write_header( - FILE* stream, iree_numpy_npy_save_options_t options, + iree_io_stream_t* stream, iree_numpy_npy_save_options_t options, iree_string_view_t header_dict) { // v1 -> v2 if the header requires it; we don't but good to be conformant. bool requires_v2 = header_dict.size > 65535; @@ -520,10 +509,8 @@ static iree_status_t iree_numpy_npy_write_header( .version_minor = 0, }; static_assert(sizeof(header) == 8, "padding"); - if (fwrite(&header, 1, sizeof(header), stream) != sizeof(header)) { - return iree_make_status(IREE_STATUS_DATA_LOSS, - "failed to write header prefix"); - } + IREE_RETURN_IF_ERROR(iree_io_stream_write(stream, sizeof(header), &header), + "failed to write header prefix"); // Pad out what we write to 64b. // Note that this includes the header prefix, length, dict, and newline. 
@@ -536,46 +523,37 @@ static iree_status_t iree_numpy_npy_write_header( iree_host_size_t header_length = header_dict.size + padding_length + /*\n*/ 1; if (requires_v2) { uint32_t header_length_u32 = (uint32_t)header_length; - if (fwrite(&header_length_u32, 1, sizeof(header_length_u32), stream) != - sizeof(header_length_u32)) { - return iree_make_status(IREE_STATUS_DATA_LOSS, - "failed to write header length"); - } + IREE_RETURN_IF_ERROR(iree_io_stream_write(stream, sizeof(header_length_u32), + &header_length_u32), + "failed to write header length"); } else { uint16_t header_length_u16 = (uint16_t)header_length; - if (fwrite(&header_length_u16, 1, sizeof(header_length_u16), stream) != - sizeof(header_length_u16)) { - return iree_make_status(IREE_STATUS_DATA_LOSS, - "failed to write header length"); - } + IREE_RETURN_IF_ERROR(iree_io_stream_write(stream, sizeof(header_length_u16), + &header_length_u16), + "failed to write header length"); } // Write header contents (without padding/trailing newline). - if (fwrite(header_dict.data, 1, header_dict.size, stream) != - header_dict.size) { - return iree_make_status(IREE_STATUS_DATA_LOSS, - "failed to write header contents"); - } + IREE_RETURN_IF_ERROR( + iree_io_stream_write(stream, header_dict.size, header_dict.data), + "failed to write header contents"); // Add space padding up to 64b alignment (minus newline). - for (iree_host_size_t i = 0; i < padding_length; ++i) { - if (fputc(' ', stream) != ' ') { - return iree_make_status(IREE_STATUS_DATA_LOSS, "failed to pad header"); - } - } + const char space = ' '; + IREE_RETURN_IF_ERROR( + iree_io_stream_fill(stream, padding_length, &space, sizeof(space)), + "failed to pad header"); // Trailing newline, which should put us right at the %64=0 alignment. 
- if (fputc('\n', stream) != '\n') { - return iree_make_status(IREE_STATUS_DATA_LOSS, - "failed to write trailing newline"); - } + IREE_RETURN_IF_ERROR(iree_io_stream_write_char(stream, '\n'), + "failed to write trailing newline"); return iree_ok_status(); } // Writes |buffer_view| contents to |stream|. static iree_status_t iree_numpy_npy_write_bytes( - FILE* stream, iree_hal_buffer_view_t* buffer_view) { + iree_io_stream_t* stream, iree_hal_buffer_view_t* buffer_view) { iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(buffer_view); iree_device_size_t write_length = iree_hal_buffer_view_byte_length(buffer_view); @@ -585,17 +563,16 @@ static iree_status_t iree_numpy_npy_write_bytes( buffer, IREE_HAL_MAPPING_MODE_SCOPED, IREE_HAL_MEMORY_ACCESS_READ, 0, write_length, &mapping)); - bool write_ok = - fwrite(mapping.contents.data, 1, write_length, stream) == write_length; + iree_status_t status = iree_status_annotate( + iree_io_stream_write(stream, write_length, mapping.contents.data), + IREE_SV("failed to write buffer contents")); - IREE_RETURN_IF_ERROR(iree_hal_buffer_unmap_range(&mapping)); - return write_ok ? 
iree_ok_status() - : iree_make_status(IREE_STATUS_DATA_LOSS, - "failed to write buffer contents"); + IREE_IGNORE_ERROR(iree_hal_buffer_unmap_range(&mapping)); + return status; } IREE_API_EXPORT iree_status_t iree_numpy_npy_save_ndarray( - FILE* stream, iree_numpy_npy_save_options_t options, + iree_io_stream_t* stream, iree_numpy_npy_save_options_t options, iree_hal_buffer_view_t* buffer_view, iree_allocator_t host_allocator) { IREE_ASSERT_ARGUMENT(stream); IREE_ASSERT_ARGUMENT(buffer_view); diff --git a/runtime/src/iree/tooling/numpy_io.h b/runtime/src/iree/tooling/numpy_io.h index 0baaefc8470c..a9ea0be4c0f7 100644 --- a/runtime/src/iree/tooling/numpy_io.h +++ b/runtime/src/iree/tooling/numpy_io.h @@ -36,10 +36,9 @@ #ifndef IREE_TOOLING_NUMPY_IO_H_ #define IREE_TOOLING_NUMPY_IO_H_ -#include - #include "iree/base/api.h" #include "iree/hal/api.h" +#include "iree/io/stream.h" #ifdef __cplusplus extern "C" { @@ -84,7 +83,7 @@ typedef uint32_t iree_numpy_npy_save_options_t; // See `numpy.load`: // https://numpy.org/doc/stable/reference/generated/numpy.load.html IREE_API_EXPORT iree_status_t iree_numpy_npy_load_ndarray( - FILE* stream, iree_numpy_npy_load_options_t options, + iree_io_stream_t* stream, iree_numpy_npy_load_options_t options, iree_hal_buffer_params_t buffer_params, iree_hal_device_t* device, iree_hal_allocator_t* device_allocator, iree_hal_buffer_view_t** out_buffer_view); @@ -95,7 +94,7 @@ IREE_API_EXPORT iree_status_t iree_numpy_npy_load_ndarray( // See `numpy.save`: // https://numpy.org/doc/stable/reference/generated/numpy.save.html IREE_API_EXPORT iree_status_t iree_numpy_npy_save_ndarray( - FILE* stream, iree_numpy_npy_save_options_t options, + iree_io_stream_t* stream, iree_numpy_npy_save_options_t options, iree_hal_buffer_view_t* buffer_view, iree_allocator_t host_allocator); #ifdef __cplusplus diff --git a/runtime/src/iree/tooling/numpy_io_test.cc b/runtime/src/iree/tooling/numpy_io_test.cc index 554f870ee00c..976cbe038064 100644 --- 
a/runtime/src/iree/tooling/numpy_io_test.cc +++ b/runtime/src/iree/tooling/numpy_io_test.cc @@ -6,7 +6,8 @@ #include "iree/tooling/numpy_io.h" -#include "iree/base/internal/file_io.h" +#include "iree/io/memory_stream.h" +#include "iree/io/vec_stream.h" #include "iree/testing/gtest.h" #include "iree/testing/status_matchers.h" #include "iree/tooling/device_util.h" @@ -19,6 +20,9 @@ using iree::testing::status::IsOk; using iree::testing::status::StatusIs; using ::testing::ElementsAreArray; +using StreamPtr = + std::unique_ptr; + class NumpyIOTest : public ::testing::Test { protected: virtual void SetUp() { @@ -36,53 +40,35 @@ class NumpyIOTest : public ::testing::Test { virtual void TearDown() { iree_hal_device_release(device_); } - static std::string GetTempFilename(const char* suffix) { - static int unique_id = 0; - char* test_tmpdir = getenv("TEST_TMPDIR"); - if (!test_tmpdir) { - test_tmpdir = getenv("TMPDIR"); - } - if (!test_tmpdir) { - test_tmpdir = getenv("TEMP"); - } - if (!test_tmpdir) { - std::cerr << "TEST_TMPDIR/TMPDIR/TEMP not defined\n"; - exit(1); - } - return test_tmpdir + std::string("/iree_test_") + - std::to_string(unique_id++) + '_' + suffix; - } - - FILE* OpenInputFile(const char* name) { + StreamPtr OpenInputFile(const char* name) { const struct iree_file_toc_t* file_toc = iree_numpy_npy_files_create(); for (size_t i = 0; i < iree_numpy_npy_files_size(); ++i) { if (strcmp(file_toc[i].name, name) != 0) continue; - auto file_path = GetTempFilename(name); - IREE_CHECK_OK(iree_file_write_contents( - file_path.c_str(), - iree_make_const_byte_span(file_toc[i].data, file_toc[i].size))); - return fopen(file_path.c_str(), "rb"); + iree_io_stream_t* stream = NULL; + IREE_CHECK_OK(iree_io_memory_stream_wrap( + IREE_IO_STREAM_MODE_READABLE | IREE_IO_STREAM_MODE_SEEKABLE, + iree_make_byte_span((void*)file_toc[i].data, file_toc[i].size), + iree_io_memory_stream_release_callback_null(), + iree_allocator_system(), &stream)); + return StreamPtr(stream, 
iree_io_stream_release); } - return NULL; + return StreamPtr{nullptr, iree_io_stream_release}; } - FILE* OpenOutputFile(const char* name) { - auto file_path = GetTempFilename(name); - return fopen(file_path.c_str(), "w+b"); + StreamPtr OpenOutputFile(const char* name) { + iree_io_stream_t* stream = NULL; + IREE_CHECK_OK(iree_io_vec_stream_create( + IREE_IO_STREAM_MODE_READABLE | IREE_IO_STREAM_MODE_WRITABLE | + IREE_IO_STREAM_MODE_SEEKABLE, + // /*block_size=*/32 * 1024, + /*block_size=*/64, iree_allocator_system(), &stream)); + return StreamPtr(stream, iree_io_stream_release); } iree_hal_device_t* device_ = nullptr; iree_hal_allocator_t* device_allocator_ = nullptr; }; -static bool IsEOF(FILE* stream) { - long original_pos = ftell(stream); - fseek(stream, 0, SEEK_END); - long end_pos = ftell(stream); - fseek(stream, original_pos, SEEK_SET); - return original_pos == end_pos; -} - template static void AssertBufferViewContents(iree_hal_buffer_view_t* buffer_view, std::vector shape, @@ -109,7 +95,8 @@ static void AssertBufferViewContents(iree_hal_buffer_view_t* buffer_view, } template -static void LoadArrayAndAssertContents(FILE* stream, iree_hal_device_t* device, +static void LoadArrayAndAssertContents(iree_io_stream_t* stream, + iree_hal_device_t* device, iree_hal_allocator_t* device_allocator, std::vector shape, iree_hal_element_type_t element_type, @@ -130,10 +117,10 @@ static void LoadArrayAndAssertContents(FILE* stream, iree_hal_device_t* device, // Tests that an empty file returns EOF. TEST_F(NumpyIOTest, LoadEmptyFile) { - FILE* stream = OpenInputFile("empty.npy"); + auto stream = OpenInputFile("empty.npy"); // Should start at EOF - the file is empty. - ASSERT_TRUE(IsEOF(stream)); + ASSERT_TRUE(iree_io_stream_is_eos(stream.get())); // Try (and fail) to parse something from the empty file. 
iree_hal_buffer_params_t buffer_params = {}; @@ -142,183 +129,200 @@ TEST_F(NumpyIOTest, LoadEmptyFile) { buffer_params.type = IREE_HAL_MEMORY_TYPE_HOST_LOCAL; iree_hal_buffer_view_t* buffer_view = NULL; EXPECT_THAT(Status(iree_numpy_npy_load_ndarray( - stream, IREE_NUMPY_NPY_LOAD_OPTION_DEFAULT, buffer_params, - device_, device_allocator_, &buffer_view)), - StatusIs(StatusCode::kResourceExhausted)); + stream.get(), IREE_NUMPY_NPY_LOAD_OPTION_DEFAULT, + buffer_params, device_, device_allocator_, &buffer_view)), + StatusIs(StatusCode::kOutOfRange)); // Should still be at EOF. - ASSERT_TRUE(IsEOF(stream)); - fclose(stream); + ASSERT_TRUE(iree_io_stream_is_eos(stream.get())); } // Tests loading a single array from a file. TEST_F(NumpyIOTest, LoadSingleArray) { - FILE* stream = OpenInputFile("single.npy"); + auto stream = OpenInputFile("single.npy"); // np.array([1.1, 2.2, 3.3], dtype=np.float32) - LoadArrayAndAssertContents( - stream, device_, device_allocator_, {3}, IREE_HAL_ELEMENT_TYPE_FLOAT_32, - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, {1.1f, 2.2f, 3.3f}); + LoadArrayAndAssertContents(stream.get(), device_, device_allocator_, + {3}, IREE_HAL_ELEMENT_TYPE_FLOAT_32, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + {1.1f, 2.2f, 3.3f}); // Should have hit EOF. - ASSERT_TRUE(IsEOF(stream)); - fclose(stream); + ASSERT_TRUE(iree_io_stream_is_eos(stream.get())); } // Tests loading multiple arrays from a concatenated file. 
TEST_F(NumpyIOTest, LoadMultipleArrays) { - FILE* stream = OpenInputFile("multiple.npy"); + auto stream = OpenInputFile("multiple.npy"); // np.array([1.1, 2.2, 3.3], dtype=np.float32) - LoadArrayAndAssertContents( - stream, device_, device_allocator_, {3}, IREE_HAL_ELEMENT_TYPE_FLOAT_32, - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, {1.1f, 2.2f, 3.3f}); + LoadArrayAndAssertContents(stream.get(), device_, device_allocator_, + {3}, IREE_HAL_ELEMENT_TYPE_FLOAT_32, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + {1.1f, 2.2f, 3.3f}); // np.array([[0, 1], [2, 3]], dtype=np.int32) - LoadArrayAndAssertContents( - stream, device_, device_allocator_, {2, 2}, IREE_HAL_ELEMENT_TYPE_SINT_32, - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, {0, 1, 2, 3}); + LoadArrayAndAssertContents(stream.get(), device_, device_allocator_, + {2, 2}, IREE_HAL_ELEMENT_TYPE_SINT_32, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + {0, 1, 2, 3}); // np.array(42, dtype=np.int32) - LoadArrayAndAssertContents( - stream, device_, device_allocator_, {}, IREE_HAL_ELEMENT_TYPE_SINT_32, - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, {42}); + LoadArrayAndAssertContents(stream.get(), device_, device_allocator_, + {}, IREE_HAL_ELEMENT_TYPE_SINT_32, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + {42}); // Should have hit EOF. - ASSERT_TRUE(IsEOF(stream)); - fclose(stream); + ASSERT_TRUE(iree_io_stream_is_eos(stream.get())); } // Tests loading arrays with various shapes. 
TEST_F(NumpyIOTest, ArrayShapes) { - FILE* stream = OpenInputFile("array_shapes.npy"); + auto stream = OpenInputFile("array_shapes.npy"); // np.array(1, dtype=np.int8) - LoadArrayAndAssertContents( - stream, device_, device_allocator_, {}, IREE_HAL_ELEMENT_TYPE_SINT_8, - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, {1}); + LoadArrayAndAssertContents(stream.get(), device_, device_allocator_, + {}, IREE_HAL_ELEMENT_TYPE_SINT_8, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + {1}); // np.array([], dtype=np.int8) LoadArrayAndAssertContents( - stream, device_, device_allocator_, {0}, IREE_HAL_ELEMENT_TYPE_SINT_8, - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, {}); + stream.get(), device_, device_allocator_, {0}, + IREE_HAL_ELEMENT_TYPE_SINT_8, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, {}); // np.array([1], dtype=np.int8) - LoadArrayAndAssertContents( - stream, device_, device_allocator_, {1}, IREE_HAL_ELEMENT_TYPE_SINT_8, - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, {1}); + LoadArrayAndAssertContents(stream.get(), device_, device_allocator_, + {1}, IREE_HAL_ELEMENT_TYPE_SINT_8, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + {1}); // np.array([[1], [2]], dtype=np.int8) - LoadArrayAndAssertContents( - stream, device_, device_allocator_, {2, 1}, IREE_HAL_ELEMENT_TYPE_SINT_8, - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, {1, 2}); + LoadArrayAndAssertContents(stream.get(), device_, device_allocator_, + {2, 1}, IREE_HAL_ELEMENT_TYPE_SINT_8, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + {1, 2}); // np.array([[0], [1], [2], [3], [4], [5], [6], [7]], dtype=np.int8) - LoadArrayAndAssertContents( - stream, device_, device_allocator_, {8, 1}, IREE_HAL_ELEMENT_TYPE_SINT_8, - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, {0, 1, 2, 3, 4, 5, 6, 7}); + LoadArrayAndAssertContents(stream.get(), device_, device_allocator_, + {8, 1}, IREE_HAL_ELEMENT_TYPE_SINT_8, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + {0, 1, 2, 3, 4, 5, 6, 7}); // np.array([[1, 2], [3, 4]], dtype=np.int8) - LoadArrayAndAssertContents( - stream, 
device_, device_allocator_, {2, 2}, IREE_HAL_ELEMENT_TYPE_SINT_8, - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, {1, 2, 3, 4}); + LoadArrayAndAssertContents(stream.get(), device_, device_allocator_, + {2, 2}, IREE_HAL_ELEMENT_TYPE_SINT_8, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + {1, 2, 3, 4}); // np.array([[[1], [2]], [[3], [4]]], dtype=np.int8) - LoadArrayAndAssertContents(stream, device_, device_allocator_, + LoadArrayAndAssertContents(stream.get(), device_, device_allocator_, {2, 2, 1}, IREE_HAL_ELEMENT_TYPE_SINT_8, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, {1, 2, 3, 4}); // Should have hit EOF. - ASSERT_TRUE(IsEOF(stream)); - fclose(stream); + ASSERT_TRUE(iree_io_stream_is_eos(stream.get())); } // Tests loading arrays with various element types. TEST_F(NumpyIOTest, ArrayTypes) { - FILE* stream = OpenInputFile("array_types.npy"); + auto stream = OpenInputFile("array_types.npy"); // np.array([True, False], dtype=np.bool_) - LoadArrayAndAssertContents( - stream, device_, device_allocator_, {2}, IREE_HAL_ELEMENT_TYPE_BOOL_8, - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, {1, 0}); + LoadArrayAndAssertContents(stream.get(), device_, device_allocator_, + {2}, IREE_HAL_ELEMENT_TYPE_BOOL_8, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + {1, 0}); // np.array([-1, 1], dtype=np.int8) - LoadArrayAndAssertContents( - stream, device_, device_allocator_, {2}, IREE_HAL_ELEMENT_TYPE_SINT_8, - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, {-1, 1}); + LoadArrayAndAssertContents(stream.get(), device_, device_allocator_, + {2}, IREE_HAL_ELEMENT_TYPE_SINT_8, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + {-1, 1}); // np.array([-20000, 20000], dtype=np.int16) - LoadArrayAndAssertContents( - stream, device_, device_allocator_, {2}, IREE_HAL_ELEMENT_TYPE_SINT_16, - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, {-20000, 20000}); + LoadArrayAndAssertContents(stream.get(), device_, device_allocator_, + {2}, IREE_HAL_ELEMENT_TYPE_SINT_16, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + {-20000, 20000}); // 
np.array([-2000000, 2000000], dtype=np.int32) - LoadArrayAndAssertContents( - stream, device_, device_allocator_, {2}, IREE_HAL_ELEMENT_TYPE_SINT_32, - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, {-2000000, 2000000}); + LoadArrayAndAssertContents(stream.get(), device_, device_allocator_, + {2}, IREE_HAL_ELEMENT_TYPE_SINT_32, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + {-2000000, 2000000}); // np.array([-20000000000, 20000000000], dtype=np.int64) - LoadArrayAndAssertContents( - stream, device_, device_allocator_, {2}, IREE_HAL_ELEMENT_TYPE_SINT_64, - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, {-20000000000, 20000000000}); + LoadArrayAndAssertContents(stream.get(), device_, device_allocator_, + {2}, IREE_HAL_ELEMENT_TYPE_SINT_64, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + {-20000000000, 20000000000}); // np.array([1, 255], dtype=np.uint8) - LoadArrayAndAssertContents( - stream, device_, device_allocator_, {2}, IREE_HAL_ELEMENT_TYPE_UINT_8, - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, {1, 255}); + LoadArrayAndAssertContents(stream.get(), device_, device_allocator_, + {2}, IREE_HAL_ELEMENT_TYPE_UINT_8, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + {1, 255}); // np.array([1, 65535], dtype=np.uint16) - LoadArrayAndAssertContents( - stream, device_, device_allocator_, {2}, IREE_HAL_ELEMENT_TYPE_UINT_16, - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, {1, 65535}); + LoadArrayAndAssertContents(stream.get(), device_, device_allocator_, + {2}, IREE_HAL_ELEMENT_TYPE_UINT_16, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + {1, 65535}); // np.array([1, 4294967295], dtype=np.uint32) - LoadArrayAndAssertContents( - stream, device_, device_allocator_, {2}, IREE_HAL_ELEMENT_TYPE_UINT_32, - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, {1, 4294967295u}); + LoadArrayAndAssertContents(stream.get(), device_, device_allocator_, + {2}, IREE_HAL_ELEMENT_TYPE_UINT_32, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + {1, 4294967295u}); // np.array([1, 18446744073709551615], dtype=np.uint64) - LoadArrayAndAssertContents( - 
stream, device_, device_allocator_, {2}, IREE_HAL_ELEMENT_TYPE_UINT_64, - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, {1, 18446744073709551615ull}); + LoadArrayAndAssertContents(stream.get(), device_, device_allocator_, + {2}, IREE_HAL_ELEMENT_TYPE_UINT_64, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + {1, 18446744073709551615ull}); // np.array([-1.1, 1.1], dtype=np.float16) - LoadArrayAndAssertContents( - stream, device_, device_allocator_, {2}, IREE_HAL_ELEMENT_TYPE_FLOAT_16, - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, {0xBC66, 0x3C66}); + LoadArrayAndAssertContents(stream.get(), device_, device_allocator_, + {2}, IREE_HAL_ELEMENT_TYPE_FLOAT_16, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + {0xBC66, 0x3C66}); // np.array([-1.1, 1.1], dtype=np.float32) - LoadArrayAndAssertContents( - stream, device_, device_allocator_, {2}, IREE_HAL_ELEMENT_TYPE_FLOAT_32, - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, {-1.1f, 1.1f}); + LoadArrayAndAssertContents(stream.get(), device_, device_allocator_, + {2}, IREE_HAL_ELEMENT_TYPE_FLOAT_32, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + {-1.1f, 1.1f}); // np.array([-1.1, 1.1], dtype=np.float64) - LoadArrayAndAssertContents( - stream, device_, device_allocator_, {2}, IREE_HAL_ELEMENT_TYPE_FLOAT_64, - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, {-1.1, 1.1}); + LoadArrayAndAssertContents(stream.get(), device_, device_allocator_, + {2}, IREE_HAL_ELEMENT_TYPE_FLOAT_64, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + {-1.1, 1.1}); // np.array([1 + 5j, 2 + 6j], dtype=np.complex64) - LoadArrayAndAssertContents(stream, device_, device_allocator_, {2}, - IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64, + LoadArrayAndAssertContents(stream.get(), device_, device_allocator_, + {2}, IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, {1.0f, 5.0f, 2.0f, 6.0f}); // np.array([-1.1, 1.1], dtype=np.float64) - LoadArrayAndAssertContents(stream, device_, device_allocator_, {2}, - IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128, - 
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, - {1.0, 5.0, 2.0, 6.0}); + LoadArrayAndAssertContents( + stream.get(), device_, device_allocator_, {2}, + IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, {1.0, 5.0, 2.0, 6.0}); // Should have hit EOF. - ASSERT_TRUE(IsEOF(stream)); - fclose(stream); + ASSERT_TRUE(iree_io_stream_is_eos(stream.get())); } -static void RoundTripArrays(FILE* source_stream, FILE* target_stream, +static void RoundTripArrays(iree_io_stream_t* source_stream, + iree_io_stream_t* target_stream, iree_hal_device_t* device, iree_hal_allocator_t* device_allocator) { - while (!IsEOF(source_stream)) { + while (!iree_io_stream_is_eos(source_stream)) { iree_hal_buffer_params_t buffer_params = {}; buffer_params.usage = IREE_HAL_BUFFER_USAGE_TRANSFER; buffer_params.access = IREE_HAL_MEMORY_ACCESS_READ; @@ -332,71 +336,68 @@ static void RoundTripArrays(FILE* source_stream, FILE* target_stream, iree_hal_allocator_host_allocator(device_allocator))); iree_hal_buffer_view_release(buffer_view); } - fflush(target_stream); } -static void CompareStreams(FILE* source_stream, FILE* target_stream) { - fseek(source_stream, 0, SEEK_END); - fseek(target_stream, 0, SEEK_END); - size_t source_size = ftell(source_stream); - size_t target_size = ftell(target_stream); +static void CompareStreams(iree_io_stream_t* source_stream, + iree_io_stream_t* target_stream) { + iree_io_stream_pos_t source_size = iree_io_stream_length(source_stream); + iree_io_stream_pos_t target_size = iree_io_stream_length(target_stream); ASSERT_EQ(source_size, target_size) << "streams should have the same length"; - fseek(source_stream, 0, SEEK_SET); - fseek(target_stream, 0, SEEK_SET); + + IREE_ASSERT_OK( + iree_io_stream_seek(source_stream, IREE_IO_STREAM_SEEK_SET, 0)); + IREE_ASSERT_OK( + iree_io_stream_seek(target_stream, IREE_IO_STREAM_SEEK_SET, 0)); std::vector source_data; source_data.resize(source_size); std::vector target_data; target_data.resize(target_size); - - 
ASSERT_EQ(source_data.size(), - fread(source_data.data(), 1, source_data.size(), source_stream)); - ASSERT_EQ(target_data.size(), - fread(target_data.data(), 1, target_data.size(), target_stream)); + IREE_ASSERT_OK(iree_io_stream_read(source_stream, source_data.size(), + source_data.data(), NULL)); + IREE_ASSERT_OK(iree_io_stream_read(target_stream, target_data.size(), + target_data.data(), NULL)); ASSERT_THAT(target_data, ElementsAreArray(source_data)); - ASSERT_EQ(IsEOF(source_stream), IsEOF(target_stream)) + ASSERT_EQ(iree_io_stream_is_eos(source_stream), + iree_io_stream_is_eos(target_stream)) << "streams should have the same length"; } // Tests round-tripping a single array. TEST_F(NumpyIOTest, RoundTripSingleArray) { - FILE* source_stream = OpenInputFile("single.npy"); - FILE* target_stream = OpenOutputFile("single_out.npy"); - RoundTripArrays(source_stream, target_stream, device_, device_allocator_); - CompareStreams(source_stream, target_stream); - fclose(source_stream); - fclose(target_stream); + auto source_stream = OpenInputFile("single.npy"); + auto target_stream = OpenOutputFile("single_out.npy"); + RoundTripArrays(source_stream.get(), target_stream.get(), device_, + device_allocator_); + CompareStreams(source_stream.get(), target_stream.get()); } // Tests round-tripping multiple array. TEST_F(NumpyIOTest, RoundTripMultipleArrays) { - FILE* source_stream = OpenInputFile("multiple.npy"); - FILE* target_stream = OpenOutputFile("multiple_out.npy"); - RoundTripArrays(source_stream, target_stream, device_, device_allocator_); - CompareStreams(source_stream, target_stream); - fclose(source_stream); - fclose(target_stream); + auto source_stream = OpenInputFile("multiple.npy"); + auto target_stream = OpenOutputFile("multiple_out.npy"); + RoundTripArrays(source_stream.get(), target_stream.get(), device_, + device_allocator_); + CompareStreams(source_stream.get(), target_stream.get()); } // Tests round-tripping arrays with various shapes. 
TEST_F(NumpyIOTest, RoundTripArrayShapes) { - FILE* source_stream = OpenInputFile("array_shapes.npy"); - FILE* target_stream = OpenOutputFile("array_shapes_out.npy"); - RoundTripArrays(source_stream, target_stream, device_, device_allocator_); - CompareStreams(source_stream, target_stream); - fclose(source_stream); - fclose(target_stream); + auto source_stream = OpenInputFile("array_shapes.npy"); + auto target_stream = OpenOutputFile("array_shapes_out.npy"); + RoundTripArrays(source_stream.get(), target_stream.get(), device_, + device_allocator_); + CompareStreams(source_stream.get(), target_stream.get()); } // Tests round-tripping arrays with various types. TEST_F(NumpyIOTest, RoundTripArrayTypes) { - FILE* source_stream = OpenInputFile("array_types.npy"); - FILE* target_stream = OpenOutputFile("array_types_out.npy"); - RoundTripArrays(source_stream, target_stream, device_, device_allocator_); - CompareStreams(source_stream, target_stream); - fclose(source_stream); - fclose(target_stream); + auto source_stream = OpenInputFile("array_types.npy"); + auto target_stream = OpenOutputFile("array_types_out.npy"); + RoundTripArrays(source_stream.get(), target_stream.get(), device_, + device_allocator_); + CompareStreams(source_stream.get(), target_stream.get()); } } // namespace diff --git a/runtime/src/iree/tooling/run_module.c b/runtime/src/iree/tooling/run_module.c index f7ed56a74772..9d1a093591df 100644 --- a/runtime/src/iree/tooling/run_module.c +++ b/runtime/src/iree/tooling/run_module.c @@ -9,12 +9,14 @@ #include "iree/base/api.h" #include "iree/base/internal/flags.h" #include "iree/hal/api.h" +#include "iree/io/stdio_stream.h" #include "iree/modules/hal/types.h" #include "iree/tooling/comparison.h" #include "iree/tooling/context_util.h" #include "iree/tooling/device_util.h" +#include "iree/tooling/function_io.h" +#include "iree/tooling/function_util.h" #include "iree/tooling/instrument_util.h" -#include "iree/tooling/vm_util.h" #include "iree/vm/api.h" #include 
"iree/vm/bytecode/module.h" @@ -89,8 +91,9 @@ IREE_FLAG( IREE_FLAG(bool, print_statistics, false, "Prints runtime statistics to stderr on exit."); -static iree_status_t iree_tooling_process_outputs( - iree_hal_device_t* device, iree_vm_list_t* outputs, +static iree_status_t iree_tooling_process_results( + iree_hal_device_t* device, iree_string_view_t results_cconv, + iree_vm_list_t* results, iree_io_stream_t* stream, iree_allocator_t host_allocator, int* out_exit_code); static iree_status_t iree_tooling_create_run_context( @@ -200,20 +203,27 @@ static iree_status_t iree_tooling_run_function( iree_string_view_t function_name = iree_vm_function_name(&function); (void)function_name; + iree_vm_function_signature_t signature = + iree_vm_function_signature(&function); + iree_string_view_t arguments_cconv, results_cconv; + iree_status_t status = iree_vm_function_call_get_cconv_fragments( + &signature, &arguments_cconv, &results_cconv); + // Parse --input= values into device buffers. iree_vm_list_t* inputs = NULL; - iree_status_t status = iree_status_annotate_f( - iree_tooling_parse_to_variant_list( - device, device_allocator, FLAG_input_list().values, - FLAG_input_list().count, host_allocator, &inputs), - "parsing function inputs"); + if (iree_status_is_ok(status)) { + status = iree_status_annotate_f( + iree_tooling_parse_variants(arguments_cconv, FLAG_input_list(), device, + device_allocator, host_allocator, &inputs), + "parsing function inputs"); + } // If the function is async add fences so we can invoke it synchronously. 
iree_hal_fence_t* finish_fence = NULL; if (iree_status_is_ok(status)) { status = iree_status_annotate_f( - iree_tooling_append_async_fence_inputs( - inputs, &function, device, /*wait_fence=*/NULL, &finish_fence), + iree_tooling_append_async_fences(inputs, function, device, + /*wait_fence=*/NULL, &finish_fence), "setting up async-external fence inputs"); } @@ -228,6 +238,7 @@ static iree_status_t iree_tooling_run_function( if (iree_status_is_ok(status)) { fprintf(stdout, "EXEC @%.*s\n", (int)function_name.size, function_name.data); + fflush(stdout); } // Begin profiling immediate prior to invocation. @@ -278,28 +289,41 @@ static iree_status_t iree_tooling_run_function( .queue_affinity = IREE_HAL_QUEUE_AFFINITY_ANY, .min_alignment = 0, }; - status = iree_tooling_transfer_variant_list( - device, outputs, device_allocator, target_params, + status = iree_tooling_transfer_variants( + outputs, device, device_allocator, target_params, /*wait_fence=*/NULL, /*signal_fence=*/NULL); } + // Wrap stdout for printing results. + iree_io_stream_t* stdout_stream = NULL; + if (iree_status_is_ok(status)) { + status = iree_status_annotate_f( + iree_io_stdio_stream_wrap(IREE_IO_STREAM_MODE_WRITABLE, stdout, + /*owns_handle=*/false, host_allocator, + &stdout_stream), + "opening stdout stream"); + } + // Handle either printing/writing the outputs or checking them against // expected values (basic pass/fail testing). 
if (iree_status_is_ok(status)) { status = iree_status_annotate_f( - iree_tooling_process_outputs(device, outputs, host_allocator, + iree_tooling_process_results(device, results_cconv, outputs, + stdout_stream, host_allocator, out_exit_code), "processing function outputs"); } iree_vm_list_release(outputs); + iree_io_stream_release(stdout_stream); fflush(stdout); return status; } -static iree_status_t iree_tooling_process_outputs( - iree_hal_device_t* device, iree_vm_list_t* outputs, +static iree_status_t iree_tooling_process_results( + iree_hal_device_t* device, iree_string_view_t results_cconv, + iree_vm_list_t* results, iree_io_stream_t* stream, iree_allocator_t host_allocator, int* out_exit_code) { *out_exit_code = EXIT_SUCCESS; @@ -308,16 +332,18 @@ static iree_status_t iree_tooling_process_outputs( if (FLAG_output_list().count == 0) { // Print all outputs. return iree_status_annotate_f( - iree_tooling_variant_list_fprint( - IREE_SV("result"), outputs, - (iree_host_size_t)FLAG_output_max_element_count, stdout), + iree_tooling_print_variants( + IREE_SV("result"), results, + (iree_host_size_t)FLAG_output_max_element_count, stream, + host_allocator), "printing results"); } else { // Write (or ignore) all outputs. return iree_status_annotate_f( - iree_tooling_output_variant_list( - outputs, FLAG_output_list().values, FLAG_output_list().count, - (iree_host_size_t)FLAG_output_max_element_count, stdout), + iree_tooling_write_variants( + results, FLAG_output_list(), + (iree_host_size_t)FLAG_output_max_element_count, stream, + host_allocator), "outputting results"); } } @@ -331,14 +357,14 @@ static iree_status_t iree_tooling_process_outputs( // Parse expected list into host-local memory that we can easily access. 
iree_vm_list_t* expected_list = NULL; iree_status_t status = iree_status_annotate_f( - iree_tooling_parse_to_variant_list( - device, heap_allocator, FLAG_expected_output_list().values, - FLAG_expected_output_list().count, host_allocator, &expected_list), + iree_tooling_parse_variants(results_cconv, FLAG_expected_output_list(), + device, heap_allocator, host_allocator, + &expected_list), "parsing expected function outputs"); // Compare expected vs actual lists and output diffs. if (iree_status_is_ok(status)) { - bool did_match = iree_tooling_compare_variant_lists(expected_list, outputs, + bool did_match = iree_tooling_compare_variant_lists(expected_list, results, host_allocator, stdout); if (did_match) { fprintf( diff --git a/runtime/src/iree/tooling/vm_util.c b/runtime/src/iree/tooling/vm_util.c deleted file mode 100644 index c028b358051e..000000000000 --- a/runtime/src/iree/tooling/vm_util.c +++ /dev/null @@ -1,930 +0,0 @@ -// Copyright 2020 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/tooling/vm_util.h" - -#include -#include -#include - -#include "iree/base/api.h" -#include "iree/base/internal/file_io.h" -#include "iree/hal/api.h" -#include "iree/modules/hal/module.h" -#include "iree/tooling/numpy_io.h" - -static iree_status_t iree_allocate_and_copy_cstring_from_view( - iree_allocator_t allocator, iree_string_view_t view, char** cstring) { - IREE_RETURN_IF_ERROR( - iree_allocator_malloc(allocator, view.size + 1, (void**)cstring)); - memcpy(*cstring, view.data, view.size); - (*cstring)[view.size] = 0; - return iree_ok_status(); -} - -static iree_status_t iree_tooling_load_ndarrays_from_file( - iree_string_view_t file_path, iree_hal_device_t* device, - iree_hal_allocator_t* device_allocator, iree_vm_list_t* list) { - char* file_path_cstring = NULL; - IREE_RETURN_IF_ERROR(iree_allocate_and_copy_cstring_from_view( - iree_allocator_system(), file_path, &file_path_cstring)); - FILE* file = fopen(file_path_cstring, "rb"); - iree_allocator_free(iree_allocator_system(), file_path_cstring); - if (!file) { - return iree_make_status(iree_status_code_from_errno(errno), - "failed to open file '%.*s'", (int)file_path.size, - file_path.data); - } - - uint64_t file_length = 0; - iree_status_t status = iree_file_query_length(file, &file_length); - - iree_hal_buffer_params_t buffer_params = {0}; - buffer_params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT; - buffer_params.access = IREE_HAL_MEMORY_ACCESS_READ; - buffer_params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL; - - while (iree_status_is_ok(status) && !iree_file_is_at(file, file_length)) { - iree_hal_buffer_view_t* buffer_view = NULL; - status = iree_numpy_npy_load_ndarray( - file, IREE_NUMPY_NPY_LOAD_OPTION_DEFAULT, buffer_params, device, - device_allocator, &buffer_view); - if (iree_status_is_ok(status)) { - iree_vm_ref_t buffer_view_ref = - iree_hal_buffer_view_retain_ref(buffer_view); - status = iree_vm_list_push_ref_move(list, 
&buffer_view_ref); - } - iree_hal_buffer_view_release(buffer_view); - } - - fclose(file); - return status; -} - -struct iree_create_buffer_from_file_generator_user_data_t { - FILE* file; -}; - -static iree_status_t iree_create_buffer_from_file_generator_callback( - iree_hal_buffer_mapping_t* mapping, void* user_data) { - struct iree_create_buffer_from_file_generator_user_data_t* read_params = - user_data; - size_t bytes_read = fread(mapping->contents.data, 1, - mapping->contents.data_length, read_params->file); - if (bytes_read != mapping->contents.data_length) { - return iree_make_status(IREE_STATUS_OUT_OF_RANGE, - "file contents truncated; expected %" PRIhsz - " bytes " - "based on buffer view size", - mapping->contents.data_length); - } - return iree_ok_status(); -} - -// Creates a HAL buffer view with the given |metadata| and reads the contents -// from the file at |file_path|. -// -// The file contents are directly read in to memory with no processing. -static iree_status_t iree_create_buffer_view_from_file( - iree_string_view_t metadata, iree_string_view_t file_path, - iree_hal_device_t* device, iree_hal_allocator_t* device_allocator, - iree_hal_buffer_view_t** out_buffer_view) { - *out_buffer_view = NULL; - - // Parse shape and element type used to allocate the buffer view. 
- iree_hal_element_type_t element_type = IREE_HAL_ELEMENT_TYPE_NONE; - iree_host_size_t shape_rank = 0; - iree_status_t shape_result = iree_hal_parse_shape_and_element_type( - metadata, 0, &shape_rank, NULL, &element_type); - if (!iree_status_is_ok(shape_result) && - !iree_status_is_out_of_range(shape_result)) { - return shape_result; - } else if (shape_rank > 128) { - return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, - "a shape rank of %" PRIhsz - " is just a little bit excessive, eh?", - shape_rank); - } - iree_status_ignore(shape_result); - iree_hal_dim_t* shape = - (iree_hal_dim_t*)iree_alloca(shape_rank * sizeof(iree_hal_dim_t)); - IREE_RETURN_IF_ERROR(iree_hal_parse_shape_and_element_type( - metadata, shape_rank, &shape_rank, shape, &element_type)); - - // TODO(benvanik): allow specifying the encoding. - iree_hal_encoding_type_t encoding_type = - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR; - - // Open the file for reading. - char* file_path_cstring = NULL; - IREE_RETURN_IF_ERROR(iree_allocate_and_copy_cstring_from_view( - iree_allocator_system(), file_path, &file_path_cstring)); - FILE* file = fopen(file_path_cstring, "rb"); - iree_allocator_free(iree_allocator_system(), file_path_cstring); - if (!file) { - return iree_make_status(iree_status_code_from_errno(errno), - "failed to open file '%.*s'", (int)file_path.size, - file_path.data); - } - - iree_hal_buffer_params_t buffer_params = {0}; - buffer_params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL; - buffer_params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT; - struct iree_create_buffer_from_file_generator_user_data_t read_params = { - file, - }; - iree_status_t status = iree_hal_buffer_view_generate_buffer( - device, device_allocator, shape_rank, shape, element_type, encoding_type, - buffer_params, iree_create_buffer_from_file_generator_callback, - &read_params, out_buffer_view); - - fclose(file); - - return status; -} - -iree_status_t iree_tooling_parse_to_variant_list( - iree_hal_device_t* device, 
iree_hal_allocator_t* device_allocator, - const iree_string_view_t* input_strings, - iree_host_size_t input_strings_count, iree_allocator_t host_allocator, - iree_vm_list_t** out_list) { - IREE_TRACE_ZONE_BEGIN(z0); - - *out_list = NULL; - iree_vm_list_t* list = NULL; - - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_vm_list_create(iree_vm_make_undefined_type_def(), - input_strings_count, host_allocator, &list)); - - iree_status_t status = iree_tooling_parse_into_variant_list( - device, device_allocator, input_strings, input_strings_count, - host_allocator, list); - if (iree_status_is_ok(status)) { - *out_list = list; - } else { - iree_vm_list_release(list); - } - IREE_TRACE_ZONE_END(z0); - return status; -} - -iree_status_t iree_tooling_parse_into_variant_list( - iree_hal_device_t* device, iree_hal_allocator_t* device_allocator, - const iree_string_view_t* input_strings, - iree_host_size_t input_strings_count, iree_allocator_t host_allocator, - iree_vm_list_t* list) { - IREE_TRACE_ZONE_BEGIN(z0); - - // Reset the list and prepare for pushing items. 
- iree_vm_list_clear(list); - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_vm_list_reserve(list, input_strings_count)); - - iree_status_t status = iree_ok_status(); - for (size_t i = 0; i < input_strings_count; ++i) { - if (!iree_status_is_ok(status)) break; - iree_string_view_t input_view = iree_string_view_trim(input_strings[i]); - if (iree_string_view_is_empty(input_view)) { - status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "no value specified for input"); - break; - } else if (iree_string_view_consume_prefix(&input_view, IREE_SV("@"))) { - status = iree_tooling_load_ndarrays_from_file(input_view, device, - device_allocator, list); - continue; - } else if (iree_string_view_equal(input_view, IREE_SV("(null)")) || - iree_string_view_equal(input_view, IREE_SV("(ignored)"))) { - iree_vm_ref_t null_ref = iree_vm_ref_null(); - status = iree_vm_list_push_ref_retain(list, &null_ref); - continue; - } - bool has_equal = - iree_string_view_find_char(input_view, '=', 0) != IREE_STRING_VIEW_NPOS; - bool has_x = - iree_string_view_find_char(input_view, 'x', 0) != IREE_STRING_VIEW_NPOS; - if (has_equal || has_x) { - // Buffer view (either just a shape or a shape=value) or buffer. - bool is_storage_reference = iree_string_view_consume_prefix( - &input_view, iree_make_cstring_view("&")); - iree_hal_buffer_view_t* buffer_view = NULL; - bool has_at = iree_string_view_find_char(input_view, '@', 0) != - IREE_STRING_VIEW_NPOS; - if (has_at) { - // Referencing an external file; split into the portion used to - // initialize the buffer view and the file contents. 
- iree_string_view_t metadata, file_path; - iree_string_view_split(input_view, '@', &metadata, &file_path); - iree_string_view_consume_suffix(&metadata, iree_make_cstring_view("=")); - status = iree_create_buffer_view_from_file( - metadata, file_path, device, device_allocator, &buffer_view); - if (!iree_status_is_ok(status)) break; - } else { - status = iree_hal_buffer_view_parse(input_view, device, - device_allocator, &buffer_view); - if (!iree_status_is_ok(status)) { - status = - iree_status_annotate_f(status, "parsing value '%.*s'", - (int)input_view.size, input_view.data); - break; - } - } - if (is_storage_reference) { - // Storage buffer reference; just take the storage for the buffer view - - // it'll still have whatever contents were specified (or 0) but we'll - // discard the metadata. - iree_vm_ref_t buffer_ref = iree_hal_buffer_retain_ref( - iree_hal_buffer_view_buffer(buffer_view)); - iree_hal_buffer_view_release(buffer_view); - status = iree_vm_list_push_ref_move(list, &buffer_ref); - if (!iree_status_is_ok(status)) break; - } else { - iree_vm_ref_t buffer_view_ref = - iree_hal_buffer_view_move_ref(buffer_view); - status = iree_vm_list_push_ref_move(list, &buffer_view_ref); - if (!iree_status_is_ok(status)) break; - } - } else { - // Scalar. - bool has_dot = iree_string_view_find_char(input_view, '.', 0) != - IREE_STRING_VIEW_NPOS; - iree_vm_value_t val; - if (has_dot) { - // Float. - val = iree_vm_value_make_f32(0.0f); - if (!iree_string_view_atof(input_view, &val.f32)) { - status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "parsing value '%.*s' as f32", - (int)input_view.size, input_view.data); - break; - } - } else { - // Integer. 
- val = iree_vm_value_make_i32(0); - if (!iree_string_view_atoi_int32(input_view, &val.i32)) { - status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "parsing value '%.*s' as i32", - (int)input_view.size, input_view.data); - break; - } - } - status = iree_vm_list_push_value(list, &val); - if (!iree_status_is_ok(status)) break; - } - } - - IREE_TRACE_ZONE_END(z0); - return status; -} - -iree_status_t iree_tooling_append_async_fence_inputs( - iree_vm_list_t* list, const iree_vm_function_t* function, - iree_hal_device_t* device, iree_hal_fence_t* wait_fence, - iree_hal_fence_t** out_signal_fence) { - IREE_TRACE_ZONE_BEGIN(z0); - - iree_string_view_t model = - iree_vm_function_lookup_attr_by_name(function, IREE_SV("iree.abi.model")); - if (!iree_string_view_equal(model, IREE_SV("coarse-fences"))) { - // Ignore unknown models - the user may have provided their own fences. - IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); - } - - // Create the signal fence as a 0->1 transition. The caller will wait on that. - iree_hal_semaphore_t* semaphore = NULL; - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_semaphore_create(device, 0ull, &semaphore)); - iree_hal_fence_t* signal_fence = NULL; - iree_status_t status = iree_hal_fence_create_at( - semaphore, 1ull, iree_hal_device_host_allocator(device), &signal_fence); - iree_hal_semaphore_release(semaphore); - - // Append (wait, signal) fences. 
- if (iree_status_is_ok(status)) { - iree_vm_ref_t wait_fence_ref = iree_hal_fence_retain_ref(wait_fence); - status = iree_vm_list_push_ref_move(list, &wait_fence_ref); - iree_vm_ref_release(&wait_fence_ref); - } - if (iree_status_is_ok(status)) { - iree_vm_ref_t signal_fence_ref = iree_hal_fence_retain_ref(signal_fence); - status = iree_vm_list_push_ref_move(list, &signal_fence_ref); - iree_vm_ref_release(&signal_fence_ref); - } - - if (iree_status_is_ok(status)) { - *out_signal_fence = signal_fence; - } else { - iree_hal_fence_release(signal_fence); - } - IREE_TRACE_ZONE_END(z0); - return status; -} - -static bool iree_tooling_requires_buffer_transfer( - iree_hal_buffer_t* source_buffer, iree_hal_buffer_params_t target_params) { - return !iree_all_bits_set(iree_hal_buffer_memory_type(source_buffer), - target_params.type) || - !iree_all_bits_set(iree_hal_buffer_allowed_usage(source_buffer), - target_params.usage); -} - -static iree_status_t iree_tooling_setup_buffer_transfer( - iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_t* source_buffer, - iree_hal_allocator_t* target_allocator, - iree_hal_buffer_params_t target_params, - iree_hal_buffer_t** out_target_buffer) { - IREE_ASSERT_ARGUMENT(command_buffer); - IREE_ASSERT_ARGUMENT(source_buffer); - IREE_ASSERT_ARGUMENT(target_allocator); - IREE_ASSERT_ARGUMENT(out_target_buffer); - *out_target_buffer = NULL; - IREE_TRACE_ZONE_BEGIN(z0); - - iree_hal_buffer_t* target_buffer = NULL; - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_allocator_allocate_buffer( - target_allocator, target_params, - iree_hal_buffer_allocation_size(source_buffer), &target_buffer)); - - iree_status_t status = iree_hal_command_buffer_copy_buffer( - command_buffer, source_buffer, 0, target_buffer, 0, - iree_hal_buffer_byte_length(source_buffer)); - - if (iree_status_is_ok(status)) { - *out_target_buffer = target_buffer; - } else { - iree_hal_buffer_release(target_buffer); - } - IREE_TRACE_ZONE_END(z0); - return status; -} - 
-static iree_status_t iree_tooling_submit_transfer( - iree_hal_device_t* device, iree_hal_fence_t* wait_fence, - iree_hal_queue_affinity_t queue_affinity, - iree_hal_command_buffer_t* command_buffer, iree_hal_fence_t* signal_fence) { - IREE_TRACE_ZONE_BEGIN(z0); - - iree_status_t status = iree_ok_status(); - - bool needs_wait = signal_fence == NULL; - if (needs_wait) { - iree_hal_semaphore_t* semaphore = NULL; - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_semaphore_create(device, 0ull, &semaphore)); - status = iree_hal_fence_create_at( - semaphore, 1ull, iree_hal_device_host_allocator(device), &signal_fence); - iree_hal_semaphore_release(semaphore); - } else { - iree_hal_fence_retain(signal_fence); - } - - if (iree_status_is_ok(status)) { - status = iree_hal_device_queue_execute( - device, queue_affinity, iree_hal_fence_semaphore_list(wait_fence), - iree_hal_fence_semaphore_list(signal_fence), 1, &command_buffer); - } - - if (iree_status_is_ok(status) && needs_wait) { - status = iree_hal_fence_wait(signal_fence, iree_infinite_timeout()); - } - - iree_hal_fence_release(signal_fence); - IREE_TRACE_ZONE_END(z0); - return status; -} - -iree_status_t iree_tooling_transfer_variant_list( - iree_hal_device_t* device, iree_vm_list_t* list, - iree_hal_allocator_t* target_allocator, - iree_hal_buffer_params_t target_params, iree_hal_fence_t* wait_fence, - iree_hal_fence_t* signal_fence) { - IREE_ASSERT_ARGUMENT(device); - IREE_ASSERT_ARGUMENT(list); - IREE_ASSERT_ARGUMENT(target_allocator); - IREE_TRACE_ZONE_BEGIN(z0); - - // If all buffers are already host-accessible we can skip the transfer. 
- bool requires_transfer = false; - for (iree_host_size_t i = 0; i < iree_vm_list_size(list); ++i) { - iree_vm_ref_t value = iree_vm_ref_null(); - IREE_IGNORE_ERROR(iree_vm_list_get_ref_assign(list, i, &value)); - if (iree_hal_buffer_isa(value)) { - iree_hal_buffer_t* source_buffer = iree_hal_buffer_deref(value); - if (iree_tooling_requires_buffer_transfer(source_buffer, target_params)) { - requires_transfer = true; - break; - } - } else if (iree_hal_buffer_view_isa(value)) { - iree_hal_buffer_view_t* source_view = iree_hal_buffer_view_deref(value); - iree_hal_buffer_t* source_buffer = - iree_hal_buffer_view_buffer(source_view); - if (iree_tooling_requires_buffer_transfer(source_buffer, target_params)) { - requires_transfer = true; - break; - } - } - } - if (!requires_transfer) { - IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); - } - - iree_hal_command_buffer_t* command_buffer = NULL; - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_command_buffer_create( - device, - IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT | - IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION, - IREE_HAL_COMMAND_CATEGORY_TRANSFER, target_params.queue_affinity, - /*binding_capacity=*/0, &command_buffer)); - - iree_status_t status = iree_hal_command_buffer_begin(command_buffer); - if (iree_status_is_ok(status)) { - for (iree_host_size_t i = 0; i < iree_vm_list_size(list); ++i) { - iree_vm_ref_t value = iree_vm_ref_null(); - IREE_IGNORE_ERROR(iree_vm_list_get_ref_assign(list, i, &value)); - if (iree_hal_buffer_isa(value)) { - iree_hal_buffer_t* source_buffer = iree_hal_buffer_deref(value); - if (!iree_tooling_requires_buffer_transfer(source_buffer, - target_params)) { - // Already ok. 
- continue; - } - iree_hal_buffer_t* target_buffer = NULL; - status = iree_tooling_setup_buffer_transfer( - command_buffer, source_buffer, target_allocator, target_params, - &target_buffer); - if (!iree_status_is_ok(status)) break; - status = iree_vm_list_set_buffer_retain(list, i, target_buffer); - iree_hal_buffer_release(target_buffer); - if (!iree_status_is_ok(status)) break; - } else if (iree_hal_buffer_view_isa(value)) { - iree_hal_buffer_view_t* source_view = iree_hal_buffer_view_deref(value); - iree_hal_buffer_t* source_buffer = - iree_hal_buffer_view_buffer(source_view); - if (!iree_tooling_requires_buffer_transfer(source_buffer, - target_params)) { - // Already ok. - continue; - } - iree_hal_buffer_t* target_buffer = NULL; - status = iree_tooling_setup_buffer_transfer( - command_buffer, source_buffer, target_allocator, target_params, - &target_buffer); - if (!iree_status_is_ok(status)) break; - iree_hal_buffer_view_t* target_view = NULL; - status = iree_hal_buffer_view_create_like( - target_buffer, source_view, - iree_hal_allocator_host_allocator(target_allocator), &target_view); - iree_hal_buffer_release(target_buffer); - if (!iree_status_is_ok(status)) break; - status = iree_vm_list_set_buffer_view_retain(list, i, target_view); - iree_hal_buffer_view_release(target_view); - if (!iree_status_is_ok(status)) break; - } - } - } - if (iree_status_is_ok(status)) { - status = iree_hal_command_buffer_end(command_buffer); - } - - if (iree_status_is_ok(status)) { - status = iree_tooling_submit_transfer(device, wait_fence, - target_params.queue_affinity, - command_buffer, signal_fence); - } - - iree_hal_command_buffer_release(command_buffer); - - IREE_TRACE_ZONE_END(z0); - return status; -} - -#define IREE_PRINTVARIANT_CASE_I(SIZE, B, V) \ - case IREE_VM_VALUE_TYPE_I##SIZE: \ - return iree_string_builder_append_format( \ - B, "i" #SIZE "=%" PRIi##SIZE "\n", (V).i##SIZE); - -#define IREE_PRINTVARIANT_CASE_F(SIZE, B, V) \ - case IREE_VM_VALUE_TYPE_F##SIZE: \ - return 
iree_string_builder_append_format(B, "f" #SIZE "=%g\n", (V).f##SIZE); - -// Prints variant description including a trailing newline. -static iree_status_t iree_variant_format(iree_vm_variant_t variant, - iree_host_size_t max_element_count, - iree_string_builder_t* builder) { - if (iree_vm_variant_is_empty(variant)) { - return iree_string_builder_append_string(builder, IREE_SV("(null)\n")); - } else if (iree_vm_variant_is_value(variant)) { - switch (iree_vm_type_def_as_value(variant.type)) { - IREE_PRINTVARIANT_CASE_I(8, builder, variant) - IREE_PRINTVARIANT_CASE_I(16, builder, variant) - IREE_PRINTVARIANT_CASE_I(32, builder, variant) - IREE_PRINTVARIANT_CASE_I(64, builder, variant) - IREE_PRINTVARIANT_CASE_F(32, builder, variant) - IREE_PRINTVARIANT_CASE_F(64, builder, variant) - default: - return iree_string_builder_append_string(builder, IREE_SV("?\n")); - } - } else if (iree_vm_variant_is_ref(variant)) { - iree_string_view_t type_name = - iree_vm_ref_type_name(iree_vm_type_def_as_ref(variant.type)); - IREE_RETURN_IF_ERROR(iree_string_builder_append_string(builder, type_name)); - IREE_RETURN_IF_ERROR( - iree_string_builder_append_string(builder, IREE_SV("\n"))); - if (iree_vm_list_isa(variant.ref)) { - iree_vm_list_t* child_list = iree_vm_list_deref(variant.ref); - IREE_RETURN_IF_ERROR(iree_tooling_append_variant_list_lines( - IREE_SV("child_list"), child_list, max_element_count, builder)); - return iree_string_builder_append_string(builder, IREE_SV("\n")); - } else if (iree_hal_buffer_view_isa(variant.ref)) { - iree_hal_buffer_view_t* buffer_view = - iree_hal_buffer_view_deref(variant.ref); - IREE_RETURN_IF_ERROR(iree_hal_buffer_view_append_to_builder( - buffer_view, max_element_count, builder)); - return iree_string_builder_append_string(builder, IREE_SV("\n")); - } else { - // TODO(benvanik): a way for ref types to describe themselves. 
- return iree_string_builder_append_string(builder, - IREE_SV("(no printer)\n")); - } - } else { - return iree_string_builder_append_string(builder, IREE_SV("(null)\n")); - } - return iree_ok_status(); -} - -static iree_status_t iree_variant_fprint(iree_vm_variant_t variant, - iree_host_size_t max_element_count, - FILE* file) { - iree_string_builder_t builder; - iree_string_builder_initialize(iree_allocator_system(), &builder); - iree_status_t status = - iree_variant_format(variant, max_element_count, &builder); - if (iree_status_is_ok(status)) { - size_t written = fwrite(iree_string_builder_buffer(&builder), 1, - iree_string_builder_size(&builder), file); - if (written != iree_string_builder_size(&builder)) { - status = iree_status_from_code(IREE_STATUS_PERMISSION_DENIED); - } - fflush(file); - } - iree_string_builder_deinitialize(&builder); - return status; -} - -iree_status_t iree_tooling_append_variant_list_lines( - iree_string_view_t list_name, iree_vm_list_t* list, - iree_host_size_t max_element_count, iree_string_builder_t* builder) { - IREE_TRACE_ZONE_BEGIN(z0); - for (iree_host_size_t i = 0; i < iree_vm_list_size(list); ++i) { - iree_vm_variant_t variant = iree_vm_variant_empty(); - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_vm_list_get_variant_assign(list, i, &variant), - "variant %" PRIhsz " not present", i); - iree_string_builder_append_format( - builder, "%.*s[%" PRIhsz "]: ", (int)list_name.size, list_name.data, i); - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_variant_format(variant, max_element_count, builder)); - } - IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); -} - -iree_status_t iree_tooling_variant_list_fprint( - iree_string_view_t list_name, iree_vm_list_t* list, - iree_host_size_t max_element_count, FILE* file) { - iree_string_builder_t builder; - iree_string_builder_initialize(iree_allocator_system(), &builder); - iree_status_t status = iree_tooling_append_variant_list_lines( - list_name, list, max_element_count, &builder); - if 
(iree_status_is_ok(status)) { - size_t written = fwrite(iree_string_builder_buffer(&builder), 1, - iree_string_builder_size(&builder), file); - if (written != iree_string_builder_size(&builder)) { - status = iree_status_from_code(IREE_STATUS_PERMISSION_DENIED); - } - fflush(file); - } - iree_string_builder_deinitialize(&builder); - return status; -} - -static iree_status_t iree_tooling_create_buffer_view_with_hal_buffer( - iree_hal_buffer_t* hal_buffer, iree_allocator_t host_allocator, - iree_hal_buffer_view_t** out_buffer_view) { - iree_hal_dim_t shape[1] = { - (iree_hal_dim_t)iree_hal_buffer_byte_length(hal_buffer), - }; - return iree_hal_buffer_view_create( - hal_buffer, IREE_ARRAYSIZE(shape), shape, IREE_HAL_ELEMENT_TYPE_INT_8, - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, host_allocator, out_buffer_view); -} - -static void iree_hal_buffer_release_vm_buffer( - void* user_data, struct iree_hal_buffer_t* buffer) { - iree_vm_buffer_release((iree_vm_buffer_t*)user_data); -} - -static iree_status_t iree_tooling_create_buffer_view_with_vm_buffer( - iree_vm_buffer_t* vm_buffer, iree_hal_allocator_t* device_allocator, - iree_allocator_t host_allocator, iree_hal_buffer_view_t** out_buffer_view) { - // Get read-only pointer to the underlying buffer heap memory. - iree_const_byte_span_t span = iree_const_byte_span_empty(); - IREE_RETURN_IF_ERROR(iree_vm_buffer_map_ro( - vm_buffer, 0, iree_vm_buffer_length(vm_buffer), 1, &span)); - - // Wrap the heap memory in a HAL buffer for read-only access. 
- iree_hal_buffer_release_callback_t release_callback = { - .fn = iree_hal_buffer_release_vm_buffer, - .user_data = vm_buffer, - }; - iree_vm_buffer_retain(vm_buffer); - iree_hal_buffer_t* hal_buffer = NULL; - iree_status_t status = iree_hal_heap_buffer_wrap( - device_allocator, IREE_HAL_MEMORY_TYPE_HOST_LOCAL, - IREE_HAL_MEMORY_ACCESS_READ, - IREE_HAL_BUFFER_USAGE_TRANSFER_SOURCE | IREE_HAL_BUFFER_USAGE_MAPPING, - span.data_length, iree_cast_const_byte_span(span), release_callback, - &hal_buffer); - iree_vm_buffer_release(vm_buffer); - - // Wrap the HAL buffer in a buffer view. - if (iree_status_is_ok(status)) { - status = iree_tooling_create_buffer_view_with_hal_buffer( - hal_buffer, host_allocator, out_buffer_view); - } - - iree_hal_buffer_release(hal_buffer); - return status; -} - -static iree_status_t iree_tooling_create_buffer_view_empty( - iree_hal_allocator_t* device_allocator, iree_allocator_t host_allocator, - iree_hal_buffer_view_t** out_buffer_view) { - iree_hal_buffer_t* hal_buffer = NULL; - IREE_RETURN_IF_ERROR(iree_hal_heap_buffer_wrap( - device_allocator, IREE_HAL_MEMORY_TYPE_HOST_LOCAL, - IREE_HAL_MEMORY_ACCESS_READ, - IREE_HAL_BUFFER_USAGE_TRANSFER_SOURCE | IREE_HAL_BUFFER_USAGE_MAPPING, 0, - iree_byte_span_empty(), iree_hal_buffer_release_callback_null(), - &hal_buffer)); - iree_status_t status = iree_tooling_create_buffer_view_with_hal_buffer( - hal_buffer, host_allocator, out_buffer_view); - iree_hal_buffer_release(hal_buffer); - return status; -} - -static iree_status_t iree_tooling_create_buffer_view_with_value( - iree_vm_value_t value, iree_hal_allocator_t* device_allocator, - iree_allocator_t host_allocator, iree_hal_buffer_view_t** out_buffer_view) { - iree_device_size_t byte_length = 0; - iree_hal_element_type_t element_type = IREE_HAL_ELEMENT_TYPE_NONE; - switch (value.type) { - case IREE_VM_VALUE_TYPE_NONE: - return iree_tooling_create_buffer_view_empty( - device_allocator, host_allocator, out_buffer_view); - case IREE_VM_VALUE_TYPE_I8: 
- byte_length = sizeof(value.i8); - element_type = IREE_HAL_ELEMENT_TYPE_INT_8; - break; - case IREE_VM_VALUE_TYPE_I16: - byte_length = sizeof(value.i16); - element_type = IREE_HAL_ELEMENT_TYPE_INT_16; - break; - case IREE_VM_VALUE_TYPE_I32: - byte_length = sizeof(value.i32); - element_type = IREE_HAL_ELEMENT_TYPE_INT_32; - break; - case IREE_VM_VALUE_TYPE_I64: - byte_length = sizeof(value.i64); - element_type = IREE_HAL_ELEMENT_TYPE_INT_64; - break; - case IREE_VM_VALUE_TYPE_F32: - byte_length = sizeof(value.f32); - element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_32; - break; - case IREE_VM_VALUE_TYPE_F64: - byte_length = sizeof(value.f64); - element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_64; - break; - default: - return iree_make_status(IREE_STATUS_UNIMPLEMENTED, - "unsupported value type"); - } - - iree_hal_buffer_params_t params = { - .usage = - IREE_HAL_BUFFER_USAGE_TRANSFER_SOURCE | IREE_HAL_BUFFER_USAGE_MAPPING, - .access = IREE_HAL_MEMORY_ACCESS_ALL, - .type = IREE_HAL_MEMORY_TYPE_HOST_LOCAL, - }; - iree_hal_buffer_t* hal_buffer = NULL; - IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer( - device_allocator, params, byte_length, &hal_buffer)); - - iree_status_t status = iree_hal_buffer_map_write( - hal_buffer, 0, value.value_storage, byte_length); - - if (iree_status_is_ok(status)) { - status = iree_hal_buffer_view_create(hal_buffer, /*shape_rank=*/0, - /*shape=*/NULL, element_type, - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, - host_allocator, out_buffer_view); - } - - iree_hal_buffer_release(hal_buffer); - return status; -} - -static iree_status_t iree_tooling_create_buffer_view_from_variant( - iree_vm_variant_t variant, iree_hal_allocator_t* device_allocator, - iree_allocator_t host_allocator, iree_hal_buffer_view_t** out_buffer_view) { - *out_buffer_view = NULL; - if (iree_vm_variant_is_empty(variant)) { - // Empty value - we need to emit a zero-length value to keep the npy file - // ordered when there are multiple entries. 
- return iree_tooling_create_buffer_view_empty( - device_allocator, host_allocator, out_buffer_view); - } else if (iree_vm_variant_is_ref(variant)) { - if (iree_hal_buffer_view_isa(variant.ref)) { - // Buffer view returned can provide the metadata required. - *out_buffer_view = iree_hal_buffer_view_deref(variant.ref); - iree_hal_buffer_view_retain(*out_buffer_view); - return iree_ok_status(); - } else if (iree_hal_buffer_isa(variant.ref)) { - // i8 buffer view of the total length of the HAL buffer. - iree_hal_buffer_t* buffer = iree_hal_buffer_deref(variant.ref); - return iree_tooling_create_buffer_view_with_hal_buffer( - buffer, host_allocator, out_buffer_view); - } else if (iree_vm_buffer_isa(variant.ref)) { - // i8 buffer view of the total length of the VM buffer wrapped in a HAL - // buffer. - iree_vm_buffer_t* buffer = iree_vm_buffer_deref(variant.ref); - return iree_tooling_create_buffer_view_with_vm_buffer( - buffer, device_allocator, host_allocator, out_buffer_view); - } else { - // Unsupported type. - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "unsupported output source type; expected: " - "!hal.buffer, !hal.buffer_view, !vm.buffer"); - } - } else { - // Primitive value that we wrap in a scalar buffer view. - return iree_tooling_create_buffer_view_with_value( - iree_vm_variant_value(variant), device_allocator, host_allocator, - out_buffer_view); - } -} - -static iree_status_t iree_tooling_output_variant_to_npy_file( - FILE* file, iree_vm_variant_t variant, - iree_hal_allocator_t* device_allocator, iree_allocator_t host_allocator) { - // npy files require buffer views so if we receive anything but a buffer view - // we wrap it in one typed as bytes. - iree_hal_buffer_view_t* buffer_view = NULL; - IREE_RETURN_IF_ERROR(iree_tooling_create_buffer_view_from_variant( - variant, device_allocator, host_allocator, &buffer_view)); - - // Append buffer view contents to the file stream. 
- iree_numpy_npy_save_options_t options = IREE_NUMPY_NPY_SAVE_OPTION_DEFAULT; - iree_status_t status = iree_numpy_npy_save_ndarray(file, options, buffer_view, - iree_allocator_system()); - - iree_hal_buffer_view_release(buffer_view); - return status; -} - -static iree_status_t iree_tooling_output_variant_to_binary_file( - FILE* file, iree_vm_variant_t variant, - iree_hal_allocator_t* device_allocator, iree_allocator_t host_allocator) { - // Today we reuse the buffer view code to get the variant into a byte buffer - // to write out even though we don't use any of the metadata. This is a - // command line tool writing out files using stdio and not an example of how - // to create a high performance I/O mechanism. - iree_hal_buffer_view_t* buffer_view = NULL; - IREE_RETURN_IF_ERROR(iree_tooling_create_buffer_view_from_variant( - variant, device_allocator, host_allocator, &buffer_view)); - iree_device_size_t byte_length = - iree_hal_buffer_view_byte_length(buffer_view); - - // Map the buffer memory into a host pointer so we can access it. - iree_hal_buffer_mapping_t mapping; - iree_status_t status = iree_hal_buffer_map_range( - iree_hal_buffer_view_buffer(buffer_view), IREE_HAL_MAPPING_MODE_SCOPED, - IREE_HAL_MEMORY_ACCESS_READ, 0, IREE_WHOLE_BUFFER, &mapping); - - // Write to the file from the mapped memory. - if (iree_status_is_ok(status)) { - bool write_ok = - fwrite(mapping.contents.data, 1, byte_length, file) == byte_length; - status = write_ok ? 
iree_ok_status() - : iree_make_status(IREE_STATUS_DATA_LOSS, - "failed to write buffer contents"); - } - - iree_status_ignore(iree_hal_buffer_unmap_range(&mapping)); - - iree_hal_buffer_view_release(buffer_view); - return status; -} - -static iree_status_t iree_tooling_output_variant( - iree_vm_variant_t variant, iree_string_view_t output_str, - iree_host_size_t max_element_count, FILE* default_file) { - iree_allocator_t host_allocator = iree_allocator_system(); - - if (iree_string_view_is_empty(output_str)) { - // Send into the void. - return iree_ok_status(); - } else if (iree_string_view_equal(output_str, IREE_SV("-"))) { - // Route to the provided file. - return iree_variant_fprint(variant, max_element_count, default_file); - } - - bool has_at = iree_string_view_consume_prefix(&output_str, IREE_SV("@")); - bool has_plus = iree_string_view_consume_prefix(&output_str, IREE_SV("+")); - if (!has_at && !has_plus) { - // Other types of outputs are not yet supported. We could allow for shapes - // and either verify metadata or output binary files ala - // `--input=4xf32=@foo.bin`. - return iree_make_status(IREE_STATUS_UNIMPLEMENTED, - "unsupported output mode specification '%.*s'", - (int)output_str.size, output_str.data); - } - - // Output format is based on file extension with ones we don't know about - // going into binary mode. Some formats require metadata from buffer views - // but in binary mode we just dump whatever contents we have and leave it up - // to the user to handle the shape/type/encoding. - iree_string_view_t file_path = output_str; - - // Open file for either overwriting or appending (npy files can contain - // multiple arrays). - char* file_path_cstring = NULL; - IREE_RETURN_IF_ERROR(iree_allocate_and_copy_cstring_from_view( - host_allocator, file_path, &file_path_cstring)); - const char* mode = has_plus ? 
"ab" : "wb"; - FILE* file = fopen(file_path_cstring, mode); - iree_allocator_free(host_allocator, file_path_cstring); - if (!file) { - return iree_make_status(iree_status_code_from_errno(errno), - "failed to open file '%.*s'", (int)file_path.size, - file_path.data); - } - - iree_hal_allocator_t* device_allocator = NULL; - iree_status_t status = iree_hal_allocator_create_heap( - IREE_SV("tooling"), host_allocator, host_allocator, &device_allocator); - if (iree_status_is_ok(status)) { - if (iree_string_view_ends_with(file_path, IREE_SV(".npy"))) { - status = iree_tooling_output_variant_to_npy_file( - file, variant, device_allocator, host_allocator); - } else { - status = iree_tooling_output_variant_to_binary_file( - file, variant, device_allocator, host_allocator); - } - } - iree_hal_allocator_release(device_allocator); - - fclose(file); - return status; -} - -iree_status_t iree_tooling_output_variant_list( - iree_vm_list_t* list, const iree_string_view_t* output_strings, - iree_host_size_t output_strings_count, iree_host_size_t max_element_count, - FILE* file) { - IREE_ASSERT_ARGUMENT(list); - IREE_ASSERT_ARGUMENT(!output_strings_count || output_strings); - - // We only care if there are not enough outputs to satisfy the user - // request. We could force users to specify all outputs to make this a bit - // harder to misuse but saving off outputs is a power-user feature. 
- if (iree_vm_list_size(list) != output_strings_count) { - return iree_make_status( - IREE_STATUS_OUT_OF_RANGE, - "%" PRIhsz " outputs specified but the provided list only has %" PRIhsz - " elements", - output_strings_count, iree_vm_list_size(list)); - } - - IREE_TRACE_ZONE_BEGIN(z0); - - for (iree_host_size_t i = 0; i < output_strings_count; ++i) { - iree_vm_variant_t variant = iree_vm_variant_empty(); - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_vm_list_get_variant_assign(list, i, &variant)); - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_tooling_output_variant(variant, output_strings[i], - max_element_count, file)); - } - - IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); -} diff --git a/runtime/src/iree/tooling/vm_util.h b/runtime/src/iree/tooling/vm_util.h deleted file mode 100644 index bc9ca008236b..000000000000 --- a/runtime/src/iree/tooling/vm_util.h +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2020 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_TOOLING_VM_UTIL_H_ -#define IREE_TOOLING_VM_UTIL_H_ - -#include "iree/base/api.h" -#include "iree/hal/api.h" -#include "iree/vm/api.h" - -#ifdef __cplusplus -extern "C" { -#endif - -// NOTE: this file is not best-practice and needs to be rewritten; consider this -// appropriate only for test code. - -// Parses |input_strings| into a variant list of VM scalars and buffers. -// Scalars should be in the format: -// type=value -// Buffers should be in the IREE standard shaped buffer format: -// [shape]xtype=[value] -// described in iree/hal/api.h -// Uses |device_allocator| to allocate the buffers. -// The returned variant list must be freed by the caller. 
-iree_status_t iree_tooling_parse_to_variant_list( - iree_hal_device_t* device, iree_hal_allocator_t* device_allocator, - const iree_string_view_t* input_strings, - iree_host_size_t input_strings_count, iree_allocator_t host_allocator, - iree_vm_list_t** out_list); - -// Parses |input_strings| into a variant list of VM scalars and buffers. -// Scalars should be in the format: -// type=value -// Buffers should be in the IREE standard shaped buffer format: -// [shape]xtype=[value] -// described in iree/hal/api.h -// Uses |device_allocator| to allocate the buffers. -iree_status_t iree_tooling_parse_into_variant_list( - iree_hal_device_t* device, iree_hal_allocator_t* device_allocator, - const iree_string_view_t* input_strings, - iree_host_size_t input_strings_count, iree_allocator_t host_allocator, - iree_vm_list_t* list); - -// Appends fences to |list| if the invocation model of |function| requires them. -// If no |wait_fence| is provided then the invocation will begin immediately. -// The caller must wait on the returned |out_signal_fence| before accessing the -// contents of any buffers returned from the invocation. -iree_status_t iree_tooling_append_async_fence_inputs( - iree_vm_list_t* list, const iree_vm_function_t* function, - iree_hal_device_t* device, iree_hal_fence_t* wait_fence, - iree_hal_fence_t** out_signal_fence); - -// Transfers all buffers in |list| to ones using |target_params|. -// If no |wait_fence| is provided then the transfer will begin immediately. -// If no |signal_fence| is provided then the call will block until the transfer -// completes. -iree_status_t iree_tooling_transfer_variant_list( - iree_hal_device_t* device, iree_vm_list_t* list, - iree_hal_allocator_t* target_allocator, - iree_hal_buffer_params_t target_params, iree_hal_fence_t* wait_fence, - iree_hal_fence_t* signal_fence); - -// Appends a variant list of VM scalars and buffers to |builder|. -// |list_name| will be printed alongside each element ordinal. 
-// -// Prints scalars in the format: -// value -// Prints buffers in the IREE standard shaped buffer format: -// [shape]xtype=[value] -// described in -// https://github.com/openxla/iree/tree/main/runtime/src/iree/hal/api.h -iree_status_t iree_tooling_append_variant_list_lines( - iree_string_view_t list_name, iree_vm_list_t* list, - iree_host_size_t max_element_count, iree_string_builder_t* builder); - -// Prints a variant list to a |file|. -// |list_name| will be printed alongside each element ordinal. -iree_status_t iree_tooling_variant_list_fprint( - iree_string_view_t list_name, iree_vm_list_t* list, - iree_host_size_t max_element_count, FILE* file); - -// Prints a variant |list| to targets based on the provided |output_strings|. -// -// |output_strings| format: -// (empty): ignore output -// `-`: print textual form to |file| -// `@file.npy`: create/overwrite a numpy .npy file. -// `+file.npy': create/append a numpy .npy file. -iree_status_t iree_tooling_output_variant_list( - iree_vm_list_t* list, const iree_string_view_t* output_strings, - iree_host_size_t output_strings_count, iree_host_size_t max_element_count, - FILE* file); - -#ifdef __cplusplus -} // extern "C" -#endif - -#endif // IREE_TOOLING_VM_UTIL_H_ diff --git a/runtime/src/iree/tooling/vm_util_test.cc b/runtime/src/iree/tooling/vm_util_test.cc deleted file mode 100644 index 2da4059757d4..000000000000 --- a/runtime/src/iree/tooling/vm_util_test.cc +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2020 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/tooling/vm_util.h" - -#include "iree/base/api.h" -#include "iree/base/internal/span.h" -#include "iree/hal/api.h" -#include "iree/modules/hal/module.h" -#include "iree/testing/gtest.h" -#include "iree/testing/status_matchers.h" -#include "iree/tooling/device_util.h" -#include "iree/vm/api.h" - -namespace iree { -namespace { - -static Status ParseToVariantList(iree_hal_device_t* device, - iree_hal_allocator_t* device_allocator, - iree::span input_strings, - iree_allocator_t host_allocator, - iree_vm_list_t** out_list) { - std::vector input_string_views(input_strings.size()); - for (size_t i = 0; i < input_strings.size(); ++i) { - input_string_views[i].data = input_strings[i].data(); - input_string_views[i].size = input_strings[i].size(); - } - return iree_tooling_parse_to_variant_list( - device, device_allocator, input_string_views.data(), - input_string_views.size(), host_allocator, out_list); -} - -static Status PrintVariantList(iree_vm_list_t* variant_list, - std::string* out_string) { - iree_string_builder_t builder; - iree_string_builder_initialize(iree_allocator_system(), &builder); - IREE_RETURN_IF_ERROR(iree_tooling_append_variant_list_lines( - IREE_SV("result"), variant_list, /*max_element_count=*/1024, &builder)); - out_string->assign(iree_string_builder_buffer(&builder), - iree_string_builder_size(&builder)); - iree_string_builder_deinitialize(&builder); - return iree_ok_status(); -} - -class VmUtilTest : public ::testing::Test { - protected: - virtual void SetUp() { - IREE_ASSERT_OK(iree_vm_instance_create( - IREE_VM_TYPE_CAPACITY_DEFAULT, iree_allocator_system(), &instance_)); - IREE_ASSERT_OK(iree_hal_module_register_all_types(instance_)); - iree_status_t status = iree_hal_create_device( - iree_hal_available_driver_registry(), IREE_SV("local-sync"), - iree_allocator_system(), &device_); - if (iree_status_is_not_found(status)) { - fprintf(stderr, "Skipping test as 'local-sync' 
driver was not found:\n"); - iree_status_fprint(stderr, status); - iree_status_free(status); - GTEST_SKIP(); - } - allocator_ = iree_hal_device_allocator(device_); - } - - virtual void TearDown() { - iree_hal_device_release(device_); - iree_vm_instance_release(instance_); - } - - iree_vm_instance_t* instance_ = nullptr; - iree_hal_device_t* device_ = nullptr; - iree_hal_allocator_t* allocator_ = nullptr; -}; - -TEST_F(VmUtilTest, ParsePrintBuffer) { - std::string buf_string = "&2x2xi32=[42 43][44 45]"; - vm::ref variant_list; - IREE_ASSERT_OK(ParseToVariantList( - device_, allocator_, std::vector{buf_string}, - iree_vm_instance_allocator(instance_), &variant_list)); - std::string result; - IREE_ASSERT_OK(PrintVariantList(variant_list.get(), &result)); - EXPECT_EQ(result, - std::string("result[0]: hal.buffer\n") + "(no printer)" + "\n"); -} - -TEST_F(VmUtilTest, ParsePrintBufferView) { - std::string buf_string = "2x2xi32=[42 43][44 45]"; - vm::ref variant_list; - IREE_ASSERT_OK(ParseToVariantList( - device_, allocator_, std::vector{buf_string}, - iree_vm_instance_allocator(instance_), &variant_list)); - std::string result; - IREE_ASSERT_OK(PrintVariantList(variant_list.get(), &result)); - EXPECT_EQ(result, - std::string("result[0]: hal.buffer_view\n") + buf_string + "\n"); -} - -TEST_F(VmUtilTest, ParsePrintScalar) { - std::string input_string = "42"; - vm::ref variant_list; - IREE_ASSERT_OK(ParseToVariantList( - device_, allocator_, std::vector{input_string}, - iree_vm_instance_allocator(instance_), &variant_list)); - std::string result; - IREE_ASSERT_OK(PrintVariantList(variant_list.get(), &result)); - EXPECT_EQ(result, std::string("result[0]: i32=") + input_string + "\n"); -} - -TEST_F(VmUtilTest, ParsePrintRank0BufferView) { - std::string buf_string = "i32=42"; - vm::ref variant_list; - IREE_ASSERT_OK(ParseToVariantList( - device_, allocator_, std::vector{buf_string}, - iree_vm_instance_allocator(instance_), &variant_list)); - std::string result; - 
IREE_ASSERT_OK(PrintVariantList(variant_list.get(), &result)); - EXPECT_EQ(result, - std::string("result[0]: hal.buffer_view\n") + buf_string + "\n"); -} - -TEST_F(VmUtilTest, ParsePrintMultipleBufferViews) { - std::string buf_string1 = "2x2xi32=[42 43][44 45]"; - std::string buf_string2 = "2x3xf64=[1 2 3][4 5 6]"; - vm::ref variant_list; - IREE_ASSERT_OK(ParseToVariantList( - device_, allocator_, std::vector{buf_string1, buf_string2}, - iree_vm_instance_allocator(instance_), &variant_list)); - std::string result; - IREE_ASSERT_OK(PrintVariantList(variant_list.get(), &result)); - EXPECT_EQ(result, std::string("result[0]: hal.buffer_view\n") + buf_string1 + - "\nresult[1]: hal.buffer_view\n" + buf_string2 + "\n"); -} - -} // namespace -} // namespace iree diff --git a/tools/BUILD.bazel b/tools/BUILD.bazel index 13727e32c22d..18d620f4a071 100644 --- a/tools/BUILD.bazel +++ b/tools/BUILD.bazel @@ -34,7 +34,7 @@ iree_runtime_cc_binary( "//runtime/src/iree/modules/hal:types", "//runtime/src/iree/tooling:context_util", "//runtime/src/iree/tooling:device_util", - "//runtime/src/iree/tooling:vm_util", + "//runtime/src/iree/tooling:function_io", "//runtime/src/iree/vm", "@com_google_benchmark//:benchmark", ], @@ -53,7 +53,6 @@ iree_runtime_cc_binary( "//runtime/src/iree/testing:gtest", "//runtime/src/iree/tooling:context_util", "//runtime/src/iree/tooling:device_util", - "//runtime/src/iree/tooling:vm_util", "//runtime/src/iree/vm", "//runtime/src/iree/vm/bytecode:module", ], @@ -216,7 +215,6 @@ iree_runtime_cc_binary( "//runtime/src/iree/modules/hal", "//runtime/src/iree/tooling:context_util", "//runtime/src/iree/tooling:device_util", - "//runtime/src/iree/tooling:vm_util", "//runtime/src/iree/vm", "//runtime/src/iree/vm:cc", ], diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 67a8a7f7c8b1..08e3e51d4bc4 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -76,7 +76,7 @@ iree_cc_binary( iree::modules::hal::types iree::tooling::context_util 
iree::tooling::device_util - iree::tooling::vm_util + iree::tooling::function_io iree::vm INSTALL_COMPONENT IREETools-Runtime ) @@ -95,7 +95,6 @@ iree_cc_binary( iree::testing::gtest iree::tooling::context_util iree::tooling::device_util - iree::tooling::vm_util iree::vm iree::vm::bytecode::module TESTONLY @@ -242,7 +241,6 @@ iree_cc_binary( iree::modules::hal iree::tooling::context_util iree::tooling::device_util - iree::tooling::vm_util iree::vm iree::vm::cc ) diff --git a/tools/iree-benchmark-module-main.cc b/tools/iree-benchmark-module-main.cc index f7536641c6fa..fd08546a2856 100644 --- a/tools/iree-benchmark-module-main.cc +++ b/tools/iree-benchmark-module-main.cc @@ -67,7 +67,7 @@ #include "iree/modules/hal/types.h" #include "iree/tooling/context_util.h" #include "iree/tooling/device_util.h" -#include "iree/tooling/vm_util.h" +#include "iree/tooling/function_io.h" #include "iree/vm/api.h" constexpr char kNanosecondsUnitString[] = "ns"; @@ -498,10 +498,15 @@ class IREEBenchmark { iree_string_view_t{function_name.data(), (iree_host_size_t)function_name.size()}, &function)); - - IREE_CHECK_OK(iree_tooling_parse_to_variant_list( - device_.get(), device_allocator_.get(), FLAG_input_list().values, - FLAG_input_list().count, iree_vm_instance_allocator(instance_.get()), + iree_vm_function_signature_t signature = + iree_vm_function_signature(&function); + iree_string_view_t arguments_cconv, results_cconv; + IREE_RETURN_IF_ERROR(iree_vm_function_call_get_cconv_fragments( + &signature, &arguments_cconv, &results_cconv)); + + IREE_CHECK_OK(iree_tooling_parse_variants( + arguments_cconv, FLAG_input_list(), device_.get(), + device_allocator_.get(), iree_vm_instance_allocator(instance_.get()), &inputs_)); iree_string_view_t invocation_model = iree_vm_function_lookup_attr_by_name( diff --git a/tools/iree-check-module-main.cc b/tools/iree-check-module-main.cc index baa1d330a80f..b820b6069eba 100644 --- a/tools/iree-check-module-main.cc +++ b/tools/iree-check-module-main.cc @@ 
-20,7 +20,6 @@ #include "iree/testing/status_matchers.h" #include "iree/tooling/context_util.h" #include "iree/tooling/device_util.h" -#include "iree/tooling/vm_util.h" #include "iree/vm/api.h" #include "iree/vm/bytecode/module.h" diff --git a/tools/iree-e2e-matmul-test.cc b/tools/iree-e2e-matmul-test.cc index 904fe56bf13c..a4d5b11f8007 100644 --- a/tools/iree-e2e-matmul-test.cc +++ b/tools/iree-e2e-matmul-test.cc @@ -19,7 +19,6 @@ #include "iree/modules/hal/module.h" #include "iree/tooling/context_util.h" #include "iree/tooling/device_util.h" -#include "iree/tooling/vm_util.h" #include "iree/vm/api.h" #include "iree/vm/native_module_cc.h" diff --git a/tools/test/iree-run-module-expected.mlir b/tools/test/iree-run-module-expected.mlir index 33953fe57975..ae70ea3b9457 100644 --- a/tools/test/iree-run-module-expected.mlir +++ b/tools/test/iree-run-module-expected.mlir @@ -4,7 +4,6 @@ // RUN: (iree-compile --iree-hal-target-backends=vmvx %s | not iree-run-module --device=local-task --module=- --function=abs --input=f32=-2 --expected_output=f32=123 --expected_output=f32=2.0) | FileCheck %s --check-prefix=FAILED-FIRST // RUN: (iree-compile --iree-hal-target-backends=vmvx %s | not iree-run-module --device=local-task --module=- --function=abs --input=f32=-2 --expected_output=f32=-2 --expected_output=f32=4.5) | FileCheck %s --check-prefix=FAILED-SECOND // RUN: (iree-compile --iree-hal-target-backends=vmvx %s | not iree-run-module --device=local-task --module=- --function=abs --input=f32=-2 --expected_output=f32=-2 --expected_output=4xf32=2.0) | FileCheck %s --check-prefix=FAILED-SHAPE -// RUN: (iree-compile --iree-hal-target-backends=vmvx %s | not iree-run-module --device=local-task --module=- --function=abs --input=f32=-2 --expected_output=f32=-2 --expected_output=8) | FileCheck %s --check-prefix=FAILED-TYPE // SUCCESS-MATCHES: [SUCCESS] // SUCCESS-THRESHOLD: [SUCCESS] @@ -12,7 +11,6 @@ // FAILED-FIRST: [FAILED] result[0]: element at index 0 (-2) does not match the 
expected (123) // FAILED-SECOND: [FAILED] result[1]: element at index 0 (2) does not match the expected (4.5) // FAILED-SHAPE: [FAILED] result[1]: metadata is f32; expected that the view matches 4xf32 -// FAILED-TYPE: [FAILED] result[1]: variant types mismatch func.func @abs(%input: tensor) -> (tensor, tensor) { %result = math.absf %input : tensor diff --git a/tools/test/iree-run-module-inputs.mlir b/tools/test/iree-run-module-inputs.mlir index c8b739aea78c..08e6c87445d6 100644 --- a/tools/test/iree-run-module-inputs.mlir +++ b/tools/test/iree-run-module-inputs.mlir @@ -14,8 +14,13 @@ func.func @no_input() { // * The VM does not use i1/i8 types, so i32 VM types are returned instead. // RUN: (iree-compile --iree-hal-target-backends=vmvx %s | \ -// RUN: iree-run-module --device=local-sync --module=- --function=scalars \ -// RUN: --input=1 --input=5 --input=1234 --input=-3.14) | \ +// RUN: iree-run-module --device=local-sync \ +// RUN: --module=- \ +// RUN: --function=scalars \ +// RUN: --input=1 \ +// RUN: --input=5 \ +// RUN: --input=1234 \ +// RUN: --input=-3.14) | \ // RUN: FileCheck --check-prefix=INPUT-SCALARS %s // INPUT-SCALARS-LABEL: EXEC @scalars func.func @scalars(%arg0: i1, %arg1: i8, %arg2 : i32, %arg3 : f32) -> (i1, i8, i32, f32) { @@ -34,8 +39,12 @@ func.func @scalars(%arg0: i1, %arg1: i8, %arg2 : i32, %arg3 : f32) -> (i1, i8, i // * Brackets may also be used to separate element values. 
// RUN: (iree-compile --iree-hal-target-backends=vmvx %s | \ -// RUN: iree-run-module --device=local-sync --module=- --function=buffers \ -// RUN: --input=i32=5 --input=2xi32 --input="2x3xi32=1 2 3 4 5 6") | \ +// RUN: iree-run-module --device=local-sync \ +// RUN: --module=- \ +// RUN: --function=buffers \ +// RUN: --input=i32=5 \ +// RUN: --input=2xi32 \ +// RUN: --input="2x3xi32=1 2 3 4 5 6") | \ // RUN: FileCheck --check-prefix=INPUT-BUFFERS %s // INPUT-BUFFERS-LABEL: EXEC @buffers func.func @buffers(%arg0: tensor, %arg1: tensor<2xi32>, %arg2: tensor<2x3xi32>) -> (tensor, tensor<2xi32>, tensor<2x3xi32>) { @@ -55,20 +64,25 @@ func.func @buffers(%arg0: tensor, %arg1: tensor<2xi32>, %arg2: tensor<2x3xi // provide 1+ values. // * Some data types may be converted (i32 -> si32 here) - bug? -// RUN: iree-compile --iree-hal-target-backends=vmvx %s -o %t.vmfb -// -// RUN: iree-run-module --device=local-sync --module=%t.vmfb --function=npy_round_trip \ -// RUN: --input="2xi32=11 12" --input="3xi32=1 2 3" --output=@%t.npy --output=+%t.npy -// -// RUN: iree-run-module --device=local-sync --module=%t.vmfb --function=npy_round_trip \ -// RUN: --input=@%t.npy | \ -// RUN: FileCheck --check-prefix=INPUT-NPY_FILES %s +// RUN: (iree-compile --iree-hal-target-backends=vmvx %s -o=%t.vmfb && \ +// RUN: iree-run-module --device=local-sync \ +// RUN: --module=%t.vmfb \ +// RUN: --function=npy_round_trip \ +// RUN: --input=2xi32=11,12 \ +// RUN: --input=3xi32=1,2,3 \ +// RUN: --output=@%t.npy \ +// RUN: --output=+%t.npy && \ +// RUN: iree-run-module --device=local-sync \ +// RUN: --module=%t.vmfb \ +// RUN: --function=npy_round_trip \ +// RUN: --input=*%t.npy) | \ +// RUN: FileCheck --check-prefix=INPUT-NUMPY %s -// INPUT-NPY_FILES-LABEL: EXEC @npy_round_trip +// INPUT-NUMPY-LABEL: EXEC @npy_round_trip func.func @npy_round_trip(%arg0: tensor<2xi32>, %arg1: tensor<3xi32>) -> (tensor<2xi32>, tensor<3xi32>) { - // INPUT-NPY_FILES: result[0]: hal.buffer_view - // INPUT-NPY_FILES-NEXT: 
2xsi32=11 12 - // INPUT-NPY_FILES: result[1]: hal.buffer_view - // INPUT-NPY_FILES-NEXT: 3xsi32=1 2 3 + // INPUT-NUMPY: result[0]: hal.buffer_view + // INPUT-NUMPY-NEXT: 2xsi32=11 12 + // INPUT-NUMPY: result[1]: hal.buffer_view + // INPUT-NUMPY-NEXT: 3xsi32=1 2 3 return %arg0, %arg1 : tensor<2xi32>, tensor<3xi32> } diff --git a/tools/test/iree-run-module-outputs.mlir b/tools/test/iree-run-module-outputs.mlir index 5ae3e3a9cb6f..fc3157f3474d 100644 --- a/tools/test/iree-run-module-outputs.mlir +++ b/tools/test/iree-run-module-outputs.mlir @@ -55,7 +55,8 @@ func.func @numpy() -> (i32, tensor, tensor) { // ----- // Tests output to binary files by round-tripping the output of a function into -// another invocation reading from the binary files. +// another invocation reading from the binary files. Each output is written to +// its own file (optimal for alignment/easier to inspect). // RUN: (iree-compile --iree-hal-target-backends=vmvx %s -o=%t.vmfb && \ // RUN: iree-run-module --device=local-sync \ @@ -69,6 +70,24 @@ func.func @numpy() -> (i32, tensor, tensor) { // RUN: --input=f32=@%t.0.bin \ // RUN: --input=2x4xi32=@%t.1.bin) | \ // RUN: FileCheck --check-prefix=OUTPUT-BINARY %s + +// Tests output to binary files by round-tripping the output of a function into +// another invocation reading from the binary files. The values are appended to +// a single file and read from the single file. 
+ +// RUN: (iree-compile --iree-hal-target-backends=vmvx %s -o=%t.vmfb && \ +// RUN: iree-run-module --device=local-sync \ +// RUN: --module=%t.vmfb \ +// RUN: --function=write_binary \ +// RUN: --output=@%t.bin \ +// RUN: --output=+%t.bin && \ +// RUN: iree-run-module --device=local-sync \ +// RUN: --module=%t.vmfb \ +// RUN: --function=echo_binary \ +// RUN: --input=f32=@%t.bin \ +// RUN: --input=2x4xi32=+%t.bin) | \ +// RUN: FileCheck --check-prefix=OUTPUT-BINARY %s + func.func @write_binary() -> (tensor, tensor) { %0 = arith.constant dense<4.0> : tensor %1 = flow.tensor.constant dense<[[0,1,2,3],[4,5,6,7]]> : tensor<2x4xi32> -> tensor