Skip to content

Commit

Permalink
fix getShapedFloat with APFloat value (#471)
Browse files Browse the repository at this point in the history
- fix `getShapedValue` in ONNXToStablehloCommon to support APFloat with
fp16 input
  • Loading branch information
YellowHCH authored Oct 24, 2024
1 parent 4c96ed9 commit 07f9744
Showing 1 changed file with 27 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
diff --git a/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp b/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp
index 832e72b9..3d8e1d7f 100644
--- a/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp
+++ b/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.hpp
@@ -71,14 +71,20 @@ Value getShapedFloat(Location loc, ConversionPatternRewriter &rewriter,
const T &value, Value &inp) {
Value broadcastedValue;
ShapedType inpType = inp.getType().cast<ShapedType>();
+ float f32Value;
+ if constexpr (std::is_same_v<APFloat, T>) {
+ f32Value = cast<APFloat>(value).convertToFloat();
+ } else {
+ f32Value = value;
+ }
if (inpType.hasStaticShape())
broadcastedValue = rewriter.create<stablehlo::ConstantOp>(
loc, DenseElementsAttr::get(inpType,
- rewriter.getFloatAttr(inpType.getElementType(), value)));
+ rewriter.getFloatAttr(inpType.getElementType(), f32Value)));
else {
Type elemType = inpType.getElementType();
Value floatValue = rewriter.create<stablehlo::ConstantOp>(
- loc, rewriter.getFloatAttr(elemType, value));
+ loc, rewriter.getFloatAttr(elemType, f32Value));
Value shape = rewriter.create<shape::ShapeOfOp>(loc, inp);
broadcastedValue = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
loc, inpType, floatValue, shape, rewriter.getI64TensorAttr({}));

0 comments on commit 07f9744

Please sign in to comment.