Skip to content

Commit

Permalink
Add skeleton regression test jobs.
Browse files Browse the repository at this point in the history
  • Loading branch information
stellaraccident committed Aug 24, 2023
1 parent 8313854 commit 0828d05
Show file tree
Hide file tree
Showing 8 changed files with 292 additions and 0 deletions.
205 changes: 205 additions & 0 deletions build_tools/pkgci/setup_venv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
#!/usr/bin/env python3
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Sets up a Python venv with compiler/runtime from a workflow run.
There are two modes in which to use this script:
* Within a workflow, an artifact action will typically be used to fetch
relevant package artifacts. Specify the fetch location with
`--artifact-path=`.
* Locally, the `--fetch-gh-workflow=WORKFLOW_ID` can be used instead in order
to download and setup the venv in one step.
You must have the `gh` command line tool installed and authenticated if you
will be fetching artifacts.
"""

from typing import Optional, Dict, Tuple

import argparse
import functools
from glob import glob
import json
import sys
from pathlib import Path
import platform
import subprocess
import sys
import tempfile
import zipfile


@functools.lru_cache
def list_gh_artifacts(run_id: str) -> Dict[str, str]:
print(f"Fetching artifacts for workflow run {run_id}")
base_path = f"/repos/openxla/iree"
output = subprocess.check_output(
[
"gh",
"api",
"-H",
"Accept: application/vnd.github+json",
"-H",
"X-GitHub-Api-Version: 2022-11-28",
f"{base_path}/actions/runs/{run_id}/artifacts",
]
)
data = json.loads(output)
# Uncomment to debug:
# print(json.dumps(data, indent=2))
artifacts = {
rec["name"]: f"{base_path}/actions/artifacts/{rec['id']}/zip"
for rec in data["artifacts"]
}
print("Found artifacts:")
for k, v in artifacts.items():
print(f" {k}: {v}")
return artifacts


def fetch_gh_artifact(api_path: str, file: Path):
print(f"Downloading artifact {api_path}")
contents = subprocess.check_output(
[
"gh",
"api",
"-H",
"Accept: application/vnd.github+json",
"-H",
"X-GitHub-Api-Version: 2022-11-28",
api_path,
]
)
with open(file, "wb") as f:
f.write(contents)


def find_venv_python(venv_path: Path) -> Optional[Path]:
paths = [venv_path / "bin" / "python", venv_path / "Scripts" / "python.exe"]
for p in paths:
if p.exists():
return p
return None


def parse_arguments(argv=None):
parser = argparse.ArgumentParser(description="Setup venv")
parser.add_argument("--artifact-path", help="Path in which to find/fetch artifacts")
parser.add_argument(
"--fetch-gh-workflow", help="Fetch artifacts from a GitHub workflow"
)
parser.add_argument(
"--compiler-variant",
default="",
help="Package variant to install for the compiler ('', 'asserts')",
)
parser.add_argument(
"--runtime-variant",
default="",
help="Package variant to install for the runtime ('', 'asserts')",
)
parser.add_argument("venv_dir", help="Directory in which to create the venv")
args = parser.parse_args(argv)
return args


def main(args):
# Make sure we have an artifact path if fetching.
if not args.artifact_path and args.fetch_gh_workflow:
with tempfile.TemporaryDirectory() as td:
args.artifact_path = td
return main(args)

artifact_prefix = f"{platform.system().lower()}_{platform.machine()}"
wheels = []
for package_stem, variant in [
("iree-compiler", args.compiler_variant),
("iree-runtime", args.runtime_variant),
]:
wheels.append(
find_wheel_for_variants(args, artifact_prefix, package_stem, variant)
)
print("Installing wheels:", wheels)

# Set up venv.
venv_path = Path(args.venv_dir)
python_exe = find_venv_python(venv_path)

if not python_exe:
print(f"Creating venv at {str(venv_path)}")
subprocess.check_call([sys.executable, "-m", "venv", str(venv_path)])
python_exe = find_venv_python(venv_path)
if not python_exe:
raise RuntimeError("Error creating venv")

for artifact_path, package_name in wheels:
cmd = [
python_exe,
"-m",
"pip",
"install",
"--no-deps",
"--no-index",
"-f",
str(artifact_path),
"--force-reinstall",
package_name,
]
print(f"Running command: {' '.join([str(c) for c in cmd])}")
subprocess.check_call(cmd)

return 0


def find_wheel_for_variants(
args, artifact_prefix: str, package_stem: str, variant: str
) -> Tuple[Path, str]:
artifact_path = Path(args.artifact_path)
package_suffix = "" if variant == "" else f"-{variant}"
package_name = f"{package_stem}{package_suffix}"

def has_package():
norm_package_name = package_name.replace("-", "_")
pattern = str(artifact_path / f"{norm_package_name}-*.whl")
files = glob(pattern)
return bool(files)

if has_package():
return (artifact_path, package_name)

if not args.fetch_gh_workflow:
raise RuntimeError(
f"Could not find package {package_name} to install from {artifact_path}"
)

# Fetch.
artifact_path.mkdir(parents=True, exist_ok=True)
artifact_suffix = "" if variant == "" else f"_{variant}"
artifact_name = f"{artifact_prefix}_release{artifact_suffix}_packages"
artifact_file = artifact_path / f"{artifact_name}.zip"
if not artifact_file.exists():
print(f"Package {package_name} not found. Fetching from {artifact_name}...")
artifacts = list_gh_artifacts(args.fetch_gh_workflow)
if artifact_name not in artifacts:
raise RuntimeError(
f"Could not find required artifact {artifact_name} in run {args.fetch_gh_workflow}"
)
fetch_gh_artifact(artifacts[artifact_name], artifact_file)
print(f"Extracting {artifact_file}")
with zipfile.ZipFile(artifact_file) as zip_ref:
zip_ref.extractall(artifact_path)

# Try again.
if not has_package():
raise RuntimeError(f"Could not find {package_name} in {artifact_path}")
return (artifact_path, package_name)


if __name__ == "__main__":
sys.exit(main(parse_arguments()))
9 changes: 9 additions & 0 deletions experimental/regression_suite/ireers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .fetch import (
fetch_source,
)
9 changes: 9 additions & 0 deletions experimental/regression_suite/ireers/compile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import pytest


15 changes: 15 additions & 0 deletions experimental/regression_suite/ireers/fetch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import pytest


def fetch_source(url: str):
@pytest.fixture
def fetcher(tmp_path_factory, worker_id):
return f"Hi: {tmp_path_factory}, {worker_id}"

return fetcher
10 changes: 10 additions & 0 deletions experimental/regression_suite/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"

[tool.pytest.ini_options]
markers = [
"plat_rdna_vulkan: mark tests as running on AMD RDNA Vulkan device",
"presubmit: mark test as running on presubmit",
"unstable_linalg: mark test as depending on unstable, serialized linalg IR",
]
3 changes: 3 additions & 0 deletions experimental/regression_suite/setup.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[tool:pytest]
testpaths =
./tests
24 changes: 24 additions & 0 deletions experimental/regression_suite/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from setuptools import find_namespace_packages, setup

setup(
name=f"iree-regression-suite",
version=f"0.1dev1",
packages=find_namespace_packages(include=[
"ireers",
],),
install_requires=[
"numpy",
"pytest",
"pytest-xdist",
"wget",
"PyYAML",
],
extras_require={
},
)
17 changes: 17 additions & 0 deletions experimental/regression_suite/tests/pregenerated/test_llama2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import pytest
from ireers import *

llama2_7b_f16qi4_source = fetch_source("foobar")

@pytest.mark.plat_rdna_vulkan
@pytest.mark.presubmit
@pytest.mark.unstable_linalg
def test_step_rdna_vulkan(llama2_7b_f16qi4_source):
print("GOOD")
print(llama2_7b_f16qi4_source)

0 comments on commit 0828d05

Please sign in to comment.