From 807e00ed22cb5b3a6afdc87a35b173ca8fa83cc4 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 4 Apr 2024 13:00:26 -0700 Subject: [PATCH] [xla:ffi] Add auto-binding for FFI results PiperOrigin-RevId: 621945774 --- xla/ffi/api/api.h | 156 ++++++++++++++++++++++----- xla/ffi/api/c_api.h | 34 ++++-- xla/ffi/api/ffi.h | 99 ++++++++++++++--- xla/ffi/api/ffi_test.cc | 31 ++++++ xla/ffi/call_frame.cc | 137 ++++++++++++++++------- xla/ffi/call_frame.h | 12 +++ xla/python/weakref_lru_cache.cc | 1 + xla/python/weakref_lru_cache_test.py | 16 +++ 8 files changed, 396 insertions(+), 90 deletions(-) diff --git a/xla/ffi/api/api.h b/xla/ffi/api/api.h index f3595cc20195c3..6169f06b2ce82e 100644 --- a/xla/ffi/api/api.h +++ b/xla/ffi/api/api.h @@ -210,8 +210,15 @@ namespace internal { // A type tag to forward all remaining args as `RemainingArgs`. struct RemainingArgsTag {}; -// A type tag to distinguish arguments tied to the attributes in the -// `Binding` variadic template argument. +// A type tag to distinguish parameters tied to results in the `Binding` +// variadic template. In XLA FFI we use destination passing style APIs and don't +// return anything from the handler, but instead pass a destination where the +// handler should write the result. +template +struct RetTag {}; + +// A type tag to distinguish parameters tied to the attributes in the +// `Binding` variadic template. template struct AttrTag {}; @@ -220,7 +227,7 @@ struct AttrTag {}; template struct AttrsTag {}; -// A type tag to distinguish arguments extracted from an execution context. +// A type tag to distinguish parameter extracted from an execution context. 
template struct CtxTag {}; @@ -267,6 +274,11 @@ class Binding { return {std::move(*this)}; } + template + Binding> Ret() && { + return {std::move(*this)}; + } + Binding RemainingArgs() && { static_assert(!internal::HasRemainingArgsTag::value, "remaining arguments can be passed just once"); @@ -340,6 +352,20 @@ struct ArgBinding { using Arg = void; }; +// XLA FFI binding for a returned result. +// +// Example: binding for the `MyType` result +// +// template <> +// struct RetBinding { +// using Ret = MyType; +// }; +// +template +struct RetBinding { + using Ret = void; +}; + // XLA FFI binding for a named attribute. // // Example: binding for the `MyType` attribute @@ -382,6 +408,10 @@ template inline constexpr bool is_arg_binding_v = !std::is_void_v::Arg>; +template +inline constexpr bool is_ret_binding_v = + !std::is_void_v::Ret>; + template inline constexpr bool is_attr_binding_v = !std::is_void_v::Attr>; @@ -410,6 +440,11 @@ struct BindOne { return BindOne::To( std::move(fn), std::move(binding).template Arg::Arg>()); + } else if constexpr (is_ret_binding_v) { + // Bind parameter as an FFI handler result. + return BindOne::To( + std::move(fn), + std::move(binding).template Ret::Ret>()); } else if constexpr (is_attr_binding_v) { // Bind parameter as a named FFI handler attribute. @@ -482,12 +517,26 @@ auto Ffi::BindTo(Fn fn) { } } -// A container for defining attribute type and name as compile time parameters. +// A container for defining parameters corresponding to results. +template +class Result { + public: + Result(T value) : value_(value) {} // NOLINT + T& operator*() { return value_; } + T* operator->() { return &value_; } + + private: + T value_; +}; + +// A container for defining parameters corresponding to attributes with an +// attribute name available as compile time value. 
template class Attr { public: Attr(T value) : value_(value) {} // NOLINT T& operator*() { return value_; } + T* operator->() { return &value_; } private: T value_; @@ -527,6 +576,23 @@ struct AttrsBinding { template struct ArgDecoding; +//===----------------------------------------------------------------------===// +// Results decoding implementation +//===----------------------------------------------------------------------===// + +// XLA FFI results decoding must be defined by specializing this template. +// +// Example: decoding for the `MyType` results +// +// template <> +// struct RetDecoding { +// static std::optional Decode(XLA_FFI_RetType type, void* ret); +// }; +// +// If the result can't be decoded, it should return an empty optional. +template +struct RetDecoding; + //===----------------------------------------------------------------------===// // Attributes decoding implementation //===----------------------------------------------------------------------===// @@ -654,6 +720,7 @@ namespace internal { // attributes we decoded so far to compute call frame offsets. 
struct DecodingOffsets { int64_t args = 0; + int64_t rets = 0; int64_t attrs = 0; }; @@ -677,6 +744,17 @@ struct Decode { } // namespace internal +template +struct internal::Decode> { + static std::optional> call(DecodingOffsets& offsets, + DecodingContext& ctx, + DiagnosticEngine& diagnostic) { + int64_t idx = offsets.rets++; + return RetDecoding::Decode(ctx.call_frame->rets.types[idx], + ctx.call_frame->rets.rets[idx], diagnostic); + } +}; + template struct internal::Decode> { using R = typename AttrDecoding::Type; @@ -774,16 +852,16 @@ class RemainingArgs { public: RemainingArgs(const XLA_FFI_Args* args, size_t offset) : args_(args), offset_(offset) { - assert(offset <= args_->num_args && "illegal remaining args offset"); + assert(offset <= args_->size && "illegal remaining args offset"); } - size_t size() const { return args_->num_args - offset_; } + size_t size() const { return args_->size - offset_; } bool empty() const { return size() == 0; } template Expected get(size_t index) const { size_t idx = offset_ + index; - if (idx >= args_->num_args) { + if (idx >= args_->size) { return Unexpected("Index out of range."); } @@ -818,10 +896,10 @@ class Dictionary { public: explicit Dictionary(const XLA_FFI_Attrs* attrs) : attrs_(attrs) {} - size_t size() const { return attrs_->num_attrs; } + size_t size() const { return attrs_->size; } bool contains(std::string_view name) const { - return Find(name) < attrs_->num_attrs; + return Find(name) < attrs_->size; } template @@ -838,7 +916,7 @@ class Dictionary { std::optional get(std::string_view name, DiagnosticEngine& diagnostic) const { size_t idx = Find(name); - if (idx >= attrs_->num_attrs) { + if (idx >= attrs_->size) { return diagnostic.Emit("Unexpected attribute: ") << name; } @@ -850,7 +928,7 @@ class Dictionary { private: size_t Find(std::string_view name) const { XLA_FFI_ByteSpan** begin = attrs_->names; - XLA_FFI_ByteSpan** end = begin + attrs_->num_attrs; + XLA_FFI_ByteSpan** end = begin + attrs_->size; auto 
name_eq = [&](XLA_FFI_ByteSpan* attr) { std::string_view name_view = {attr->ptr, attr->len}; @@ -897,21 +975,21 @@ struct FnArgType { using Type = T; }; -// Extracts the underlying type from the attribute type tag. -template -struct FnArgType> { - using Type = typename AttrDecoding::Type; +template <> +struct FnArgType { + using Type = RemainingArgs; }; -// Extracts the underlying type from the context type tag. +// Extracts the underlying type from the returned result type tag. template -struct FnArgType> { - using Type = typename CtxDecoding::Type; +struct FnArgType> { + using Type = Result; }; -template <> -struct FnArgType { - using Type = RemainingArgs; +// Extracts the underlying type from the attribute type tag. +template +struct FnArgType> { + using Type = typename AttrDecoding::Type; }; template @@ -919,11 +997,19 @@ struct FnArgType> { using Type = T; }; +// Extracts the underlying type from the context type tag. +template +struct FnArgType> { + using Type = typename CtxDecoding::Type; +}; + // A template for checking if type in a parameter pack is a tagged one and has // a special decoding rule defined by template specialization. template struct IsTagged : std::false_type {}; template +struct IsTagged> : std::true_type {}; +template struct IsTagged> : std::true_type {}; template struct IsTagged> : std::true_type {}; @@ -958,6 +1044,9 @@ class Handler : public Ffi { static constexpr int64_t kNumArgs = internal::NumArgs::value; + static constexpr int64_t kNumRets = + internal::NumTagged::value; + static constexpr int64_t kNumAttrs = internal::NumTagged::value; @@ -986,32 +1075,41 @@ class Handler : public Ffi { // Check that the number of passed arguments matches the signature. Each // individual argument decoding will check the actual type. 
if (internal::HasRemainingArgsTag::value) { - if (XLA_FFI_PREDICT_FALSE(call_frame->args.num_args < kNumArgs)) { + if (XLA_FFI_PREDICT_FALSE(call_frame->args.size < kNumArgs)) { return InvalidArgument( call_frame->api, StrCat("Wrong number of arguments: expected at least ", - kNumArgs - 1, " but got ", call_frame->args.num_args)); + kNumArgs - 1, " but got ", call_frame->args.size)); } } else { - if (XLA_FFI_PREDICT_FALSE(call_frame->args.num_args != kNumArgs)) { + if (XLA_FFI_PREDICT_FALSE(call_frame->args.size != kNumArgs)) { return InvalidArgument( call_frame->api, StrCat("Wrong number of arguments: expected ", kNumArgs, - " but got ", call_frame->args.num_args)); + " but got ", call_frame->args.size)); } } + // Check that the number of results matches the signature. Each individual + // result decoding will check the actual type. + if (XLA_FFI_PREDICT_FALSE(call_frame->rets.size != kNumRets)) { + return InvalidArgument( + call_frame->api, + StrCat("Wrong number of results: expected ", kNumRets, " but got ", + call_frame->rets.size)); + } + // Check that the number of passed attributes matches the signature. Each // individual attribute decoding will check the actual type. If we decode // attributes into a dictionary (or a custom struct decoded from a // dictionary), then there is no need to check attributes, as the FFI // handler (or a struct decoding) should be responsible for it. if (XLA_FFI_PREDICT_FALSE(kNumDictAttrs == 0 && - call_frame->attrs.num_attrs != kNumAttrs)) { + call_frame->attrs.size != kNumAttrs)) { return InvalidArgument( call_frame->api, StrCat("Wrong number of attributes: expected ", kNumAttrs, - " but got ", call_frame->attrs.num_attrs)); + " but got ", call_frame->attrs.size)); } // Define index sequences to access custom call operands. 
@@ -1205,9 +1303,9 @@ struct DecodeDictionaryAttr { XLA_FFI_ATTRIBUTE_ALWAYS_INLINE static std::optional Decode( const XLA_FFI_Attrs* attrs, std::array names, std::index_sequence, DiagnosticEngine& diagnostic) { - if (XLA_FFI_PREDICT_FALSE(kSize != attrs->num_attrs)) { + if (XLA_FFI_PREDICT_FALSE(kSize != attrs->size)) { return diagnostic.Emit("Wrong number of attributes: expected ") - << kSize << " attributes but got " << attrs->num_attrs; + << kSize << " attributes but got " << attrs->size; } // TODO(ezhulenev): We rely on dictionary to lookup struct members by name diff --git a/xla/ffi/api/c_api.h b/xla/ffi/api/c_api.h index e7ebd8da060796..d4dc246fc6f8a4 100644 --- a/xla/ffi/api/c_api.h +++ b/xla/ffi/api/c_api.h @@ -187,6 +187,14 @@ typedef enum { XLA_FFI_ArgType_BUFFER = 1, } XLA_FFI_ArgType; +//===----------------------------------------------------------------------===// +// Builtin result types +//===----------------------------------------------------------------------===// + +typedef enum { + XLA_FFI_RetType_BUFFER = 1, +} XLA_FFI_RetType; + //===----------------------------------------------------------------------===// // Builtin attribute types //===----------------------------------------------------------------------===// @@ -249,23 +257,34 @@ struct XLA_FFI_Args { size_t struct_size; void* priv; - int64_t num_args; - XLA_FFI_ArgType* types; // length == num_args - void** args; // length == num_args + int64_t size; + XLA_FFI_ArgType* types; // length == size + void** args; // length == size }; XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Args, args); +struct XLA_FFI_Rets { + size_t struct_size; + void* priv; + + int64_t size; + XLA_FFI_RetType* types; // length == size + void** rets; // length == size +}; + +XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Rets, rets); + // FFI handler attributes are always sorted by name, so that the handler can // rely on binary search to look up attributes by name. 
struct XLA_FFI_Attrs { size_t struct_size; void* priv; - int64_t num_attrs; - XLA_FFI_AttrType* types; // length == num_attrs - XLA_FFI_ByteSpan** names; // length == num_attrs - void** attrs; // length == num_attrs + int64_t size; + XLA_FFI_AttrType* types; // length == size + XLA_FFI_ByteSpan** names; // length == size + void** attrs; // length == size }; XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Attrs, attrs); @@ -277,6 +296,7 @@ struct XLA_FFI_CallFrame { XLA_FFI_Api* api; XLA_FFI_ExecutionContext* ctx; XLA_FFI_Args args; + XLA_FFI_Rets rets; XLA_FFI_Attrs attrs; }; diff --git a/xla/ffi/api/ffi.h b/xla/ffi/api/ffi.h index ab7a539c707215..b652d5accda0d7 100644 --- a/xla/ffi/api/ffi.h +++ b/xla/ffi/api/ffi.h @@ -171,6 +171,37 @@ template using BufferR3 = Buffer; template using BufferR4 = Buffer; // clang-format on +namespace internal { + +inline BufferBase DecodeBuffer(XLA_FFI_Buffer* buf) { + return BufferBase{static_cast(buf->dtype), buf->data, + Span(buf->dims, buf->rank)}; +} + +template +std::optional> DecodeBuffer(XLA_FFI_Buffer* buf, + DiagnosticEngine& diagnostic) { + if (auto buf_dtype = static_cast(buf->dtype); + XLA_FFI_PREDICT_FALSE(buf_dtype != dtype)) { + return diagnostic.Emit("Wrong buffer dtype: expected ") + << dtype << " but got " << buf_dtype; + } + + if constexpr (rank != internal::kDynamicRank) { + if (XLA_FFI_PREDICT_FALSE(buf->rank != rank)) { + return diagnostic.Emit("Wrong buffer rank: expected ") + << rank << " but got " << buf->rank; + } + } + + Buffer buffer; + buffer.data = static_cast*>(buf->data); + buffer.dimensions = Span(buf->dims, buf->rank); + return buffer; +} + +} // namespace internal + //===----------------------------------------------------------------------===// // Arguments binding //===----------------------------------------------------------------------===// @@ -185,6 +216,20 @@ struct ArgBinding> { using Arg = Buffer; }; +//===----------------------------------------------------------------------===// +// Results 
binding +//===----------------------------------------------------------------------===// + +template <> +struct RetBinding> { + using Ret = BufferBase; +}; + +template +struct RetBinding>> { + using Ret = Buffer; +}; + //===----------------------------------------------------------------------===// // Arguments decoding //===----------------------------------------------------------------------===// @@ -205,9 +250,7 @@ struct ArgDecoding { return diagnostic.Emit("Wrong argument type: expected ") << XLA_FFI_ArgType_BUFFER << " but got " << type; } - auto* buf = reinterpret_cast(arg); - return BufferBase{static_cast(buf->dtype), buf->data, - Span(buf->dims, buf->rank)}; + return internal::DecodeBuffer(reinterpret_cast(arg)); } }; @@ -221,25 +264,47 @@ struct ArgDecoding> { << XLA_FFI_ArgType_BUFFER << " but got " << type; } - auto* buf = reinterpret_cast(arg); + return internal::DecodeBuffer( + reinterpret_cast(arg), diagnostic); + } +}; + +//===----------------------------------------------------------------------===// +// Results decoding +//===----------------------------------------------------------------------===// - if (auto actual_dtype = static_cast(buf->dtype); - XLA_FFI_PREDICT_FALSE(actual_dtype != dtype)) { - return diagnostic.Emit("Wrong buffer dtype: expected ") - << dtype << " but got " << actual_dtype; +inline std::ostream& operator<<(std::ostream& os, const XLA_FFI_RetType type) { + switch (type) { + case XLA_FFI_RetType_BUFFER: + return os << "buffer"; + } +} + +template <> +struct RetDecoding { + XLA_FFI_ATTRIBUTE_ALWAYS_INLINE + static std::optional> Decode( + XLA_FFI_RetType type, void* ret, DiagnosticEngine& diagnostic) { + if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_RetType_BUFFER)) { + return diagnostic.Emit("Wrong result type: expected ") + << XLA_FFI_RetType_BUFFER << " but got " << type; } + return internal::DecodeBuffer(reinterpret_cast(ret)); + } +}; - if constexpr (rank != internal::kDynamicRank) { - if (XLA_FFI_PREDICT_FALSE(buf->rank 
!= rank)) { - return diagnostic.Emit("Wrong buffer rank: expected ") - << rank << " but got " << buf->rank; - } +template +struct RetDecoding> { + XLA_FFI_ATTRIBUTE_ALWAYS_INLINE + static std::optional>> Decode( + XLA_FFI_RetType type, void* ret, DiagnosticEngine& diagnostic) { + if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_RetType_BUFFER)) { + return diagnostic.Emit("Wrong result type: expected ") + << XLA_FFI_RetType_BUFFER << " but got " << type; } - Buffer buffer; - buffer.data = static_cast*>(buf->data); - buffer.dimensions = Span(buf->dims, buf->rank); - return buffer; + return internal::DecodeBuffer( + reinterpret_cast(ret), diagnostic); } }; diff --git a/xla/ffi/api/ffi_test.cc b/xla/ffi/api/ffi_test.cc index f389a2ac05fb19..b1dc769ca8de9d 100644 --- a/xla/ffi/api/ffi_test.cc +++ b/xla/ffi/api/ffi_test.cc @@ -100,6 +100,25 @@ TEST(FfiTest, BufferArgument) { TF_ASSERT_OK(status); } +TEST(FfiTest, BufferBaseResult) { + std::vector storage(4, 0.0f); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + + CallFrameBuilder builder; + builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2}); + auto call_frame = builder.Build(); + + auto handler = + Ffi::Bind().Ret().To([&](Result buffer) { + EXPECT_EQ(buffer->data, storage.data()); + EXPECT_EQ(buffer->dimensions.size(), 2); + return Error::Success(); + }); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); +} + TEST(FfiTest, MissingBufferArgument) { CallFrameBuilder builder; auto call_frame = builder.Build(); @@ -170,6 +189,18 @@ TEST(FfiTest, AutoBinding) { TF_ASSERT_OK(status); } +TEST(FfiTest, AutoBindingResult) { + auto handler = + Ffi::BindTo(+[](Result buffer) { return Error::Success(); }); + + CallFrameBuilder builder; + builder.AddBufferRet(se::DeviceMemoryBase(), PrimitiveType::F32, /*dims=*/{}); + auto call_frame = builder.Build(); + + auto status = Call(*handler, call_frame); + TF_ASSERT_OK(status); +} + struct I32AndF32 { int32_t i32; float f32; diff --git 
a/xla/ffi/call_frame.cc b/xla/ffi/call_frame.cc index 06dd10fb0833c0..4d064bfafcd06b 100644 --- a/xla/ffi/call_frame.cc +++ b/xla/ffi/call_frame.cc @@ -85,13 +85,19 @@ void CallFrameBuilder::AddBufferArg(se::DeviceMemoryBase memory, args_.push_back(Buffer{memory, type, {dims.begin(), dims.end()}}); } +void CallFrameBuilder::AddBufferRet(se::DeviceMemoryBase memory, + PrimitiveType type, + absl::Span dims) { + rets_.push_back(Buffer{memory, type, {dims.begin(), dims.end()}}); +} + void CallFrameBuilder::AddAttributes(AttributesMap attrs) { for (auto& [name, attr] : attrs) { attrs_.try_emplace(std::move(name), std::move(attr)); } } -CallFrame CallFrameBuilder::Build() { return CallFrame(args_, attrs_); } +CallFrame CallFrameBuilder::Build() { return CallFrame(args_, rets_, attrs_); } CallFrameBuilder::CallFrameBuilder(CallFrameBuilder&&) = default; CallFrameBuilder& CallFrameBuilder::operator=(CallFrameBuilder&&) = default; @@ -165,6 +171,21 @@ struct CallFrame::Arguments { XLA_FFI_Args ffi_args = {XLA_FFI_Args_STRUCT_SIZE, nullptr}; }; +struct CallFrame::Results { + explicit Results(size_t size) { + results.reserve(size); + types.reserve(size); + rets.reserve(size); + } + + std::vector results; + + std::vector types; // XLA_FFI_Rets::types + std::vector rets; // XLA_FFI_Rets::rets + + XLA_FFI_Rets ffi_rets = {XLA_FFI_Rets_STRUCT_SIZE, nullptr}; +}; + struct CallFrame::Attributes { explicit Attributes(size_t size) { attributes.reserve(size); @@ -187,8 +208,11 @@ struct CallFrame::Attributes { //===----------------------------------------------------------------------===// CallFrame::CallFrame(absl::Span args, + absl::Span rets, const CallFrameBuilder::AttributesMap& attrs) - : arguments_(InitArgs(args)), attributes_(InitAttrs(attrs)) {} + : arguments_(InitArgs(args)), + results_(InitRets(rets)), + attributes_(InitAttrs(attrs)) {} XLA_FFI_CallFrame CallFrame::Build(XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx) { @@ -196,54 +220,60 @@ XLA_FFI_CallFrame 
CallFrame::Build(XLA_FFI_Api* api, call_frame.api = api; call_frame.ctx = ctx; call_frame.args = arguments_->ffi_args; + call_frame.rets = results_->ffi_rets; call_frame.attrs = attributes_->ffi_attrs; return call_frame; } CallFrame::~CallFrame() = default; +// We rely on casting to and from underlying integral type to convert from +// PrimitiveType to XLA FFI DataType, and for safety convert all unknown types +// to invalid type, otherwise we can accidentally cause UB. +static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) { + switch (primitive_type) { + case PrimitiveType::PRIMITIVE_TYPE_INVALID: + case PrimitiveType::PRED: + case PrimitiveType::S8: + case PrimitiveType::S16: + case PrimitiveType::S32: + case PrimitiveType::S64: + case PrimitiveType::U8: + case PrimitiveType::U16: + case PrimitiveType::U32: + case PrimitiveType::U64: + case PrimitiveType::F16: + case PrimitiveType::F32: + case PrimitiveType::F64: + case PrimitiveType::BF16: + return static_cast(primitive_type); + default: + DCHECK(false) << "Unsupported primitive type" << primitive_type; + return XLA_FFI_DataType_INVALID; + } +} + +CallFrame::Buffer CallFrame::ConvertBuffer( + const CallFrameBuilder::Buffer& buffer) { + Buffer result; + result.dims = buffer.dims; + result.buffer.data = const_cast(buffer.memory.opaque()); + result.buffer.dtype = ToDataType(buffer.type); + result.buffer.rank = result.dims.size(); + return result; +} + //===----------------------------------------------------------------------===// // Call frame arguments //===----------------------------------------------------------------------===// -/*static*/ std::unique_ptr CallFrame::InitArgs( +std::unique_ptr CallFrame::InitArgs( absl::Span bargs) { auto res = std::make_unique(bargs.size()); - // We rely on casting to and from underlying integral type to convert from - // PrimitiveType to XLA FFI DataType, and for safety convert all unknown types - // to invalid type, otherwise we can accidentally cause UB. 
- auto to_data_type = [](PrimitiveType primitive_type) { - switch (primitive_type) { - case PrimitiveType::PRIMITIVE_TYPE_INVALID: - case PrimitiveType::PRED: - case PrimitiveType::S8: - case PrimitiveType::S16: - case PrimitiveType::S32: - case PrimitiveType::S64: - case PrimitiveType::U8: - case PrimitiveType::U16: - case PrimitiveType::U32: - case PrimitiveType::U64: - case PrimitiveType::F16: - case PrimitiveType::F32: - case PrimitiveType::F64: - case PrimitiveType::BF16: - return static_cast(primitive_type); - default: - DCHECK(false) << "Unsupported primitive type" << primitive_type; - return XLA_FFI_DataType_INVALID; - } - }; - // Convert call frame builder arguments to call frame arguments. for (const CallFrameBuilder::Buffer& barg : bargs) { - Buffer buffer; - buffer.dims = barg.dims; - buffer.buffer.data = const_cast(barg.memory.opaque()); - buffer.buffer.dtype = to_data_type(barg.type); - buffer.buffer.rank = buffer.dims.size(); - res->arguments.push_back(std::move(buffer)); + res->arguments.push_back(ConvertBuffer(barg)); } // Fix up pointers in XLA FFI structs. @@ -259,13 +289,46 @@ CallFrame::~CallFrame() = default; // Finally initialize the XLA FFI struct. At this point all storage is // allocated and it's safe to grab a pointer to it. - res->ffi_args.num_args = res->arguments.size(); + res->ffi_args.size = res->arguments.size(); res->ffi_args.types = res->types.data(); res->ffi_args.args = res->args.data(); return res; } +//===----------------------------------------------------------------------===// +// Call frame results +//===----------------------------------------------------------------------===// + +std::unique_ptr CallFrame::InitRets( + absl::Span brets) { + auto res = std::make_unique(brets.size()); + + // Convert call frame builder results to call frame results. + for (const CallFrameBuilder::Buffer& bret : brets) { + res->results.push_back(ConvertBuffer(bret)); + } + + // Fix up pointers in XLA FFI structs. 
+ for (CallFrame::Buffer& arg : res->results) { + arg.buffer.dims = arg.dims.data(); + } + + // Initialize vectors required for building XLA_FFI_Rets. + for (CallFrame::Buffer& ret : res->results) { + res->types.push_back(XLA_FFI_RetType_BUFFER); + res->rets.push_back(&ret.buffer); + } + + // Finally initialize the XLA FFI struct. At this point all storage is + // allocated and it's safe to grab a pointer to it. + res->ffi_rets.size = res->results.size(); + res->ffi_rets.types = res->types.data(); + res->ffi_rets.rets = res->rets.data(); + + return res; +} + //===----------------------------------------------------------------------===// // Call frame attributes //===----------------------------------------------------------------------===// @@ -401,7 +464,7 @@ struct CallFrame::AttributeStorage { // Finally initialize XLA FFI struct. At this point all storage is allocated // and it's safe to grab a pointer to it. - res->ffi_attrs.num_attrs = res->attributes.size(); + res->ffi_attrs.size = res->attributes.size(); res->ffi_attrs.names = res->names.data(); res->ffi_attrs.types = res->types.data(); res->ffi_attrs.attrs = res->attrs.data(); diff --git a/xla/ffi/call_frame.h b/xla/ffi/call_frame.h index b8f05105d92343..d283f607fb84a0 100644 --- a/xla/ffi/call_frame.h +++ b/xla/ffi/call_frame.h @@ -100,6 +100,9 @@ class CallFrameBuilder { void AddBufferArg(se::DeviceMemoryBase memory, PrimitiveType type, absl::Span dims); + void AddBufferRet(se::DeviceMemoryBase memory, PrimitiveType type, + absl::Span dims); + void AddAttributes(AttributesMap attrs); private: @@ -108,6 +111,7 @@ class CallFrameBuilder { struct Buffer; std::vector args_; + std::vector rets_; AttributesMap attrs_; }; @@ -132,21 +136,29 @@ class CallFrame { struct Buffer; struct Dictionary; struct NamedAttribute; + struct Results; struct Scalar; struct String; using Attribute = std::variant; CallFrame(absl::Span args, + absl::Span rets, const CallFrameBuilder::AttributesMap& attrs); static std::unique_ptr 
InitArgs( absl::Span args); + static std::unique_ptr InitRets( + absl::Span rets); + static std::unique_ptr InitAttrs( const CallFrameBuilder::AttributesMap& attrs); + static Buffer ConvertBuffer(const CallFrameBuilder::Buffer& buffer); + std::unique_ptr arguments_; + std::unique_ptr results_; std::unique_ptr attributes_; // Declare implementation detail structs to grant access to private members. diff --git a/xla/python/weakref_lru_cache.cc b/xla/python/weakref_lru_cache.cc index 4ce20a5c9a8451..a64cf93a6c4dde 100644 --- a/xla/python/weakref_lru_cache.cc +++ b/xla/python/weakref_lru_cache.cc @@ -32,6 +32,7 @@ limitations under the License. #include "absl/synchronization/notification.h" #include "nanobind/nanobind.h" // from @nanobind #include "nanobind/stl/shared_ptr.h" // from @nanobind // IWYU pragma: keep +#include "nanobind/stl/string.h" // from @nanobind // IWYU pragma: keep #include "nanobind/stl/vector.h" // from @nanobind // IWYU pragma: keep #include "xla/pjrt/lru_cache.h" #include "xla/python/nb_helpers.h" diff --git a/xla/python/weakref_lru_cache_test.py b/xla/python/weakref_lru_cache_test.py index a3b3d5ee6bf5cf..ad5f07bee0bf72 100644 --- a/xla/python/weakref_lru_cache_test.py +++ b/xla/python/weakref_lru_cache_test.py @@ -130,6 +130,22 @@ def __hash__(self): for _ in range(100): cache(wrkey, CrashingKey()) + def testPrintingStats(self): + class WRKey: + pass + + cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048) + wrkey = WRKey() + for i in range(10): + cache(wrkey, i) + for i in range(5): + cache(wrkey, i) + + self.assertEqual( + repr(cache.cache_info()), + "WeakrefLRUCache(hits=5, misses=10, maxsize=2048, currsize=10)", + ) + if __name__ == "__main__": absltest.main()