[torch-frontend] use new register method to register byteir.flash_attn ops
qingyunqu committed Oct 25, 2024
1 parent d10fadf commit 9d3578d
Showing 1 changed file with 73 additions and 35 deletions.
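For context: the "new register method" in this commit is PyTorch's torch.library.custom_op / torch.library.register_fake API, which replaces the hand-rolled op() helper built on torch.library.Library that the diff below removes. A minimal sketch of the pattern (not part of the diff; the byteir_demo::scale op and its body are illustrative only, and the API requires a recent PyTorch release):

import torch

# Register the op together with an eager implementation; mutates_args=()
# declares that no inputs are mutated in place.
@torch.library.custom_op("byteir_demo::scale", mutates_args=())
def scale(x: torch.Tensor, alpha: float) -> torch.Tensor:
    return x * alpha

# Attach a "fake" (meta) kernel so tracing and compilation can propagate
# shapes and dtypes without running the real kernel.
@torch.library.register_fake("byteir_demo::scale")
def scale_fake(x: torch.Tensor, alpha: float) -> torch.Tensor:
    return torch.empty_like(x)

# The op is then reachable through torch.ops, as with the byteir ops below.
y = torch.ops.byteir_demo.scale(torch.randn(4), 2.0)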
@@ -1,34 +1,35 @@
import torch
import math
from torch.library import Library

OPERATORS = []


def op(schema):
    def inner(f):
        # TODO: Refactor the Library API so this is less rage inducing
        # TODO: Perhaps the namespace should be directly based on Python
        # module
        if "::" in schema:
            ns = schema.split("::", 2)[0]
        else:
            ns = "contrib"
        # TODO: Library doesn't allow FRAGMENT, need to allow it
        lib = Library(ns, "FRAGMENT")
        name = lib.define(schema)
        if "::" in name:
            name = name.split("::", 2)[1]
        lib.impl(name, f, "CompositeExplicitAutograd")
        OPERATORS.append(lib)
        return getattr(getattr(torch.ops, ns), name)

    return inner


@op(
    "byteir::flash_attn_fwd(Tensor q, Tensor k, Tensor v, float dropout_p, float softmax_scale, bool causal, bool return_softmax) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)"
)

@torch.library.custom_op("byteir::flash_attn_fwd", mutates_args=())
def byteir_flash_attn_fwd(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, dropout_p: float, softmax_scale: float, causal: bool, return_softmax: bool
) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
    sizes = q.shape
    batch_size = sizes[0]
    seqlen_q = sizes[1]
    num_heads = sizes[2]
    seqlen_k = k.shape[1]

    rng = torch.empty((2), dtype=torch.int64, device="meta")
    softmax_lse = torch.empty(
        (batch_size, num_heads, seqlen_q), dtype=torch.float, device="meta"
    )
    p = None
    if return_softmax:
        p = torch.empty(
            (batch_size, num_heads, seqlen_q, seqlen_k),
            dtype=torch.float,
            device="meta",
        )
    q_padded = q
    k_padded = k
    v_padded = v
    out = torch.empty_like(q_padded)
    out_padded = torch.empty_like(out)
    return out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng

@torch.library.register_fake("byteir::flash_attn_fwd")
def byteir_flash_attn_fwd(q, k, v, dropout_p, softmax_scale, causal, return_softmax):
    sizes = q.shape
    batch_size = sizes[0]
@@ -55,9 +56,32 @@ def byteir_flash_attn_fwd(q, k, v, dropout_p, softmax_scale, causal, return_soft
    return out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng


@op(
    "byteir::flash_attn_bwd(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, float dropout_p, float softmax_scale, bool causal, Tensor rng) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"
)
@torch.library.custom_op("byteir::flash_attn_bwd", mutates_args=())
def byteir_flash_attn_bwd(
    dout: torch.Tensor, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor, softmax_lse: torch.Tensor, dropout_p: float, softmax_scale: float, causal: bool, rng_state: torch.Tensor
) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
    sizes = q.shape
    batch_size = sizes[0]
    seqlen_q = sizes[1]
    num_heads = sizes[2]
    seqlen_q_rounded = ((seqlen_q + 127) // 128) * 128
    head_size = sizes[3]
    head_size_rounded = ((head_size + 31) // 32) * 32
    dq_accum = torch.empty(
        (batch_size, num_heads, seqlen_q_rounded, head_size_rounded),
        dtype=torch.float,
        device="meta",
    )
    softmax_d = torch.empty(
        (batch_size, num_heads, seqlen_q_rounded), dtype=torch.float, device="meta"
    )
    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)
    return dq, dk, dv, softmax_d, dq_accum


@torch.library.register_fake("byteir::byteir_flash_attn_bwd")
def byteir_flash_attn_bwd(
dout, q, k, v, out, softmax_lse, dropout_p, softmax_scale, causal, rng_state
):
@@ -82,9 +106,23 @@ def byteir_flash_attn_bwd(
    return dq, dk, dv, softmax_d, dq_accum


@op(
    "byteir::flash_attn_kvcache(Tensor q, Tensor k, Tensor v, Tensor kcache, Tensor vcache, Tensor seqlen_k, float softmax_scale, bool causal) -> (Tensor, Tensor)"
)
@torch.library.custom_op("byteir::flash_attn_kvcache", mutates_args=())
def byteir_flash_attn_kvcache(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, kcache: torch.Tensor, vcache: torch.Tensor, seqlen_k: torch.Tensor, softmax_scale: float, causal: bool
) -> (torch.Tensor, torch.Tensor):
    sizes = q.shape
    batch_size = sizes[0]
    seqlen_q = sizes[1]
    num_heads = sizes[2]

    softmax_lse = torch.empty(
        (batch_size, num_heads, seqlen_q), dtype=torch.float, device="meta"
    )
    out = torch.empty_like(q)
    return out, softmax_lse


@torch.library.register_fake("byteir::flash_attn_kvcache")
def byteir_flash_attn_kvcache(q, k, v, kcache, vcache, seqlen_k, softmax_scale, causal):
    sizes = q.shape
    batch_size = sizes[0]
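A quick way to exercise the new fake registrations (illustrative only, not part of the commit; it assumes the module above has been imported and uses the (batch, seqlen, num_heads, head_size) layout that the meta shape code indexes into; the cache shapes and scalar values are made up):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode  # internal but widely used API

# Build fake inputs and call the op; the register_fake kernel supplies
# output shapes and dtypes without running any real attention kernel.
with FakeTensorMode():
    q = torch.empty(2, 1, 8, 64, dtype=torch.float16)
    k = torch.empty(2, 1, 8, 64, dtype=torch.float16)
    v = torch.empty(2, 1, 8, 64, dtype=torch.float16)
    kcache = torch.empty(2, 512, 8, 64, dtype=torch.float16)
    vcache = torch.empty(2, 512, 8, 64, dtype=torch.float16)
    seqlen_k = torch.empty(2, dtype=torch.int32)
    out, softmax_lse = torch.ops.byteir.flash_attn_kvcache(
        q, k, v, kcache, vcache, seqlen_k, 0.125, True
    )
    print(out.shape, softmax_lse.shape)  # out mirrors q; lse covers (batch, heads, seqlen_q)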
