-
Notifications
You must be signed in to change notification settings - Fork 608
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8313854
commit 0828d05
Showing
8 changed files
with
292 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,205 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright 2023 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
"""Sets up a Python venv with compiler/runtime from a workflow run. | ||
There are two modes in which to use this script: | ||
* Within a workflow, an artifact action will typically be used to fetch | ||
relevant package artifacts. Specify the fetch location with | ||
`--artifact-path=`. | ||
* Locally, the `--fetch-gh-workflow=WORKFLOW_ID` can be used instead in order | ||
to download and setup the venv in one step. | ||
You must have the `gh` command line tool installed and authenticated if you | ||
will be fetching artifacts. | ||
""" | ||
|
||
from typing import Optional, Dict, Tuple | ||
|
||
import argparse | ||
import functools | ||
from glob import glob | ||
import json | ||
import sys | ||
from pathlib import Path | ||
import platform | ||
import subprocess | ||
import sys | ||
import tempfile | ||
import zipfile | ||
|
||
|
||
@functools.lru_cache | ||
def list_gh_artifacts(run_id: str) -> Dict[str, str]: | ||
print(f"Fetching artifacts for workflow run {run_id}") | ||
base_path = f"/repos/openxla/iree" | ||
output = subprocess.check_output( | ||
[ | ||
"gh", | ||
"api", | ||
"-H", | ||
"Accept: application/vnd.github+json", | ||
"-H", | ||
"X-GitHub-Api-Version: 2022-11-28", | ||
f"{base_path}/actions/runs/{run_id}/artifacts", | ||
] | ||
) | ||
data = json.loads(output) | ||
# Uncomment to debug: | ||
# print(json.dumps(data, indent=2)) | ||
artifacts = { | ||
rec["name"]: f"{base_path}/actions/artifacts/{rec['id']}/zip" | ||
for rec in data["artifacts"] | ||
} | ||
print("Found artifacts:") | ||
for k, v in artifacts.items(): | ||
print(f" {k}: {v}") | ||
return artifacts | ||
|
||
|
||
def fetch_gh_artifact(api_path: str, file: Path): | ||
print(f"Downloading artifact {api_path}") | ||
contents = subprocess.check_output( | ||
[ | ||
"gh", | ||
"api", | ||
"-H", | ||
"Accept: application/vnd.github+json", | ||
"-H", | ||
"X-GitHub-Api-Version: 2022-11-28", | ||
api_path, | ||
] | ||
) | ||
with open(file, "wb") as f: | ||
f.write(contents) | ||
|
||
|
||
def find_venv_python(venv_path: Path) -> Optional[Path]: | ||
paths = [venv_path / "bin" / "python", venv_path / "Scripts" / "python.exe"] | ||
for p in paths: | ||
if p.exists(): | ||
return p | ||
return None | ||
|
||
|
||
def parse_arguments(argv=None): | ||
parser = argparse.ArgumentParser(description="Setup venv") | ||
parser.add_argument("--artifact-path", help="Path in which to find/fetch artifacts") | ||
parser.add_argument( | ||
"--fetch-gh-workflow", help="Fetch artifacts from a GitHub workflow" | ||
) | ||
parser.add_argument( | ||
"--compiler-variant", | ||
default="", | ||
help="Package variant to install for the compiler ('', 'asserts')", | ||
) | ||
parser.add_argument( | ||
"--runtime-variant", | ||
default="", | ||
help="Package variant to install for the runtime ('', 'asserts')", | ||
) | ||
parser.add_argument("venv_dir", help="Directory in which to create the venv") | ||
args = parser.parse_args(argv) | ||
return args | ||
|
||
|
||
def main(args): | ||
# Make sure we have an artifact path if fetching. | ||
if not args.artifact_path and args.fetch_gh_workflow: | ||
with tempfile.TemporaryDirectory() as td: | ||
args.artifact_path = td | ||
return main(args) | ||
|
||
artifact_prefix = f"{platform.system().lower()}_{platform.machine()}" | ||
wheels = [] | ||
for package_stem, variant in [ | ||
("iree-compiler", args.compiler_variant), | ||
("iree-runtime", args.runtime_variant), | ||
]: | ||
wheels.append( | ||
find_wheel_for_variants(args, artifact_prefix, package_stem, variant) | ||
) | ||
print("Installing wheels:", wheels) | ||
|
||
# Set up venv. | ||
venv_path = Path(args.venv_dir) | ||
python_exe = find_venv_python(venv_path) | ||
|
||
if not python_exe: | ||
print(f"Creating venv at {str(venv_path)}") | ||
subprocess.check_call([sys.executable, "-m", "venv", str(venv_path)]) | ||
python_exe = find_venv_python(venv_path) | ||
if not python_exe: | ||
raise RuntimeError("Error creating venv") | ||
|
||
for artifact_path, package_name in wheels: | ||
cmd = [ | ||
python_exe, | ||
"-m", | ||
"pip", | ||
"install", | ||
"--no-deps", | ||
"--no-index", | ||
"-f", | ||
str(artifact_path), | ||
"--force-reinstall", | ||
package_name, | ||
] | ||
print(f"Running command: {' '.join([str(c) for c in cmd])}") | ||
subprocess.check_call(cmd) | ||
|
||
return 0 | ||
|
||
|
||
def find_wheel_for_variants( | ||
args, artifact_prefix: str, package_stem: str, variant: str | ||
) -> Tuple[Path, str]: | ||
artifact_path = Path(args.artifact_path) | ||
package_suffix = "" if variant == "" else f"-{variant}" | ||
package_name = f"{package_stem}{package_suffix}" | ||
|
||
def has_package(): | ||
norm_package_name = package_name.replace("-", "_") | ||
pattern = str(artifact_path / f"{norm_package_name}-*.whl") | ||
files = glob(pattern) | ||
return bool(files) | ||
|
||
if has_package(): | ||
return (artifact_path, package_name) | ||
|
||
if not args.fetch_gh_workflow: | ||
raise RuntimeError( | ||
f"Could not find package {package_name} to install from {artifact_path}" | ||
) | ||
|
||
# Fetch. | ||
artifact_path.mkdir(parents=True, exist_ok=True) | ||
artifact_suffix = "" if variant == "" else f"_{variant}" | ||
artifact_name = f"{artifact_prefix}_release{artifact_suffix}_packages" | ||
artifact_file = artifact_path / f"{artifact_name}.zip" | ||
if not artifact_file.exists(): | ||
print(f"Package {package_name} not found. Fetching from {artifact_name}...") | ||
artifacts = list_gh_artifacts(args.fetch_gh_workflow) | ||
if artifact_name not in artifacts: | ||
raise RuntimeError( | ||
f"Could not find required artifact {artifact_name} in run {args.fetch_gh_workflow}" | ||
) | ||
fetch_gh_artifact(artifacts[artifact_name], artifact_file) | ||
print(f"Extracting {artifact_file}") | ||
with zipfile.ZipFile(artifact_file) as zip_ref: | ||
zip_ref.extractall(artifact_path) | ||
|
||
# Try again. | ||
if not has_package(): | ||
raise RuntimeError(f"Could not find {package_name} in {artifact_path}") | ||
return (artifact_path, package_name) | ||
|
||
|
||
if __name__ == "__main__": | ||
sys.exit(main(parse_arguments())) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Copyright 2023 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from .fetch import ( | ||
fetch_source, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Copyright 2023 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import pytest | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Copyright 2023 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import pytest | ||
|
||
|
||
def fetch_source(url: str): | ||
@pytest.fixture | ||
def fetcher(tmp_path_factory, worker_id): | ||
return f"Hi: {tmp_path_factory}, {worker_id}" | ||
|
||
return fetcher |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
[build-system] | ||
requires = ["setuptools", "wheel"] | ||
build-backend = "setuptools.build_meta" | ||
|
||
[tool.pytest.ini_options] | ||
markers = [ | ||
"plat_rdna_vulkan: mark tests as running on AMD RDNA Vulkan device", | ||
"presubmit: mark test as running on presubmit", | ||
"unstable_linalg: mark test as depending on unstable, serialized linalg IR", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[tool:pytest] | ||
testpaths = | ||
./tests |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# Copyright 2023 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from setuptools import find_namespace_packages, setup | ||
|
||
setup( | ||
name=f"iree-regression-suite", | ||
version=f"0.1dev1", | ||
packages=find_namespace_packages(include=[ | ||
"ireers", | ||
],), | ||
install_requires=[ | ||
"numpy", | ||
"pytest", | ||
"pytest-xdist", | ||
"wget", | ||
"PyYAML", | ||
], | ||
extras_require={ | ||
}, | ||
) |
17 changes: 17 additions & 0 deletions
17
experimental/regression_suite/tests/pregenerated/test_llama2.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Copyright 2023 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import pytest | ||
from ireers import * | ||
|
||
llama2_7b_f16qi4_source = fetch_source("foobar") | ||
|
||
@pytest.mark.plat_rdna_vulkan | ||
@pytest.mark.presubmit | ||
@pytest.mark.unstable_linalg | ||
def test_step_rdna_vulkan(llama2_7b_f16qi4_source): | ||
print("GOOD") | ||
print(llama2_7b_f16qi4_source) |