Skip to content

Commit

Permalink
Add integration_tests.yml
Browse files Browse the repository at this point in the history
  • Loading branch information
TaekyungHeo committed May 16, 2024
1 parent e5616d1 commit 68114b5
Show file tree
Hide file tree
Showing 3 changed files with 244 additions and 0 deletions.
35 changes: 35 additions & 0 deletions .github/workflows/integration_tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Runs the integration test script (ci_tools/integration_tests.py) on every pull request.
name: Integration Tests

on: pull_request

jobs:
  integration-tests:
    runs-on: ubuntu-latest

    steps:
    - name: Checkout Code
      # v4 replaces v2, which is built on the deprecated Node.js 12 action runtime.
      uses: actions/checkout@v4
      with:
        # NOTE(review): presumably required so the trace archive under tests/data is
        # downloaded rather than left as an LFS pointer -- confirm.
        lfs: true

    - name: Setup Python Environment
      # v5 replaces v2 for the same Node.js runtime reason as the checkout action.
      uses: actions/setup-python@v5
      with:
        python-version: '3.10.14'

    - name: Install Dependencies
      run: |
        pip install -r requirements-dev.txt
        pip install .

    - name: Install PARAM
      # PARAM is pinned to a fixed commit so the integration results stay reproducible.
      run: |
        git clone https://github.com/facebookresearch/param.git
        cd param/train/compute/python/
        git checkout c83ce8429110a86549c40fec5a01acbd9fbd54a4
        pip install .

    - name: Extract and Validate
      # One expected completion time (in milliseconds) is passed per rank; --num_ranks must match.
      run: |
        python3 ci_tools/integration_tests.py --tgz_path tests/data/1.0.2-chakra.0.0.4/llama_pytorch24.05.tgz \
          --num_ranks 8 --tolerance 0.05 --expected_times_ms 14597 14597 14968 14638 14649 14700 14677 14735
160 changes: 160 additions & 0 deletions ci_tools/integration_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import argparse
import concurrent.futures
import os
import re
import subprocess
import tarfile


def extract_tgz(tgz_path: str, extract_to: str) -> None:
    """
    Unpack a gzip-compressed tar archive into a target directory.

    Args:
        tgz_path (str): Location of the .tgz archive on disk.
        extract_to (str): Directory that receives the archive contents.
    """
    announcement = f"Extracting {tgz_path} to {extract_to}"
    print(announcement)

    # Open in gzip read mode; the context manager closes the archive once extraction is done.
    opened = tarfile.open(tgz_path, "r:gz")
    with opened as archive:
        archive.extractall(path=extract_to)


def run_command(command: str) -> None:
    """
    Executes a given shell command and checks for errors.

    The command is run through the shell with stdout and stderr captured. When the command
    exits with a non-zero status, the exit code and the captured stderr are included in the
    raised error so the failure can be diagnosed.

    Args:
        command (str): The shell command to execute.

    Raises:
        RuntimeError: If the command fails.
    """
    print(f"Running command: {command}")

    # NOTE(review): this looks like a developer-local virtualenv directory; it is kept so that
    # existing setups behave the same, but it probably belongs in configuration -- confirm.
    path_prefix = "/Users/theo/venv/bin/:"
    current_path = os.environ.get("PATH", "")
    # Prepend only when it is not already in front: this function runs once per command, and
    # prepending unconditionally made PATH grow by one entry on every call.
    if not current_path.startswith(path_prefix):
        os.environ["PATH"] = path_prefix + current_path

    try:
        subprocess.run(command, check=True, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    except subprocess.CalledProcessError as e:
        # Keep the "Command failed: <command>" prefix stable and append the diagnostics after it.
        message = f"Command failed: {command} (exit code {e.returncode})"
        captured_stderr = (e.stderr or "").strip()
        if captured_stderr:
            message = f"{message}\n{captured_stderr}"
        raise RuntimeError(message) from e


def run_commands_in_parallel(commands: list) -> None:
    """
    Run a batch of shell commands concurrently on a thread pool.

    Every command is handed to run_command on its own worker. Results are collected as the
    workers finish, so an exception raised by a worker is re-raised here.

    Args:
        commands (list): Shell commands to execute.
    """
    with concurrent.futures.ThreadPoolExecutor() as pool:
        pending = []
        for command in commands:
            pending.append(pool.submit(run_command, command))

        # result() re-raises whatever the worker raised.
        for finished in concurrent.futures.as_completed(pending):
            finished.result()


def run_trace_link(data_path: str, num_ranks: int) -> None:
    """
    Build one chakra_trace_link command per rank and run them all in parallel.

    For rank i the command reads chakra_host_et_{i}.json and kineto_{i}.json from data_path
    and writes chakra_et_plus_{i}.json into the same directory.

    Args:
        data_path (str): Directory containing the per-rank input files.
        num_ranks (int): Number of ranks, i.e. how many commands are generated.
    """
    link_commands = []
    for rank in range(num_ranks):
        link_commands.append(
            f"chakra_trace_link --pytorch-et-file {data_path}/chakra_host_et_{rank}.json "
            f"--kineto-file {data_path}/kineto_{rank}.json "
            f"--output-file {data_path}/chakra_et_plus_{rank}.json"
        )
    run_commands_in_parallel(link_commands)


def run_converter(data_path: str, num_ranks: int) -> None:
    """
    Build one chakra_converter command per rank and run them all in parallel.

    For rank i the command converts chakra_et_plus_{i}.json into chakra_final_{i}.chakra
    inside data_path and writes its log to /tmp/rank_{i}.log.

    Args:
        data_path (str): Directory containing the chakra_trace_link outputs.
        num_ranks (int): Number of ranks, i.e. how many commands are generated.
    """

    def converter_command(rank: int) -> str:
        # Assemble the full command line for a single rank.
        return (
            f"chakra_converter --input_filename {data_path}/chakra_et_plus_{rank}.json "
            f"--output_filename {data_path}/chakra_final_{rank}.chakra "
            f"--input_type PyTorch --log_filename /tmp/rank_{rank}.log"
        )

    run_commands_in_parallel(list(map(converter_command, range(num_ranks))))


def validate_log(filename: str, expected_time_us: int, tolerance: float) -> None:
    """
    Validates the log file to ensure the last operation completes within the expected time with an allowable error.

    The file is scanned for "GPU Node ID <n> completed at <t>us" lines; the time from the last
    such line is compared against expected_time_us * (1 +/- tolerance), bounds included.

    Args:
        filename (str): Path to the log file.
        expected_time_us (int): Expected completion time in microseconds.
        tolerance (float): Acceptable error percentage as a decimal.

    Raises:
        ValueError: If the log does not contain the expected output or is outside the acceptable time range.
    """
    # The timestamp carries an AM/PM marker. Both must be accepted: matching only "PM" made
    # every log written with an AM timestamp look as if it had no completion line at all.
    completion_pattern = re.compile(
        r"INFO \[\d{2}/\d{2}/\d{4} \d{2}:\d{2}:\d{2} (?:AM|PM)\] GPU Node ID \d+ completed at (\d+)us"
    )

    last_time = None
    with open(filename, "r") as file:
        for line in file:
            match = completion_pattern.search(line)
            if match:
                # Keep overwriting so that only the final completion line counts.
                last_time = int(match.group(1))

    if last_time is None:
        raise ValueError(f"No completion time found in {filename}")

    lower_bound = expected_time_us * (1 - tolerance)
    upper_bound = expected_time_us * (1 + tolerance)

    if not lower_bound <= last_time <= upper_bound:
        raise ValueError(
            f"Completion time in {filename} is {last_time}us; expected between {lower_bound}us and {upper_bound}us."
        )
    print(f"Validation successful for {filename}: {last_time}us is within the acceptable range.")


def parse_args():
    """
    Build the command line parser for the integration test script and parse sys.argv.

    All four options (--tgz_path, --num_ranks, --tolerance, --expected_times_ms) are required;
    --expected_times_ms accepts one or more integers.
    """
    cli = argparse.ArgumentParser(description="Run integration tests for chakra_trace_link and chakra_converter.")

    # Registered in this order so the generated help text lists the options the same way.
    required_options = (
        ("--tgz_path", {"type": str, "help": "Path to the tgz file to extract."}),
        ("--num_ranks", {"type": int, "help": "Number of ranks to process."}),
        ("--tolerance", {"type": float, "help": "Acceptable error percentage as a decimal."}),
        ("--expected_times_ms", {"type": int, "nargs": "+", "help": "List of expected times in milliseconds."}),
    )
    for flag, settings in required_options:
        cli.add_argument(flag, required=True, **settings)

    return cli.parse_args()


def main() -> None:
    """
    Main function to execute the integration test sequence.

    Steps: parse the arguments, extract the archive next to itself, run chakra_trace_link and
    chakra_converter for every rank, then validate each rank's converter log against its
    expected completion time.

    Raises:
        ValueError: If fewer expected times than ranks were supplied, or if a log fails validation.
    """
    args = parse_args()

    # Fail fast: every rank is checked against its own expected time, so a short list would
    # otherwise only surface as an IndexError after extraction and all conversions had run.
    if len(args.expected_times_ms) < args.num_ranks:
        raise ValueError(
            f"--expected_times_ms has {len(args.expected_times_ms)} value(s) but --num_ranks is {args.num_ranks}; "
            "provide one expected time per rank."
        )

    extract_dir = os.path.dirname(args.tgz_path)
    # NOTE(review): assumes the archive unpacks into a directory named after the archive
    # (file name without ".tgz") -- confirm against the test data layout.
    data_path = os.path.join(extract_dir, os.path.basename(args.tgz_path).replace(".tgz", ""))

    # Extracting files
    extract_tgz(args.tgz_path, extract_dir)

    # validate_log compares in microseconds; the command line takes milliseconds.
    expected_times_us = [time * 1000 for time in args.expected_times_ms]

    # Run trace link and converter processes
    run_trace_link(data_path, args.num_ranks)
    run_converter(data_path, args.num_ranks)

    # Validate output logs
    for i in range(args.num_ranks):
        log_file = f"/tmp/rank_{i}.log"
        validate_log(log_file, expected_times_us[i], args.tolerance)


if __name__ == "__main__":
    main()
49 changes: 49 additions & 0 deletions tests/ci_tools/test_integration_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import subprocess
from unittest.mock import MagicMock, patch

import pytest

from ci_tools.integration_tests import extract_tgz, run_command, validate_log


def test_extract_tgz():
    """Check that extract_tgz opens the archive in gzip read mode and extracts into the target directory."""
    archive_path = "path/to/test.tgz"
    destination = "path/to/extract"
    with patch("tarfile.open", MagicMock()) as fake_open:
        # The object bound by "with tarfile.open(...) as tar" is the __enter__ return value.
        entered_archive = fake_open.return_value.__enter__.return_value
        entered_archive.extractall = MagicMock()

        extract_tgz(archive_path, destination)

        fake_open.assert_called_once_with(archive_path, "r:gz")
        entered_archive.extractall.assert_called_once_with(path=destination)


def test_run_command_success():
    """Check that run_command invokes subprocess.run exactly once and does not raise on success."""
    completed = MagicMock(check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    with patch("subprocess.run", return_value=completed) as fake_run:
        run_command("echo 'Hello World'")
        fake_run.assert_called_once()


def test_run_command_failure():
    """Check that a failing subprocess is reported by run_command as a RuntimeError."""
    failure = subprocess.CalledProcessError(1, "cmd", "Error message")
    with patch("subprocess.run", side_effect=failure), pytest.raises(RuntimeError) as excinfo:
        run_command("exit 1")
    assert "Command failed: exit 1" in str(excinfo.value)


def test_validate_log_success(tmp_path):
    """Check that validate_log accepts a log whose last completion time equals the expected time."""
    completion_line = "INFO [05/15/2024 08:32:04 PM] GPU Node ID 301123 completed at 1000000us"
    log_path = tmp_path / "log.txt"
    log_path.write_text(completion_line)

    # No exception means the validation passed.
    validate_log(filename=str(log_path), expected_time_us=1000000, tolerance=0.05)


def test_validate_log_failure(tmp_path):
    """
    Check that validate_log raises a ValueError when the last completion time is outside the tolerance window.
    """
    completion_line = "INFO [05/15/2024 08:32:04 PM] GPU Node ID 301123 completed at 900000us"
    log_path = tmp_path / "log.txt"
    log_path.write_text(completion_line)

    # 900000us is 10% below the expected 1000000us, which is outside the 5% tolerance.
    with pytest.raises(ValueError) as excinfo:
        validate_log(filename=str(log_path), expected_time_us=1000000, tolerance=0.05)
    assert "expected between" in str(excinfo.value)

0 comments on commit 68114b5

Please sign in to comment.