From cc9c43bd4a8f5ba4eb2fdeca2d8d0c97573a9fad Mon Sep 17 00:00:00 2001 From: Gagandeep Singh Date: Mon, 27 May 2024 20:22:34 -0600 Subject: [PATCH] PTQ readme --- .../ml/resnet/ptq_conv2x/CMakeLists.txt | 89 ------------------- .../ml/resnet/ptq_conv2x/Makefile | 2 +- .../ml/resnet/ptq_conv2x/README.md | 76 ++++++++++++++++ .../ml/resnet/ptq_conv2x/test.py | 16 ++-- 4 files changed, 85 insertions(+), 98 deletions(-) delete mode 100755 programming_examples/ml/resnet/ptq_conv2x/CMakeLists.txt create mode 100644 programming_examples/ml/resnet/ptq_conv2x/README.md diff --git a/programming_examples/ml/resnet/ptq_conv2x/CMakeLists.txt b/programming_examples/ml/resnet/ptq_conv2x/CMakeLists.txt deleted file mode 100755 index c7db0e9c5c..0000000000 --- a/programming_examples/ml/resnet/ptq_conv2x/CMakeLists.txt +++ /dev/null @@ -1,89 +0,0 @@ -# This file is licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# -# Copyright (C) 2024, Advanced Micro Devices, Inc. - -# parameters -# -DBOOST_ROOT: Path to Boost install -# -DOpenCV_DIR: Path to OpenCV install -# -DXRT_INC_DIR: Full path to src/runtime_src/core/include in XRT cloned repo -# -DXRT_LIB_DIR: Path to xrt_coreutil.lib -# -DTARGET_NAME: Target name to be built - -# cmake needs this line -cmake_minimum_required(VERSION 3.1) - -find_program(WSL NAMES powershell.exe) - -if (NOT WSL) - set(BOOST_ROOT /usr/include/boost CACHE STRING "Path to Boost install") - set(OpenCV_DIR /usr/include/opencv4 CACHE STRING "Path to OpenCV install") - set(XRT_INC_DIR /opt/xilinx/xrt/include CACHE STRING "Path to XRT cloned repo") - set(XRT_LIB_DIR /opt/xilinx/xrt/lib CACHE STRING "Path to xrt_coreutil.lib") -else() - set(BOOST_ROOT C:/Technical/thirdParty/boost_1_83_0 CACHE STRING "Path to Boost install") - set(OpenCV_DIR C:/Technical/thirdParty/opencv/build CACHE STRING "Path to OpenCV install") - set(XRT_INC_DIR C:/Technical/XRT/src/runtime_src/core/include CACHE STRING "Path to XRT cloned repo") - set(XRT_LIB_DIR C:/Technical/xrtNPUfromDLL CACHE STRING "Path to xrt_coreutil.lib") -endif () - -set(EDGEDETECT_WIDTH 1920 CACHE STRING "image width") -set(EDGEDETECT_HEIGHT 1080 CACHE STRING "image height") - -set(TARGET_NAME test CACHE STRING "Target to be built") - -SET (ProjectName ${TARGET_NAME}) -SET (currentTarget ${TARGET_NAME}) - -if ( WSL ) - set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_RELEASE ${CMAKE_BINARY_DIR}) -endif () - -project(${ProjectName}) - -# Find packages -find_package(Boost REQUIRED) -find_package(OpenCV REQUIRED) -message("opencv library paht: ${OpenCV_LIB_PATH}") -message("opencv libs: ${OpenCV_LIBS}") - - -add_executable(${currentTarget} - ${CMAKE_CURRENT_SOURCE_DIR}/../../../utils/OpenCVUtils.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/../../../utils/xrtUtils.cpp - test.cpp -) - -target_compile_definitions(${currentTarget} PUBLIC - EDGEDETECT_WIDTH=${EDGEDETECT_WIDTH} - EDGEDETECT_HEIGHT=${EDGEDETECT_HEIGHT} - DISABLE_ABI_CHECK=1 - ) - -target_include_directories (${currentTarget} PUBLIC - ${CMAKE_CURRENT_SOURCE_DIR}/../../../utils - ${XRT_INC_DIR} - ${OpenCV_INCLUDE_DIRS} - ${Boost_INCLUDE_DIRS} -) - -target_link_directories(${currentTarget} PUBLIC - ${XRT_LIB_DIR} - ${OpenCV_LIB_PATH} - ${Boost_LIBRARY_DIRS} -) - -if (NOT WSL) - target_link_libraries(${currentTarget} PUBLIC - xrt_coreutil - ${OpenCV_LIBS} - boost_program_options - boost_filesystem - ) -else() - target_link_libraries(${currentTarget} PUBLIC - xrt_coreutil - ${OpenCV_LIBS} - ) -endif() diff --git a/programming_examples/ml/resnet/ptq_conv2x/Makefile b/programming_examples/ml/resnet/ptq_conv2x/Makefile index 4b40c07da9..79e443d308 100755 --- a/programming_examples/ml/resnet/ptq_conv2x/Makefile +++ b/programming_examples/ml/resnet/ptq_conv2x/Makefile @@ -43,7 +43,7 @@ build/final.xclbin: build/${mlirFileName}.mlir build/conv2dk1_i8.o build/conv2dk clean: rm -rf build/*.elf* build/*.lst build/*.bif log* build/${mlirFileName}.mlir.prj build/*.xclbin sim \ build/chess* build/insts.txt \ - build/*.log build/aie_partition.json build/*.bin build/BOOT.BIN _x test.exe + build/*.log build/aie_partition.json build/*.bin build/BOOT.BIN _x run_py: ${powershell} python3 ${srcdir}/test.py -x build/final.xclbin -i build/insts.txt -k MLIR_AIE diff --git a/programming_examples/ml/resnet/ptq_conv2x/README.md b/programming_examples/ml/resnet/ptq_conv2x/README.md new file mode 100644 index 0000000000..f1e68a3b3e --- /dev/null +++ b/programming_examples/ml/resnet/ptq_conv2x/README.md @@ -0,0 +1,76 @@ + + +# ResNet with Offloaded Conv2_x Layers and Post-Training Quantization + +Quantization involves reducing the precision of the weights and activations of a neural network from floating-point (e.g., 32-bit float) to lower bit-width formats (e.g., 8-bit integers). Quantization reduces model size and speeds up inference, making it more suitable for deployment on resource-constrained devices. In AI Engine (AIE), we use a power-of-two scale factor to set up the SRS to shift and scale the values to the integer range. A power of two is a number of the form 2^n, where n is an integer. Power-of-two scale factors can lead to more efficient hardware implementations, as multiplication by a power of two can be performed using bit shifts rather than more complex multiplication operations. + +[Brevitas](https://github.com/Xilinx/brevitas) is a PyTorch-based library designed for quantization of neural networks. It enables users to train models with reduced numerical precision, typically using lower bit widths for weights and activations, which can lead to significant improvements in computational efficiency and memory usage. Brevitas supports various quantization schemes, including uniform and non-uniform quantization, and can be used to target a wide range of hardware platforms, including FPGAs, ASICs, and CPUs. We use Brevitas to: +1. Quantize weights and activations of a model to lower bit format for AIE deployment, and +2. Extract proper power-of-two scale factors to set up the SRS unit. + +## Source Files Overview + +``` +. ++-- ptq_conv2x # Implementation of ResNet conv2_x layers on NPU with PTQ ++-- +-- data # Labels for CIFAR dataset. +| +-- aie2.py # A Python script that defines the AIE array structural design using MLIR-AIE operations. +| +-- Makefile # Contains instructions for building and compiling software projects. +| +-- model.py # Python code for ResNet Model where we apply PTQ. +| +-- README.md # This file. +| +-- requirements.txt # pip requirements to perform PTQ. +| +-- run_makefile.lit # For LLVM Integrated Tester (LIT) of the design. +| +-- test.py # Python code testbench for the design example. +| +-- utils.py # Python code for miscellaneous functions needed for inference. + + +``` + +# Post-Training Quantization Using Brevitas +To enhance the efficiency of our implementation, we perform post-training quantization on the model using the Brevitas library. This step converts the model to use 8-bit weights and power-of-two scale factors, optimizing it for deployment on hardware with limited precision requirements. + + +## Step-by-Step Process +We use test.py to: + +**1. Loading the Pre-trained ResNet Model**: The script begins by loading a pre-trained ResNet model, which serves as the baseline for quantization and inference. + +**2. Applying Post-Training Quantization (PTQ)**: Using the Brevitas library, the script applies PTQ to the conv2_x layers of the ResNet model. This involves converting the weights and activations to 8-bit precision. + +**3. Extracting Power-of-Two Scale Factors**: After quantizing the weights and activations, the script extracts the power-of-two scale factors. These factors are crucial for efficient hardware implementation, as they simplify multiplication operations to bit shifts. + +**4. Calculating Combined Scales**: The combined scale factors are calculated by multiplying the extracted weight and activation scales for each layer. These combined scales are then used to set up the SRS unit. + +**5. Setting Up the SRS Unit**: +The SRS unit uses the calculated combined scales to efficiently shift and scale the values to the integer range required for the NPU. + +**6. Running Inference**: Finally, the script runs inference on the quantized model. The conv2_x layers are offloaded to the NPU, utilizing the SRS unit to scale the quantized weights and activations to the int8 range properly. + +# Compilation and Execution + +## Prerequisites +Ensure you have the necessary dependencies installed. You can install the required packages using: + +``` +pip install -r requirements.txt +``` +## Compilation +To compile the design: +``` +make +``` + +## Running the Design + +To run the design: +``` +make run_py +``` diff --git a/programming_examples/ml/resnet/ptq_conv2x/test.py b/programming_examples/ml/resnet/ptq_conv2x/test.py index e9278132b5..67257625a9 100755 --- a/programming_examples/ml/resnet/ptq_conv2x/test.py +++ b/programming_examples/ml/resnet/ptq_conv2x/test.py @@ -70,6 +70,7 @@ def main(opts): # ------------------------------------------------------ # Post training quantization to get int8 weights and activation for AIE # ------------------------------------------------------ + # Step 1: Load the pre-trained ResNet model num_classes = 10 model = res.Resnet50_conv2x_offload(num_classes) weights = "trained_resnet50/weight.tar" # trained FP model @@ -109,7 +110,7 @@ def main(opts): root=data_dir, train=False, transform=transform_test, download=True ) - # Data loader + # Data loader for calibration indices = torch.arange(256) tr_sub = data_utils.Subset(train_dataset, indices) val_sub = data_utils.Subset(test_dataset, indices) @@ -119,6 +120,8 @@ def main(opts): val_loader = torch.utils.data.DataLoader( dataset=val_sub, batch_size=64, shuffle=False ) + + # Step 2: Apply quantization to the conv2_x layers to convert weights to 8-bit precision img_shape = 32 model_aie = preprocess_for_flexml_quantize( model.aie, @@ -131,7 +134,7 @@ def main(opts): quant_model = quantize_model( model_aie, backend="flexml", - scale_factor_type="po2_scale", + scale_factor_type="po2_scale", # Ensuring scale factors are powers of two bias_bit_width=32, weight_bit_width=8, weight_narrow_range=False, @@ -165,29 +168,25 @@ def main(opts): from numpy import load + # Extracting quantized weights and scale factors params = {} weights = {} for name, module in model.named_modules(): if isinstance(module, QuantConv2d): - # print(name) - # print(module.quant_weight().scale) weights[name + ".int_weight"] = module.quant_weight().int( float_datatype=False ) params[name + "_scale"] = module.quant_weight().scale.detach().numpy() if isinstance(module, QuantIdentity): - # print(name) - # print(module.quant_act_scale()) params[name + "_scale"] = module.quant_act_scale() if isinstance(module, QuantReLU): - # print(name) - # print(module.quant_act_scale()) params[name + "_scale"] = module.quant_act_scale() np.savez(os.path.join(os.getcwd(), "int_weights.npz"), **weights) np.savez(os.path.join(os.getcwd(), "int_conv_scale.npz"), **params) int_wts_data = load("int_weights.npz", allow_pickle=True) int_scale_data = load("int_conv_scale.npz", allow_pickle=True) + # Loading weights and scales int_wts_data_lst = int_wts_data.files block_0_int_weight_1 = torch.from_numpy(int_wts_data["aie.layer1.conv1.int_weight"]) block_0_int_weight_2 = torch.from_numpy(int_wts_data["aie.layer1.conv2.int_weight"]) @@ -239,6 +238,7 @@ def main(opts): if name.endswith(".bias"): param.data.fill_(0) + # Calculate combined scales block_0_combined_scale1 = -math.log( init_scale * block_0_weight_scale_1 / block_0_relu_1, 2 ) # after conv1x1