Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC] Microscaling data types (f4E2M1FN, f6E2M3FN, f6E3M2FN, f8E8M0FNU) #2581

Merged
merged 1 commit into from
Oct 23, 2024

Conversation

sergey-kozub
Copy link
Contributor

@sergey-kozub sergey-kozub commented Oct 7, 2024

This is a proposal to add MX (microscaling) floating point types to StableHLO.

Related links:

  • StableHLO PR#2582 Add MX floating point types (f4E2M1FN, f6E2M3FN, f6E3M2FN, f8E8M0FNU)
  • LLVM PR#95392 [APFloat] Add APFloat support for FP4 data type
  • LLVM PR#94735 [APFloat] Add APFloat support for FP6 data types
  • LLVM PR#107127 [APFloat] Add APFloat support for E8M0 type
  • LLVM PR#108877 [MLIR] Add f4E2M1FN type
  • LLVM PR#107999 [MLIR] Add f6E2M3FN type
  • LLVM PR#105573 [MLIR] Add f6E3M2FN type
  • LLVM PR#111028 [MLIR] Add f8E8M0FNU type
  • JAX-ML PR#181 Add sub-byte data types: float4_e2m1fn, float6_e2m3fn, float6_e3m2fn
  • JAX-ML PR#166 Add float8_e8m0_fnu (E8M0) OCP MX scale format

Copy link
Member

@GleasonK GleasonK left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for sending the RFC! I've signal boosted it around internally, would also recommend a post on openxla-discuss similar to the FP8 posting to get the attention of folks not actively watching PRs on the StableHLO repo.

Leaving a "request changes" reply to make the review blocking until RFC period is over, but overall LGTM!

@reedwm
Copy link
Member

reedwm commented Oct 8, 2024

Does this add the MXFP8, MXFP6, and MXFP4 types where a scale applies to a block of k elements? Or does it just add the raw unscaled element types (F4E2M1FN, etc), similar to the existing FP8 types in StableHLO? Would be good to clarify this in the RFC

Also this is more of an XLA question, but do you intend for FP4 and FP6 to be packed? FP6 in particular would span bytes if packed, unlike the existing packed int4 and int2 types.

CC @mooskagh

@sergey-kozub
Copy link
Contributor Author

would also recommend a post on openxla-discuss similar to the FP8 posting

I'm currently working on this, will post it within the next few days.

does it just add the raw unscaled element types (F4E2M1FN, etc), similar to the existing FP8 types in StableHLO?

Yes.

Does this add the MXFP8, MXFP6, and MXFP4 types where a scale applies to a block of k elements?

No, but I'll submit a JAX API endpoint for (some) MXFP types soon after the primitive types are landed.

Also this is more of an XLA question, but do you intend for FP4 and FP6 to be packed?

Yes, these will be packed, same as sub-byte integer types. I'll add this info to the XLA RFC.

@sergey-kozub
Copy link
Contributor Author

sergey-kozub commented Oct 9, 2024

Posted the RFC in the XLA discussion forum.

Also added a note in the openxla-discuss group.

@GleasonK
Copy link
Member

GleasonK commented Oct 9, 2024

Are MXFP types defined in MLIR yet? Or what will a program using one of those datatypes look like?

Edit: seeing the MX RFC you shared in the previous comment, has some MLIR examples, can address this Q there.

@sergey-kozub
Copy link
Contributor Author

Could this be unblocked now?

@GleasonK
Copy link
Member

Yes, the 2w rfc period is completed. LGTM.

@GleasonK GleasonK merged commit 2952108 into openxla:main Oct 23, 2024
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants