Rename shark-turbine -> iree.turbine (#197)
* Move files from `shark-turbine` to `iree/turbine`.
* Update imports
* Update `setup.py`
* Add a backward-compatibility redirect `shark-turbine` -> `iree.turbine` (do we need this?)

Progress on #28

---------

Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Hardcode84 authored Oct 5, 2024
1 parent 3312f73 commit 40016ad
Showing 192 changed files with 250 additions and 807 deletions.
2 changes: 1 addition & 1 deletion MANIFEST.in
@@ -2,4 +2,4 @@ include README.md
include requirements.txt
include pytorch-cpu-requirements.txt
include version_info.json
include shark_turbine/ops/templates/*.mlir
include iree/turbine/ops/templates/*.mlir
2 changes: 1 addition & 1 deletion README.md
@@ -8,7 +8,7 @@ Turbine provides a collection of tools:

* *AOT Export*: For compiling one or more `nn.Module`s to compiled, deployment
ready artifacts. This operates via both a simple one-shot export API (Already upstreamed to [torch-mlir](https://github.com/llvm/torch-mlir/blob/main/python/torch_mlir/extras/fx_importer.py))
for simple models and an underlying [advanced API](shark_turbine/aot/compiled_module.py) for complicated models
for simple models and an underlying [advanced API](iree/turbine/aot/compiled_module.py) for complicated models
and accessing the full features of the runtime.
* *Eager Execution*: A `torch.compile` backend is provided and a Turbine Tensor/Device
is available for more native, interactive use within a PyTorch session.
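To make the renamed namespace concrete, here is a minimal one-shot export sketch using the new `iree.turbine.aot` import, modeled on the examples under examples/aot_mlp/ in this diff. The `aot.export(...)` call and its `print_readable()`/`compile()` methods are assumed from Turbine's simple export API and may differ in detail; treat this as an illustrative sketch, not the repository's exact example.

import torch
import torch.nn as nn

import iree.turbine.aot as aot  # new namespace introduced by this commit


class SmallMLP(nn.Module):
    # Tiny stand-in model, just to have something to export.
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 4)

    def forward(self, x):
        return torch.relu(self.layer(x))


model = SmallMLP()
example_x = torch.randn(2, 8)

# One-shot export, then compile in-memory with IREE; the exact method names
# (print_readable, compile) are assumed and may differ between releases.
exported = aot.export(model, example_x)
exported.print_readable()
compiled_binary = exported.compile(save_to=None)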
4 changes: 1 addition & 3 deletions build_tools/build_release.py
@@ -159,10 +159,8 @@ def main():
print("Downloading remaining requirements")
download_requirements(REPO_ROOT / "requirements.txt")

print("Building shark-turbine")
build_wheel(REPO_ROOT)
print("Building iree-turbine")
build_wheel(REPO_ROOT, env={"TURBINE_PACKAGE_NAME": "iree-turbine"})
build_wheel(REPO_ROOT)


if __name__ == "__main__":
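The env={"TURBINE_PACKAGE_NAME": "iree-turbine"} override above suggests that the updated setup.py picks the published distribution name from an environment variable. A minimal sketch of that pattern follows; the variable handling, default name, and package discovery shown here are assumptions for illustration, not the contents of this repository's setup.py.

# setup.py (fragment) -- illustrative sketch of an env-driven package name.
import os
from setuptools import setup, find_namespace_packages

# Release builds can override the published name (e.g. "iree-turbine");
# the default shown here is an assumption.
PACKAGE_NAME = os.getenv("TURBINE_PACKAGE_NAME", "iree-turbine")

setup(
    name=PACKAGE_NAME,
    packages=find_namespace_packages(include=["iree.turbine", "iree.turbine.*"]),
)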
2 changes: 1 addition & 1 deletion examples/aot_mlp/mlp_export_dynamic.py
@@ -12,7 +12,7 @@
import torch
import torch.nn as nn

import shark_turbine.aot as aot
import iree.turbine.aot as aot


class MLP(nn.Module):
2 changes: 1 addition & 1 deletion examples/aot_mlp/mlp_export_simple.py
@@ -9,7 +9,7 @@
import torch
import torch.nn as nn

import shark_turbine.aot as aot
import iree.turbine.aot as aot


class MLP(nn.Module):
47 changes: 0 additions & 47 deletions examples/llama2_inference/README.md

This file was deleted.

503 changes: 0 additions & 503 deletions examples/llama2_inference/llama2.ipynb

This file was deleted.

1 change: 0 additions & 1 deletion examples/llama2_inference/llama2_state_schema.json

This file was deleted.

4 changes: 0 additions & 4 deletions examples/llama2_inference/requirements.txt

This file was deleted.

2 changes: 1 addition & 1 deletion examples/resnet-18/requirements.txt
@@ -1,2 +1,2 @@
transformers
shark_turbine==0.9.2
iree_turbine==0.9.2
2 changes: 1 addition & 1 deletion examples/resnet-18/resnet-18.py
@@ -1,6 +1,6 @@
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import torch
from shark_turbine.aot import *
from iree.turbine.aot import *
import iree.runtime as rt

# Loading feature extractor and pretrained model from huggingface
4 changes: 2 additions & 2 deletions examples/runtime_torture/launchable_torture.py
@@ -12,9 +12,9 @@
import torch
import torch.nn as nn

import shark_turbine.aot as aot
import iree.turbine.aot as aot

from shark_turbine.runtime import (
from iree.turbine.runtime import (
Launchable,
)

12 changes: 0 additions & 12 deletions iree/turbine/__init__.py
@@ -8,15 +8,3 @@
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# TODO: This redirection layer exists while we are migrating from the
# shark_turbine top-level package name to iree.turbine. It exports the
# public API but not the internal details. In a future switch, all code
# will be directly located here and the redirect will be done in the
# shark_turbine namespace.

from shark_turbine import aot
from shark_turbine import dynamo
from shark_turbine import kernel
from shark_turbine import ops
from shark_turbine import runtime
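The hunk above removes the old iree.turbine -> shark_turbine redirection layer; the commit message's "backward redirect" reverses that direction so existing `import shark_turbine` code keeps working against the new location. A minimal sketch of what such a shim could look like, mirroring the removed imports; the actual shark_turbine module added by this commit may differ.

# shark_turbine/__init__.py -- hypothetical backward-compatibility shim.
# Re-export the public sub-packages from their new iree.turbine location so
# that existing `import shark_turbine.aot`-style code keeps working.
from iree.turbine import aot
from iree.turbine import dynamo
from iree.turbine import kernel
from iree.turbine import ops
from iree.turbine import runtime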
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -474,8 +474,8 @@ def _get_device_state() -> DeviceState:
return DeviceState(driver="local-task")


# Inspiration from https://github.com/nod-ai/SHARK-Turbine/blob/8293de5414889c72ff5cd10bf33c43fb0a3ea3ee/python/shark_turbine/aot/builtins/jittable.py#L212-L237
# and https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/dynamo/backends/cpu.py
# Inspiration from https://github.com/nod-ai/SHARK-Turbine/blob/8293de5414889c72ff5cd10bf33c43fb0a3ea3ee/python/iree/turbine/aot/builtins/jittable.py#L212-L237
# and https://github.com/nod-ai/SHARK-Turbine/blob/main/python/iree/turbine/dynamo/backends/cpu.py
# TODO: Try to generalize for other devices.
def compute_method(super_fn, *args, **kwargs):
# Compute factory fns reserve the last arg as src_op
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -40,10 +40,10 @@
scf_d,
vector_d,
)
from shark_turbine.aot.support.ir_utils import _is_float_type, _is_integer_like_type
from iree.turbine.aot.support.ir_utils import _is_float_type, _is_integer_like_type

# TK infrastructure imports.
from shark_turbine.kernel.lang.global_symbols import *
from iree.turbine.kernel.lang.global_symbols import *
from ..ops.wave_ops import (
write,
broadcast,
File renamed without changes.
File renamed without changes.
@@ -5,7 +5,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ...support.logging import get_logger
from shark_turbine.kernel._support.tracing import CapturedTrace
from iree.turbine.kernel._support.tracing import CapturedTrace
import torch.fx as fx
from ..ops.wave_ops import *
from ..lang.global_symbols import *
File renamed without changes.
File renamed without changes.
@@ -5,7 +5,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ...support.logging import get_logger
from shark_turbine.kernel._support.tracing import CapturedTrace
from iree.turbine.kernel._support.tracing import CapturedTrace
import torch.fx as fx
from ..ops.wave_ops import *
from ..lang.global_symbols import *
@@ -33,7 +33,7 @@
TilingConstraint,
)
import torch.fx as fx
import shark_turbine.kernel.lang as tkl
import iree.turbine.kernel.lang as tkl


import tempfile
File renamed without changes.
@@ -42,7 +42,7 @@
from .thread_shape_analysis import determine_thread_shapes
from .scheduling.schedule import schedule_graph
from .._support.indexing import IndexingContext, IndexExpr
import shark_turbine.kernel.lang as tkl
import iree.turbine.kernel.lang as tkl
from .._support.tracing import (
CapturedTrace,
CompiledContext,
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -12,7 +12,7 @@

import re

from shark_turbine.support.ir_imports import *
from iree.turbine.support.ir_imports import *

from ..rewriter import *
from iree.compiler.ir import Context, DictAttr
File renamed without changes.
File renamed without changes.
24 changes: 12 additions & 12 deletions lit_tests/kernel/wave/barriers.py
@@ -3,18 +3,18 @@
import logging
from typing import Callable
import unittest
import shark_turbine.kernel as tk
import shark_turbine.kernel.lang as tkl
import shark_turbine.kernel.wave as tkw
from shark_turbine.kernel.wave.promotion import promote_node, promote_placeholders
from shark_turbine.kernel.wave.barriers import add_shared_memory_barriers
from shark_turbine.kernel.wave.hoisting import hoist_allocs
from shark_turbine.kernel.wave.expansion import expand_graph
from shark_turbine.kernel.lang.global_symbols import *
from shark_turbine.kernel._support.tracing import CapturedTrace
from shark_turbine.kernel._support.indexing import IndexingContext
from shark_turbine.kernel.ops.wave_ops import *
from shark_turbine.kernel.wave.utils import run_test, print_trace
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.wave.promotion import promote_node, promote_placeholders
from iree.turbine.kernel.wave.barriers import add_shared_memory_barriers
from iree.turbine.kernel.wave.hoisting import hoist_allocs
from iree.turbine.kernel.wave.expansion import expand_graph
from iree.turbine.kernel.lang.global_symbols import *
from iree.turbine.kernel._support.tracing import CapturedTrace
from iree.turbine.kernel._support.indexing import IndexingContext
from iree.turbine.kernel.ops.wave_ops import *
from iree.turbine.kernel.wave.utils import run_test, print_trace


def get_read_nodes(graph: fx.Graph) -> list[CustomOp]:
10 changes: 5 additions & 5 deletions lit_tests/kernel/wave/codegen.py
@@ -2,11 +2,11 @@

import pytest
from typing import Callable
import shark_turbine.kernel as tk
import shark_turbine.kernel.lang as tkl
import shark_turbine.kernel.wave as tkw
from shark_turbine.kernel.lang.global_symbols import *
from shark_turbine.kernel.wave.utils import run_test
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.lang.global_symbols import *
from iree.turbine.kernel.wave.utils import run_test
import torch

M = tkl.sym.M
14 changes: 7 additions & 7 deletions lit_tests/kernel/wave/expansion.py
@@ -2,13 +2,13 @@

import logging
import unittest
import shark_turbine.kernel as tk
import shark_turbine.kernel.lang as tkl
import shark_turbine.kernel.wave as tkw
from shark_turbine.kernel.wave.expansion import expand_graph
from shark_turbine.kernel._support.indexing import IndexingContext
from shark_turbine.kernel.lang.global_symbols import *
from shark_turbine.kernel.wave.utils import run_test, print_trace
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.wave.expansion import expand_graph
from iree.turbine.kernel._support.indexing import IndexingContext
from iree.turbine.kernel.lang.global_symbols import *
from iree.turbine.kernel.wave.utils import run_test, print_trace
import sympy

# Input sizes
28 changes: 14 additions & 14 deletions lit_tests/kernel/wave/index_sequence_analysis.py
@@ -2,22 +2,22 @@

import logging
import unittest
import shark_turbine.kernel as tk
import shark_turbine.kernel.lang as tkl
import shark_turbine.kernel.wave as tkw
from shark_turbine.kernel.wave.promotion import promote_placeholders
from shark_turbine.kernel.wave.hoisting import hoist_allocs
from shark_turbine.kernel.wave.expansion import expand_graph
from shark_turbine.kernel.lang.global_symbols import *
from shark_turbine.kernel._support.tracing import CapturedTrace
from shark_turbine.kernel._support.indexing import IndexingContext
from shark_turbine.kernel.ops.wave_ops import *
from shark_turbine.kernel.wave.utils import run_test, print_trace
from shark_turbine.kernel.wave.minimize_global_loads import minimize_global_loads
from shark_turbine.kernel.wave.shared_memory_indexing import (
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.wave.promotion import promote_placeholders
from iree.turbine.kernel.wave.hoisting import hoist_allocs
from iree.turbine.kernel.wave.expansion import expand_graph
from iree.turbine.kernel.lang.global_symbols import *
from iree.turbine.kernel._support.tracing import CapturedTrace
from iree.turbine.kernel._support.indexing import IndexingContext
from iree.turbine.kernel.ops.wave_ops import *
from iree.turbine.kernel.wave.utils import run_test, print_trace
from iree.turbine.kernel.wave.minimize_global_loads import minimize_global_loads
from iree.turbine.kernel.wave.shared_memory_indexing import (
apply_shared_memory_indexing_corrections,
)
from shark_turbine.kernel.wave.index_sequence_analysis import (
from iree.turbine.kernel.wave.index_sequence_analysis import (
partition_strided_operators,
)

30 changes: 15 additions & 15 deletions lit_tests/kernel/wave/minimize_global_loads.py
@@ -3,21 +3,21 @@
import logging
from typing import Callable
import unittest
import shark_turbine.kernel as tk
import shark_turbine.kernel.lang as tkl
import shark_turbine.kernel.wave as tkw
from shark_turbine.kernel.wave.promotion import promote_placeholders
from shark_turbine.kernel.wave.hoisting import hoist_allocs
from shark_turbine.kernel.wave.barriers import add_shared_memory_barriers
from shark_turbine.kernel.wave.expansion import expand_graph
from shark_turbine.kernel.lang.global_symbols import *
from shark_turbine.kernel._support.tracing import CapturedTrace
from shark_turbine.kernel._support.indexing import IndexingContext
from shark_turbine.kernel.ops.wave_ops import *
from shark_turbine.kernel.wave.utils import run_test, print_trace
from shark_turbine.kernel.wave.minimize_global_loads import minimize_global_loads
from shark_turbine.kernel.wave.visualization import visualize_graph
from shark_turbine.kernel.wave.shared_memory_indexing import (
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.wave.promotion import promote_placeholders
from iree.turbine.kernel.wave.hoisting import hoist_allocs
from iree.turbine.kernel.wave.barriers import add_shared_memory_barriers
from iree.turbine.kernel.wave.expansion import expand_graph
from iree.turbine.kernel.lang.global_symbols import *
from iree.turbine.kernel._support.tracing import CapturedTrace
from iree.turbine.kernel._support.indexing import IndexingContext
from iree.turbine.kernel.ops.wave_ops import *
from iree.turbine.kernel.wave.utils import run_test, print_trace
from iree.turbine.kernel.wave.minimize_global_loads import minimize_global_loads
from iree.turbine.kernel.wave.visualization import visualize_graph
from iree.turbine.kernel.wave.shared_memory_indexing import (
apply_shared_memory_indexing_corrections,
)
