diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml
new file mode 100644
index 00000000..6bb72ebe
--- /dev/null
+++ b/.github/workflows/integration_tests.yml
@@ -0,0 +1,35 @@
+name: Integration Tests
+
+on: pull_request
+
+jobs:
+  integration-tests:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout Code
+        uses: actions/checkout@v2
+        with:
+          lfs: true
+
+      - name: Setup Python Environment
+        uses: actions/setup-python@v2
+        with:
+          python-version: '3.10.14'
+
+      - name: Install Dependencies
+        run: |
+          pip install -r requirements-dev.txt
+          pip install .
+
+      - name: Install PARAM
+        run: |
+          git clone https://github.com/facebookresearch/param.git
+          cd param/train/compute/python/
+          git checkout c83ce8429110a86549c40fec5a01acbd9fbd54a4
+          pip install .
+
+      - name: Extract and Validate
+        run: |
+          python3 ci_tools/integration_tests.py --tgz_path tests/data/1.0.2-chakra.0.0.4/llama_pytorch24.05.tgz \
+            --num_ranks 8 --tolerance 0.05 --expected_times_ms 14597 14597 14968 14638 14649 14700 14677 14735
diff --git a/ci_tools/integration_tests.py b/ci_tools/integration_tests.py
new file mode 100644
index 00000000..bc2e4ddb
--- /dev/null
+++ b/ci_tools/integration_tests.py
@@ -0,0 +1,159 @@
+import argparse
+import concurrent.futures
+import os
+import re
+import subprocess
+import tarfile
+
+
+def extract_tgz(tgz_path: str, extract_to: str) -> None:
+    """
+    Extracts a .tgz file to the specified directory.
+
+    Args:
+        tgz_path (str): Path to the .tgz file.
+        extract_to (str): Directory to extract the files to.
+    """
+    print(f"Extracting {tgz_path} to {extract_to}")
+    with tarfile.open(tgz_path, "r:gz") as tar:
+        tar.extractall(path=extract_to)
+
+
+def run_command(command: str) -> None:
+    """
+    Executes a given shell command and checks for errors.
+
+    Args:
+        command (str): The shell command to execute.
+
+    Raises:
+        RuntimeError: If the command fails.
+    """
+    print(f"Running command: {command}")
+    # Rely on the CI environment's PATH; the chakra_* entry points are
+    # installed by "pip install ." in the workflow.
+    try:
+        subprocess.run(command, check=True, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+    except subprocess.CalledProcessError as e:
+        raise RuntimeError(f"Command failed: {command}") from e
+
+
+def run_commands_in_parallel(commands: list) -> None:
+    """
+    Executes multiple commands in parallel.
+
+    Args:
+        commands (list): A list of shell commands to execute.
+    """
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        futures = [executor.submit(run_command, cmd) for cmd in commands]
+        for future in concurrent.futures.as_completed(futures):
+            future.result()
+
+
+def run_trace_link(data_path: str, num_ranks: int) -> None:
+    """
+    Prepares and runs chakra_trace_link commands in parallel for each file pair.
+
+    Args:
+        data_path (str): The directory where the data files are located.
+        num_ranks (int): The number of file pairs to process.
+    """
+    commands = [
+        f"chakra_trace_link --pytorch-et-file {data_path}/chakra_host_et_{i}.json "
+        f"--kineto-file {data_path}/kineto_{i}.json "
+        f"--output-file {data_path}/chakra_et_plus_{i}.json"
+        for i in range(num_ranks)
+    ]
+    run_commands_in_parallel(commands)
+
+
+def run_converter(data_path: str, num_ranks: int) -> None:
+    """
+    Prepares and runs chakra_converter commands in parallel for each output of chakra_trace_link.
+
+    Args:
+        data_path (str): The directory where the output files are located.
+        num_ranks (int): The number of output files to process.
+    """
+    commands = [
+        f"chakra_converter --input_filename {data_path}/chakra_et_plus_{i}.json "
+        f"--output_filename {data_path}/chakra_final_{i}.chakra "
+        f"--input_type PyTorch --log_filename /tmp/rank_{i}.log"
+        for i in range(num_ranks)
+    ]
+    run_commands_in_parallel(commands)
+
+
+def validate_log(filename: str, expected_time_us: int, tolerance: float) -> None:
+    """
+    Validates the log file to ensure the last operation completes within the expected time with an allowable error.
+
+    Args:
+        filename (str): Path to the log file.
+        expected_time_us (int): Expected completion time in microseconds.
+        tolerance (float): Acceptable error percentage as a decimal.
+
+    Raises:
+        ValueError: If the log does not contain the expected output or is outside the acceptable time range.
+    """
+    # Accept both AM and PM timestamps; the completion time is captured in us.
+    completion_pattern = re.compile(
+        r"INFO \[\d{2}/\d{2}/\d{4} \d{2}:\d{2}:\d{2} [AP]M\] GPU Node ID \d+ completed at (\d+)us"
+    )
+    with open(filename, "r") as file:
+        last_time = None
+        for line in file:
+            match = completion_pattern.search(line)
+            if match:
+                last_time = int(match.group(1))
+
+    if last_time is None:
+        raise ValueError(f"No completion time found in {filename}")
+
+    lower_bound = expected_time_us * (1 - tolerance)
+    upper_bound = expected_time_us * (1 + tolerance)
+
+    if not lower_bound <= last_time <= upper_bound:
+        raise ValueError(
+            f"Completion time in {filename} is {last_time}us; expected between {lower_bound}us and {upper_bound}us."
+        )
+    print(f"Validation successful for {filename}: {last_time}us is within the acceptable range.")
+
+
+def parse_args():
+    """
+    Parses command line arguments.
+    """
+    parser = argparse.ArgumentParser(description="Run integration tests for chakra_trace_link and chakra_converter.")
+    parser.add_argument("--tgz_path", type=str, required=True, help="Path to the tgz file to extract.")
+    parser.add_argument("--num_ranks", type=int, required=True, help="Number of ranks to process.")
+    parser.add_argument("--tolerance", type=float, required=True, help="Acceptable error percentage as a decimal.")
+    parser.add_argument(
+        "--expected_times_ms", type=int, nargs="+", required=True, help="List of expected times in milliseconds."
+    )
+    return parser.parse_args()
+
+
+def main() -> None:
+    """
+    Main function to execute the integration test sequence.
+    """
+    args = parse_args()
+    extract_dir = os.path.dirname(args.tgz_path)
+    data_path = os.path.join(extract_dir, os.path.basename(args.tgz_path).replace(".tgz", ""))
+
+    # Extracting files
+    extract_tgz(args.tgz_path, extract_dir)
+
+    expected_times_us = [time * 1000 for time in args.expected_times_ms]
+
+    # Run trace link and converter processes
+    run_trace_link(data_path, args.num_ranks)
+    run_converter(data_path, args.num_ranks)
+
+    # Validate output logs
+    for i in range(args.num_ranks):
+        log_file = f"/tmp/rank_{i}.log"
+        validate_log(log_file, expected_times_us[i], args.tolerance)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/ci_tools/test_integration_tests.py b/tests/ci_tools/test_integration_tests.py
new file mode 100644
index 00000000..ee516cd0
--- /dev/null
+++ b/tests/ci_tools/test_integration_tests.py
@@ -0,0 +1,49 @@
+import subprocess
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from ci_tools.integration_tests import extract_tgz, run_command, validate_log
+
+
+def test_extract_tgz():
+    """Test extracting a tgz file to ensure tarfile.open and extractall are called."""
+    with patch("tarfile.open", MagicMock()) as mock_tar:
+        mock_tar.return_value.__enter__.return_value.extractall = MagicMock()
+        extract_tgz("path/to/test.tgz", "path/to/extract")
+        mock_tar.assert_called_once_with("path/to/test.tgz", "r:gz")
+        mock_tar.return_value.__enter__.return_value.extractall.assert_called_once_with(path="path/to/extract")
+
+
+def test_run_command_success():
+    """Test run_command with a command that succeeds without raising an error."""
+    with patch("subprocess.run") as mock_run:
+        mock_run.return_value = MagicMock(check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+        run_command("echo 'Hello World'")
+        mock_run.assert_called_once()
+
+
+def test_run_command_failure():
+    """Test run_command with a command that fails and should raise RuntimeError."""
+    with patch("subprocess.run", side_effect=subprocess.CalledProcessError(1, "cmd", "Error message")):
+        with pytest.raises(RuntimeError) as excinfo:
+            run_command("exit 1")
+        assert "Command failed: exit 1" in str(excinfo.value)
+
+
+def test_validate_log_success(tmp_path):
+    """Test validate_log to ensure it passes when the last operation completes within the expected time."""
+    log_file = tmp_path / "log.txt"
+    log_file.write_text("INFO [05/15/2024 08:32:04 PM] GPU Node ID 301123 completed at 1000000us")
+    validate_log(str(log_file), 1000000, 0.05)
+
+
+def test_validate_log_failure(tmp_path):
+    """
+    Test validate_log to ensure it raises a ValueError when the last operation is outside the acceptable time range.
+    """
+    log_file = tmp_path / "log.txt"
+    log_file.write_text("INFO [05/15/2024 08:32:04 PM] GPU Node ID 301123 completed at 900000us")
+    with pytest.raises(ValueError) as excinfo:
+        validate_log(str(log_file), 1000000, 0.05)
+    assert "expected between" in str(excinfo.value)