Skip to content

Commit

Permalink
[xla:ffi] Add auto-binding for FFI results
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 621945774
  • Loading branch information
ezhulenev authored and copybara-github committed Apr 4, 2024
1 parent 136750b commit 807e00e
Show file tree
Hide file tree
Showing 8 changed files with 396 additions and 90 deletions.
156 changes: 127 additions & 29 deletions xla/ffi/api/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,15 @@ namespace internal {
// A type tag to forward all remaining args as `RemainingArgs`.
struct RemainingArgsTag {};

// A type tag to distinguish arguments tied to the attributes in the
// `Binding` variadic template argument.
// A type tag to distinguish parameters tied to results in the `Binding`
// variadic template. In XLA FFI we use destination passing style APIs and don't
// return anything from the handler, but instead pass a destination where the
// handler should write the result.
template <typename T>
struct RetTag {};

// A type tag to distinguish parameters tied to the attributes in the
// `Binding` variadic template.
template <typename T>
struct AttrTag {};

Expand All @@ -220,7 +227,7 @@ struct AttrTag {};
template <typename T = Dictionary>
struct AttrsTag {};

// A type tag to distinguish arguments extracted from an execution context.
// A type tag to distinguish parameter extracted from an execution context.
template <typename T>
struct CtxTag {};

Expand Down Expand Up @@ -267,6 +274,11 @@ class Binding {
return {std::move(*this)};
}

template <typename T>
Binding<Ts..., internal::RetTag<T>> Ret() && {
return {std::move(*this)};
}

Binding<Ts..., internal::RemainingArgsTag> RemainingArgs() && {
static_assert(!internal::HasRemainingArgsTag<Ts...>::value,
"remaining arguments can be passed just once");
Expand Down Expand Up @@ -340,6 +352,20 @@ struct ArgBinding {
using Arg = void;
};

// XLA FFI binding for a returned result.
//
// Example: binding for the `MyType` result
//
// template <>
// struct RetBinding<MyType> {
// using Ret = MyType;
// };
//
template <typename T>
struct RetBinding {
using Ret = void;
};

// XLA FFI binding for a named attribute.
//
// Example: binding for the `MyType` attribute
Expand Down Expand Up @@ -382,6 +408,10 @@ template <typename Param>
inline constexpr bool is_arg_binding_v =
!std::is_void_v<typename ArgBinding<Param>::Arg>;

template <typename Param>
inline constexpr bool is_ret_binding_v =
!std::is_void_v<typename RetBinding<Param>::Ret>;

template <typename Param>
inline constexpr bool is_attr_binding_v =
!std::is_void_v<typename AttrBinding<Param>::Attr>;
Expand Down Expand Up @@ -410,6 +440,11 @@ struct BindOne<Fn, Param, Params...> {
return BindOne<Fn, Params...>::To(
std::move(fn),
std::move(binding).template Arg<typename ArgBinding<Param>::Arg>());
} else if constexpr (is_ret_binding_v<Param>) {
// Bind parameter as an FFI handler result.
return BindOne<Fn, Params...>::To(
std::move(fn),
std::move(binding).template Ret<typename RetBinding<Param>::Ret>());

} else if constexpr (is_attr_binding_v<Param>) {
// Bind parameter as a named FFI handler attribute.
Expand Down Expand Up @@ -482,12 +517,26 @@ auto Ffi::BindTo(Fn fn) {
}
}

// A container for defining attribute type and name as compile time parameters.
// A container for defining parameters corresponding to results.
template <typename T>
class Result {
public:
Result(T value) : value_(value) {} // NOLINT
T& operator*() { return value_; }
T* operator->() { return &value_; }

private:
T value_;
};

// A container for defining parameters corresponding to attributes with an
// attribute name available as compile time value.
template <typename T, char const* attr_name>
class Attr {
public:
Attr(T value) : value_(value) {} // NOLINT
T& operator*() { return value_; }
T* operator->() { return &value_; }

private:
T value_;
Expand Down Expand Up @@ -527,6 +576,23 @@ struct AttrsBinding<Dictionary> {
template <typename T>
struct ArgDecoding;

//===----------------------------------------------------------------------===//
// Results decoding implementation
//===----------------------------------------------------------------------===//

// XLA FFI results decoding must be defined by specializing this template.
//
// Example: decoding for the `MyType` results
//
// template <>
// struct RetDecoding<MyType> {
// static std::optional<MyType> Decode(XLA_FFI_RetType type, void* ret);
// };
//
// If argument can't be decoded it should return the empty optional.
template <typename T>
struct RetDecoding;

//===----------------------------------------------------------------------===//
// Attributes decoding implementation
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -654,6 +720,7 @@ namespace internal {
// attributes we decoded so far to compute call frame offsets.
struct DecodingOffsets {
int64_t args = 0;
int64_t rets = 0;
int64_t attrs = 0;
};

Expand All @@ -677,6 +744,17 @@ struct Decode {

} // namespace internal

template <typename T>
struct internal::Decode<internal::RetTag<T>> {
static std::optional<Result<T>> call(DecodingOffsets& offsets,
DecodingContext& ctx,
DiagnosticEngine& diagnostic) {
int64_t idx = offsets.rets++;
return RetDecoding<T>::Decode(ctx.call_frame->rets.types[idx],
ctx.call_frame->rets.rets[idx], diagnostic);
}
};

template <typename T>
struct internal::Decode<internal::AttrTag<T>> {
using R = typename AttrDecoding<T>::Type;
Expand Down Expand Up @@ -774,16 +852,16 @@ class RemainingArgs {
public:
RemainingArgs(const XLA_FFI_Args* args, size_t offset)
: args_(args), offset_(offset) {
assert(offset <= args_->num_args && "illegal remaining args offset");
assert(offset <= args_->size && "illegal remaining args offset");
}

size_t size() const { return args_->num_args - offset_; }
size_t size() const { return args_->size - offset_; }
bool empty() const { return size() == 0; }

template <typename T>
Expected<T, std::string> get(size_t index) const {
size_t idx = offset_ + index;
if (idx >= args_->num_args) {
if (idx >= args_->size) {
return Unexpected("Index out of range.");
}

Expand Down Expand Up @@ -818,10 +896,10 @@ class Dictionary {
public:
explicit Dictionary(const XLA_FFI_Attrs* attrs) : attrs_(attrs) {}

size_t size() const { return attrs_->num_attrs; }
size_t size() const { return attrs_->size; }

bool contains(std::string_view name) const {
return Find(name) < attrs_->num_attrs;
return Find(name) < attrs_->size;
}

template <typename T>
Expand All @@ -838,7 +916,7 @@ class Dictionary {
std::optional<T> get(std::string_view name,
DiagnosticEngine& diagnostic) const {
size_t idx = Find(name);
if (idx >= attrs_->num_attrs) {
if (idx >= attrs_->size) {
return diagnostic.Emit("Unexpected attribute: ") << name;
}

Expand All @@ -850,7 +928,7 @@ class Dictionary {
private:
size_t Find(std::string_view name) const {
XLA_FFI_ByteSpan** begin = attrs_->names;
XLA_FFI_ByteSpan** end = begin + attrs_->num_attrs;
XLA_FFI_ByteSpan** end = begin + attrs_->size;

auto name_eq = [&](XLA_FFI_ByteSpan* attr) {
std::string_view name_view = {attr->ptr, attr->len};
Expand Down Expand Up @@ -897,33 +975,41 @@ struct FnArgType {
using Type = T;
};

// Extracts the underlying type from the attribute type tag.
template <typename T>
struct FnArgType<internal::AttrTag<T>> {
using Type = typename AttrDecoding<T>::Type;
template <>
struct FnArgType<internal::RemainingArgsTag> {
using Type = RemainingArgs;
};

// Extracts the underlying type from the context type tag.
// Extracts the underlying type from the returned result type tag.
template <typename T>
struct FnArgType<internal::CtxTag<T>> {
using Type = typename CtxDecoding<T>::Type;
struct FnArgType<internal::RetTag<T>> {
using Type = Result<T>;
};

template <>
struct FnArgType<internal::RemainingArgsTag> {
using Type = RemainingArgs;
// Extracts the underlying type from the attribute type tag.
template <typename T>
struct FnArgType<internal::AttrTag<T>> {
using Type = typename AttrDecoding<T>::Type;
};

template <typename T>
struct FnArgType<internal::AttrsTag<T>> {
using Type = T;
};

// Extracts the underlying type from the context type tag.
template <typename T>
struct FnArgType<internal::CtxTag<T>> {
using Type = typename CtxDecoding<T>::Type;
};

// A template for checking if type in a parameter pack is a tagged one and has
// a special decoding rule defined by template specialization.
template <typename>
struct IsTagged : std::false_type {};
template <typename T>
struct IsTagged<RetTag<T>> : std::true_type {};
template <typename T>
struct IsTagged<AttrTag<T>> : std::true_type {};
template <typename T>
struct IsTagged<AttrsTag<T>> : std::true_type {};
Expand Down Expand Up @@ -958,6 +1044,9 @@ class Handler : public Ffi {

static constexpr int64_t kNumArgs = internal::NumArgs<Ts...>::value;

static constexpr int64_t kNumRets =
internal::NumTagged<internal::RetTag, Ts...>::value;

static constexpr int64_t kNumAttrs =
internal::NumTagged<internal::AttrTag, Ts...>::value;

Expand Down Expand Up @@ -986,32 +1075,41 @@ class Handler : public Ffi {
// Check that the number of passed arguments matches the signature. Each
// individual argument decoding will check the actual type.
if (internal::HasRemainingArgsTag<Ts...>::value) {
if (XLA_FFI_PREDICT_FALSE(call_frame->args.num_args < kNumArgs)) {
if (XLA_FFI_PREDICT_FALSE(call_frame->args.size < kNumArgs)) {
return InvalidArgument(
call_frame->api,
StrCat("Wrong number of arguments: expected at least ",
kNumArgs - 1, " but got ", call_frame->args.num_args));
kNumArgs - 1, " but got ", call_frame->args.size));
}
} else {
if (XLA_FFI_PREDICT_FALSE(call_frame->args.num_args != kNumArgs)) {
if (XLA_FFI_PREDICT_FALSE(call_frame->args.size != kNumArgs)) {
return InvalidArgument(
call_frame->api,
StrCat("Wrong number of arguments: expected ", kNumArgs,
" but got ", call_frame->args.num_args));
" but got ", call_frame->args.size));
}
}

// Check that the number of results matches the signature. Each individual
// result decoding will check the actual type.
if (XLA_FFI_PREDICT_FALSE(call_frame->rets.size != kNumRets)) {
return InvalidArgument(
call_frame->api,
StrCat("Wrong number of results: expected ", kNumRets, " but got ",
call_frame->rets.size));
}

// Check that the number of passed attributes matches the signature. Each
// individual attribute decoding will check the actual type. If we decode
// attributes into a dictionary (or a custom struct decoded from a
// dictionary), then there is no need to check attributes, as the FFI
// handler (or a struct decoding) should be responsible for it.
if (XLA_FFI_PREDICT_FALSE(kNumDictAttrs == 0 &&
call_frame->attrs.num_attrs != kNumAttrs)) {
call_frame->attrs.size != kNumAttrs)) {
return InvalidArgument(
call_frame->api,
StrCat("Wrong number of attributes: expected ", kNumAttrs,
" but got ", call_frame->attrs.num_attrs));
" but got ", call_frame->attrs.size));
}

// Define index sequences to access custom call operands.
Expand Down Expand Up @@ -1205,9 +1303,9 @@ struct DecodeDictionaryAttr {
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE static std::optional<T> Decode(
const XLA_FFI_Attrs* attrs, std::array<std::string_view, kSize> names,
std::index_sequence<Is...>, DiagnosticEngine& diagnostic) {
if (XLA_FFI_PREDICT_FALSE(kSize != attrs->num_attrs)) {
if (XLA_FFI_PREDICT_FALSE(kSize != attrs->size)) {
return diagnostic.Emit("Wrong number of attributes: expected ")
<< kSize << " attributes but got " << attrs->num_attrs;
<< kSize << " attributes but got " << attrs->size;
}

// TODO(ezhulenev): We rely on dictionary to lookup struct members by name
Expand Down
34 changes: 27 additions & 7 deletions xla/ffi/api/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,14 @@ typedef enum {
XLA_FFI_ArgType_BUFFER = 1,
} XLA_FFI_ArgType;

//===----------------------------------------------------------------------===//
// Builtin result types
//===----------------------------------------------------------------------===//

typedef enum {
XLA_FFI_RetType_BUFFER = 1,
} XLA_FFI_RetType;

//===----------------------------------------------------------------------===//
// Builtin attribute types
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -249,23 +257,34 @@ struct XLA_FFI_Args {
size_t struct_size;
void* priv;

int64_t num_args;
XLA_FFI_ArgType* types; // length == num_args
void** args; // length == num_args
int64_t size;
XLA_FFI_ArgType* types; // length == size
void** args; // length == size
};

XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Args, args);

struct XLA_FFI_Rets {
size_t struct_size;
void* priv;

int64_t size;
XLA_FFI_RetType* types; // length == size
void** rets; // length == size
};

XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Rets, rets);

// FFI handler attributes are always sorted by name, so that the handler can
// rely on binary search to look up attributes by name.
struct XLA_FFI_Attrs {
size_t struct_size;
void* priv;

int64_t num_attrs;
XLA_FFI_AttrType* types; // length == num_attrs
XLA_FFI_ByteSpan** names; // length == num_attrs
void** attrs; // length == num_attrs
int64_t size;
XLA_FFI_AttrType* types; // length == size
XLA_FFI_ByteSpan** names; // length == size
void** attrs; // length == size
};

XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Attrs, attrs);
Expand All @@ -277,6 +296,7 @@ struct XLA_FFI_CallFrame {
XLA_FFI_Api* api;
XLA_FFI_ExecutionContext* ctx;
XLA_FFI_Args args;
XLA_FFI_Rets rets;
XLA_FFI_Attrs attrs;
};

Expand Down
Loading

0 comments on commit 807e00e

Please sign in to comment.