From d1b9d1ef7648b224d0211aa12dc11ad6ef427def Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Mon, 19 Aug 2024 17:22:06 -0700 Subject: [PATCH] Python APIs for testdata generation (#2446) This PR exposes some key APIs to auto-generate StableHLO test programs in testdata format, which are leveraged in https://github.com/openxla/stablehlo/pull/2404. ## Proposed API ```python def testdata_generator( module: ir.Module, args: Sequence[np.ndarray] = [] ) -> ir.Module: ``` - `module`: The StableHLO module to generate test data for. - `args`: (Optional) A sequence of NumPy arrays representing input values for the module. If not provided, the function will attempt to extract input values from the module itself. ## Example ```python # Input (module_str) module_str = """ module { func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { %0 = stablehlo.add %arg0, %arg1 : tensor<2xf32> return %0 : tensor<2xf32> } } """ # Input (args) args = [ np.array([1.0, 2.0], dtype=np.float32), np.array([3.0, 4.0], dtype=np.float32) ] # Generate test data module_output = testdata_generator(module, args) # Output (module_output) module_output_str = """ module { func.func @main() -> tensor { %cst = stablehlo.constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32> %cst_0 = stablehlo.constant dense<[3.000000e+00, 4.000000e+00]> : tensor<2xf32> %cst_1 = stablehlo.constant dense<[4.000000e+00, 6.000000e+00]> : tensor<2xf32> %0 = stablehlo.add %cst, %cst_0 : tensor<2xf32> %1 = stablehlo.custom_call @check.eq(%cst_1, %0) : (tensor<2xf32>, tensor<2xf32>) -> tensor return %1 : tensor } } """ ``` Note to reviewer: The current PR is based on https://github.com/openxla/stablehlo/pull/2445, so please review that first. --- stablehlo/integrations/python/CMakeLists.txt | 9 + .../stablehlo/testdata_generator/README.md | 55 ++ .../testdata_execution_utils.py | 40 ++ .../testdata_generator_lib.py | 101 ++++ .../testdata_generator/testdata_processor.py | 488 ++++++++++++++++++ .../integrations/python/tests/CMakeLists.txt | 1 + .../python/tests/testdata_generator_test.py | 212 ++++++++ 7 files changed, 906 insertions(+) create mode 100644 stablehlo/integrations/python/stablehlo/testdata_generator/README.md create mode 100644 stablehlo/integrations/python/stablehlo/testdata_generator/testdata_execution_utils.py create mode 100644 stablehlo/integrations/python/stablehlo/testdata_generator/testdata_generator_lib.py create mode 100644 stablehlo/integrations/python/stablehlo/testdata_generator/testdata_processor.py create mode 100644 stablehlo/integrations/python/tests/testdata_generator_test.py diff --git a/stablehlo/integrations/python/CMakeLists.txt b/stablehlo/integrations/python/CMakeLists.txt index 43dda81f8fb..2028efd34f4 100644 --- a/stablehlo/integrations/python/CMakeLists.txt +++ b/stablehlo/integrations/python/CMakeLists.txt @@ -65,6 +65,15 @@ declare_mlir_python_sources(StablehloToSavedModelPythonSources stablehlo/savedmodel/stablehlo_to_tf_saved_model.py ) +declare_mlir_python_sources(StablehloTestdataGeneratorPythonSources + ADD_TO_PARENT StablehloPythonSources + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}" + SOURCES + stablehlo/testdata_generator/testdata_execution_utils.py + stablehlo/testdata_generator/testdata_generator_lib.py + stablehlo/testdata_generator/testdata_processor.py +) + declare_mlir_python_sources(VhloPythonSources) declare_mlir_python_sources(VhloPythonSources.Dialects ADD_TO_PARENT VhloPythonSources diff --git a/stablehlo/integrations/python/stablehlo/testdata_generator/README.md b/stablehlo/integrations/python/stablehlo/testdata_generator/README.md new file mode 100644 index 00000000000..b15bbe7171d --- /dev/null +++ b/stablehlo/integrations/python/stablehlo/testdata_generator/README.md @@ -0,0 +1,55 @@ +# Test Data Generation + +This module provides utilities for generating test data for StableHLO modules. +The primary API is the `testdata_generator` function, which automates the +process of creating test cases from existing StableHLO code. + +## Usage + +```python +def testdata_generator( + module: ir.Module, args: Sequence[np.ndarray] = [] +) -> ir.Module: +``` + +* `module`: The StableHLO module to generate test data for. +* `args`: (Optional) A sequence of NumPy arrays representing input values for + the module. If not provided, the function will attempt to extract input + values from the module itself. + +## Example + +```python +# Input (module_str) +module_str = """ +module { + func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + %0 = stablehlo.add %arg0, %arg1 : tensor<2xf32> + return %0 : tensor<2xf32> + } +} +""" + +# Input (args) +args = [ + np.array([1.0, 2.0], dtype=np.float32), + np.array([3.0, 4.0], dtype=np.float32) +] + +# Generate test data +module_output = testdata_generator(module, args) + +# Output (module_output) +module_output_str = """ +module { + func.func @main() -> tensor { + %cst = stablehlo.constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32> + %cst_0 = stablehlo.constant dense<[3.000000e+00, 4.000000e+00]> : tensor<2xf32> + %cst_1 = stablehlo.constant dense<[4.000000e+00, 6.000000e+00]> : tensor<2xf32> + %0 = stablehlo.add %cst, %cst_0 : tensor<2xf32> + %1 = stablehlo.custom_call @check.eq(%cst_1, %0) : (tensor<2xf32>, tensor<2xf32>) -> tensor + return %1 : tensor + } +} +""" +``` diff --git a/stablehlo/integrations/python/stablehlo/testdata_generator/testdata_execution_utils.py b/stablehlo/integrations/python/stablehlo/testdata_generator/testdata_execution_utils.py new file mode 100644 index 00000000000..233139e0a32 --- /dev/null +++ b/stablehlo/integrations/python/stablehlo/testdata_generator/testdata_execution_utils.py @@ -0,0 +1,40 @@ +# Copyright 2024 The StableHLO Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Sequence +from absl import logging +from mlir import ir +from mlir.dialects import stablehlo as stablehlo_dialect +import numpy as np + + +def run_stablehlo_interpreter( + module: ir.Module, args: Sequence[np.ndarray] +) -> Sequence[np.ndarray]: + """Evaluates a StableHLO module. + + Args: + module: The MLIR module in StableHLO dialect. + args: Input data for the module as a sequence of NumPy arrays. + + Returns: + Sequence[np.ndarray]: Evaluated results from the interpreter as a sequence + of NumPy arrays. + """ + inputs = [ir.DenseElementsAttr.get(arg) for arg in args] + results = stablehlo_dialect.eval_module(module, inputs) + np_results = [np.array(result) for result in results] + + return np_results diff --git a/stablehlo/integrations/python/stablehlo/testdata_generator/testdata_generator_lib.py b/stablehlo/integrations/python/stablehlo/testdata_generator/testdata_generator_lib.py new file mode 100644 index 00000000000..ac3213ab3d0 --- /dev/null +++ b/stablehlo/integrations/python/stablehlo/testdata_generator/testdata_generator_lib.py @@ -0,0 +1,101 @@ +# Copyright 2024 The StableHLO Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Testdata Generator Utils.""" + +from typing import Sequence + +from absl import logging +from mlir import ir +from mlir.stablehlo.testdata_generator import testdata_execution_utils +from mlir.stablehlo.testdata_generator import testdata_processor +import numpy as np + + +def testdata_generator( + module: ir.Module, args: Sequence[np.ndarray] = [] +) -> ir.Module: + """Generates test data for a StableHLO module. + + This function takes a StableHLO module and optional input arguments, processes + the module to + extract relevant information, executes the module to obtain golden results, + and then converts + the module, inputs, and golden results into a standardized test data format. + + Args: + module: The StableHLO module to generate test data for. + args: (Optional) A sequence of NumPy arrays representing input values for + the module. If not provided, the function will attempt to extract input + values from the module itself. + + Returns: + An MLIR module in the test data format, containing the original module, + inputs, and golden results. + + Example: + Input (module_str): + ``` + module { + func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> + tensor<2xf32> { + %0 = stablehlo.add %arg0, %arg1 : tensor<2xf32> + return %0 : tensor<2xf32> + } + } + ``` + + Input (args): + ``` + [ + np.array([1.0, 2.0], dtype=np.float32), + np.array([3.0, 4.0], dtype=np.float32) + ] + ``` + + Output (module_output_str): + ``` + module { + func.func @main() -> tensor { + %cst = stablehlo.constant dense<[1.000000e+00, 2.000000e+00]> : + tensor<2xf32> + %cst_0 = stablehlo.constant dense<[3.000000e+00, 4.000000e+00]> : + tensor<2xf32> + %cst_1 = stablehlo.constant dense<[4.000000e+00, 6.000000e+00]> : + tensor<2xf32> + %0 = stablehlo.add %cst, %cst_0 : tensor<2xf32> + %1 = stablehlo.custom_call @check.eq(%cst_1, %0) : (tensor<2xf32>, + tensor<2xf32>) -> tensor + return %1 : tensor + } + } + ``` + """ + module, inputs = testdata_processor.preprocess_input_module(module) + if args: + inputs = args + logging.info( + f"\t[testdata-generator] Processed module and inputs: {module}, {inputs}" + ) + + golden_results = testdata_execution_utils.run_stablehlo_interpreter( + module, inputs + ) + logging.info(f"\t[testdata-generator] Golden results: {golden_results}") + + module_output = testdata_processor.to_testdata_format( + module, inputs, golden_results + ) + + return module_output diff --git a/stablehlo/integrations/python/stablehlo/testdata_generator/testdata_processor.py b/stablehlo/integrations/python/stablehlo/testdata_generator/testdata_processor.py new file mode 100644 index 00000000000..65f54ad8ab0 --- /dev/null +++ b/stablehlo/integrations/python/stablehlo/testdata_generator/testdata_processor.py @@ -0,0 +1,488 @@ +# Copyright 2024 The StableHLO Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utils to process testdata.""" + +import re +from typing import Sequence, Set +from absl import logging +from mlir import ir +from mlir import passmanager as pm +from mlir.dialects import check as check_dialect +from mlir.dialects import func as func_dialect +from mlir.dialects import stablehlo as stablehlo_dialect +import numpy as np + + +def _is_check_op(op: ir.Operation) -> bool: + """Checks if an MLIR operation is a supported check operation. + + This function identifies whether the given operation is one of the recognized + check operations, + including: + - check.expect_eq + - check.expect_eq_const + - check.expect_almost_eq + - check.expect_almost_eq_const + Or a stablehlo.custom_call with call_target_name as 'check.eq' + + Args: + op: The MLIR operation to check. + + Returns: + True if the operation is a check operation, False otherwise. + """ + check_ops = [ + check_dialect.ExpectEqOp, + check_dialect.ExpectEqConstOp, + check_dialect.ExpectAlmostEqOp, + check_dialect.ExpectAlmostEqConstOp, + stablehlo_dialect.CustomCallOp, + ] + + if any(isinstance(op.opview, check_op) for check_op in check_ops): + return True + + if isinstance(op.opview, stablehlo_dialect.CustomCallOp): + return op.opview.call_target_name.value == "check.eq" + + return False + + +def _get_constant_ops_providing_program_inputs( + module: ir.Module, +) -> Sequence[ir.Operation]: + """Identifies and returns the constant operations that provide input values to the program logic in an MLIR module. + + This function analyzes the first function within the provided MLIR module and + extracts a sequence of StableHLO constant + operations (`stablehlo.ConstantOp`) whose results are used as inputs to the + main computational logic of the program. + + Args: + module: The MLIR module to analyze. + + Returns: + A list `stablehlo.ConstantOp` operations that are + considered to provide input values to the program's + main logic. + """ + + return [ + op + for op in module.body.operations[0].body.blocks[0].operations + if isinstance(op, stablehlo_dialect.ConstantOp) + and any(not _is_check_op(use.owner) for use in op.result.uses) + ] + + +def _get_ops_under_test( + module: ir.Module, +) -> Set[ir.Operation]: + """Identifies and returns the operations under test within an MLIR module. + + This function performs the following steps to extract operations under test: + + 1. It calls `_get_constant_ops_providing_program_inputs` to obtain the + constant + operations that + provide input values to the main logic of the program. + 2. For each constant operation, it finds all operations that use its results. + 3. It filters out any operations that are of the following types: + - `stablehlo_dialect.CustomCallOp` + - `check.*` + 4. It returns a set of the remaining operations, which are considered the + operations under test. + + Args: + module: The MLIR module (in StableHLO dialect) to analyze. + + Returns: + A set of unique MLIR operations that are identified as being under test. + """ + constant_ops_for_program_inputs = _get_constant_ops_providing_program_inputs( + module + ) + return { + use.owner + for const_op in constant_ops_for_program_inputs + for use in const_op.result.uses + if not _is_check_op(use.owner) + } + + +def _extract_testdata_inputs( + module: ir.Module, +) -> Sequence[np.ndarray]: + """Extracts input data (as NumPy arrays) from a StableHLO module. + + It performs the following steps: + + 1. Identifies the constant operations within the module that are + used as input values for the main computation (excluding those used + for checks or other purposes). + 2. Extracts the numerical values from the dense elements attributes of + these constant operations. + 3. Converts the extracted values into NumPy arrays. + + Args: + module: The MLIR module (in StableHLO dialect). + + Returns: + A sequence of NumPy arrays, where each array corresponds to the input data + extracted from a constant operation. + + Raises: + ValueError: If an error occurs during the extraction of input values, + such as a mismatch between the number of constant operations + and the extracted values. + """ + constant_ops = _get_constant_ops_providing_program_inputs(module) + + input_values = [] + for constant_op in constant_ops: + attr = constant_op.opview.value + if isinstance(attr, ir.DenseElementsAttr): + input_values.append(np.array(attr)) + + if len(input_values) != len(constant_ops): + raise ValueError("Error in extracting input values") + + return input_values + + +def _replace_argument_with_constant( + module: ir.Module, + inputs: Sequence[np.ndarray], +) -> ir.Module: + """Replaces arguments of the main function in an MLIR module with constant values. + + The constant values are derived from the provided `inputs` NumPy arrays. + The function also updates the function signature to reflect the removal of the + arguments. + + Args: + module: The MLIR module containing the function whose arguments are to be + replaced. + inputs: A list of NumPy arrays, where each array represents the constant + value for a corresponding function argument. + + Returns: + ir.Module: The modified MLIR module with the function arguments replaced + by constants. + """ + main_func = module.body.operations[0] + with module.context as ctx, ir.Location.unknown(ctx): + entry_block = main_func.body.blocks[0] + with ir.InsertionPoint.at_block_begin(entry_block): + # Replace function arguments with constants for the input values + for input in inputs: + const_op = stablehlo_dialect.ConstantOp(ir.DenseElementsAttr.get(input)) + entry_block.arguments[0].replace_all_uses_with(const_op.result) + entry_block.erase_argument(0) + + # Update the type of the entry function. + main_ftype = ir.FunctionType.get([], main_func.type.results) + main_func.function_type = ir.TypeAttr.get(main_ftype) + + return module + + +def is_testdata_format(module: ir.Module) -> bool: + """Checks if an MLIR module has one function with no arguments and contains a check operation. + + Args: + module: The MLIR module to be verified. + + Raises: + AssertionError: If the module fails on any of the above criterias. + """ + functions = [ + op for op in module.body.operations if isinstance(op, func_dialect.FuncOp) + ] + if len(functions) != 1: + func_names = [func.name for func in functions] + raise AssertionError( + "Testdata format expected to have module with one function, but got" + f" {func_names}." + ) + + main_func = functions[0] + if len(main_func.body.blocks) != 1: + raise AssertionError( + "Testdata format expected to have the main function with a single" + f" block, but got {len(main_func.body.blocks)} blocks." + ) + + check_op_available = any( + _is_check_op(op.operation) for op in main_func.body.blocks[0].operations + ) + return check_op_available and not main_func.type.inputs + + +def to_testdata_format( + module: ir.Module, + inputs: Sequence[np.ndarray], + golden_results: Sequence[np.ndarray], +) -> ir.Module: + """Transforms an MLIR module to testdata format. + + Transforms `module` with the following modfications: + - Removes Function Arguments: It eliminates the function arguments from the + function definition. + - Introduces Constant Operations: It replaces those arguments with + stablehlo.constant operations on input values `inputs`. + - Maintains operation under test: The operation under test of the function + remains unchanged. It now takes the newly created constant values as + operands. + - Adds Check Operation: A new stablehlo.custom_call operation with a + call_target_name of "check.eq" is added to the function body. + This operation compares the result of the operation under test with the + constant value `golden_results`. The result of this comparison is a tensor + of booleans indicating where the results match the expected values. + - Adjusts Function Return: The function's return type is updated to + reflect the fact that it now returns the result of stablehlo.custom_call + operation. + + Args: + module: The MLIR module containing the stablehlo operations. + inputs: A list of NumPy arrays representing the input data. + golden_results: A list of NumPy arrays representing the expected outputs + of executing `module` on `inputs`. + + Returns: + A new MLIR module with with transformation mentioned above. + + Example Input: + + - `module`: + module { + func.func public @main(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> + tensor<2xf32> { + %0 = stablehlo.add %arg0, %arg1 : tensor<2xf32> + return %0 : tensor<2xf32> + } + } + - inputs`: [np.array([1.0, 2.0]), np.array([3.0, 4.0])] + - golden_results: [np.array([4.0, 6.0])] + + Example Output: + module { + func.func @add_op_test_f32() -> tensor<2xxi1> { + %0 = stablehlo.constant dense<[1.0, 2.0]> : tensor<2xf32> + %1 = stablehlo.constant dense<[3.0, 4.0]> : tensor<2xf32> + %2 = stablehlo.add %0, %1 : tensor<2xf32> + %3 = stablehlo.custom_call("check.eq", %2, dense<[4.0, 6.0]> : + tensor<2xi1> + func.return %3 : tensor2xi1<> + } + } + """ + + with module.context as ctx, ir.Location.unknown(ctx): + entry_block = module.body.operations[0].body.blocks[0] + with ir.InsertionPoint.at_block_begin(entry_block): + # Replace function arguments with constants for the input values + for input in inputs: + const_op = stablehlo_dialect.ConstantOp(ir.DenseElementsAttr.get(input)) + entry_block.arguments[0].replace_all_uses_with(const_op.result) + entry_block.erase_argument(0) + + # Create constant ops for golden results + golden_result_constants = [ + stablehlo_dialect.ConstantOp(ir.DenseElementsAttr.get(golden_result)) + for golden_result in golden_results + ] + + # Find the original return operation + return_op = entry_block.operations[len(entry_block.operations) - 1] + if not isinstance(return_op, func_dialect.ReturnOp): + raise AssertionError( + "Expects the last operation in function block to be a return op, but" + f" got: {return_op}" + ) + return_operands = return_op.operands + + # Insert check operations at the end of the block, just before the return + with ir.InsertionPoint.at_block_terminator(entry_block): + check_ops = [] + for idx, operand in enumerate(return_operands): + custom_call = stablehlo_dialect.CustomCallOp( + [ir.RankedTensorType.get([], ir.IntegerType.get_signless(1))], + [golden_result_constants[idx].result, operand], + call_target_name="check.eq", + ) + check_ops.append(custom_call.result) + + # Replace the original return with the check results + new_return_op = func_dialect.ReturnOp(check_ops) + return_op.erase() + + # Update the function's type to reflect the new return values + main_ftype = ir.FunctionType.get( + [], [check_op.type for check_op in check_ops] + ) + module.body.operations[0].function_type = ir.TypeAttr.get(main_ftype) + + # Verify the module is valid and in testdata format + assert module.operation.verify() + assert is_testdata_format(module) + + return module + + +def from_testdata_format(module: ir.Module) -> ir.Module: + """Transforms `module` in testdata format. + + Transforms `module` with the following modfications: + - The constants in the `module` will be replaced with arguments of the main + function in `module`. + - The transformed module returns the results of the the operations under + test. + + Args: + module: The original MLIR module containing the `op_under_test`. + op_under_test: An operation in `module`. + + Returns: + A new MLIR module with transformation mentioned above. + + Example Input: + + - `module`: + module { + func.func @add_op_test_f32() { + %0 = stablehlo.constant dense<[1.0, 2.0]> : tensor<2xf32> + %1 = stablehlo.constant dense<[3.0, 4.0]> : tensor<2xf32> + %2 = stablehlo.add %0, %1 : tensor<2xf32> + check.expect_eq_const %2, dense<[4.0, 6.0]> : tensor<2xf32> + func.return + } + } + - `op_under_test`: %2 = stablehlo.add %0, %1 : tensor<2xf32> + + Example Output: + + module { + func.func public @main(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> + tensor<2xf32> { + %0 = stablehlo.add %arg0, %arg1 : tensor<2xf32> + return %0 : tensor<2xf32> + } + } + """ + + if not is_testdata_format(module): + return module + + with module.context as ctx, ir.Location.unknown(ctx) as loc: + main_func = module.body.operations[0] + module_ops = main_func.body.blocks[0].operations + + # Extract constant ops used as inputs for the program (not for check ops) + constant_ops_for_program_inputs = [ + op + for op in module_ops + if isinstance(op, stablehlo_dialect.ConstantOp) + and any(not _is_check_op(use.owner) for use in op.result.uses) + ] + + # Extract check operations + check_ops = [op for op in module_ops if _is_check_op(op)] + + # Extract constant ops feeding into check ops + constant_ops_feeding_to_check_ops = [ + op + for op in module_ops + if isinstance(op, stablehlo_dialect.ConstantOp) + and all(_is_check_op(use.owner) for use in op.result.uses) + ] + + # Extract non-constant ops feeding into check ops + non_constant_ops_feeding_to_check_ops = [ + operand.owner + for check_op in check_ops + for operand in check_op.operands + if not isinstance(operand.owner.opview, stablehlo_dialect.ConstantOp) + ] + + # Remove unused ops that feed into check ops + ops_to_remove = constant_ops_feeding_to_check_ops + check_ops + + # Update the main function's type based on inputs and original output types + input_types = [ + value.result.type for value in constant_ops_for_program_inputs + ] + result_types = [ + check_op.operands[0].owner.result.type for check_op in check_ops + ] + + # Update the function signature with the derived input and result types + main_ftype = ir.FunctionType.get(input_types, result_types) + main_func.function_type = ir.TypeAttr.get(main_ftype) + + # Replace constants with function arguments + entry_block = main_func.body.blocks[0] + for idx, constant_op in enumerate(constant_ops_for_program_inputs): + arg = entry_block.add_argument(constant_op.result.type, loc) + constant_op.result.replace_all_uses_with(arg) + constant_op.erase() + + # Validate that the last operation is a return operation + return_op = entry_block.operations[len(entry_block.operations) - 1] + if not isinstance(return_op, func_dialect.ReturnOp): + raise AssertionError( + "Expects the last operation in function block to be a return op, but" + f" got: {return_op}" + ) + + # Remove unused ops in reverse order to avoid invalidating indices + ops_to_remove.append(return_op) + ops_to_remove.reverse() + [op.erase() for op in ops_to_remove] + + # Update the return statement to return the results of the non-constant operations + with ir.InsertionPoint(entry_block): + func_dialect.ReturnOp(non_constant_ops_feeding_to_check_ops) + + # Verify that the module is valid after the transformations + assert module.operation.verify() + + return module + + +def preprocess_input_module(module: ir.Module): + """Preprocesses a StableHLO module in testdata format. + + This function performs the following key steps: + + 1. Extracts the operation under test and its input values as NumPy arrays. + 2. Transforms the module to isolate the operation under test and make its + operands into function arguments. + + Args: + module: MLIR module in text format. + + Returns: + A tuple containing: + - The preprocessed MLIR module with the isolated operation under test. + - A list of NumPy arrays representing the input values for the operation. + + Raises: + AssertionError: If the module structure fails to comply with testdata + format. + """ + inputs = _extract_testdata_inputs(module) + module = from_testdata_format(module) + return module, inputs diff --git a/stablehlo/integrations/python/tests/CMakeLists.txt b/stablehlo/integrations/python/tests/CMakeLists.txt index 8661662ae96..ae4e0074d07 100644 --- a/stablehlo/integrations/python/tests/CMakeLists.txt +++ b/stablehlo/integrations/python/tests/CMakeLists.txt @@ -29,6 +29,7 @@ add_stablehlo_python_test(stablehlo-python-chlo chlo.py) add_stablehlo_python_test(stablehlo-python-smoketest smoketest.py) add_stablehlo_python_test(stablehlo-python-stablehlo stablehlo.py) add_stablehlo_python_test(stablehlo-python-vhlo vhlo.py) +add_stablehlo_python_test(stablehlo-python-testdata-generator testdata_generator_test.py) if(STABLEHLO_ENABLE_PYTHON_TF_TESTS) add_stablehlo_python_test(stablehlo-python-stablehlo-to-saved-model stablehlo_to_tf_saved_model_test.py) diff --git a/stablehlo/integrations/python/tests/testdata_generator_test.py b/stablehlo/integrations/python/tests/testdata_generator_test.py new file mode 100644 index 00000000000..bb08e97f814 --- /dev/null +++ b/stablehlo/integrations/python/tests/testdata_generator_test.py @@ -0,0 +1,212 @@ +# Copyright 2024 The StableHLO Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import mlir.dialects.check as check_dialect +import mlir.dialects.stablehlo as stablehlo_dialect +import mlir.ir as ir +from mlir.stablehlo.testdata_generator.testdata_generator_lib import testdata_generator +import numpy as np + + +MODULE_STR = """ +module { + func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + %0 = stablehlo.add %arg0, %arg1 : tensor<2xf32> + return %0 : tensor<2xf32> + } +} +""" + +MODULE_STR_IN_TESTDATA_FORMAT_WITH_CHECK_OPS = """ +module { + func.func @main() { + %cst = stablehlo.constant dense<[1.0, 2.0]> : tensor<2xf32> + %cst_0 = stablehlo.constant dense<[3.0, 4.0]> : tensor<2xf32> + %0 = stablehlo.add %cst, %cst_0 : tensor<2xf32> + check.expect_eq_const %0, dense<[4.0, 6.0]> : tensor<2xf32> + return + } +} +""" + +MODULE_STR_IN_TESTDATA_FORMAT_WITH_CUSTOM_CALL_AS_CHECK_OPS = """ +module { + func.func @main() -> tensor<2xi1> { + %cst = stablehlo.constant dense<[1.0, 2.0]> : tensor<2xf32> + %cst_0 = stablehlo.constant dense<[3.0, 4.0]> : tensor<2xf32> + %golden = stablehlo.constant dense<[4.0, 6.0]> : tensor<2xf32> + %0 = stablehlo.add %cst, %cst_0 : tensor<2xf32> + %1 = stablehlo.custom_call @check.eq(%0, %golden) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> + return %1: tensor<2xi1> + } +} +""" + +MODULE_STR_IN_TESTDATA_FORMAT_WITH_SHARED_USE_OF_CONSTANT_OP = """ +module { + func.func @main() -> tensor<2xi1> { + %cst = stablehlo.constant dense<0.0> : tensor<2xf32> + %0 = stablehlo.add %cst, %cst : tensor<2xf32> + %1 = stablehlo.custom_call @check.eq(%0, %cst) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> + return %1 : tensor<2xi1> + } +} +""" + +MODULE_STR_IN_TESTDATA_FORMAT_MUTI_OP_UNDER_TEST = """ +module { + func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + %0 = stablehlo.add %arg0, %arg1 : tensor<2xf32> + %1 = stablehlo.add %0, %0 : tensor<2xf32> + return %1: tensor<2xf32> + } +} +""" + +EXPECTED_TESTDATA_MODULE_STR_1 = """ +module { + func.func @main() -> tensor { + %cst = stablehlo.constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32> + %cst_0 = stablehlo.constant dense<[3.000000e+00, 4.000000e+00]> : tensor<2xf32> + %cst_1 = stablehlo.constant dense<[4.000000e+00, 6.000000e+00]> : tensor<2xf32> + %0 = stablehlo.add %cst, %cst_0 : tensor<2xf32> + %1 = stablehlo.custom_call @check.eq(%cst_1, %0) : (tensor<2xf32>, tensor<2xf32>) -> tensor + return %1 : tensor + } +} +""" + +EXPECTED_TESTDATA_MODULE_STR_2 = """ +module { + func.func @main() -> tensor { + %cst = stablehlo.constant dense<[1.000000e+01, 2.000000e+01]> : tensor<2xf32> + %cst_0 = stablehlo.constant dense<[3.000000e+01, 4.000000e+01]> : tensor<2xf32> + %cst_1 = stablehlo.constant dense<[4.000000e+01, 6.000000e+01]> : tensor<2xf32> + %0 = stablehlo.add %cst, %cst_0 : tensor<2xf32> + %1 = stablehlo.custom_call @check.eq(%cst_1, %0) : (tensor<2xf32>, tensor<2xf32>) -> tensor + return %1 : tensor + } +} +""" + +EXPECTED_TESTDATA_MODULE_STR_3 = """ +module { + func.func @main() -> tensor { + %cst = stablehlo.constant dense<0.000000e+00> : tensor<2xf32> + %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<2xf32> + %0 = stablehlo.add %cst, %cst : tensor<2xf32> + %1 = stablehlo.custom_call @check.eq(%cst_0, %0) : (tensor<2xf32>, tensor<2xf32>) -> tensor + return %1 : tensor + } +} +""" + +EXPECTED_TESTDATA_MODULE_STR_4 = """ +module { + func.func @main() -> tensor { + %cst = stablehlo.constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32> + %cst_0 = stablehlo.constant dense<[3.000000e+00, 4.000000e+00]> : tensor<2xf32> + %cst_1 = stablehlo.constant dense<[8.000000e+00, 1.200000e+01]> : tensor<2xf32> + %0 = stablehlo.add %cst, %cst_0 : tensor<2xf32> + %1 = stablehlo.add %0, %0 : tensor<2xf32> + %2 = stablehlo.custom_call @check.eq(%cst_1, %1) : (tensor<2xf32>, tensor<2xf32>) -> tensor + return %2 : tensor + } +} +""" + + +def test_testdata_generator( + module_str: str, args: list, expected_output_str: str +) -> None: + """Parses the MLIR module string, registers dialects, runs testdata_generator, + + and compares the output with the expected output. + """ + with ir.Context() as ctx: + stablehlo_dialect.register_dialect(ctx) + check_dialect.register_dialect(ctx) + module = ir.Module.parse(module_str) + result_module = testdata_generator(module, args) + expected_output_module = ir.Module.parse(expected_output_str) + if str(result_module) != str(expected_output_module): + raise AssertionError( + "Output mismatch:\n" + f"Expected:\n{expected_output_module}\n" + f"Got:\n{result_module}" + ) + + +# List of test cases as tuples: (input_module_str, args, expected_output_str) +test_cases = [ + ( + # Typical use-case. + MODULE_STR, + [ + np.array([1.0, 2.0], dtype=np.float32), + np.array([3.0, 4.0], dtype=np.float32), + ], + EXPECTED_TESTDATA_MODULE_STR_1, + ), + ( + # To test + # - Input programs already in testdata format + # - If no concrete inputs are provoded, they are derived from the embedded + # stablehlo.constant ops in the program. + MODULE_STR_IN_TESTDATA_FORMAT_WITH_CHECK_OPS, + [], + EXPECTED_TESTDATA_MODULE_STR_1, + ), + ( + # To test + # - Input programs already in testdata format + # - If concrete inputs are provoded, the embedded stablehlo.constant ops, + # feeding as program inputs, will be ignored. + MODULE_STR_IN_TESTDATA_FORMAT_WITH_CHECK_OPS, + [ + np.array([10.0, 20.0], dtype=np.float32), + np.array([30.0, 40.0], dtype=np.float32), + ], + EXPECTED_TESTDATA_MODULE_STR_2, + ), + ( + # To test + # - Input programs already in testdata format + # - Usage of custom_call ops as check ops. + MODULE_STR_IN_TESTDATA_FORMAT_WITH_CUSTOM_CALL_AS_CHECK_OPS, + [], + EXPECTED_TESTDATA_MODULE_STR_1, + ), + ( + # To test + # - Input programs already in testdata format + # - Proper identification of the constants feeding to program input and + # check ops. + MODULE_STR_IN_TESTDATA_FORMAT_WITH_SHARED_USE_OF_CONSTANT_OP, + [], + EXPECTED_TESTDATA_MODULE_STR_3, + ), + ( + # To test handling of programs with multiple operations. + MODULE_STR_IN_TESTDATA_FORMAT_MUTI_OP_UNDER_TEST, + [ + np.array([1.0, 2.0], dtype=np.float32), + np.array([3.0, 4.0], dtype=np.float32), + ], + EXPECTED_TESTDATA_MODULE_STR_4, + ), +] + +for module_str, args, expected_output_str in test_cases: + test_testdata_generator(module_str, args, expected_output_str)