From 40016ad258936df3429edb4b9256819bcd48e9ee Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 5 Oct 2024 19:30:17 +0300 Subject: [PATCH] Rename `shark-turbine` -> `iree.turbine` (#197) * Move files from files from `shark-turbine` to `iree/turbine`. * Update imports * Update `setup.py` * Make backward redirect `shark-turbine` -> `iree.turbine` (do we need this?) Progress on #28 --------- Signed-off-by: Ivan Butygin --- MANIFEST.in | 2 +- README.md | 2 +- build_tools/build_release.py | 4 +- examples/aot_mlp/mlp_export_dynamic.py | 2 +- examples/aot_mlp/mlp_export_simple.py | 2 +- examples/llama2_inference/README.md | 47 -- examples/llama2_inference/llama2.ipynb | 503 ------------------ .../llama2_inference/llama2_state_schema.json | 1 - examples/llama2_inference/requirements.txt | 4 - examples/resnet-18/requirements.txt | 2 +- examples/resnet-18/resnet-18.py | 2 +- .../runtime_torture/launchable_torture.py | 4 +- iree/turbine/__init__.py | 12 - .../turbine}/aot/__init__.py | 0 .../turbine}/aot/builtins/__init__.py | 0 .../turbine}/aot/builtins/globals.py | 0 .../turbine}/aot/builtins/jittable.py | 0 .../turbine}/aot/compiled_module.py | 0 .../turbine}/aot/decompositions.py | 0 .../turbine}/aot/exporter.py | 0 .../turbine}/aot/fx_programs.py | 0 {shark_turbine => iree/turbine}/aot/params.py | 0 .../turbine}/aot/passes/__init__.py | 0 .../turbine}/aot/passes/functorch.py | 0 .../turbine}/aot/support/ir_utils.py | 0 .../aot/support/procedural/__init__.py | 0 .../turbine}/aot/support/procedural/base.py | 0 .../support/procedural/exported_program.py | 0 .../aot/support/procedural/globals.py | 0 .../aot/support/procedural/iree_emitter.py | 0 .../aot/support/procedural/primitives.py | 0 .../turbine}/aot/support/procedural/tracer.py | 0 .../turbine}/aot/tensor_traits.py | 0 .../turbine}/dynamo/__init__.py | 0 .../turbine}/dynamo/backends/cpu.py | 0 .../turbine}/dynamo/decompositions.py | 0 .../turbine}/dynamo/executor.py | 0 .../turbine}/dynamo/passes.py | 0 .../turbine}/dynamo/tensor.py | 4 +- .../turbine}/dynamo/type_conversion.py | 0 .../turbine}/importers/README.md | 0 .../turbine}/importers/ir.py | 0 .../turbine}/importers/utils.py | 0 .../turbine}/kernel/__init__.py | 0 .../turbine}/kernel/_support/context.py | 0 .../turbine}/kernel/_support/dtype.py | 0 .../turbine}/kernel/_support/indexing.py | 0 .../turbine}/kernel/_support/regions.py | 0 .../turbine}/kernel/_support/shaped_type.py | 0 .../turbine}/kernel/_support/tracing.py | 0 .../turbine}/kernel/compiler/base.py | 0 .../turbine}/kernel/compiler/builder.py | 0 .../kernel/compiler/dispatch_codegen.py | 0 .../turbine}/kernel/compiler/host_codegen.py | 0 .../turbine}/kernel/compiler/ir.py | 0 .../kernel/compiler/kernel_codegen.py | 0 .../turbine}/kernel/compiler/op_matchers.py | 0 .../turbine}/kernel/compiler/utils.py | 0 .../kernel/compiler/vector_codegen.py | 0 .../turbine}/kernel/gen/__init__.py | 0 .../turbine}/kernel/gen/kernel.py | 0 .../turbine}/kernel/gen/thread.py | 0 .../turbine}/kernel/lang/__init__.py | 0 .../turbine}/kernel/lang/global_symbols.py | 0 .../turbine}/kernel/lang/grid.py | 0 .../turbine}/kernel/lang/kernel_buffer.py | 0 .../turbine}/kernel/lang/prims.py | 0 .../turbine}/kernel/lang/types.py | 0 .../turbine}/kernel/lang/wave_types.py | 0 .../turbine}/kernel/ops/__init__.py | 0 .../turbine}/kernel/ops/base.py | 0 .../turbine}/kernel/ops/control_flow.py | 0 .../turbine}/kernel/ops/core.py | 0 .../turbine}/kernel/ops/math.py | 0 .../turbine}/kernel/ops/memory.py | 0 .../turbine}/kernel/ops/reduction.py | 0 .../turbine}/kernel/ops/shape_manipulation.py | 0 .../turbine}/kernel/ops/wave_ops.py | 0 .../turbine}/kernel/wave/README.md | 0 .../turbine}/kernel/wave/__init__.py | 0 .../turbine}/kernel/wave/barriers.py | 0 .../turbine}/kernel/wave/codegen.py | 4 +- .../turbine}/kernel/wave/constraints.py | 0 .../kernel/wave/decompose_reduce_ops.py | 0 .../turbine}/kernel/wave/docs/gemm_example.md | 0 .../turbine}/kernel/wave/expansion.py | 0 .../turbine}/kernel/wave/hoisting.py | 2 +- .../kernel/wave/index_sequence_analysis.py | 0 .../turbine}/kernel/wave/iree_utils.py | 0 .../kernel/wave/minimize_global_loads.py | 0 .../turbine}/kernel/wave/promotion.py | 0 .../kernel/wave/scheduling/__init__.py | 0 .../kernel/wave/scheduling/graph_utils.py | 0 .../wave/scheduling/loop_reconstruction.py | 0 .../scheduling/loop_reconstruction_utils.py | 0 .../wave/scheduling/modulo_scheduling.py | 0 .../kernel/wave/scheduling/resources.py | 0 .../kernel/wave/scheduling/schedule.py | 0 .../kernel/wave/shared_memory_indexing.py | 0 .../kernel/wave/thread_shape_analysis.py | 2 +- .../turbine}/kernel/wave/utils.py | 2 +- .../turbine}/kernel/wave/visualization.py | 0 .../turbine}/kernel/wave/wave.py | 2 +- .../turbine}/kernel/wave/wave_sim.py | 0 .../turbine}/ops/__init__.py | 0 .../turbine}/ops/_jinja_test_ops.py | 0 .../turbine}/ops/_str_format_test_ops.py | 0 {shark_turbine => iree/turbine}/ops/iree.py | 0 .../ops/templates/test_add_jinja.mlir | 0 .../ops/templates/test_add_strformat.mlir | 0 .../ops/templates/test_syntax_error.mlir | 0 .../turbine}/runtime/__init__.py | 0 .../turbine}/runtime/device.py | 0 .../turbine}/runtime/launch.py | 0 .../turbine}/runtime/op_reg/__init__.py | 0 .../turbine}/runtime/op_reg/base.py | 0 .../turbine}/runtime/op_reg/compiler.py | 0 .../turbine}/runtime/op_reg/eager.py | 0 .../turbine}/runtime/op_reg/impl_helper.py | 0 .../turbine}/runtime/tracing.py | 0 .../turbine}/support/__init__.py | 0 .../turbine}/support/conversions.py | 0 .../turbine}/support/debugging.py | 0 .../turbine}/support/exceptions.py | 0 .../turbine}/support/ir_imports.py | 0 .../turbine}/support/logging.py | 0 .../turbine}/tools/__init__.py | 0 .../turbine}/tools/interpreter.py | 0 .../turbine}/transforms/builder.py | 0 .../transforms/general/add_metadata.py | 2 +- .../transforms/general/custom_op_expansion.py | 0 .../transforms/general/rename_parameters.py | 0 .../turbine}/transforms/merger.py | 0 .../transforms/quantization/mm_group_quant.py | 0 .../turbine}/transforms/rewriter.py | 0 lit_tests/kernel/wave/barriers.py | 24 +- lit_tests/kernel/wave/codegen.py | 10 +- lit_tests/kernel/wave/expansion.py | 14 +- .../kernel/wave/index_sequence_analysis.py | 28 +- .../kernel/wave/minimize_global_loads.py | 30 +- lit_tests/kernel/wave/promotion.py | 20 +- lit_tests/kernel/wave/scheduling.py | 28 +- lit_tests/kernel/wave/tracing.py | 10 +- lit_tests/lit.cfg.py | 2 +- mypy.ini | 8 +- setup.py | 15 +- shark_turbine/__init__.py | 13 + tests/aot/api_test.py | 2 +- tests/aot/args_test.py | 2 +- tests/aot/compiled_exported_program_test.py | 4 +- tests/aot/decompositions_test.py | 2 +- tests/aot/dynamic_shape_export_test.py | 2 +- tests/aot/functionalize_test.py | 2 +- tests/aot/fx_programs_test.py | 2 +- tests/aot/globals_test.py | 2 +- tests/aot/iree_procedural_test.py | 2 +- tests/aot/jittable_test.py | 2 +- tests/aot/non_strict_export_test.py | 2 +- tests/aot/params_test.py | 2 +- tests/dynamo/importer_dynamic_test.py | 2 +- tests/dynamo/tensor_test.py | 4 +- tests/dynamo/type_conversion_test.py | 2 +- tests/generated/evaluate.py | 2 +- tests/kernel/aot_kernel_test.py | 6 +- tests/kernel/arith_test.py | 8 +- tests/kernel/compiler/utils_test.py | 6 +- tests/kernel/dispatch_codegen_test.py | 8 +- tests/kernel/fused_attention_test.py | 4 +- tests/kernel/indexing_test.py | 4 +- tests/kernel/simple_kernel_test.py | 20 +- tests/kernel/types_test.py | 2 +- tests/kernel/vector_codegen_test.py | 4 +- tests/kernel/wave/constraints_test.py | 4 +- tests/kernel/wave/scheduling_test.py | 30 +- tests/kernel/wave/types_test.py | 6 +- tests/kernel/wave/visualization_test.py | 18 +- tests/kernel/wave/wave_e2e_test.py | 12 +- tests/kernel/wave/wave_gemm_test.py | 10 +- tests/kernel/wave/wave_sim_test.py | 6 +- tests/kernel/wave/wave_utils_test.py | 4 +- tests/ops/iree_test.py | 4 +- tests/runtime/device_test.py | 12 +- tests/runtime/launch_test.py | 4 +- tests/runtime/op_reg/impl_helper_test.py | 2 +- tests/runtime/op_reg/kernel_aot_test.py | 6 +- tests/runtime/op_reg/kernel_reg_test.py | 4 +- tests/tools/interpreter_test.py | 10 +- tests/top_level_package_test.py | 4 +- tests/transforms/general/add_metadata_test.py | 2 +- .../general/custom_op_expansion_test.py | 6 +- .../general/rename_parameters_test.py | 4 +- .../quantization/mm_group_quant_test.py | 4 +- 192 files changed, 250 insertions(+), 807 deletions(-) delete mode 100644 examples/llama2_inference/README.md delete mode 100644 examples/llama2_inference/llama2.ipynb delete mode 100644 examples/llama2_inference/llama2_state_schema.json delete mode 100644 examples/llama2_inference/requirements.txt rename {shark_turbine => iree/turbine}/aot/__init__.py (100%) rename {shark_turbine => iree/turbine}/aot/builtins/__init__.py (100%) rename {shark_turbine => iree/turbine}/aot/builtins/globals.py (100%) rename {shark_turbine => iree/turbine}/aot/builtins/jittable.py (100%) rename {shark_turbine => iree/turbine}/aot/compiled_module.py (100%) rename {shark_turbine => iree/turbine}/aot/decompositions.py (100%) rename {shark_turbine => iree/turbine}/aot/exporter.py (100%) rename {shark_turbine => iree/turbine}/aot/fx_programs.py (100%) rename {shark_turbine => iree/turbine}/aot/params.py (100%) rename {shark_turbine => iree/turbine}/aot/passes/__init__.py (100%) rename {shark_turbine => iree/turbine}/aot/passes/functorch.py (100%) rename {shark_turbine => iree/turbine}/aot/support/ir_utils.py (100%) rename {shark_turbine => iree/turbine}/aot/support/procedural/__init__.py (100%) rename {shark_turbine => iree/turbine}/aot/support/procedural/base.py (100%) rename {shark_turbine => iree/turbine}/aot/support/procedural/exported_program.py (100%) rename {shark_turbine => iree/turbine}/aot/support/procedural/globals.py (100%) rename {shark_turbine => iree/turbine}/aot/support/procedural/iree_emitter.py (100%) rename {shark_turbine => iree/turbine}/aot/support/procedural/primitives.py (100%) rename {shark_turbine => iree/turbine}/aot/support/procedural/tracer.py (100%) rename {shark_turbine => iree/turbine}/aot/tensor_traits.py (100%) rename {shark_turbine => iree/turbine}/dynamo/__init__.py (100%) rename {shark_turbine => iree/turbine}/dynamo/backends/cpu.py (100%) rename {shark_turbine => iree/turbine}/dynamo/decompositions.py (100%) rename {shark_turbine => iree/turbine}/dynamo/executor.py (100%) rename {shark_turbine => iree/turbine}/dynamo/passes.py (100%) rename {shark_turbine => iree/turbine}/dynamo/tensor.py (99%) rename {shark_turbine => iree/turbine}/dynamo/type_conversion.py (100%) rename {shark_turbine => iree/turbine}/importers/README.md (100%) rename {shark_turbine => iree/turbine}/importers/ir.py (100%) rename {shark_turbine => iree/turbine}/importers/utils.py (100%) rename {shark_turbine => iree/turbine}/kernel/__init__.py (100%) rename {shark_turbine => iree/turbine}/kernel/_support/context.py (100%) rename {shark_turbine => iree/turbine}/kernel/_support/dtype.py (100%) rename {shark_turbine => iree/turbine}/kernel/_support/indexing.py (100%) rename {shark_turbine => iree/turbine}/kernel/_support/regions.py (100%) rename {shark_turbine => iree/turbine}/kernel/_support/shaped_type.py (100%) rename {shark_turbine => iree/turbine}/kernel/_support/tracing.py (100%) rename {shark_turbine => iree/turbine}/kernel/compiler/base.py (100%) rename {shark_turbine => iree/turbine}/kernel/compiler/builder.py (100%) rename {shark_turbine => iree/turbine}/kernel/compiler/dispatch_codegen.py (100%) rename {shark_turbine => iree/turbine}/kernel/compiler/host_codegen.py (100%) rename {shark_turbine => iree/turbine}/kernel/compiler/ir.py (100%) rename {shark_turbine => iree/turbine}/kernel/compiler/kernel_codegen.py (100%) rename {shark_turbine => iree/turbine}/kernel/compiler/op_matchers.py (100%) rename {shark_turbine => iree/turbine}/kernel/compiler/utils.py (100%) rename {shark_turbine => iree/turbine}/kernel/compiler/vector_codegen.py (100%) rename {shark_turbine => iree/turbine}/kernel/gen/__init__.py (100%) rename {shark_turbine => iree/turbine}/kernel/gen/kernel.py (100%) rename {shark_turbine => iree/turbine}/kernel/gen/thread.py (100%) rename {shark_turbine => iree/turbine}/kernel/lang/__init__.py (100%) rename {shark_turbine => iree/turbine}/kernel/lang/global_symbols.py (100%) rename {shark_turbine => iree/turbine}/kernel/lang/grid.py (100%) rename {shark_turbine => iree/turbine}/kernel/lang/kernel_buffer.py (100%) rename {shark_turbine => iree/turbine}/kernel/lang/prims.py (100%) rename {shark_turbine => iree/turbine}/kernel/lang/types.py (100%) rename {shark_turbine => iree/turbine}/kernel/lang/wave_types.py (100%) rename {shark_turbine => iree/turbine}/kernel/ops/__init__.py (100%) rename {shark_turbine => iree/turbine}/kernel/ops/base.py (100%) rename {shark_turbine => iree/turbine}/kernel/ops/control_flow.py (100%) rename {shark_turbine => iree/turbine}/kernel/ops/core.py (100%) rename {shark_turbine => iree/turbine}/kernel/ops/math.py (100%) rename {shark_turbine => iree/turbine}/kernel/ops/memory.py (100%) rename {shark_turbine => iree/turbine}/kernel/ops/reduction.py (100%) rename {shark_turbine => iree/turbine}/kernel/ops/shape_manipulation.py (100%) rename {shark_turbine => iree/turbine}/kernel/ops/wave_ops.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/README.md (100%) rename {shark_turbine => iree/turbine}/kernel/wave/__init__.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/barriers.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/codegen.py (99%) rename {shark_turbine => iree/turbine}/kernel/wave/constraints.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/decompose_reduce_ops.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/docs/gemm_example.md (100%) rename {shark_turbine => iree/turbine}/kernel/wave/expansion.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/hoisting.py (95%) rename {shark_turbine => iree/turbine}/kernel/wave/index_sequence_analysis.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/iree_utils.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/minimize_global_loads.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/promotion.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/scheduling/__init__.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/scheduling/graph_utils.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/scheduling/loop_reconstruction.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/scheduling/loop_reconstruction_utils.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/scheduling/modulo_scheduling.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/scheduling/resources.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/scheduling/schedule.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/shared_memory_indexing.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/thread_shape_analysis.py (99%) rename {shark_turbine => iree/turbine}/kernel/wave/utils.py (99%) rename {shark_turbine => iree/turbine}/kernel/wave/visualization.py (100%) rename {shark_turbine => iree/turbine}/kernel/wave/wave.py (99%) rename {shark_turbine => iree/turbine}/kernel/wave/wave_sim.py (100%) rename {shark_turbine => iree/turbine}/ops/__init__.py (100%) rename {shark_turbine => iree/turbine}/ops/_jinja_test_ops.py (100%) rename {shark_turbine => iree/turbine}/ops/_str_format_test_ops.py (100%) rename {shark_turbine => iree/turbine}/ops/iree.py (100%) rename {shark_turbine => iree/turbine}/ops/templates/test_add_jinja.mlir (100%) rename {shark_turbine => iree/turbine}/ops/templates/test_add_strformat.mlir (100%) rename {shark_turbine => iree/turbine}/ops/templates/test_syntax_error.mlir (100%) rename {shark_turbine => iree/turbine}/runtime/__init__.py (100%) rename {shark_turbine => iree/turbine}/runtime/device.py (100%) rename {shark_turbine => iree/turbine}/runtime/launch.py (100%) rename {shark_turbine => iree/turbine}/runtime/op_reg/__init__.py (100%) rename {shark_turbine => iree/turbine}/runtime/op_reg/base.py (100%) rename {shark_turbine => iree/turbine}/runtime/op_reg/compiler.py (100%) rename {shark_turbine => iree/turbine}/runtime/op_reg/eager.py (100%) rename {shark_turbine => iree/turbine}/runtime/op_reg/impl_helper.py (100%) rename {shark_turbine => iree/turbine}/runtime/tracing.py (100%) rename {shark_turbine => iree/turbine}/support/__init__.py (100%) rename {shark_turbine => iree/turbine}/support/conversions.py (100%) rename {shark_turbine => iree/turbine}/support/debugging.py (100%) rename {shark_turbine => iree/turbine}/support/exceptions.py (100%) rename {shark_turbine => iree/turbine}/support/ir_imports.py (100%) rename {shark_turbine => iree/turbine}/support/logging.py (100%) rename {shark_turbine => iree/turbine}/tools/__init__.py (100%) rename {shark_turbine => iree/turbine}/tools/interpreter.py (100%) rename {shark_turbine => iree/turbine}/transforms/builder.py (100%) rename {shark_turbine => iree/turbine}/transforms/general/add_metadata.py (97%) rename {shark_turbine => iree/turbine}/transforms/general/custom_op_expansion.py (100%) rename {shark_turbine => iree/turbine}/transforms/general/rename_parameters.py (100%) rename {shark_turbine => iree/turbine}/transforms/merger.py (100%) rename {shark_turbine => iree/turbine}/transforms/quantization/mm_group_quant.py (100%) rename {shark_turbine => iree/turbine}/transforms/rewriter.py (100%) create mode 100644 shark_turbine/__init__.py diff --git a/MANIFEST.in b/MANIFEST.in index 97971bba..65338637 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,4 +2,4 @@ include README.md include requirements.txt include pytorch-cpu-requirements.txt include version_info.json -include shark_turbine/ops/templates/*.mlir +include iree/turbine/ops/templates/*.mlir diff --git a/README.md b/README.md index 4d0d0c22..aa01b826 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Turbine provides a collection of tools: * *AOT Export*: For compiling one or more `nn.Module`s to compiled, deployment ready artifacts. This operates via both a simple one-shot export API (Already upstreamed to [torch-mlir](https://github.com/llvm/torch-mlir/blob/main/python/torch_mlir/extras/fx_importer.py)) - for simple models and an underlying [advanced API](shark_turbine/aot/compiled_module.py) for complicated models + for simple models and an underlying [advanced API](iree/turbine/aot/compiled_module.py) for complicated models and accessing the full features of the runtime. * *Eager Execution*: A `torch.compile` backend is provided and a Turbine Tensor/Device is available for more native, interactive use within a PyTorch session. diff --git a/build_tools/build_release.py b/build_tools/build_release.py index 5a6ef98d..5a90a7cf 100755 --- a/build_tools/build_release.py +++ b/build_tools/build_release.py @@ -159,10 +159,8 @@ def main(): print("Downloading remaining requirements") download_requirements(REPO_ROOT / "requirements.txt") - print("Building shark-turbine") - build_wheel(REPO_ROOT) print("Building iree-turbine") - build_wheel(REPO_ROOT, env={"TURBINE_PACKAGE_NAME": "iree-turbine"}) + build_wheel(REPO_ROOT) if __name__ == "__main__": diff --git a/examples/aot_mlp/mlp_export_dynamic.py b/examples/aot_mlp/mlp_export_dynamic.py index cd863655..3bedd7c1 100644 --- a/examples/aot_mlp/mlp_export_dynamic.py +++ b/examples/aot_mlp/mlp_export_dynamic.py @@ -12,7 +12,7 @@ import torch import torch.nn as nn -import shark_turbine.aot as aot +import iree.turbine.aot as aot class MLP(nn.Module): diff --git a/examples/aot_mlp/mlp_export_simple.py b/examples/aot_mlp/mlp_export_simple.py index fed4795d..30d7ae95 100644 --- a/examples/aot_mlp/mlp_export_simple.py +++ b/examples/aot_mlp/mlp_export_simple.py @@ -9,7 +9,7 @@ import torch import torch.nn as nn -import shark_turbine.aot as aot +import iree.turbine.aot as aot class MLP(nn.Module): diff --git a/examples/llama2_inference/README.md b/examples/llama2_inference/README.md deleted file mode 100644 index 50bc6537..00000000 --- a/examples/llama2_inference/README.md +++ /dev/null @@ -1,47 +0,0 @@ -# LLAMA 2 Inference - -This example require some extra dependencies. Here's an easy way to get it running on a fresh server. - -Don't forget to put in your huggingface token from https://huggingface.co/settings/tokens - -```bash -#!/bin/bash - - -# if you don't insert it, you will be prompted to log in later; -# you may need to rerun this script after logging in -YOUR_HF_TOKEN="insert token for headless" - -# clone and install dependencies -sudo apt install -y git -git clone https://github.com/nod-ai/SHARK-Turbine.git -cd SHARK-Turbine -pip install -r requirements.txt -pip install --update "huggingface_hub[cli]" transformers sentencepiece protobuf - -# do an editable install from the cloned SHARK-Turbine -pip install --editable . - -# Log in with Hugging Face CLI if token setup is required -if [[ $YOUR_HF_TOKEN == hf_* ]]; then - huggingface login --token $YOUR_HF_TOKEN - echo "Logged in with YOUR_HF_TOKEN." -elif [ -f ~/.cache/huggingface/token ]; then - # Read token from the file - TOKEN_CONTENT=$(cat ~/.cache/huggingface/token) - - # Check if the token starts with "hf_" - if [[ $TOKEN_CONTENT == hf_* ]]; then - echo "Already logged in with a Hugging Face token." - else - echo "Token in file does not start with 'hf_'. Please log into huggingface to download models." - huggingface-cli login - fi -else - echo "Please log into huggingface to download models." - huggingface-cli login -fi - -# Step 7: Run the Python script -python examples/llama2_inference/stateless_llama.py -``` diff --git a/examples/llama2_inference/llama2.ipynb b/examples/llama2_inference/llama2.ipynb deleted file mode 100644 index b008bbd2..00000000 --- a/examples/llama2_inference/llama2.ipynb +++ /dev/null @@ -1,503 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "c0c9f034-7af1-4dc2-bbfb-5bb9e27c07ca", - "metadata": {}, - "outputs": [], - "source": [ - "from transformers import AutoTokenizer, AutoModelForCausalLM\n", - "import torch\n", - "from torch.utils import _pytree as pytree\n", - "from shark_turbine.aot import *\n", - "from iree.compiler.ir import Context\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "4d92bb47-2b93-4f32-a445-c0ad2adc37ad", - "metadata": {}, - "outputs": [], - "source": [ - "#set some config values\n", - "\n", - "hf_auth_token = \"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\"\n", - "hf_model_name = \"meta-llama/Llama-2-7b-chat-hf\"\n", - "state_schema_path = \"llama2_state_schema.json\"\n", - "with open(state_schema_path, \"r+\") as f:\n", - " state_schema = pytree.treespec_loads(f.read())\n", - "prompt = \"\"\"\n", - "[INST] <>\n", - "Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <> hi what are you? [/INST]\n", - "\"\"\"\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "d4664585-5e15-45c7-8c5c-c8eaf6381435", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/transformers/models/auto/tokenization_auto.py:640: FutureWarning: The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.\n", - " warnings.warn(\n", - "/home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/transformers/models/auto/auto_factory.py:479: FutureWarning: The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.\n", - " warnings.warn(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "5e411acda19c4228b008ff622bdf110e", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Loading checkpoint shards: 0%| | 0/2 [00:00.5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:26 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:33,234] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s1 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:72 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:33,409] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s2, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:118 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:33,707] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s3 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:189 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:33,845] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s3, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:228 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:33,878] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s4, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:235 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:34,188] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s5 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:306 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:34,326] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s5, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:345 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:34,359] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s6, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:352 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:34,661] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s7 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:423 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:34,800] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s7, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:462 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:34,832] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s8, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:469 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:35,130] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s9 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:540 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:35,271] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s9, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:579 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:35,305] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s10, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:586 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:35,611] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s11 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:657 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:35,762] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s11, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:696 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:35,795] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s12, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:703 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:36,107] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s13 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:774 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:36,249] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s13, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:813 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:36,282] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s14, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:820 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:36,589] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s15 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:891 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:36,734] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s15, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:930 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:36,768] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s16, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:937 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:37,105] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s17 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1008 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:37,249] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s17, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1047 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:37,286] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s18, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1054 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:37,595] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s19 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1125 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:37,744] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s19, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1164 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:37,778] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s20, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1171 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:38,090] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s21 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1242 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:38,238] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s21, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1281 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:38,272] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s22, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1288 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:38,584] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s23 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1359 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:38,734] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s23, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1398 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:38,768] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s24, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1405 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:39,086] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s25 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1476 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:39,239] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s25, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1515 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:39,274] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s26, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1522 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:39,597] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s27 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1593 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:39,759] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s27, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1632 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:39,812] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s28, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1639 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:40,330] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s29 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1710 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:40,534] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s29, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1749 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:40,582] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s30, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1756 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:41,068] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s31 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1827 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:41,242] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s31, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1866 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:41,280] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s32, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1873 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:41,686] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s33 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1944 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:41,968] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s33, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1983 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:42,004] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s34, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:1990 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:42,419] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s35 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2061 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:42,580] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s35, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2100 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:42,618] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s36, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2107 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:43,002] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s37 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2178 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:43,174] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s37, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2217 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:43,215] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s38, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2224 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:43,566] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s39 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2295 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:43,738] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s39, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2334 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:43,776] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s40, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2341 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:44,116] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s41 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2412 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:44,281] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s41, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2451 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:44,320] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s42, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2458 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:44,656] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s43 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2529 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:44,822] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s43, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2568 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:44,860] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s44, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2575 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:45,218] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s45 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2646 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:45,387] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s45, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2685 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:45,426] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s46, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2692 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:45,772] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s47 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2763 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:45,943] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s47, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2802 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:45,983] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s48, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2809 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:46,376] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s49 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2880 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:46,563] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s49, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2919 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:46,605] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s50, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2926 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:46,962] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s51 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:2997 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:47,136] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s51, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3036 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:47,176] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s52, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3043 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:47,540] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s53 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3114 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:47,718] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s53, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3153 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:47,758] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s54, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3160 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:48,125] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s55 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3231 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:48,308] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s55, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3270 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:48,349] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s56, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3277 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:48,715] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s57 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3348 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:48,897] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s57, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3387 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:48,937] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s58, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3394 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:49,317] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s59 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3465 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:49,499] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s59, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3504 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:49,540] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s60, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3511 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:49,915] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s61 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3582 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:50,113] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s61, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3621 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:50,155] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s62, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3628 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:50,515] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s63 <= 4096 [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3699 in forward (_decomp/decompositions.py:725 in slice_forward)\n", - "[2023-10-09 18:49:50,697] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s63, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3738 in forward (_subclasses/fake_tensor.py:740 in infer_size)\n", - "[2023-10-09 18:49:50,737] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] eval Eq(s0 + s64, s0 + s1) [guard added] at .5 from /home/dan/SHARK-Turbine/turbine_venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped:3745 in forward (_meta_registrations.py:3515 in common_meta_baddbmm_bmm)\n", - "[2023-10-09 18:49:53,791] [1/0] torch.fx.experimental.symbolic_shapes: [INFO] produce_guards\n", - "[2023-10-09 18:49:54,155] torch.fx.experimental.symbolic_shapes: [WARNING] Ignored guard s0 + s1 > 4096 == False, this could result in accuracy problems\n", - "[2023-10-09 18:49:54,157] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 + s1 <= 4096 [guard added] (_decomp/decompositions.py:725 in slice_forward)\n" - ] - } - ], - "source": [ - "#Run the export pipeline\n", - "inst = StateUpdateModule(context=Context(), import_to=\"IMPORT\")\n", - "module_str = str(CompiledModule.get_mlir_module(inst))" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "bc04e1db-a8cc-4182-884d-ba3d8ae5adeb", - "metadata": {}, - "outputs": [], - "source": [ - "#Output a torch-ir mlir file\n", - "with open(\"llama2_torch.mlir\", \"w+\") as f:\n", - " f.write(module_str)\n", - "#TODO: run the rest of the compile pipeline and do inference" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.4" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/llama2_inference/llama2_state_schema.json b/examples/llama2_inference/llama2_state_schema.json deleted file mode 100644 index b5506055..00000000 --- a/examples/llama2_inference/llama2_state_schema.json +++ /dev/null @@ -1 +0,0 @@ -[1, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}, {"type": "builtins.tuple", "context": "null", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]}] diff --git a/examples/llama2_inference/requirements.txt b/examples/llama2_inference/requirements.txt deleted file mode 100644 index acbc93ca..00000000 --- a/examples/llama2_inference/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -protobuf -sentencepiece -shark_turbine -transformers @ git+https://github.com/huggingface/transformers.git@7d8ff3629b2725ec43ace99c1a6e87ac1978d433 diff --git a/examples/resnet-18/requirements.txt b/examples/resnet-18/requirements.txt index a5123e97..b7428649 100644 --- a/examples/resnet-18/requirements.txt +++ b/examples/resnet-18/requirements.txt @@ -1,2 +1,2 @@ transformers -shark_turbine==0.9.2 +iree_turbine==0.9.2 diff --git a/examples/resnet-18/resnet-18.py b/examples/resnet-18/resnet-18.py index 20340013..2b3fce56 100644 --- a/examples/resnet-18/resnet-18.py +++ b/examples/resnet-18/resnet-18.py @@ -1,6 +1,6 @@ from transformers import AutoFeatureExtractor, AutoModelForImageClassification import torch -from shark_turbine.aot import * +from iree.turbine.aot import * import iree.runtime as rt # Loading feature extractor and pretrained model from huggingface diff --git a/examples/runtime_torture/launchable_torture.py b/examples/runtime_torture/launchable_torture.py index 56f92a99..d58c6a80 100644 --- a/examples/runtime_torture/launchable_torture.py +++ b/examples/runtime_torture/launchable_torture.py @@ -12,9 +12,9 @@ import torch import torch.nn as nn -import shark_turbine.aot as aot +import iree.turbine.aot as aot -from shark_turbine.runtime import ( +from iree.turbine.runtime import ( Launchable, ) diff --git a/iree/turbine/__init__.py b/iree/turbine/__init__.py index c59e85c2..d95aa54f 100644 --- a/iree/turbine/__init__.py +++ b/iree/turbine/__init__.py @@ -8,15 +8,3 @@ # Licensed under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# TODO: This redirection layer exists while we are migrating from the -# shark_turbine top-level package name to iree.turbine. It exports the -# public API but not the internal details. In a future switch, all code -# will be directly located here and the redirect will be done in the -# shark_turbine namespace. - -from shark_turbine import aot -from shark_turbine import dynamo -from shark_turbine import kernel -from shark_turbine import ops -from shark_turbine import runtime diff --git a/shark_turbine/aot/__init__.py b/iree/turbine/aot/__init__.py similarity index 100% rename from shark_turbine/aot/__init__.py rename to iree/turbine/aot/__init__.py diff --git a/shark_turbine/aot/builtins/__init__.py b/iree/turbine/aot/builtins/__init__.py similarity index 100% rename from shark_turbine/aot/builtins/__init__.py rename to iree/turbine/aot/builtins/__init__.py diff --git a/shark_turbine/aot/builtins/globals.py b/iree/turbine/aot/builtins/globals.py similarity index 100% rename from shark_turbine/aot/builtins/globals.py rename to iree/turbine/aot/builtins/globals.py diff --git a/shark_turbine/aot/builtins/jittable.py b/iree/turbine/aot/builtins/jittable.py similarity index 100% rename from shark_turbine/aot/builtins/jittable.py rename to iree/turbine/aot/builtins/jittable.py diff --git a/shark_turbine/aot/compiled_module.py b/iree/turbine/aot/compiled_module.py similarity index 100% rename from shark_turbine/aot/compiled_module.py rename to iree/turbine/aot/compiled_module.py diff --git a/shark_turbine/aot/decompositions.py b/iree/turbine/aot/decompositions.py similarity index 100% rename from shark_turbine/aot/decompositions.py rename to iree/turbine/aot/decompositions.py diff --git a/shark_turbine/aot/exporter.py b/iree/turbine/aot/exporter.py similarity index 100% rename from shark_turbine/aot/exporter.py rename to iree/turbine/aot/exporter.py diff --git a/shark_turbine/aot/fx_programs.py b/iree/turbine/aot/fx_programs.py similarity index 100% rename from shark_turbine/aot/fx_programs.py rename to iree/turbine/aot/fx_programs.py diff --git a/shark_turbine/aot/params.py b/iree/turbine/aot/params.py similarity index 100% rename from shark_turbine/aot/params.py rename to iree/turbine/aot/params.py diff --git a/shark_turbine/aot/passes/__init__.py b/iree/turbine/aot/passes/__init__.py similarity index 100% rename from shark_turbine/aot/passes/__init__.py rename to iree/turbine/aot/passes/__init__.py diff --git a/shark_turbine/aot/passes/functorch.py b/iree/turbine/aot/passes/functorch.py similarity index 100% rename from shark_turbine/aot/passes/functorch.py rename to iree/turbine/aot/passes/functorch.py diff --git a/shark_turbine/aot/support/ir_utils.py b/iree/turbine/aot/support/ir_utils.py similarity index 100% rename from shark_turbine/aot/support/ir_utils.py rename to iree/turbine/aot/support/ir_utils.py diff --git a/shark_turbine/aot/support/procedural/__init__.py b/iree/turbine/aot/support/procedural/__init__.py similarity index 100% rename from shark_turbine/aot/support/procedural/__init__.py rename to iree/turbine/aot/support/procedural/__init__.py diff --git a/shark_turbine/aot/support/procedural/base.py b/iree/turbine/aot/support/procedural/base.py similarity index 100% rename from shark_turbine/aot/support/procedural/base.py rename to iree/turbine/aot/support/procedural/base.py diff --git a/shark_turbine/aot/support/procedural/exported_program.py b/iree/turbine/aot/support/procedural/exported_program.py similarity index 100% rename from shark_turbine/aot/support/procedural/exported_program.py rename to iree/turbine/aot/support/procedural/exported_program.py diff --git a/shark_turbine/aot/support/procedural/globals.py b/iree/turbine/aot/support/procedural/globals.py similarity index 100% rename from shark_turbine/aot/support/procedural/globals.py rename to iree/turbine/aot/support/procedural/globals.py diff --git a/shark_turbine/aot/support/procedural/iree_emitter.py b/iree/turbine/aot/support/procedural/iree_emitter.py similarity index 100% rename from shark_turbine/aot/support/procedural/iree_emitter.py rename to iree/turbine/aot/support/procedural/iree_emitter.py diff --git a/shark_turbine/aot/support/procedural/primitives.py b/iree/turbine/aot/support/procedural/primitives.py similarity index 100% rename from shark_turbine/aot/support/procedural/primitives.py rename to iree/turbine/aot/support/procedural/primitives.py diff --git a/shark_turbine/aot/support/procedural/tracer.py b/iree/turbine/aot/support/procedural/tracer.py similarity index 100% rename from shark_turbine/aot/support/procedural/tracer.py rename to iree/turbine/aot/support/procedural/tracer.py diff --git a/shark_turbine/aot/tensor_traits.py b/iree/turbine/aot/tensor_traits.py similarity index 100% rename from shark_turbine/aot/tensor_traits.py rename to iree/turbine/aot/tensor_traits.py diff --git a/shark_turbine/dynamo/__init__.py b/iree/turbine/dynamo/__init__.py similarity index 100% rename from shark_turbine/dynamo/__init__.py rename to iree/turbine/dynamo/__init__.py diff --git a/shark_turbine/dynamo/backends/cpu.py b/iree/turbine/dynamo/backends/cpu.py similarity index 100% rename from shark_turbine/dynamo/backends/cpu.py rename to iree/turbine/dynamo/backends/cpu.py diff --git a/shark_turbine/dynamo/decompositions.py b/iree/turbine/dynamo/decompositions.py similarity index 100% rename from shark_turbine/dynamo/decompositions.py rename to iree/turbine/dynamo/decompositions.py diff --git a/shark_turbine/dynamo/executor.py b/iree/turbine/dynamo/executor.py similarity index 100% rename from shark_turbine/dynamo/executor.py rename to iree/turbine/dynamo/executor.py diff --git a/shark_turbine/dynamo/passes.py b/iree/turbine/dynamo/passes.py similarity index 100% rename from shark_turbine/dynamo/passes.py rename to iree/turbine/dynamo/passes.py diff --git a/shark_turbine/dynamo/tensor.py b/iree/turbine/dynamo/tensor.py similarity index 99% rename from shark_turbine/dynamo/tensor.py rename to iree/turbine/dynamo/tensor.py index cd1de1ea..bdf1cb83 100644 --- a/shark_turbine/dynamo/tensor.py +++ b/iree/turbine/dynamo/tensor.py @@ -474,8 +474,8 @@ def _get_device_state() -> DeviceState: return DeviceState(driver="local-task") -# Inspiration from https://github.com/nod-ai/SHARK-Turbine/blob/8293de5414889c72ff5cd10bf33c43fb0a3ea3ee/python/shark_turbine/aot/builtins/jittable.py#L212-L237 -# and https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/dynamo/backends/cpu.py +# Inspiration from https://github.com/nod-ai/SHARK-Turbine/blob/8293de5414889c72ff5cd10bf33c43fb0a3ea3ee/python/iree/turbine/aot/builtins/jittable.py#L212-L237 +# and https://github.com/nod-ai/SHARK-Turbine/blob/main/python/iree/turbine/dynamo/backends/cpu.py # TODO: Try to generalize for other devices. def compute_method(super_fn, *args, **kwargs): # Compute factory fns reserve the last arg as src_op diff --git a/shark_turbine/dynamo/type_conversion.py b/iree/turbine/dynamo/type_conversion.py similarity index 100% rename from shark_turbine/dynamo/type_conversion.py rename to iree/turbine/dynamo/type_conversion.py diff --git a/shark_turbine/importers/README.md b/iree/turbine/importers/README.md similarity index 100% rename from shark_turbine/importers/README.md rename to iree/turbine/importers/README.md diff --git a/shark_turbine/importers/ir.py b/iree/turbine/importers/ir.py similarity index 100% rename from shark_turbine/importers/ir.py rename to iree/turbine/importers/ir.py diff --git a/shark_turbine/importers/utils.py b/iree/turbine/importers/utils.py similarity index 100% rename from shark_turbine/importers/utils.py rename to iree/turbine/importers/utils.py diff --git a/shark_turbine/kernel/__init__.py b/iree/turbine/kernel/__init__.py similarity index 100% rename from shark_turbine/kernel/__init__.py rename to iree/turbine/kernel/__init__.py diff --git a/shark_turbine/kernel/_support/context.py b/iree/turbine/kernel/_support/context.py similarity index 100% rename from shark_turbine/kernel/_support/context.py rename to iree/turbine/kernel/_support/context.py diff --git a/shark_turbine/kernel/_support/dtype.py b/iree/turbine/kernel/_support/dtype.py similarity index 100% rename from shark_turbine/kernel/_support/dtype.py rename to iree/turbine/kernel/_support/dtype.py diff --git a/shark_turbine/kernel/_support/indexing.py b/iree/turbine/kernel/_support/indexing.py similarity index 100% rename from shark_turbine/kernel/_support/indexing.py rename to iree/turbine/kernel/_support/indexing.py diff --git a/shark_turbine/kernel/_support/regions.py b/iree/turbine/kernel/_support/regions.py similarity index 100% rename from shark_turbine/kernel/_support/regions.py rename to iree/turbine/kernel/_support/regions.py diff --git a/shark_turbine/kernel/_support/shaped_type.py b/iree/turbine/kernel/_support/shaped_type.py similarity index 100% rename from shark_turbine/kernel/_support/shaped_type.py rename to iree/turbine/kernel/_support/shaped_type.py diff --git a/shark_turbine/kernel/_support/tracing.py b/iree/turbine/kernel/_support/tracing.py similarity index 100% rename from shark_turbine/kernel/_support/tracing.py rename to iree/turbine/kernel/_support/tracing.py diff --git a/shark_turbine/kernel/compiler/base.py b/iree/turbine/kernel/compiler/base.py similarity index 100% rename from shark_turbine/kernel/compiler/base.py rename to iree/turbine/kernel/compiler/base.py diff --git a/shark_turbine/kernel/compiler/builder.py b/iree/turbine/kernel/compiler/builder.py similarity index 100% rename from shark_turbine/kernel/compiler/builder.py rename to iree/turbine/kernel/compiler/builder.py diff --git a/shark_turbine/kernel/compiler/dispatch_codegen.py b/iree/turbine/kernel/compiler/dispatch_codegen.py similarity index 100% rename from shark_turbine/kernel/compiler/dispatch_codegen.py rename to iree/turbine/kernel/compiler/dispatch_codegen.py diff --git a/shark_turbine/kernel/compiler/host_codegen.py b/iree/turbine/kernel/compiler/host_codegen.py similarity index 100% rename from shark_turbine/kernel/compiler/host_codegen.py rename to iree/turbine/kernel/compiler/host_codegen.py diff --git a/shark_turbine/kernel/compiler/ir.py b/iree/turbine/kernel/compiler/ir.py similarity index 100% rename from shark_turbine/kernel/compiler/ir.py rename to iree/turbine/kernel/compiler/ir.py diff --git a/shark_turbine/kernel/compiler/kernel_codegen.py b/iree/turbine/kernel/compiler/kernel_codegen.py similarity index 100% rename from shark_turbine/kernel/compiler/kernel_codegen.py rename to iree/turbine/kernel/compiler/kernel_codegen.py diff --git a/shark_turbine/kernel/compiler/op_matchers.py b/iree/turbine/kernel/compiler/op_matchers.py similarity index 100% rename from shark_turbine/kernel/compiler/op_matchers.py rename to iree/turbine/kernel/compiler/op_matchers.py diff --git a/shark_turbine/kernel/compiler/utils.py b/iree/turbine/kernel/compiler/utils.py similarity index 100% rename from shark_turbine/kernel/compiler/utils.py rename to iree/turbine/kernel/compiler/utils.py diff --git a/shark_turbine/kernel/compiler/vector_codegen.py b/iree/turbine/kernel/compiler/vector_codegen.py similarity index 100% rename from shark_turbine/kernel/compiler/vector_codegen.py rename to iree/turbine/kernel/compiler/vector_codegen.py diff --git a/shark_turbine/kernel/gen/__init__.py b/iree/turbine/kernel/gen/__init__.py similarity index 100% rename from shark_turbine/kernel/gen/__init__.py rename to iree/turbine/kernel/gen/__init__.py diff --git a/shark_turbine/kernel/gen/kernel.py b/iree/turbine/kernel/gen/kernel.py similarity index 100% rename from shark_turbine/kernel/gen/kernel.py rename to iree/turbine/kernel/gen/kernel.py diff --git a/shark_turbine/kernel/gen/thread.py b/iree/turbine/kernel/gen/thread.py similarity index 100% rename from shark_turbine/kernel/gen/thread.py rename to iree/turbine/kernel/gen/thread.py diff --git a/shark_turbine/kernel/lang/__init__.py b/iree/turbine/kernel/lang/__init__.py similarity index 100% rename from shark_turbine/kernel/lang/__init__.py rename to iree/turbine/kernel/lang/__init__.py diff --git a/shark_turbine/kernel/lang/global_symbols.py b/iree/turbine/kernel/lang/global_symbols.py similarity index 100% rename from shark_turbine/kernel/lang/global_symbols.py rename to iree/turbine/kernel/lang/global_symbols.py diff --git a/shark_turbine/kernel/lang/grid.py b/iree/turbine/kernel/lang/grid.py similarity index 100% rename from shark_turbine/kernel/lang/grid.py rename to iree/turbine/kernel/lang/grid.py diff --git a/shark_turbine/kernel/lang/kernel_buffer.py b/iree/turbine/kernel/lang/kernel_buffer.py similarity index 100% rename from shark_turbine/kernel/lang/kernel_buffer.py rename to iree/turbine/kernel/lang/kernel_buffer.py diff --git a/shark_turbine/kernel/lang/prims.py b/iree/turbine/kernel/lang/prims.py similarity index 100% rename from shark_turbine/kernel/lang/prims.py rename to iree/turbine/kernel/lang/prims.py diff --git a/shark_turbine/kernel/lang/types.py b/iree/turbine/kernel/lang/types.py similarity index 100% rename from shark_turbine/kernel/lang/types.py rename to iree/turbine/kernel/lang/types.py diff --git a/shark_turbine/kernel/lang/wave_types.py b/iree/turbine/kernel/lang/wave_types.py similarity index 100% rename from shark_turbine/kernel/lang/wave_types.py rename to iree/turbine/kernel/lang/wave_types.py diff --git a/shark_turbine/kernel/ops/__init__.py b/iree/turbine/kernel/ops/__init__.py similarity index 100% rename from shark_turbine/kernel/ops/__init__.py rename to iree/turbine/kernel/ops/__init__.py diff --git a/shark_turbine/kernel/ops/base.py b/iree/turbine/kernel/ops/base.py similarity index 100% rename from shark_turbine/kernel/ops/base.py rename to iree/turbine/kernel/ops/base.py diff --git a/shark_turbine/kernel/ops/control_flow.py b/iree/turbine/kernel/ops/control_flow.py similarity index 100% rename from shark_turbine/kernel/ops/control_flow.py rename to iree/turbine/kernel/ops/control_flow.py diff --git a/shark_turbine/kernel/ops/core.py b/iree/turbine/kernel/ops/core.py similarity index 100% rename from shark_turbine/kernel/ops/core.py rename to iree/turbine/kernel/ops/core.py diff --git a/shark_turbine/kernel/ops/math.py b/iree/turbine/kernel/ops/math.py similarity index 100% rename from shark_turbine/kernel/ops/math.py rename to iree/turbine/kernel/ops/math.py diff --git a/shark_turbine/kernel/ops/memory.py b/iree/turbine/kernel/ops/memory.py similarity index 100% rename from shark_turbine/kernel/ops/memory.py rename to iree/turbine/kernel/ops/memory.py diff --git a/shark_turbine/kernel/ops/reduction.py b/iree/turbine/kernel/ops/reduction.py similarity index 100% rename from shark_turbine/kernel/ops/reduction.py rename to iree/turbine/kernel/ops/reduction.py diff --git a/shark_turbine/kernel/ops/shape_manipulation.py b/iree/turbine/kernel/ops/shape_manipulation.py similarity index 100% rename from shark_turbine/kernel/ops/shape_manipulation.py rename to iree/turbine/kernel/ops/shape_manipulation.py diff --git a/shark_turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py similarity index 100% rename from shark_turbine/kernel/ops/wave_ops.py rename to iree/turbine/kernel/ops/wave_ops.py diff --git a/shark_turbine/kernel/wave/README.md b/iree/turbine/kernel/wave/README.md similarity index 100% rename from shark_turbine/kernel/wave/README.md rename to iree/turbine/kernel/wave/README.md diff --git a/shark_turbine/kernel/wave/__init__.py b/iree/turbine/kernel/wave/__init__.py similarity index 100% rename from shark_turbine/kernel/wave/__init__.py rename to iree/turbine/kernel/wave/__init__.py diff --git a/shark_turbine/kernel/wave/barriers.py b/iree/turbine/kernel/wave/barriers.py similarity index 100% rename from shark_turbine/kernel/wave/barriers.py rename to iree/turbine/kernel/wave/barriers.py diff --git a/shark_turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py similarity index 99% rename from shark_turbine/kernel/wave/codegen.py rename to iree/turbine/kernel/wave/codegen.py index e218d71c..233a571d 100644 --- a/shark_turbine/kernel/wave/codegen.py +++ b/iree/turbine/kernel/wave/codegen.py @@ -40,10 +40,10 @@ scf_d, vector_d, ) -from shark_turbine.aot.support.ir_utils import _is_float_type, _is_integer_like_type +from iree.turbine.aot.support.ir_utils import _is_float_type, _is_integer_like_type # TK infrastructure imports. -from shark_turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.lang.global_symbols import * from ..ops.wave_ops import ( write, broadcast, diff --git a/shark_turbine/kernel/wave/constraints.py b/iree/turbine/kernel/wave/constraints.py similarity index 100% rename from shark_turbine/kernel/wave/constraints.py rename to iree/turbine/kernel/wave/constraints.py diff --git a/shark_turbine/kernel/wave/decompose_reduce_ops.py b/iree/turbine/kernel/wave/decompose_reduce_ops.py similarity index 100% rename from shark_turbine/kernel/wave/decompose_reduce_ops.py rename to iree/turbine/kernel/wave/decompose_reduce_ops.py diff --git a/shark_turbine/kernel/wave/docs/gemm_example.md b/iree/turbine/kernel/wave/docs/gemm_example.md similarity index 100% rename from shark_turbine/kernel/wave/docs/gemm_example.md rename to iree/turbine/kernel/wave/docs/gemm_example.md diff --git a/shark_turbine/kernel/wave/expansion.py b/iree/turbine/kernel/wave/expansion.py similarity index 100% rename from shark_turbine/kernel/wave/expansion.py rename to iree/turbine/kernel/wave/expansion.py diff --git a/shark_turbine/kernel/wave/hoisting.py b/iree/turbine/kernel/wave/hoisting.py similarity index 95% rename from shark_turbine/kernel/wave/hoisting.py rename to iree/turbine/kernel/wave/hoisting.py index df68c753..5a4773d7 100644 --- a/shark_turbine/kernel/wave/hoisting.py +++ b/iree/turbine/kernel/wave/hoisting.py @@ -5,7 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ...support.logging import get_logger -from shark_turbine.kernel._support.tracing import CapturedTrace +from iree.turbine.kernel._support.tracing import CapturedTrace import torch.fx as fx from ..ops.wave_ops import * from ..lang.global_symbols import * diff --git a/shark_turbine/kernel/wave/index_sequence_analysis.py b/iree/turbine/kernel/wave/index_sequence_analysis.py similarity index 100% rename from shark_turbine/kernel/wave/index_sequence_analysis.py rename to iree/turbine/kernel/wave/index_sequence_analysis.py diff --git a/shark_turbine/kernel/wave/iree_utils.py b/iree/turbine/kernel/wave/iree_utils.py similarity index 100% rename from shark_turbine/kernel/wave/iree_utils.py rename to iree/turbine/kernel/wave/iree_utils.py diff --git a/shark_turbine/kernel/wave/minimize_global_loads.py b/iree/turbine/kernel/wave/minimize_global_loads.py similarity index 100% rename from shark_turbine/kernel/wave/minimize_global_loads.py rename to iree/turbine/kernel/wave/minimize_global_loads.py diff --git a/shark_turbine/kernel/wave/promotion.py b/iree/turbine/kernel/wave/promotion.py similarity index 100% rename from shark_turbine/kernel/wave/promotion.py rename to iree/turbine/kernel/wave/promotion.py diff --git a/shark_turbine/kernel/wave/scheduling/__init__.py b/iree/turbine/kernel/wave/scheduling/__init__.py similarity index 100% rename from shark_turbine/kernel/wave/scheduling/__init__.py rename to iree/turbine/kernel/wave/scheduling/__init__.py diff --git a/shark_turbine/kernel/wave/scheduling/graph_utils.py b/iree/turbine/kernel/wave/scheduling/graph_utils.py similarity index 100% rename from shark_turbine/kernel/wave/scheduling/graph_utils.py rename to iree/turbine/kernel/wave/scheduling/graph_utils.py diff --git a/shark_turbine/kernel/wave/scheduling/loop_reconstruction.py b/iree/turbine/kernel/wave/scheduling/loop_reconstruction.py similarity index 100% rename from shark_turbine/kernel/wave/scheduling/loop_reconstruction.py rename to iree/turbine/kernel/wave/scheduling/loop_reconstruction.py diff --git a/shark_turbine/kernel/wave/scheduling/loop_reconstruction_utils.py b/iree/turbine/kernel/wave/scheduling/loop_reconstruction_utils.py similarity index 100% rename from shark_turbine/kernel/wave/scheduling/loop_reconstruction_utils.py rename to iree/turbine/kernel/wave/scheduling/loop_reconstruction_utils.py diff --git a/shark_turbine/kernel/wave/scheduling/modulo_scheduling.py b/iree/turbine/kernel/wave/scheduling/modulo_scheduling.py similarity index 100% rename from shark_turbine/kernel/wave/scheduling/modulo_scheduling.py rename to iree/turbine/kernel/wave/scheduling/modulo_scheduling.py diff --git a/shark_turbine/kernel/wave/scheduling/resources.py b/iree/turbine/kernel/wave/scheduling/resources.py similarity index 100% rename from shark_turbine/kernel/wave/scheduling/resources.py rename to iree/turbine/kernel/wave/scheduling/resources.py diff --git a/shark_turbine/kernel/wave/scheduling/schedule.py b/iree/turbine/kernel/wave/scheduling/schedule.py similarity index 100% rename from shark_turbine/kernel/wave/scheduling/schedule.py rename to iree/turbine/kernel/wave/scheduling/schedule.py diff --git a/shark_turbine/kernel/wave/shared_memory_indexing.py b/iree/turbine/kernel/wave/shared_memory_indexing.py similarity index 100% rename from shark_turbine/kernel/wave/shared_memory_indexing.py rename to iree/turbine/kernel/wave/shared_memory_indexing.py diff --git a/shark_turbine/kernel/wave/thread_shape_analysis.py b/iree/turbine/kernel/wave/thread_shape_analysis.py similarity index 99% rename from shark_turbine/kernel/wave/thread_shape_analysis.py rename to iree/turbine/kernel/wave/thread_shape_analysis.py index 5fd0b999..927bd363 100644 --- a/shark_turbine/kernel/wave/thread_shape_analysis.py +++ b/iree/turbine/kernel/wave/thread_shape_analysis.py @@ -5,7 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ...support.logging import get_logger -from shark_turbine.kernel._support.tracing import CapturedTrace +from iree.turbine.kernel._support.tracing import CapturedTrace import torch.fx as fx from ..ops.wave_ops import * from ..lang.global_symbols import * diff --git a/shark_turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py similarity index 99% rename from shark_turbine/kernel/wave/utils.py rename to iree/turbine/kernel/wave/utils.py index 869df061..020adf1f 100644 --- a/shark_turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -33,7 +33,7 @@ TilingConstraint, ) import torch.fx as fx -import shark_turbine.kernel.lang as tkl +import iree.turbine.kernel.lang as tkl import tempfile diff --git a/shark_turbine/kernel/wave/visualization.py b/iree/turbine/kernel/wave/visualization.py similarity index 100% rename from shark_turbine/kernel/wave/visualization.py rename to iree/turbine/kernel/wave/visualization.py diff --git a/shark_turbine/kernel/wave/wave.py b/iree/turbine/kernel/wave/wave.py similarity index 99% rename from shark_turbine/kernel/wave/wave.py rename to iree/turbine/kernel/wave/wave.py index 202cdd92..21485ed1 100644 --- a/shark_turbine/kernel/wave/wave.py +++ b/iree/turbine/kernel/wave/wave.py @@ -42,7 +42,7 @@ from .thread_shape_analysis import determine_thread_shapes from .scheduling.schedule import schedule_graph from .._support.indexing import IndexingContext, IndexExpr -import shark_turbine.kernel.lang as tkl +import iree.turbine.kernel.lang as tkl from .._support.tracing import ( CapturedTrace, CompiledContext, diff --git a/shark_turbine/kernel/wave/wave_sim.py b/iree/turbine/kernel/wave/wave_sim.py similarity index 100% rename from shark_turbine/kernel/wave/wave_sim.py rename to iree/turbine/kernel/wave/wave_sim.py diff --git a/shark_turbine/ops/__init__.py b/iree/turbine/ops/__init__.py similarity index 100% rename from shark_turbine/ops/__init__.py rename to iree/turbine/ops/__init__.py diff --git a/shark_turbine/ops/_jinja_test_ops.py b/iree/turbine/ops/_jinja_test_ops.py similarity index 100% rename from shark_turbine/ops/_jinja_test_ops.py rename to iree/turbine/ops/_jinja_test_ops.py diff --git a/shark_turbine/ops/_str_format_test_ops.py b/iree/turbine/ops/_str_format_test_ops.py similarity index 100% rename from shark_turbine/ops/_str_format_test_ops.py rename to iree/turbine/ops/_str_format_test_ops.py diff --git a/shark_turbine/ops/iree.py b/iree/turbine/ops/iree.py similarity index 100% rename from shark_turbine/ops/iree.py rename to iree/turbine/ops/iree.py diff --git a/shark_turbine/ops/templates/test_add_jinja.mlir b/iree/turbine/ops/templates/test_add_jinja.mlir similarity index 100% rename from shark_turbine/ops/templates/test_add_jinja.mlir rename to iree/turbine/ops/templates/test_add_jinja.mlir diff --git a/shark_turbine/ops/templates/test_add_strformat.mlir b/iree/turbine/ops/templates/test_add_strformat.mlir similarity index 100% rename from shark_turbine/ops/templates/test_add_strformat.mlir rename to iree/turbine/ops/templates/test_add_strformat.mlir diff --git a/shark_turbine/ops/templates/test_syntax_error.mlir b/iree/turbine/ops/templates/test_syntax_error.mlir similarity index 100% rename from shark_turbine/ops/templates/test_syntax_error.mlir rename to iree/turbine/ops/templates/test_syntax_error.mlir diff --git a/shark_turbine/runtime/__init__.py b/iree/turbine/runtime/__init__.py similarity index 100% rename from shark_turbine/runtime/__init__.py rename to iree/turbine/runtime/__init__.py diff --git a/shark_turbine/runtime/device.py b/iree/turbine/runtime/device.py similarity index 100% rename from shark_turbine/runtime/device.py rename to iree/turbine/runtime/device.py diff --git a/shark_turbine/runtime/launch.py b/iree/turbine/runtime/launch.py similarity index 100% rename from shark_turbine/runtime/launch.py rename to iree/turbine/runtime/launch.py diff --git a/shark_turbine/runtime/op_reg/__init__.py b/iree/turbine/runtime/op_reg/__init__.py similarity index 100% rename from shark_turbine/runtime/op_reg/__init__.py rename to iree/turbine/runtime/op_reg/__init__.py diff --git a/shark_turbine/runtime/op_reg/base.py b/iree/turbine/runtime/op_reg/base.py similarity index 100% rename from shark_turbine/runtime/op_reg/base.py rename to iree/turbine/runtime/op_reg/base.py diff --git a/shark_turbine/runtime/op_reg/compiler.py b/iree/turbine/runtime/op_reg/compiler.py similarity index 100% rename from shark_turbine/runtime/op_reg/compiler.py rename to iree/turbine/runtime/op_reg/compiler.py diff --git a/shark_turbine/runtime/op_reg/eager.py b/iree/turbine/runtime/op_reg/eager.py similarity index 100% rename from shark_turbine/runtime/op_reg/eager.py rename to iree/turbine/runtime/op_reg/eager.py diff --git a/shark_turbine/runtime/op_reg/impl_helper.py b/iree/turbine/runtime/op_reg/impl_helper.py similarity index 100% rename from shark_turbine/runtime/op_reg/impl_helper.py rename to iree/turbine/runtime/op_reg/impl_helper.py diff --git a/shark_turbine/runtime/tracing.py b/iree/turbine/runtime/tracing.py similarity index 100% rename from shark_turbine/runtime/tracing.py rename to iree/turbine/runtime/tracing.py diff --git a/shark_turbine/support/__init__.py b/iree/turbine/support/__init__.py similarity index 100% rename from shark_turbine/support/__init__.py rename to iree/turbine/support/__init__.py diff --git a/shark_turbine/support/conversions.py b/iree/turbine/support/conversions.py similarity index 100% rename from shark_turbine/support/conversions.py rename to iree/turbine/support/conversions.py diff --git a/shark_turbine/support/debugging.py b/iree/turbine/support/debugging.py similarity index 100% rename from shark_turbine/support/debugging.py rename to iree/turbine/support/debugging.py diff --git a/shark_turbine/support/exceptions.py b/iree/turbine/support/exceptions.py similarity index 100% rename from shark_turbine/support/exceptions.py rename to iree/turbine/support/exceptions.py diff --git a/shark_turbine/support/ir_imports.py b/iree/turbine/support/ir_imports.py similarity index 100% rename from shark_turbine/support/ir_imports.py rename to iree/turbine/support/ir_imports.py diff --git a/shark_turbine/support/logging.py b/iree/turbine/support/logging.py similarity index 100% rename from shark_turbine/support/logging.py rename to iree/turbine/support/logging.py diff --git a/shark_turbine/tools/__init__.py b/iree/turbine/tools/__init__.py similarity index 100% rename from shark_turbine/tools/__init__.py rename to iree/turbine/tools/__init__.py diff --git a/shark_turbine/tools/interpreter.py b/iree/turbine/tools/interpreter.py similarity index 100% rename from shark_turbine/tools/interpreter.py rename to iree/turbine/tools/interpreter.py diff --git a/shark_turbine/transforms/builder.py b/iree/turbine/transforms/builder.py similarity index 100% rename from shark_turbine/transforms/builder.py rename to iree/turbine/transforms/builder.py diff --git a/shark_turbine/transforms/general/add_metadata.py b/iree/turbine/transforms/general/add_metadata.py similarity index 97% rename from shark_turbine/transforms/general/add_metadata.py rename to iree/turbine/transforms/general/add_metadata.py index 44aa2413..340169ec 100644 --- a/shark_turbine/transforms/general/add_metadata.py +++ b/iree/turbine/transforms/general/add_metadata.py @@ -12,7 +12,7 @@ import re -from shark_turbine.support.ir_imports import * +from iree.turbine.support.ir_imports import * from ..rewriter import * from iree.compiler.ir import Context, DictAttr diff --git a/shark_turbine/transforms/general/custom_op_expansion.py b/iree/turbine/transforms/general/custom_op_expansion.py similarity index 100% rename from shark_turbine/transforms/general/custom_op_expansion.py rename to iree/turbine/transforms/general/custom_op_expansion.py diff --git a/shark_turbine/transforms/general/rename_parameters.py b/iree/turbine/transforms/general/rename_parameters.py similarity index 100% rename from shark_turbine/transforms/general/rename_parameters.py rename to iree/turbine/transforms/general/rename_parameters.py diff --git a/shark_turbine/transforms/merger.py b/iree/turbine/transforms/merger.py similarity index 100% rename from shark_turbine/transforms/merger.py rename to iree/turbine/transforms/merger.py diff --git a/shark_turbine/transforms/quantization/mm_group_quant.py b/iree/turbine/transforms/quantization/mm_group_quant.py similarity index 100% rename from shark_turbine/transforms/quantization/mm_group_quant.py rename to iree/turbine/transforms/quantization/mm_group_quant.py diff --git a/shark_turbine/transforms/rewriter.py b/iree/turbine/transforms/rewriter.py similarity index 100% rename from shark_turbine/transforms/rewriter.py rename to iree/turbine/transforms/rewriter.py diff --git a/lit_tests/kernel/wave/barriers.py b/lit_tests/kernel/wave/barriers.py index fcb7dbed..14eb2e60 100644 --- a/lit_tests/kernel/wave/barriers.py +++ b/lit_tests/kernel/wave/barriers.py @@ -3,18 +3,18 @@ import logging from typing import Callable import unittest -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.wave.promotion import promote_node, promote_placeholders -from shark_turbine.kernel.wave.barriers import add_shared_memory_barriers -from shark_turbine.kernel.wave.hoisting import hoist_allocs -from shark_turbine.kernel.wave.expansion import expand_graph -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel._support.tracing import CapturedTrace -from shark_turbine.kernel._support.indexing import IndexingContext -from shark_turbine.kernel.ops.wave_ops import * -from shark_turbine.kernel.wave.utils import run_test, print_trace +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.wave.promotion import promote_node, promote_placeholders +from iree.turbine.kernel.wave.barriers import add_shared_memory_barriers +from iree.turbine.kernel.wave.hoisting import hoist_allocs +from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel._support.tracing import CapturedTrace +from iree.turbine.kernel._support.indexing import IndexingContext +from iree.turbine.kernel.ops.wave_ops import * +from iree.turbine.kernel.wave.utils import run_test, print_trace def get_read_nodes(graph: fx.Graph) -> list[CustomOp]: diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index 4800e9bd..30144024 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -2,11 +2,11 @@ import pytest from typing import Callable -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel.wave.utils import run_test +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.utils import run_test import torch M = tkl.sym.M diff --git a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py index 6f4e2f29..efcdd582 100644 --- a/lit_tests/kernel/wave/expansion.py +++ b/lit_tests/kernel/wave/expansion.py @@ -2,13 +2,13 @@ import logging import unittest -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.wave.expansion import expand_graph -from shark_turbine.kernel._support.indexing import IndexingContext -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel.wave.utils import run_test, print_trace +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel._support.indexing import IndexingContext +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.utils import run_test, print_trace import sympy # Input sizes diff --git a/lit_tests/kernel/wave/index_sequence_analysis.py b/lit_tests/kernel/wave/index_sequence_analysis.py index 7dd266ee..f0149b70 100644 --- a/lit_tests/kernel/wave/index_sequence_analysis.py +++ b/lit_tests/kernel/wave/index_sequence_analysis.py @@ -2,22 +2,22 @@ import logging import unittest -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.wave.promotion import promote_placeholders -from shark_turbine.kernel.wave.hoisting import hoist_allocs -from shark_turbine.kernel.wave.expansion import expand_graph -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel._support.tracing import CapturedTrace -from shark_turbine.kernel._support.indexing import IndexingContext -from shark_turbine.kernel.ops.wave_ops import * -from shark_turbine.kernel.wave.utils import run_test, print_trace -from shark_turbine.kernel.wave.minimize_global_loads import minimize_global_loads -from shark_turbine.kernel.wave.shared_memory_indexing import ( +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.wave.promotion import promote_placeholders +from iree.turbine.kernel.wave.hoisting import hoist_allocs +from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel._support.tracing import CapturedTrace +from iree.turbine.kernel._support.indexing import IndexingContext +from iree.turbine.kernel.ops.wave_ops import * +from iree.turbine.kernel.wave.utils import run_test, print_trace +from iree.turbine.kernel.wave.minimize_global_loads import minimize_global_loads +from iree.turbine.kernel.wave.shared_memory_indexing import ( apply_shared_memory_indexing_corrections, ) -from shark_turbine.kernel.wave.index_sequence_analysis import ( +from iree.turbine.kernel.wave.index_sequence_analysis import ( partition_strided_operators, ) diff --git a/lit_tests/kernel/wave/minimize_global_loads.py b/lit_tests/kernel/wave/minimize_global_loads.py index 7596a94b..329a9ccf 100644 --- a/lit_tests/kernel/wave/minimize_global_loads.py +++ b/lit_tests/kernel/wave/minimize_global_loads.py @@ -3,21 +3,21 @@ import logging from typing import Callable import unittest -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.wave.promotion import promote_placeholders -from shark_turbine.kernel.wave.hoisting import hoist_allocs -from shark_turbine.kernel.wave.barriers import add_shared_memory_barriers -from shark_turbine.kernel.wave.expansion import expand_graph -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel._support.tracing import CapturedTrace -from shark_turbine.kernel._support.indexing import IndexingContext -from shark_turbine.kernel.ops.wave_ops import * -from shark_turbine.kernel.wave.utils import run_test, print_trace -from shark_turbine.kernel.wave.minimize_global_loads import minimize_global_loads -from shark_turbine.kernel.wave.visualization import visualize_graph -from shark_turbine.kernel.wave.shared_memory_indexing import ( +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.wave.promotion import promote_placeholders +from iree.turbine.kernel.wave.hoisting import hoist_allocs +from iree.turbine.kernel.wave.barriers import add_shared_memory_barriers +from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel._support.tracing import CapturedTrace +from iree.turbine.kernel._support.indexing import IndexingContext +from iree.turbine.kernel.ops.wave_ops import * +from iree.turbine.kernel.wave.utils import run_test, print_trace +from iree.turbine.kernel.wave.minimize_global_loads import minimize_global_loads +from iree.turbine.kernel.wave.visualization import visualize_graph +from iree.turbine.kernel.wave.shared_memory_indexing import ( apply_shared_memory_indexing_corrections, ) diff --git a/lit_tests/kernel/wave/promotion.py b/lit_tests/kernel/wave/promotion.py index 3843c406..c3836f4f 100644 --- a/lit_tests/kernel/wave/promotion.py +++ b/lit_tests/kernel/wave/promotion.py @@ -2,16 +2,16 @@ import logging import unittest -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.wave.promotion import promote_node, promote_placeholders -from shark_turbine.kernel.wave.hoisting import hoist_allocs -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel._support.tracing import CapturedTrace -from shark_turbine.kernel._support.indexing import IndexingContext -from shark_turbine.kernel.ops.wave_ops import * -from shark_turbine.kernel.wave.utils import run_test, print_trace +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.wave.promotion import promote_node, promote_placeholders +from iree.turbine.kernel.wave.hoisting import hoist_allocs +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel._support.tracing import CapturedTrace +from iree.turbine.kernel._support.indexing import IndexingContext +from iree.turbine.kernel.ops.wave_ops import * +from iree.turbine.kernel.wave.utils import run_test, print_trace def get_read_nodes(graph: fx.Graph) -> list[CustomOp]: diff --git a/lit_tests/kernel/wave/scheduling.py b/lit_tests/kernel/wave/scheduling.py index eafabb27..aefad516 100644 --- a/lit_tests/kernel/wave/scheduling.py +++ b/lit_tests/kernel/wave/scheduling.py @@ -2,22 +2,22 @@ import logging import unittest -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.wave.promotion import promote_placeholders -from shark_turbine.kernel.wave.hoisting import hoist_allocs -from shark_turbine.kernel.wave.expansion import expand_graph -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel._support.tracing import CapturedTrace -from shark_turbine.kernel._support.indexing import IndexingContext -from shark_turbine.kernel.ops.wave_ops import * -from shark_turbine.kernel.wave.utils import run_test, print_subgraph -from shark_turbine.kernel.wave.minimize_global_loads import minimize_global_loads -from shark_turbine.kernel.wave.shared_memory_indexing import ( +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.wave.promotion import promote_placeholders +from iree.turbine.kernel.wave.hoisting import hoist_allocs +from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel._support.tracing import CapturedTrace +from iree.turbine.kernel._support.indexing import IndexingContext +from iree.turbine.kernel.ops.wave_ops import * +from iree.turbine.kernel.wave.utils import run_test, print_subgraph +from iree.turbine.kernel.wave.minimize_global_loads import minimize_global_loads +from iree.turbine.kernel.wave.shared_memory_indexing import ( apply_shared_memory_indexing_corrections, ) -from shark_turbine.kernel.wave.scheduling.schedule import schedule_graph +from iree.turbine.kernel.wave.scheduling.schedule import schedule_graph # Input sizes diff --git a/lit_tests/kernel/wave/tracing.py b/lit_tests/kernel/wave/tracing.py index 283b6436..f6c9306b 100644 --- a/lit_tests/kernel/wave/tracing.py +++ b/lit_tests/kernel/wave/tracing.py @@ -1,11 +1,11 @@ # RUN: python %s | FileCheck %s from typing import Callable -from shark_turbine.kernel._support.tracing import CapturedTrace -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.ops.wave_ops import get_custom, Read, Write -from shark_turbine.kernel.wave.utils import run_test, print_trace +from iree.turbine.kernel._support.tracing import CapturedTrace +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.ops.wave_ops import get_custom, Read, Write +from iree.turbine.kernel.wave.utils import run_test, print_trace M = tkl.sym.M N = tkl.sym.N diff --git a/lit_tests/lit.cfg.py b/lit_tests/lit.cfg.py index 5b40c7eb..614383fc 100644 --- a/lit_tests/lit.cfg.py +++ b/lit_tests/lit.cfg.py @@ -7,7 +7,7 @@ import lit.llvm -from shark_turbine.support.logging import get_logger +from iree.turbine.support.logging import get_logger logger = get_logger("turbine.lit_tests") diff --git a/mypy.ini b/mypy.ini index 528b8d48..5638faef 100644 --- a/mypy.ini +++ b/mypy.ini @@ -2,7 +2,7 @@ explicit_package_bases = True mypy_path = $MYPY_CONFIG_FILE_DIR -packages = shark_turbine +packages = iree.turbine # Missing typing stubs for iree.compiler. [mypy-iree.compiler.*] @@ -13,15 +13,15 @@ ignore_missing_imports = True ignore_missing_imports = True # fx_importer needs to be fixed upstream. -[mypy-shark_turbine.importers.fx_importer.*] +[mypy-iree.turbine.importers.fx_importer.*] ignore_errors = True # TODO: Fix all typing errors in TK. -[mypy-shark_turbine.kernel.*] +[mypy-iree.turbine.kernel.*] ignore_errors = True # TODO: Some pytorch errors. -[mypy-shark_turbine.tools.interpreter] +[mypy-iree.turbine.tools.interpreter] ignore_errors = True # Ignore all typing errors in tests/tools (these depend on TK). diff --git a/setup.py b/setup.py index 63a028cb..c73c3532 100644 --- a/setup.py +++ b/setup.py @@ -15,8 +15,7 @@ REPO_DIR = THIS_DIR VERSION_INFO_FILE = os.path.join(REPO_DIR, "version_info.json") -# Transitional as we migrate from shark-turbine -> iree-turbine. -TURBINE_PACKAGE_NAME = os.getenv("TURBINE_PACKAGE_NAME", "shark-turbine") +TURBINE_PACKAGE_NAME = "iree-turbine" with open( os.path.join( @@ -81,12 +80,12 @@ def initialize_options(self): setup( name=f"{TURBINE_PACKAGE_NAME}", version=f"{PACKAGE_VERSION}", - author="SHARK Authors", - author_email="stella@nod.ai", - description="SHARK Turbine Machine Learning Deployment Tools", + author="IREE Authors", + author_email="iree-technical-discussion@lists.lfaidata.foundation", + description="IREE Turbine Machine Learning Deployment Tools", long_description=README, long_description_content_type="text/markdown", - url="https://github.com/nod-ai/SHARK-Turbine", + url="https://github.com/iree-org/iree-turbine/", license="Apache-2.0", classifiers=[ "Development Status :: 5 - Production/Stable", @@ -96,11 +95,11 @@ def initialize_options(self): packages=packages, include_package_data=True, package_data={ - "shark_turbine": ["ops/templates/*.mlir"], # Include MLIR templates + "iree.turbine": ["ops/templates/*.mlir"], # Include MLIR templates }, entry_points={ "torch_dynamo_backends": [ - "turbine_cpu = shark_turbine.dynamo.backends.cpu:backend", + "turbine_cpu = iree.turbine.dynamo.backends.cpu:backend", ], }, install_requires=[ diff --git a/shark_turbine/__init__.py b/shark_turbine/__init__.py new file mode 100644 index 00000000..f1e1c318 --- /dev/null +++ b/shark_turbine/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + + +# Temp redirect from old shark_turbine namespace. +from iree.turbine import aot +from iree.turbine import dynamo +from iree.turbine import kernel +from iree.turbine import ops +from iree.turbine import runtime diff --git a/tests/aot/api_test.py b/tests/aot/api_test.py index e038704d..0d5f4215 100644 --- a/tests/aot/api_test.py +++ b/tests/aot/api_test.py @@ -11,7 +11,7 @@ Context, ) -from shark_turbine.aot import * +from iree.turbine.aot import * import torch import torch.nn as nn diff --git a/tests/aot/args_test.py b/tests/aot/args_test.py index d7ec458d..efbce489 100644 --- a/tests/aot/args_test.py +++ b/tests/aot/args_test.py @@ -11,7 +11,7 @@ Context, ) -from shark_turbine.aot import * +from iree.turbine.aot import * class ArgsTest(unittest.TestCase): diff --git a/tests/aot/compiled_exported_program_test.py b/tests/aot/compiled_exported_program_test.py index baaeb9bb..6b86b185 100644 --- a/tests/aot/compiled_exported_program_test.py +++ b/tests/aot/compiled_exported_program_test.py @@ -14,8 +14,8 @@ Context, ) -from shark_turbine.aot import * -from shark_turbine.aot.builtins import * +from iree.turbine.aot import * +from iree.turbine.aot.builtins import * class TorchExportTests(unittest.TestCase): diff --git a/tests/aot/decompositions_test.py b/tests/aot/decompositions_test.py index baf96604..f186cf12 100644 --- a/tests/aot/decompositions_test.py +++ b/tests/aot/decompositions_test.py @@ -9,7 +9,7 @@ import logging import unittest -from shark_turbine.aot import decompositions +from iree.turbine.aot import decompositions class DecompTest(unittest.TestCase): diff --git a/tests/aot/dynamic_shape_export_test.py b/tests/aot/dynamic_shape_export_test.py index da8c11b7..8f53df27 100644 --- a/tests/aot/dynamic_shape_export_test.py +++ b/tests/aot/dynamic_shape_export_test.py @@ -2,7 +2,7 @@ import pytest -from shark_turbine.aot import * +from iree.turbine.aot import * @pytest.mark.parametrize( diff --git a/tests/aot/functionalize_test.py b/tests/aot/functionalize_test.py index 0cad8e93..2a2ea309 100644 --- a/tests/aot/functionalize_test.py +++ b/tests/aot/functionalize_test.py @@ -13,7 +13,7 @@ Context, ) -from shark_turbine.aot import * +from iree.turbine.aot import * class FunctionalizeTests(unittest.TestCase): diff --git a/tests/aot/fx_programs_test.py b/tests/aot/fx_programs_test.py index c54f1851..f2c70456 100644 --- a/tests/aot/fx_programs_test.py +++ b/tests/aot/fx_programs_test.py @@ -10,7 +10,7 @@ import pytest import torch -from shark_turbine.aot import ( +from iree.turbine.aot import ( FxPrograms, FxProgramsBuilder, ) diff --git a/tests/aot/globals_test.py b/tests/aot/globals_test.py index 607382fd..7a250531 100644 --- a/tests/aot/globals_test.py +++ b/tests/aot/globals_test.py @@ -11,7 +11,7 @@ Context, ) -from shark_turbine.aot import * +from iree.turbine.aot import * import torch import torch.nn as nn diff --git a/tests/aot/iree_procedural_test.py b/tests/aot/iree_procedural_test.py index 9f479921..251c8f12 100644 --- a/tests/aot/iree_procedural_test.py +++ b/tests/aot/iree_procedural_test.py @@ -13,7 +13,7 @@ Context, ) -from shark_turbine.aot import * +from iree.turbine.aot import * class CompiledModuleAPI(unittest.TestCase): diff --git a/tests/aot/jittable_test.py b/tests/aot/jittable_test.py index 9c87fb11..d19988bc 100644 --- a/tests/aot/jittable_test.py +++ b/tests/aot/jittable_test.py @@ -13,7 +13,7 @@ Context, ) -from shark_turbine.aot import * +from iree.turbine.aot import * class JittableTests(unittest.TestCase): diff --git a/tests/aot/non_strict_export_test.py b/tests/aot/non_strict_export_test.py index ece961dc..2ed1b603 100644 --- a/tests/aot/non_strict_export_test.py +++ b/tests/aot/non_strict_export_test.py @@ -3,7 +3,7 @@ from torch import nn import torch -from shark_turbine.aot import * +from iree.turbine.aot import * logger = logging.getLogger(__file__) diff --git a/tests/aot/params_test.py b/tests/aot/params_test.py index a1d64206..895cb2b9 100644 --- a/tests/aot/params_test.py +++ b/tests/aot/params_test.py @@ -12,7 +12,7 @@ import torch import torch.nn as nn -from shark_turbine.aot import ( +from iree.turbine.aot import ( export, externalize_module_parameters, save_module_parameters, diff --git a/tests/dynamo/importer_dynamic_test.py b/tests/dynamo/importer_dynamic_test.py index 72ff4f82..682aa140 100644 --- a/tests/dynamo/importer_dynamic_test.py +++ b/tests/dynamo/importer_dynamic_test.py @@ -14,7 +14,7 @@ # from torch._export.constraints import constrain_as_size, constrain_as_value from iree.compiler.extras.fx_importer import FxImporter -from shark_turbine.dynamo.passes import turbine_cpu_pass_pipeline +from iree.turbine.dynamo.passes import turbine_cpu_pass_pipeline import torch import torch._dynamo as dynamo from torch._dynamo.backends.common import aot_autograd diff --git a/tests/dynamo/tensor_test.py b/tests/dynamo/tensor_test.py index fcd40660..0562c071 100644 --- a/tests/dynamo/tensor_test.py +++ b/tests/dynamo/tensor_test.py @@ -12,8 +12,8 @@ import torch # Public API imports. -from shark_turbine.runtime import Device -from shark_turbine.dynamo import TurbineMode, DeviceTensor +from iree.turbine.runtime import Device +from iree.turbine.dynamo import TurbineMode, DeviceTensor class TensorTest(unittest.TestCase): diff --git a/tests/dynamo/type_conversion_test.py b/tests/dynamo/type_conversion_test.py index 617c5d05..70375efb 100644 --- a/tests/dynamo/type_conversion_test.py +++ b/tests/dynamo/type_conversion_test.py @@ -12,7 +12,7 @@ Type as IrType, ) -import shark_turbine.dynamo.type_conversion as tc +import iree.turbine.dynamo.type_conversion as tc class TypeConversionTest(unittest.TestCase): diff --git a/tests/generated/evaluate.py b/tests/generated/evaluate.py index 3184930d..a971e23c 100644 --- a/tests/generated/evaluate.py +++ b/tests/generated/evaluate.py @@ -2,7 +2,7 @@ import logging from iree.compiler.extras.fx_importer import FxImporter -from shark_turbine.dynamo.passes import turbine_cpu_pass_pipeline +from iree.turbine.dynamo.passes import turbine_cpu_pass_pipeline import torch import torch._dynamo as dynamo from torch._dynamo.backends.common import aot_autograd diff --git a/tests/kernel/aot_kernel_test.py b/tests/kernel/aot_kernel_test.py index 690e366a..16363048 100644 --- a/tests/kernel/aot_kernel_test.py +++ b/tests/kernel/aot_kernel_test.py @@ -8,9 +8,9 @@ import unittest import torch -from shark_turbine.aot import export -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl +from iree.turbine.aot import export +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl def export_softmax_kernel(): diff --git a/tests/kernel/arith_test.py b/tests/kernel/arith_test.py index 1631454c..ce9e659e 100644 --- a/tests/kernel/arith_test.py +++ b/tests/kernel/arith_test.py @@ -8,15 +8,15 @@ import unittest import torch -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl -from shark_turbine.kernel.compiler import ( +from iree.turbine.kernel.compiler import ( builder, kernel_codegen, vector_codegen, ) -from shark_turbine.kernel._support import ( +from iree.turbine.kernel._support import ( indexing, ) diff --git a/tests/kernel/compiler/utils_test.py b/tests/kernel/compiler/utils_test.py index be084613..6f2db310 100644 --- a/tests/kernel/compiler/utils_test.py +++ b/tests/kernel/compiler/utils_test.py @@ -1,9 +1,9 @@ import logging import pytest import unittest -from shark_turbine.kernel.lang import sym -from shark_turbine.kernel._support.indexing import IndexSymbol, IndexingContext -from shark_turbine.kernel.compiler.utils import strides_from_symbolic_shape +from iree.turbine.kernel.lang import sym +from iree.turbine.kernel._support.indexing import IndexSymbol, IndexingContext +from iree.turbine.kernel.compiler.utils import strides_from_symbolic_shape class UtilsTest(unittest.TestCase): diff --git a/tests/kernel/dispatch_codegen_test.py b/tests/kernel/dispatch_codegen_test.py index be17a86d..b76ed2e1 100644 --- a/tests/kernel/dispatch_codegen_test.py +++ b/tests/kernel/dispatch_codegen_test.py @@ -8,16 +8,16 @@ import unittest import torch -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl -from shark_turbine.kernel.compiler import ( +from iree.turbine.kernel.compiler import ( builder, dispatch_codegen, kernel_codegen, vector_codegen, ) -from shark_turbine.kernel._support import ( +from iree.turbine.kernel._support import ( indexing, ) diff --git a/tests/kernel/fused_attention_test.py b/tests/kernel/fused_attention_test.py index 89883780..abc9d7ad 100644 --- a/tests/kernel/fused_attention_test.py +++ b/tests/kernel/fused_attention_test.py @@ -8,8 +8,8 @@ import unittest import torch -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl BATCH = tkl.sym.BATCH N_HEADS = tkl.sym.N_HEADS diff --git a/tests/kernel/indexing_test.py b/tests/kernel/indexing_test.py index 8bc27c50..677bbf09 100644 --- a/tests/kernel/indexing_test.py +++ b/tests/kernel/indexing_test.py @@ -9,8 +9,8 @@ import torch -from shark_turbine.kernel._support.indexing import * -from shark_turbine.kernel.lang import * +from iree.turbine.kernel._support.indexing import * +from iree.turbine.kernel.lang import * M = sym.M N = sym.N diff --git a/tests/kernel/simple_kernel_test.py b/tests/kernel/simple_kernel_test.py index 87cf3ed2..bffe723c 100644 --- a/tests/kernel/simple_kernel_test.py +++ b/tests/kernel/simple_kernel_test.py @@ -9,8 +9,8 @@ import torch -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl M = tk.lang.sym.M K = tk.lang.sym.K @@ -37,9 +37,9 @@ def iota_kernel(out: tk.lang.KernelBuffer[M, tkl.index]): print(iota_kernel._trace().region_graph) # Prints: # .graph(): - # %out : shark_turbine.kernel.lang.types.KernelBuffer [num_users=1] = placeholder[target=out] - # %program_id : [num_users=1] = call_function[target=shark_turbine.kernel.lang.prims.program_id](args = (0,), kwargs = {}) - # %_global_buffer_setitem : [num_users=0] = call_function[target=shark_turbine.kernel._support.tracing._global_buffer_setitem](args = (%out, %program_id, %program_id), kwargs = {}) + # %out : iree.turbine.kernel.lang.types.KernelBuffer [num_users=1] = placeholder[target=out] + # %program_id : [num_users=1] = call_function[target=iree.turbine.kernel.lang.prims.program_id](args = (0,), kwargs = {}) + # %_global_buffer_setitem : [num_users=0] = call_function[target=iree.turbine.kernel._support.tracing._global_buffer_setitem](args = (%out, %program_id, %program_id), kwargs = {}) # return None def testSoftmax(self): @@ -76,17 +76,17 @@ def softmax(x): print(softmax_kernel._trace().region_graph) # Prints: # graph(): - # %input_1 : shark_turbine.kernel.lang.types.KernelBuffer [num_users=1] = placeholder[target=input] - # %output : shark_turbine.kernel.lang.types.KernelBuffer [num_users=1] = placeholder[target=output] - # %program_id : [num_users=1] = call_function[target=shark_turbine.kernel.lang.prims.program_id](args = (0,), kwargs = {}) + # %input_1 : iree.turbine.kernel.lang.types.KernelBuffer [num_users=1] = placeholder[target=input] + # %output : iree.turbine.kernel.lang.types.KernelBuffer [num_users=1] = placeholder[target=output] + # %program_id : [num_users=1] = call_function[target=iree.turbine.kernel.lang.prims.program_id](args = (0,), kwargs = {}) # %getitem : [num_users=2] = call_function[target=operator.getitem](args = (%input_1, (%program_id, slice(None, None, None))), kwargs = {}) # %max_1 : [num_users=1] = call_function[target=torch.max](args = (%getitem,), kwargs = {}) # %sub : [num_users=1] = call_function[target=operator.sub](args = (%getitem, %max_1), kwargs = {}) # %exp : [num_users=2] = call_function[target=torch.exp](args = (%sub,), kwargs = {}) # %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%exp,), kwargs = {}) # %truediv : [num_users=1] = call_function[target=operator.truediv](args = (%exp, %sum_1), kwargs = {}) - # %program_id_1 : [num_users=1] = call_function[target=shark_turbine.kernel.lang.prims.program_id](args = (0,), kwargs = {}) - # %_kernel_buffer_setitem : [num_users=0] = call_function[target=shark_turbine.kernel._support.tracing._kernel_buffer_setitem](args = (%output, (%program_id_1, slice(None, None, None)), %truediv), kwargs = {}) + # %program_id_1 : [num_users=1] = call_function[target=iree.turbine.kernel.lang.prims.program_id](args = (0,), kwargs = {}) + # %_kernel_buffer_setitem : [num_users=0] = call_function[target=iree.turbine.kernel._support.tracing._kernel_buffer_setitem](args = (%output, (%program_id_1, slice(None, None, None)), %truediv), kwargs = {}) # return None diff --git a/tests/kernel/types_test.py b/tests/kernel/types_test.py index 87dc6536..e355db31 100644 --- a/tests/kernel/types_test.py +++ b/tests/kernel/types_test.py @@ -7,7 +7,7 @@ import logging import unittest -from shark_turbine.kernel.lang import ( +from iree.turbine.kernel.lang import ( Index, ) diff --git a/tests/kernel/vector_codegen_test.py b/tests/kernel/vector_codegen_test.py index fcd33462..696852c0 100644 --- a/tests/kernel/vector_codegen_test.py +++ b/tests/kernel/vector_codegen_test.py @@ -8,8 +8,8 @@ import unittest import torch -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl M = tk.lang.sym.M K = tk.lang.sym.K diff --git a/tests/kernel/wave/constraints_test.py b/tests/kernel/wave/constraints_test.py index 418c3c8b..f2915ac4 100644 --- a/tests/kernel/wave/constraints_test.py +++ b/tests/kernel/wave/constraints_test.py @@ -8,8 +8,8 @@ import pytest import unittest from sympy import ceiling -from shark_turbine.kernel.lang import sym -from shark_turbine.kernel.wave.constraints import ( +from iree.turbine.kernel.lang import sym +from iree.turbine.kernel.wave.constraints import ( WorkgroupConstraint, get_grid_shape, TilingConstraint, diff --git a/tests/kernel/wave/scheduling_test.py b/tests/kernel/wave/scheduling_test.py index 93d9cb6c..bb7cbc25 100644 --- a/tests/kernel/wave/scheduling_test.py +++ b/tests/kernel/wave/scheduling_test.py @@ -6,32 +6,32 @@ import unittest import logging -from shark_turbine.kernel.wave.scheduling.modulo_scheduling import ( +from iree.turbine.kernel.wave.scheduling.modulo_scheduling import ( ModuloScheduler, EdgeWeight, Edge, ) import torch.fx as fx import numpy as np -from shark_turbine.kernel.wave.visualization import visualize_graph -from shark_turbine.kernel.wave.scheduling.graph_utils import ( +from iree.turbine.kernel.wave.visualization import visualize_graph +from iree.turbine.kernel.wave.scheduling.graph_utils import ( find_strongly_connected_components, find_cycles_in_scc, all_pairs_longest_paths, evaluate_all_pairs_longest_paths, ) -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel._support.tracing import CapturedTrace -from shark_turbine.kernel._support.indexing import IndexingContext -from shark_turbine.kernel.wave.promotion import promote_placeholders -from shark_turbine.kernel.wave.hoisting import hoist_allocs -from shark_turbine.kernel.wave.expansion import expand_graph -from shark_turbine.kernel.wave.minimize_global_loads import minimize_global_loads -from shark_turbine.kernel.wave.scheduling.schedule import schedule_graph -from shark_turbine.kernel.ops.wave_ops import get_custom +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel._support.tracing import CapturedTrace +from iree.turbine.kernel._support.indexing import IndexingContext +from iree.turbine.kernel.wave.promotion import promote_placeholders +from iree.turbine.kernel.wave.hoisting import hoist_allocs +from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.wave.minimize_global_loads import minimize_global_loads +from iree.turbine.kernel.wave.scheduling.schedule import schedule_graph +from iree.turbine.kernel.ops.wave_ops import get_custom class SchedulingTest(unittest.TestCase): diff --git a/tests/kernel/wave/types_test.py b/tests/kernel/wave/types_test.py index d27c4c47..cdb05c3e 100644 --- a/tests/kernel/wave/types_test.py +++ b/tests/kernel/wave/types_test.py @@ -9,9 +9,9 @@ import sympy import unittest -from shark_turbine.kernel.lang import Memory, Register, sym, f16 -from shark_turbine.kernel.lang.wave_types import AddressSpace -from shark_turbine.kernel.lang.kernel_buffer import KernelBufferUsage +from iree.turbine.kernel.lang import Memory, Register, sym, f16 +from iree.turbine.kernel.lang.wave_types import AddressSpace +from iree.turbine.kernel.lang.kernel_buffer import KernelBufferUsage M = sym.M N = sym.N diff --git a/tests/kernel/wave/visualization_test.py b/tests/kernel/wave/visualization_test.py index 17cce11c..ebe6a75f 100644 --- a/tests/kernel/wave/visualization_test.py +++ b/tests/kernel/wave/visualization_test.py @@ -9,15 +9,15 @@ import unittest import os import pytest -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.wave.expansion import expand_graph -from shark_turbine.kernel._support.tracing import CapturedTrace -from shark_turbine.kernel._support.indexing import IndexingContext -from shark_turbine.kernel.ops.wave_ops import get_custom -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel.wave.visualization import visualize_graph +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel._support.tracing import CapturedTrace +from iree.turbine.kernel._support.indexing import IndexingContext +from iree.turbine.kernel.ops.wave_ops import get_custom +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.visualization import visualize_graph def run(func: Callable[[], None]) -> Callable[[], None]: diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index cf2d1315..0611257d 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -4,12 +4,12 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.wave.wave_sim import wave_sim -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel.wave.iree_utils import generate_iree_ref +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.wave.wave_sim import wave_sim +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.iree_utils import generate_iree_ref import torch from numpy.testing import assert_allclose, assert_equal import pytest diff --git a/tests/kernel/wave/wave_gemm_test.py b/tests/kernel/wave/wave_gemm_test.py index f9de28b2..e9487a64 100644 --- a/tests/kernel/wave/wave_gemm_test.py +++ b/tests/kernel/wave/wave_gemm_test.py @@ -8,11 +8,11 @@ import pytest import torch import unittest -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.lang.global_symbols import * -from shark_turbine.kernel.wave.iree_utils import generate_iree_ref +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.iree_utils import generate_iree_ref import os import json from torch.testing import assert_close diff --git a/tests/kernel/wave/wave_sim_test.py b/tests/kernel/wave/wave_sim_test.py index 5fa5695a..58ec1255 100644 --- a/tests/kernel/wave/wave_sim_test.py +++ b/tests/kernel/wave/wave_sim_test.py @@ -6,9 +6,9 @@ import pytest import torch -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.wave.wave_sim import wave_sim +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.wave.wave_sim import wave_sim from numpy.testing import assert_allclose diff --git a/tests/kernel/wave/wave_utils_test.py b/tests/kernel/wave/wave_utils_test.py index ec1198fd..bce6de9f 100644 --- a/tests/kernel/wave/wave_utils_test.py +++ b/tests/kernel/wave/wave_utils_test.py @@ -6,8 +6,8 @@ import logging import unittest -from shark_turbine.kernel.lang import sym -from shark_turbine.kernel.wave.utils import delinearize_index +from iree.turbine.kernel.lang import sym +from iree.turbine.kernel.wave.utils import delinearize_index import numpy as np M = sym.M diff --git a/tests/ops/iree_test.py b/tests/ops/iree_test.py index b06a7910..facbf545 100644 --- a/tests/ops/iree_test.py +++ b/tests/ops/iree_test.py @@ -10,8 +10,8 @@ import torch import torch.nn as nn -import shark_turbine.aot as aot -import shark_turbine.ops as ops +import iree.turbine.aot as aot +import iree.turbine.ops as ops # See runtime/op_reg/kernel_aot_test.py for additional tests of the trace diff --git a/tests/runtime/device_test.py b/tests/runtime/device_test.py index e78aff8e..89cd83c8 100644 --- a/tests/runtime/device_test.py +++ b/tests/runtime/device_test.py @@ -14,17 +14,17 @@ from iree.runtime import HalElementType # Public API imports. -from shark_turbine.runtime import ( +from iree.turbine.runtime import ( Device, ) # Internals. -from shark_turbine.runtime.device import ( +from iree.turbine.runtime.device import ( _CURRENT_THREAD, get_device_from_torch, ) -from shark_turbine.support.exceptions import * +from iree.turbine.support.exceptions import * class DeviceTest(unittest.TestCase): @@ -151,7 +151,7 @@ def testFromTorchDevice(self): print(device.dump_device_info()) def testJit(self): - from shark_turbine.ops import _str_format_test_ops as test_ops + from iree.turbine.ops import _str_format_test_ops as test_ops t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cuda:0") result = test_ops.test_add(t, t) @@ -161,7 +161,7 @@ def testJit(self): class TorchCPUInterop(unittest.TestCase): def testJitStrFormat(self): - from shark_turbine.ops import _str_format_test_ops as test_ops + from iree.turbine.ops import _str_format_test_ops as test_ops t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cpu") result = test_ops.test_add(t, t) @@ -169,7 +169,7 @@ def testJitStrFormat(self): torch.testing.assert_close(result, expected) def testJitJinja(self): - from shark_turbine.ops import _jinja_test_ops as test_ops + from iree.turbine.ops import _jinja_test_ops as test_ops t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cpu") result = test_ops.test_add(t, t) diff --git a/tests/runtime/launch_test.py b/tests/runtime/launch_test.py index 1a142161..ad12b2e3 100644 --- a/tests/runtime/launch_test.py +++ b/tests/runtime/launch_test.py @@ -8,11 +8,11 @@ import torch import unittest -from shark_turbine.aot.params import ( +from iree.turbine.aot.params import ( ParameterArchiveBuilder, ) -from shark_turbine.runtime import ( +from iree.turbine.runtime import ( Launchable, ) diff --git a/tests/runtime/op_reg/impl_helper_test.py b/tests/runtime/op_reg/impl_helper_test.py index b0797c2d..2661dc5b 100644 --- a/tests/runtime/op_reg/impl_helper_test.py +++ b/tests/runtime/op_reg/impl_helper_test.py @@ -9,7 +9,7 @@ import torch -from shark_turbine.ops import _str_format_test_ops +from iree.turbine.ops import _str_format_test_ops class KernelRegTest(unittest.TestCase): diff --git a/tests/runtime/op_reg/kernel_aot_test.py b/tests/runtime/op_reg/kernel_aot_test.py index 4aa04857..4533326a 100644 --- a/tests/runtime/op_reg/kernel_aot_test.py +++ b/tests/runtime/op_reg/kernel_aot_test.py @@ -10,10 +10,10 @@ import torch import torch.nn as nn -import shark_turbine.aot as aot -import shark_turbine.ops as ops +import iree.turbine.aot as aot +import iree.turbine.ops as ops -from shark_turbine.transforms.general.custom_op_expansion import ExpandCustomOpsPass +from iree.turbine.transforms.general.custom_op_expansion import ExpandCustomOpsPass class MLP(nn.Module): diff --git a/tests/runtime/op_reg/kernel_reg_test.py b/tests/runtime/op_reg/kernel_reg_test.py index 75554b04..dfc88d83 100644 --- a/tests/runtime/op_reg/kernel_reg_test.py +++ b/tests/runtime/op_reg/kernel_reg_test.py @@ -9,9 +9,9 @@ import torch -from shark_turbine.runtime.op_reg import * +from iree.turbine.runtime.op_reg import * -from shark_turbine.runtime.op_reg.compiler import _testing_get_cache_size +from iree.turbine.runtime.op_reg.compiler import _testing_get_cache_size class KernelRegTest(unittest.TestCase): diff --git a/tests/tools/interpreter_test.py b/tests/tools/interpreter_test.py index 0513b10b..2152c701 100644 --- a/tests/tools/interpreter_test.py +++ b/tests/tools/interpreter_test.py @@ -1,8 +1,8 @@ -from shark_turbine.tools.interpreter import Interpreter -import shark_turbine.kernel as tk -import shark_turbine.kernel.lang as tkl -import shark_turbine.kernel.wave as tkw -from shark_turbine.kernel.lang.global_symbols import * +from iree.turbine.tools.interpreter import Interpreter +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.lang.global_symbols import * import torch diff --git a/tests/top_level_package_test.py b/tests/top_level_package_test.py index 52ea796b..b2c04cdd 100644 --- a/tests/top_level_package_test.py +++ b/tests/top_level_package_test.py @@ -11,8 +11,8 @@ class TopLevelPackageTest(unittest.TestCase): def testIreeTurbineRedirect(self): # We have a temporary redirect of the top-level API to the - # iree.turbine namespace. - from iree.turbine import aot, dynamo, kernel, ops, runtime + # shark-turbine namespace. + from shark_turbine import aot, dynamo, kernel, ops, runtime if __name__ == "__main__": diff --git a/tests/transforms/general/add_metadata_test.py b/tests/transforms/general/add_metadata_test.py index 8055fa26..da5d0207 100644 --- a/tests/transforms/general/add_metadata_test.py +++ b/tests/transforms/general/add_metadata_test.py @@ -11,7 +11,7 @@ from iree.compiler.ir import Context, Operation, Module -from shark_turbine.transforms.general import add_metadata +from iree.turbine.transforms.general import add_metadata SIMPLE_FUNC_ASM = r""" func.func @list_func(%arg0 : !iree_input.list) -> !iree_input.list { diff --git a/tests/transforms/general/custom_op_expansion_test.py b/tests/transforms/general/custom_op_expansion_test.py index b94e2750..f621320d 100644 --- a/tests/transforms/general/custom_op_expansion_test.py +++ b/tests/transforms/general/custom_op_expansion_test.py @@ -9,15 +9,15 @@ import torch import unittest -from shark_turbine.transforms.general.custom_op_expansion import ExpandCustomOpsPass -from shark_turbine.runtime.op_reg import ( +from iree.turbine.transforms.general.custom_op_expansion import ExpandCustomOpsPass +from iree.turbine.runtime.op_reg import ( def_library, CustomOp, KernelBuilder, KernelSelection, ) -from shark_turbine.support.ir_imports import ( +from iree.turbine.support.ir_imports import ( Context, Module, ) diff --git a/tests/transforms/general/rename_parameters_test.py b/tests/transforms/general/rename_parameters_test.py index 74fc6753..a14dbcbd 100644 --- a/tests/transforms/general/rename_parameters_test.py +++ b/tests/transforms/general/rename_parameters_test.py @@ -14,8 +14,8 @@ Operation, ) -from shark_turbine.transforms import rewriter -from shark_turbine.transforms.general import rename_parameters +from iree.turbine.transforms import rewriter +from iree.turbine.transforms.general import rename_parameters SIMPLE_GLOBALS_ASM = r""" module { diff --git a/tests/transforms/quantization/mm_group_quant_test.py b/tests/transforms/quantization/mm_group_quant_test.py index c6870d2c..b465301b 100644 --- a/tests/transforms/quantization/mm_group_quant_test.py +++ b/tests/transforms/quantization/mm_group_quant_test.py @@ -14,8 +14,8 @@ Operation, ) -from shark_turbine.transforms import rewriter -from shark_turbine.transforms.quantization import mm_group_quant +from iree.turbine.transforms import rewriter +from iree.turbine.transforms.quantization import mm_group_quant MM_F32_TO_INT4_CONTENTS = ( Path(__file__).resolve().parent / "mm_f32_to_int4.mlir"