Skip to content

Commit

Permalink
[torch-frontend] fix expand_as with different dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyunqu committed Sep 14, 2024
1 parent dd0a9d6 commit b110102
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -335,10 +335,12 @@ struct ConvertAtenExpandAsOp : public OpConversionPattern<AtenExpandAsOp> {
Value rhs = adaptor.getOther();
int64_t lhs_rank = cast<RankedTensorType>(lhs.getType()).getRank();
int64_t rhs_rank = cast<RankedTensorType>(rhs.getType()).getRank();
auto outType =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));

Value shape = rewriter.create<shape::ShapeOfOp>(op->getLoc(), rhs);
Value result = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
op->getLoc(), rhs.getType(), lhs, shape,
op->getLoc(), outType, lhs, shape,
rewriter.getDenseI64ArrayAttr(llvm::to_vector(
llvm::seq<int64_t>(rhs_rank - lhs_rank, rhs_rank))));
rewriter.replaceOp(op, result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,6 @@ def test_expand_as():
module = compile(ExpandAsModule(), inputs, "stablehlo")
numerical_test_helper(module, inputs, ExpandAsModule()(*inputs))

inputs = [tu.tensor(1.0), tu.randn(3, 4)]
inputs = [tu.tensor(1.0), tu.randn(3, 4).to(torch.int64)]
module = compile(ExpandAsModule1(), inputs, "stablehlo")
numerical_test_helper(module, inputs, ExpandAsModule1()(*inputs))

0 comments on commit b110102

Please sign in to comment.