
Allow moving collapse/expand_shape through linalg.generic if it is a projected permutation with zeros #387

Draft

wants to merge 1 commit into base: feature/fused-ops

Commits on Oct 11, 2024

  1. Before this change, IR like

    ```
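    // Note: the affine map definitions were not captured in this dump. Based on
    // the operand shapes and the explanation below, #map presumably folds the
    // unit dimension d0 (which has size 1) to the constant 0. A reconstruction,
    // not verbatim IR:
    //   #map = affine_map<(d0, d1, d2, d3) -> (0, d1, d2, d3)>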
    module attributes {
      llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128",
      llvm.target_triple = "x86_64-unknown-linux-gnu"} {
      func.func @forward(%arg0: tensor<1x255x40x40xi8> {func.orig_type = tensor<1x255x40x40xi8>, onnx.name = "inp3in"} loc(unknown), %arg1: tensor<1x255x20x20xi8> {func.orig_type = tensor<1x255x20x20xi8>, onnx.name = "inp2in"} loc(unknown), %arg2: tensor<1x255x80x80xi8> {func.orig_type = tensor<1x255x80x80xi8>, onnx.name = "inp1in"} loc(unknown)) -> (tensor<1x3x40x40x85xf32> {func.orig_type = tensor<1x3x40x40x85xf32>, onnx.name = "onnx::Sigmoid_564"}) {
        %cst = arith.constant dense<3.100000e+00> : tensor<1x1x1x1xf32> loc(#loc)
        %0 = tensor.empty() : tensor<1x255x40x40xf32> loc(#loc1)
        %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x255x40x40xi8>) outs(%0 : tensor<1x255x40x40xf32>) {
        ^bb0(%in: i8 loc(unknown), %out: f32 loc("")):
          %5 = arith.sitofp %in : i8 to f32 loc(#loc1)
          linalg.yield %5 : f32 loc(#loc1)
        } -> tensor<1x255x40x40xf32> loc(#loc1)
        %2 = tensor.empty() : tensor<1x255x40x40xf32> loc(#loc1)
        %3 = linalg.generic {indexing_maps = [#map, #map2, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1, %cst : tensor<1x255x40x40xf32>, tensor<1x1x1x1xf32>) outs(%2 : tensor<1x255x40x40xf32>) {
        ^bb0(%in: f32 loc(""), %in_0: f32 loc(unknown), %out: f32 loc("")):
          %5 = arith.mulf %in, %in_0 : f32 loc(#loc1)
          linalg.yield %5 : f32 loc(#loc1)
        } -> tensor<1x255x40x40xf32> loc(#loc1)
        %expanded = tensor.expand_shape %3 [[0], [1, 2], [3], [4]] output_shape [1, 3, 85, 40, 40] : tensor<1x255x40x40xf32> into tensor<1x3x85x40x40xf32> loc(#loc2)
        %4 = tensor.empty() : tensor<1x3x40x40x85xf32> loc(#loc3)
        %transposed = linalg.transpose ins(%expanded : tensor<1x3x85x40x40xf32>) outs(%4 : tensor<1x3x40x40x85xf32>) permutation = [0, 1, 3, 4, 2]  loc(#loc3)
        return %transposed : tensor<1x3x40x40x85xf32> loc(#loc)
      } loc(#loc)
    } loc(#loc)
    ```
    
    would not move the tensor.expand_shape to the top, because the `isProjectedPermutation` check on `#map`
    would fail. The check passes if `#map` is changed to `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`,
    which is equivalent because the tensor has size 1 in that dimension.
    From my understanding, the transformation doesn't strictly require the latter map,
    and `isProjectedPermutation` already has an option to allow zeros in its results.
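
    For illustration, here is a minimal standalone C++ sketch (not part of this commit) of that option. It builds the `#map` as reconstructed above; `AffineMap::isProjectedPermutation(bool allowZeroInResults)` is the existing MLIR API referred to as "an option to allow zeros":

    ```
    #include "mlir/IR/AffineMap.h"
    #include "mlir/IR/Builders.h"
    #include "mlir/IR/MLIRContext.h"
    #include <cassert>

    int main() {
      mlir::MLIRContext ctx;
      mlir::Builder b(&ctx);

      // Presumed #map for the 1x255x40x40 operand: the unit dimension d0 is
      // folded to the constant 0, i.e. (d0, d1, d2, d3) -> (0, d1, d2, d3).
      mlir::AffineMap map = mlir::AffineMap::get(
          /*dimCount=*/4, /*symbolCount=*/0,
          {b.getAffineConstantExpr(0), b.getAffineDimExpr(1),
           b.getAffineDimExpr(2), b.getAffineDimExpr(3)},
          &ctx);

      // The strict check rejects the constant-zero result...
      assert(!map.isProjectedPermutation());
      // ...but the existing allowZeroInResults option accepts it, which is
      // what lets the expand_shape bubble up through the linalg.generic.
      assert(map.isProjectedPermutation(/*allowZeroInResults=*/true));
      return 0;
    }
    ```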
    mgehre-amd committed Oct 11, 2024 (commit 86ab4d3)