diff --git a/programming_examples/basic/README.md b/programming_examples/basic/README.md
index bfe8a881ef..699dcc331c 100644
--- a/programming_examples/basic/README.md
+++ b/programming_examples/basic/README.md
@@ -14,10 +14,14 @@ These programming examples provide a good starting point to illustrate how to bu
* [Passthrough DMAs](./passthrough_dmas) - This design demonstrates data movement to implement a memcpy operation using object FIFOs just using DMAs without involving the AIE core.
* [Passthrough Kernel](./passthrough_kernel) - This design demonstrates a simple AIE implementation for vectorized memcpy on a vector of integer involving AIE core kernel programming.
+* [DMA Transpose](./dma_transpose) - Transposes a matrix with the Shim DMA using `npu_dma_memcpy_nd`.
* [Vector Scalar Add](./vector_scalar_add) - Single tile performs a very simple `+` operation where the kernel loads data from local memory, increments the value by `1` and stores it back.
* [Vector Scalar Mul](./vector_scalar_mul) - Single tile performs `vector * scalar` of size `4096`. The kernel does a `1024` vector multiply and is invoked multiple times to complete the full `vector * scalar` compute.
+* [Vector Vector Add](./vector_vector_add) - Single tile performs `vector + vector` of size `1024`.
+* [Vector Vector Multiply](./vector_vector_mul) - Single tile performs `vector * vector` of size `1024`.
* [Vector Reduce Add](./vector_reduce_add) - Single tile performs a reduction of a vector to return the `sum` of the elements.
* [Vector Reduce Max](./vector_reduce_max) - Single tile performs a reduction of a vector to return the `max` of the elements.
* [Vector Reduce Min](./vector_reduce_min) - Single tile performs a reduction of a vector to return the `min` of the elements.
* [Vector Exp](./vector_exp) - A simple element-wise exponent function, using the look up table capabilities of the AI Engine.
+* [Matrix Scalar Add](./matrix_scalar_add) - Single tile performs `matrix + scalar` on a matrix of size `16x8`.
* [Matrix Multiplication](./matrix_multiplication) - This directory contains multiple designs spanning: single core and multi-core (whole array) matrix-matrix multiplication, and matrix-vector multiplication designs. It also contains sweep infrastructure for benchmarking.
\ No newline at end of file
diff --git a/programming_examples/ml/README.md b/programming_examples/ml/README.md
new file mode 100644
index 0000000000..f9525e3e44
--- /dev/null
+++ b/programming_examples/ml/README.md
@@ -0,0 +1,23 @@
+<!---//===- README.md --------------------------*- Markdown -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// Copyright (C) 2024, Advanced Micro Devices, Inc.
+//
+//===----------------------------------------------------------------------===//-->
+
+# Machine Learning Examples
+
+| Design name | Data type | Description |
+|-|-|-|
+| [Eltwise Add](../../programming_examples/ml/eltwise_add/) | bfloat16 | An element-by-element addition of two vectors |
+| [Eltwise Mul](../../programming_examples/ml/eltwise_mul/) | i32 | An element-by-element multiplication of two vectors |
+| [ReLU](../../programming_examples/ml/relu/) | bfloat16 | Rectified linear unit (ReLU) activation function on a vector |
+| [Softmax](../../programming_examples/ml/softmax/) | bfloat16 | Softmax operation on a matrix |
+| [Conv2D](../../programming_examples/ml/conv2d) | i8 | A single-core 2D convolution for CNNs |
+| [Conv2D+ReLU](../../programming_examples/ml/conv2d_fused_relu) | i8 | A Conv2D with a ReLU fused at the vector register level |
+| [Bottleneck](../../programming_examples/ml/bottleneck/) | ui8 | A bottleneck residual block, a variant of the residual block that uses three convolutions with 1x1, 3x3, and 1x1 filter sizes, respectively. The implementation features fusing of multiple kernels and dataflow optimizations, highlighting the unique architectural capabilities of AI Engines |
+| [ResNet](../../programming_examples/ml/resnet/) | ui8 | ResNet with offloaded conv2_x layers. The implementation features a depth-first implementation of multiple bottleneck blocks across multiple NPU columns. |
+
diff --git a/programming_examples/ml/weight_expand/CMakeLists.txt b/programming_examples/ml/weight_expand/CMakeLists.txt
deleted file mode 100644
index 20f5d8a4a3..0000000000
--- a/programming_examples/ml/weight_expand/CMakeLists.txt
+++ /dev/null
@@ -1,75 +0,0 @@
-# This file is licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#
-# (c) Copyright 2023 Advanced Micro Devices, Inc.
-
-# parameters
-# -DBOOST_ROOT: Path to Boost install
-# -DXRT_INC_DIR: Full path to src/runtime_src/core/include in XRT cloned repo
-# -DXRT_LIB_DIR: Path to xrt_coreutil.lib
-# -DTARGET_NAME: Target name to be built
-
-# cmake needs this line
-cmake_minimum_required(VERSION 3.1)
-
-set(CMAKE_CXX_STANDARD 23)
-set(CMAKE_CXX_STANDARD_REQUIRED YES)
-
-find_program(WSL NAMES powershell.exe)
-
-if (NOT WSL)
- set(CMAKE_C_COMPILER gcc-13)
- set(CMAKE_CXX_COMPILER g++-13)
- set(BOOST_ROOT /usr/include/boost CACHE STRING "Path to Boost install")
- set(XRT_INC_DIR /opt/xilinx/xrt/include CACHE STRING "Path to XRT cloned repo")
- set(XRT_LIB_DIR /opt/xilinx/xrt/lib CACHE STRING "Path to xrt_coreutil.lib")
-else()
- set(BOOST_ROOT C:/Technical/thirdParty/boost_1_83_0 CACHE STRING "Path to Boost install")
- set(XRT_INC_DIR C:/Technical/XRT/src/runtime_src/core/include CACHE STRING "Path to XRT cloned repo")
- set(XRT_LIB_DIR C:/Technical/xrtNPUfromDLL CACHE STRING "Path to xrt_coreutil.lib")
-endif()
-
-set(TARGET_NAME test CACHE STRING "Target to be built")
-
-SET (ProjectName ${TARGET_NAME})
-SET (currentTarget ${TARGET_NAME})
-
-if ( WSL )
- set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_RELEASE ${CMAKE_BINARY_DIR})
-endif ()
-
-project(${ProjectName})
-
-# Find packages
-find_package(Boost REQUIRED)
-
-add_executable(${currentTarget}
- ${CMAKE_CURRENT_SOURCE_DIR}/../../../runtime_lib/test_lib/test_utils.cpp
- test.cpp
-)
-
-target_compile_definitions(${currentTarget} PUBLIC DISABLE_ABI_CHECK=1)
-
-target_include_directories (${currentTarget} PUBLIC
- ${XRT_INC_DIR}
- ${Boost_INCLUDE_DIRS}
- ${CMAKE_CURRENT_SOURCE_DIR}/../../../runtime_lib/test_lib
-)
-
-target_link_directories(${currentTarget} PUBLIC
- ${XRT_LIB_DIR}
- ${Boost_LIBRARY_DIRS}
-)
-
-if (NOT WSL)
- target_link_libraries(${currentTarget} PUBLIC
- xrt_coreutil
- boost_program_options
- boost_filesystem
- )
-else()
- target_link_libraries(${currentTarget} PUBLIC
- xrt_coreutil
- )
-endif()
diff --git a/programming_examples/ml/weight_expand/Makefile b/programming_examples/ml/weight_expand/Makefile
deleted file mode 100755
index b4967596fb..0000000000
--- a/programming_examples/ml/weight_expand/Makefile
+++ /dev/null
@@ -1,50 +0,0 @@
-##===- Makefile -----------------------------------------------------------===##
-#
-# This file licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#
-##===----------------------------------------------------------------------===##
-
-include ../../makefile-common
-
-all: build/final.xclbin build/insts.txt ${targetname}.exe
-
-targetname = expand
-
-build/%.o: %.cc
- mkdir -p ${@D}
- cd ${@D} && xchesscc_wrapper ${CHESSCCWRAP2_FLAGS} -DBIT_WIDTH=8 -c $(<:%=../%) -o ${@F}
-
-build/aie.mlir: aie2.py
- mkdir -p ${@D}
- python3 $< > $@
-
-build/final.xclbin: build/aie.mlir build/expand.o
- mkdir -p ${@D}
- cd ${@D} && aiecc.py --aie-generate-cdo --no-compile-host --xclbin-name=${@F} \
- --aie-generate-npu --npu-insts-name=insts.txt $(<:%=../%)
-
-${targetname}.exe: test.cpp
- rm -rf _build
- mkdir -p _build
- cd _build && ${powershell} cmake .. -DTARGET_NAME=${targetname}
- cd _build && ${powershell} cmake --build . --config Release
-ifeq "${powershell}" "powershell.exe"
- cp _build/${targetname}.exe $@
-else
- cp _build/${targetname} $@
-endif
-
-run: ${targetname}.exe build/final.xclbin build/insts.txt
- ${powershell} ./$< -x build/final.xclbin -i build/insts.txt -k MLIR_AIE
-
-trace:
- ../../utils/parse_eventIR.py --filename trace.txt --mlir build/aie.mlir --colshift 1 > parse_eventIR_vs.json
-
-clean_trace:
- rm -rf tmpTrace trace.txt
-
-clean: clean_trace
- rm -rf build _build ${targetname}.exe
-
diff --git a/programming_examples/ml/weight_expand/README.md b/programming_examples/ml/weight_expand/README.md
deleted file mode 100644
index 7f4b2e4f76..0000000000
--- a/programming_examples/ml/weight_expand/README.md
+++ /dev/null
@@ -1,74 +0,0 @@
-<!---//===- README.md --------------------------*- Markdown -*-===//
-//
-// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-// Copyright (C) 2023, Advanced Micro Devices, Inc.
-//
-//===----------------------------------------------------------------------===//-->
-
-# int4 -> bfloat16 dequantization
-
-This IRON design flow example is a dequantization kernel which converts from `signed int4` weights, **32** of which share a single `bfloat16` scale factor.
-
-The conversion is to take each `signed int4` weight, and multiply it by the scale factor, giving a vector with 32 elements of bfloat16 numbers. This vector can then be used by another kernel, either running on the same core, or on a different core to do a high-precision operator such as bfloat16 based GEMM or GEMV. The main use model is for generative AI where the loading of the parameters during the generation phase is the limiting factor.
-
-Though other configurations are possible, the design example has a memory layout consisting of **1024** `signed int4` weights followed by **32** `bfloat16` scale factors, meaning the tile to be input is **576 bytes**, or **144 int32 words** in size.
-
-![Memory layout](memory.png?raw=true "Memory layout")
-
-The example consists of two primary design files: `aie2.py` and `scale.cc`, and a testbench `test.cpp`.
-
-## Overview
-
-1. `aie2.py`: A Python script that defines the AIE array structural design using MLIR-AIE operations. This generates MLIR that is then compiled using `aiecc.py` to produce design binaries (ie. XCLBIN and inst.txt for the NPU in Ryzen AI).
-
-1. `expand.cc`: A C++ implementation of vectorized dequantization operations for the AIE core.
-
-1. `test.cpp`: This C++ code is a testbench for the dequantization design example. The code is responsible for loading the compiled XCLBIN file, configuring the AIE module, providing input data, and executing the AIE design on the NPU. After executing, the script verifies the dequantized results.
-
-## Design Component Details
-
-### AIE Array Structural Design
-
-This design performs dequantization operations on a vector of input data. The AIE design is described in a python module as follows:
-
-1. **Constants & Configuration:** The script defines input dimensions (`N`, `n`), as well as the block size (the number of weights which share a scale factor)
-
-1. **AIE Device Definition:** `@device` defines the target device. The `device_body` function contains the AIE array design definition.
-
-1. **Dequantization Function Declarations:** `expand_int4_to_bfloat16` is an external function imported from `expand.cc`.
-
-1. **Tile Definitions:** `ShimTile` handles data movement, and `core0` processes the dequantization operations.
-
-1. **Object Fifos:** `inA` and `outB` are defined to facilitate communication between `ShimTile` and `core0`.
-
-1. **Core Definition:** The `core_body` function loops through sub-vectors of the input data, acquiring elements from `inA`, processing using `expand_int4_to_bfloat16`, and outputting the result to `outB`.
-
-1. **Data Movement Configuration:** The `sequence` function configures data movement and synchronization on the `ShimTile` for input and output buffer management.
-
-1. **Generate the design:** The `my_expand()` function triggers the code generation process. The final print statement outputs the MLIR representation of the AIE array configuration.
-
-### AIE Core Kernel Code
-
-`expand.cc` contains a C++ implementation of scalar and vectorized vector scaling operations designed for AIE cores. It consists of three main sections:
-
-1. **Vectorized dequantization:** The `expand()` function processes multiple data elements simultaneously, taking advantage of AIE vector datapath capabilities.
-
-1. **C-style Wrapper Functions:** `expand_int4_to_bfloat16()` is a C-style wrapper functions to call the `expand()` function from the AIE design implemented in `aie2.py`.
-
-## Usage
-
-To compile the design and testbench:
-
-```
-make all
-```
-
-To run the design:
-
-```
-make run
-```
\ No newline at end of file
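
> **Editor's note:** the dequantization scheme the removed README describes is easy to model on the host. The sketch below is an editor's addition (not part of the deleted sources) and assumes the layout stated above: per tile, 1024 signed int4 weights packed two per byte (512 bytes, lower nibble first, as in `test.cpp`) followed by 32 little-endian bfloat16 scale factors (64 bytes), i.e. 576 bytes = 144 int32 words, with each scale factor covering a block of 32 weights.

```python
import numpy as np

N, BLOCK = 1024, 32  # weights per tile, weights per scale factor

def dequantize_tile(tile_bytes: np.ndarray) -> np.ndarray:
    """Reference dequantization of one 576-byte tile (1-D uint8 array)."""
    packed = tile_bytes[: N // 2]
    weights = np.empty(N, dtype=np.int8)
    weights[0::2] = (packed & 0xF).astype(np.int8)  # lower nibble is the earlier element
    weights[1::2] = (packed >> 4).astype(np.int8)   # upper nibble follows
    weights[weights >= 8] -= 16                     # sign-extend two's-complement int4
    # NumPy has no bfloat16: widen each 16-bit scale factor to float32
    # by placing its bits in the top half of a 32-bit word.
    sf_bits = tile_bytes[N // 2 :].view(np.uint16).astype(np.uint32) << 16
    sfs = sf_bits.view(np.float32)
    return weights.astype(np.float32) * np.repeat(sfs, BLOCK)
```

A tile generated the same way `test.cpp` fills `AVec` should match the kernel output through this function, up to the bfloat16 truncation error the testbench's verify loop already tolerates.
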
diff --git a/programming_examples/ml/weight_expand/aie2.py b/programming_examples/ml/weight_expand/aie2.py
deleted file mode 100755
index 32fe95429f..0000000000
--- a/programming_examples/ml/weight_expand/aie2.py
+++ /dev/null
@@ -1,105 +0,0 @@
-#
-# This file is licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#
-# (c) Copyright 2023 AMD Inc.
-
-import sys
-
-from aie.dialects.aie import *
-from aie.dialects.aiex import *
-from aie.dialects.scf import *
-from aie.extras.context import mlir_mod_ctx
-
-
-def my_expand():
-
- SF_BLOCK_SIZE = 32
- word_size_in = 2
- sf_word_size_in = 2
- N = 65536
-
- N_in_bytes = (N // word_size_in) + (N / SF_BLOCK_SIZE) * sf_word_size_in
-
- A_sz_in_i32s = (N // 8) + (
- N // SF_BLOCK_SIZE
- ) // 2 # They are 4 bits per element, we need to add on the scale factors later though
- B_sz_in_i32s = N // 2 # Returning 16 bits at the moment
-
- # Tile sizes
- n = 1024
- block_size = 32
- sf_size = n // block_size
-
- input_buffer_size_bytes = (n // 2) + (
- sf_size * 2
- ) # They are bfloat16 sfs after the private values
- output_buffer_size_bytes = n * 2 # The unscaled values
-
- N_div_n = N // n
-
- n_cores = 1
- tiles = N_div_n // n_cores
- buffer_depth = 2
-
- with mlir_mod_ctx() as ctx:
-
- @device(AIEDevice.npu)
- def device_body():
- memRef_i_ty = T.memref(
- input_buffer_size_bytes, T.i8()
- ) # Just think of the input as a raw byte buffer
- memRef_o_ty = T.memref(output_buffer_size_bytes, T.i8()) # For now
-
- # AIE Core Function declarations
-
- expand_int4_to_bfloat16 = external_func(
- "expand_int4_to_bfloat16", inputs=[memRef_i_ty, memRef_o_ty]
- )
-
- # Tile declarations
- ShimTile = tile(0, 0)
-
- MemTile = tile(0, 1)
- core0 = tile(0, 2)
-
- # AIE-array data movement with object fifos
- # Input
- inA = object_fifo("inA", ShimTile, core0, buffer_depth, memRef_i_ty)
-
- # Output B
- outB = object_fifo("outB", core0, ShimTile, buffer_depth, memRef_o_ty)
-
- # Set up compute tiles
- @core(core0, "expand.o")
- def core_body():
- for _ in for_(0xFFFFFFFF):
- for _ in for_(tiles):
- elem_out = outB.acquire(ObjectFifoPort.Produce, 1)
- elem_in = inA.acquire(ObjectFifoPort.Consume, 1)
-
- call(expand_int4_to_bfloat16, [elem_in, elem_out])
- inA.release(ObjectFifoPort.Consume, 1)
- outB.release(ObjectFifoPort.Produce, 1)
- yield_([])
- yield_([])
-
- # To/from AIE-array data movement
- tensor_ty = T.memref(N, T.i32())
-
- @FuncOp.from_py_func(tensor_ty, tensor_ty)
- def sequence(A, C):
-
- npu_dma_memcpy_nd(
- metadata="outB", bd_id=0, mem=C, sizes=[1, 1, 1, B_sz_in_i32s]
- )
- npu_dma_memcpy_nd(
- metadata="inA", bd_id=1, mem=A, sizes=[1, 1, 1, A_sz_in_i32s]
- )
- npu_sync(column=0, row=0, direction=0, channel=0)
-
- print(ctx.module)
-
-
-my_expand()
diff --git a/programming_examples/ml/weight_expand/expand.cc b/programming_examples/ml/weight_expand/expand.cc
deleted file mode 100755
index 6083e535dc..0000000000
--- a/programming_examples/ml/weight_expand/expand.cc
+++ /dev/null
@@ -1,82 +0,0 @@
-//===- scale.cc -------------------------------------------------*- C++ -*-===//
-//
-// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-// Copyright (C) 2023, Advanced Micro Devices, Inc.
-//
-//===----------------------------------------------------------------------===//
-
-#define __AIENGINE__ 2
-#define NOCPP
-#define __AIEARCH__ 20
-
-#include <stdint.h>
-#include <stdio.h>
-#include <stdlib.h>
-#include <type_traits>
-
-#include <aie_api/aie.hpp>
-
-template <typename T_in, typename T_sf, typename T_out, const int N>
-void expand(T_in *in, T_out *out) {
-
- /*
- out[0] = 0x0;
- out[1] = 0x1;
- out[2] = 0x2;
- out[3] = 0x3;
-*/
- constexpr int vec_factor = 32;
- constexpr int sf_vec_factor = 8;
-
- event0();
- T_in *__restrict pI = in;
- T_in *__restrict pSFb =
- in + N / 2; // The scale factors are after the integer values
- T_sf *__restrict pSF =
- (T_sf *)pSFb; // But we only advance by the number of bytes not elements
- T_out *__restrict pO = out;
- const int F = N / (vec_factor * sf_vec_factor);
-
- for (int i = 0; i < F; i++)
- chess_prepare_for_pipelining chess_loop_range(16, ) {
-
- // Let's unroll this
-      aie::vector<T_sf, sf_vec_factor> sfV =
-          aie::load_v<sf_vec_factor>(pSF); // For example
- pSF += sf_vec_factor;
-
- for (int k = 0; k < sf_vec_factor; k++)
- chess_unroll_loop(sf_vec_factor) {
-          aie::vector<T_in, vec_factor> I0 =
-              aie::load_v<vec_factor>(pI); // For example
- pI += vec_factor / 2;
-
- bfloat16 sf = sfV[k % sf_vec_factor];
-
-          aie::vector<bfloat16, vec_factor> sf_broadcast = aie::broadcast<bfloat16, vec_factor>(sf);
-
- // Upsize these to 8 bits -> 16 -> bfloat16
-          aie::vector<int8, vec_factor> asInt8 = aie::unpack(I0);
-          aie::vector<int16, vec_factor> asInt16 = aie::unpack(asInt8);
-          aie::vector<bfloat16, vec_factor> as_bf16 =
-              aie::to_float<bfloat16>(asInt16, 0);
-          aie::vector<bfloat16, vec_factor> scaled_bf16 =
-              aie::mul(as_bf16, sf_broadcast);
-
- aie::store_v(pO, scaled_bf16);
- pO += vec_factor;
- }
- }
- event1();
-}
-
-extern "C" {
-
-void expand_int4_to_bfloat16(int4 *a_in, bfloat16 *c_out) {
-  expand<int4, bfloat16, bfloat16, 1024>(a_in, c_out);
-}
-
-} // extern "C"
diff --git a/programming_examples/ml/weight_expand/memory.png b/programming_examples/ml/weight_expand/memory.png
deleted file mode 100644
index 0e758119ae..0000000000
Binary files a/programming_examples/ml/weight_expand/memory.png and /dev/null differ
diff --git a/programming_examples/ml/weight_expand/test.cpp b/programming_examples/ml/weight_expand/test.cpp
deleted file mode 100644
index 7cf3c01fea..0000000000
--- a/programming_examples/ml/weight_expand/test.cpp
+++ /dev/null
@@ -1,329 +0,0 @@
-//===- test.cpp -------------------------------------------------*- C++ -*-===//
-//
-// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-// Copyright (C) 2023, Advanced Micro Devices, Inc.
-//
-//===----------------------------------------------------------------------===//
-
-#include <boost/program_options.hpp>
-#include <chrono>
-#include <cmath>
-#include <cstdint>
-#include <cstdlib>
-#include <cstring>
-#include <fstream>
-#include <iomanip>
-#include <iostream>
-#include <sstream>
-#include <stdfloat>
-#include <vector>
-
-#include "xrt/xrt_bo.h"
-#include "xrt/xrt_device.h"
-#include "xrt/xrt_kernel.h"
-
-constexpr bool VERIFY = true;
-
-constexpr int TEST_SIZE = 65536;
-constexpr int TILE_SIZE = 1024;
-constexpr int NUM_TILES = TEST_SIZE / TILE_SIZE;
-
-constexpr int SF_BLOCK_SIZE = 32;
-
-constexpr int TOTAL_TILE_SIZE =
- (TILE_SIZE / 2) + (TILE_SIZE / SF_BLOCK_SIZE) * 2;
-
-constexpr int IN_SIZE = (TEST_SIZE / 2) + (TEST_SIZE / SF_BLOCK_SIZE) * 2;
-constexpr int OUT_SIZE = TEST_SIZE * 2;
-
-namespace po = boost::program_options;
-
-void check_arg_file_exists(po::variables_map &vm_in, std::string name) {
- if (!vm_in.count(name)) {
- throw std::runtime_error("Error: no " + name + " file was provided\n");
- } else {
-    std::ifstream test(vm_in[name].as<std::string>());
-    if (!test) {
-      throw std::runtime_error("The " + name + " file " +
-                               vm_in[name].as<std::string>() +
-                               " does not exist.\n");
- }
- }
-}
-
-static inline std::bfloat16_t random_bfloat16_t() {
- // Random numbers should NOT be uniformly between 0 and 1, because that
- // would make the matrix product AB always close to 1.
- return std::bfloat16_t(4.0 * (float)rand() / (float)(RAND_MAX));
-}
-
-bool nearly_equal(std::bfloat16_t a, std::bfloat16_t b) {
- std::bfloat16_t diff = fabs(a - b);
- if ((diff / a) < 0.01)
- return true;
- else
- return false;
-}
-
-std::vector<uint32_t> load_instr_sequence(std::string instr_path) {
- std::ifstream instr_file(instr_path);
- std::string line;
-  std::vector<uint32_t> instr_v;
- while (std::getline(instr_file, line)) {
- std::istringstream iss(line);
- uint32_t a;
- if (!(iss >> std::hex >> a)) {
- throw std::runtime_error("Unable to parse instruction file\n");
- }
- instr_v.push_back(a);
- }
- return instr_v;
-}
-
-int main(int argc, const char *argv[]) {
-
- // Program arguments parsing
- po::options_description desc("Allowed options");
-
-  desc.add_options()("help,h", "produce help message")(
-      "xclbin,x", po::value<std::string>()->required(),
-      "the input xclbin path")(
-      "kernel,k", po::value<std::string>()->required(),
-      "the kernel name in the XCLBIN (for instance PP_PRE_FD)")(
-      "verbosity,v", po::value<int>()->default_value(0),
-      "the verbosity of the output")(
-      "instr,i", po::value<std::string>()->required(),
-      "path of file containing userspace instructions to be sent to the LX6");
- po::variables_map vm;
-
- try {
- po::store(po::parse_command_line(argc, argv, desc), vm);
- po::notify(vm);
-
- if (vm.count("help")) {
- std::cout << desc << "\n";
- return 1;
- }
- } catch (const std::exception &ex) {
- std::cerr << ex.what() << "\n\n";
- std::cerr << "Usage:\n" << desc << "\n";
- return 1;
- }
-
- check_arg_file_exists(vm, "xclbin");
- check_arg_file_exists(vm, "instr");
-
-  std::vector<uint32_t> instr_v =
-      load_instr_sequence(vm["instr"].as<std::string>());
-
-  int verbosity = vm["verbosity"].as<int>();
- if (verbosity >= 1)
- std::cout << "Sequence instr count: " << instr_v.size() << "\n";
-
- // Start the XRT test code
- // Get a device handle
- unsigned int device_index = 0;
- auto device = xrt::device(device_index);
-
- // Load the xclbin
- if (verbosity >= 1)
-    std::cout << "Loading xclbin: " << vm["xclbin"].as<std::string>() << "\n";
-  auto xclbin = xrt::xclbin(vm["xclbin"].as<std::string>());
-
- if (verbosity >= 1)
-    std::cout << "Kernel opcode: " << vm["kernel"].as<std::string>() << "\n";
-  std::string Node = vm["kernel"].as<std::string>();
-
- // Get the kernel from the xclbin
- auto xkernels = xclbin.get_kernels();
- auto xkernel = *std::find_if(xkernels.begin(), xkernels.end(),
- [Node](xrt::xclbin::kernel &k) {
- auto name = k.get_name();
- std::cout << "Name: " << name << std::endl;
- return name.rfind(Node, 0) == 0;
- });
- auto kernelName = xkernel.get_name();
-
- if (verbosity >= 1)
-    std::cout << "Registering xclbin: " << vm["xclbin"].as<std::string>()
- << "\n";
-
- device.register_xclbin(xclbin);
-
- // get a hardware context
- if (verbosity >= 1)
- std::cout << "Getting hardware context.\n";
- xrt::hw_context context(device, xclbin.get_uuid());
-
- // get a kernel handle
- if (verbosity >= 1)
- std::cout << "Getting handle to kernel:" << kernelName << "\n";
- auto kernel = xrt::kernel(context, kernelName);
-
- auto bo_instr = xrt::bo(device, instr_v.size() * sizeof(int),
- XCL_BO_FLAGS_CACHEABLE, kernel.group_id(0));
- auto bo_in = xrt::bo(device, IN_SIZE * sizeof(char), XRT_BO_FLAGS_HOST_ONLY,
- kernel.group_id(2));
- auto bo_out = xrt::bo(device, OUT_SIZE * sizeof(char), XRT_BO_FLAGS_HOST_ONLY,
- kernel.group_id(3));
-
- if (verbosity >= 1)
- std::cout << "Writing data into buffer objects.\n";
-
-  char *bufA = bo_in.map<char *>();
-  std::vector<char> AVec(IN_SIZE);
-
-  std::vector<std::int8_t> A_private;
-  std::vector<std::bfloat16_t> A_sf;
-
- for (int t = 0; t < NUM_TILES; t++) {
- for (int pr = 0; pr < TILE_SIZE / 2; pr++) {
- std::int8_t lower = (rand()) & 0xf;
- std::int8_t upper = (rand()) & 0xf;
- AVec[t * TOTAL_TILE_SIZE + pr] = ((upper) << 4) + lower;
- A_private.push_back(lower);
- A_private.push_back(upper);
- if (verbosity >= 2) {
- if (t == 0)
- std::cout << std::hex << (t * TOTAL_TILE_SIZE + pr) << " : "
- << ((upper << 4) + lower) << std::dec << std::endl;
- }
- }
- for (int isf = 0; isf < TILE_SIZE / SF_BLOCK_SIZE; isf++) {
- std::bfloat16_t sf =
- std::bfloat16_t(4.0 * (float)rand() / (float)(RAND_MAX));
- std::uint16_t bits = *((std::uint16_t *)&sf);
- std::int8_t upper = (std::int8_t)(bits >> 8);
- std::int8_t lower = (std::int8_t)(bits & 0x00ff);
- AVec[t * TOTAL_TILE_SIZE + TILE_SIZE / 2 + isf * 2] = lower;
- AVec[t * TOTAL_TILE_SIZE + TILE_SIZE / 2 + isf * 2 + 1] = upper;
- A_sf.push_back(sf);
- if (verbosity >= 2) {
- if (t == 0)
- std::cout << std::hex
- << (t * TOTAL_TILE_SIZE + TILE_SIZE / 2 + isf * 2)
- << " and +1 :" << sf << std::dec << std::endl;
- }
- }
- }
-
- memcpy(bufA, AVec.data(), (AVec.size() * sizeof(char)));
-
- if (verbosity >= 2)
- std::cout << "Pre run values in " << std::hex << int(bufA[0]) << ", "
- << int(bufA[1]) << ", " << int(bufA[2]) << std::dec << std::endl;
-
-  void *bufInstr = bo_instr.map<void *>();
- memcpy(bufInstr, instr_v.data(), instr_v.size() * sizeof(int));
-
- bo_instr.sync(XCL_BO_SYNC_BO_TO_DEVICE);
- bo_in.sync(XCL_BO_SYNC_BO_TO_DEVICE);
-
- int sticky_errors = 0;
-
- unsigned num_iter = 16;
- float npu_time_total = 0;
- float npu_time_min = 9999999;
- float npu_time_max = 0;
- for (unsigned iter = 0; iter < num_iter; iter++) {
-
- if (verbosity >= 1)
- std::cout << "Running Kernel.\n";
-
- auto start = std::chrono::high_resolution_clock::now();
-
- auto run = kernel(bo_instr, instr_v.size(), bo_in, bo_out);
- run.wait();
- auto stop = std::chrono::high_resolution_clock::now();
-
- bo_out.sync(XCL_BO_SYNC_BO_FROM_DEVICE);
-
-    std::bfloat16_t *bufOut = bo_out.map<std::bfloat16_t *>();
-
- int errors = 0;
-
- if (verbosity >= 2) {
- std::cout << "First values in " << std::hex << int(bufA[0]) << ", "
- << int(bufA[1]) << ", " << int(bufA[2]) << std::dec
- << std::endl;
- std::cout << "First sf values in " << std::hex << int(bufA[512]) << ", "
- << int(bufA[513]) << ", " << int(bufA[514]) << ", "
- << int(bufA[515]) << std::dec << std::endl;
- std::cout << "First values out " << std::hex << bufOut[0] << ", "
- << bufOut[1] << ", " << bufOut[2] << ", " << bufOut[3]
- << std::dec << std::endl;
- std::cout << "Second values out " << std::hex << bufOut[32] << ", "
- << bufOut[33] << ", " << bufOut[34] << ", " << bufOut[35]
- << std::dec << std::endl;
-
- std::cout << "Reference values " << std::hex << A_private[0] << ", "
- << A_private[1] << std::dec << std::endl;
- std::cout << "Reference sfs " << A_sf[0] << ", " << A_sf[1]
- << std::endl;
- }
-    float npu_time =
-        std::chrono::duration_cast<std::chrono::microseconds>(stop - start)
-            .count();
-
- npu_time_total += npu_time;
- npu_time_min = (npu_time < npu_time_min) ? npu_time : npu_time_min;
- npu_time_max = (npu_time > npu_time_max) ? npu_time : npu_time_max;
-
- if (VERIFY) {
- for (int t = 0; t < NUM_TILES; t++) {
- for (int pr = 0; pr < TILE_SIZE; pr++) {
- std::bfloat16_t sf = A_sf[t * SF_BLOCK_SIZE + pr / SF_BLOCK_SIZE];
- int val = (int)(A_private[t * TILE_SIZE + pr]);
- if (val >= 8)
- val = (val & 0x7) - 8; // Two's complement, but threes a crowd
- std::bfloat16_t scaled = sf * val;
-
- std::bfloat16_t from_AIE = bufOut[(t * TILE_SIZE) + pr];
-
- // These will not exactly match
- // The default rounding mode in AIE2 is to truncate, so we will get
- // off by one errors.
-          std::uint16_t from_AIE_raw =
-              *reinterpret_cast<std::uint16_t *>(&from_AIE);
-          std::uint16_t scaled_raw =
-              *reinterpret_cast<std::uint16_t *>(&scaled);
-
- std::bfloat16_t abs_diff = fabs(from_AIE - scaled);
- if ((abs_diff / fabs(from_AIE)) > 0.01) {
- std::cout << "Tile " << t << ":" << pr << " From AIE "
- << std::setprecision(12) << from_AIE << " ref "
- << std::setprecision(12) << scaled << " from "
- << std::setprecision(12) << sf << "*" << std::hex << val
- << std::dec << std::endl;
- std::cout << "Tile " << t << ":" << pr << " From AIE " << std::hex
- << from_AIE_raw << " ref " << scaled_raw << " from "
-                    << *reinterpret_cast<std::uint16_t *>(&sf) << "*" << val
- << std::dec << std::endl;
- errors++;
- }
- }
- }
- }
-
- if (VERIFY && !errors) {
- std::cout << iter << ": pass!\n";
- } else {
- std::cout << iter << ": fail! " << errors << " errors\n";
- }
- }
-
- std::cout << "Avg NPU exec time: " << npu_time_total / num_iter << "us."
- << std::endl;
- std::cout << "Min NPU matmul time: " << npu_time_min << "us." << std::endl;
- std::cout << "Max NPU matmul time: " << npu_time_max << "us." << std::endl;
- if (VERIFY && !sticky_errors) {
- std::cout << "\nPASS!\n\n";
- return 0;
- } else {
- std::cout << "\nFAIL.\n\n";
- return 1;
- }
-}
diff --git a/programming_examples/vision/README.md b/programming_examples/vision/README.md
index a369ac9bbe..ab9649aa50 100644
--- a/programming_examples/vision/README.md
+++ b/programming_examples/vision/README.md
@@ -12,18 +12,10 @@
The vision pipeline reference designs show how complex vision pipelines can be constructed from basic vision kernel building blocks. Those building blocks can be found in [aie_kernels/aie2](../../aie_kernels/aie2) and contain example kernels written for AI engines in both scalar and unoptimized vector format.
-## [Vision Pass Through](./vision_passthrough/)
-The [Vision Pass Through pipeline design](./vision_passthrough/) consists of a simple pipeline with just one `passThrough` kernel. This pipeline's main purpose is to test whether the data movement works correctly.
-
-## [Color Detect](./color_detect/)
-
-The [Color Detect pipeline design](./color_detect/) consists of the following blocks arranged in a pipeline fashion for the detecting of 2 colors in a sequence of images : `rgba2hue`, `threshold`, `threshold`, `bitwiseOR`, `gray2rgba`, `bitwiseAND`.
-
-## [Edge Detect](./edge_detect/)
-
-The [Edge Detect pipeline design](./edge_detect/) consists of the following blocks arranged in a pipeline fashion for the detection of edges in a sequence of images: `rgba2gray`, `filter2D`, `threshold`, `gray2rgba`, `addWeighted`.
-
-## [Color Threshold](./color_threshold/)
-
-The [Color Threshold pipeline design](./color_threshold/) consists of 4 threshold blocks in separate tiles that process a different region of an input image. The results are then merged back together and sent to the output.
\ No newline at end of file
+| Design name | Data type | Description |
+|-|-|-|
+| [Vision Passthrough](../../programming_examples/vision/vision_passthrough/) | i8 | A simple pipeline with just one `passThrough` kernel. This pipeline's main purpose is to test whether the data movement correctly copies a greyscale image. |
+| [Color Detect](../../programming_examples/vision/color_detect/) | i32 | A multi-kernel, multi-core pipeline that detects colors in an RGBA image. The design consists of the following blocks, arranged in a pipeline fashion, for detecting two colors in a sequence of images: `rgba2hue`, `threshold`, `threshold`, `bitwiseOR`, `gray2rgba`, `bitwiseAND`. |
+| [Edge Detect](../../programming_examples/vision/edge_detect/) | i32 | A multi-kernel, multi-core pipeline that detects edges in an image and overlays the detection on the original image. The design consists of the following blocks, arranged in a pipeline fashion, for detecting edges in a sequence of images: `rgba2gray`, `filter2D`, `threshold`, `gray2rgba`, `addWeighted`. |
+| [Color Threshold](../../programming_examples/vision/color_threshold/) | i32 | A multi-core, data-parallel implementation of color thresholding of an RGBA image. The design consists of four threshold blocks in separate tiles, each processing a different region of an input image; the results are then merged back together and sent to the output. |
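
> **Editor's note:** all four designs in this table follow the same composition pattern: each kernel building block runs on its own compute tile, and consecutive stages are stitched together with object FIFOs. The fragment below is an editor's sketch of that pattern in the same IRON Python style as the designs in this patch; the kernel names (`filter2dLine`, `thresholdLine`), object sizes, and tile placement are illustrative assumptions, not code from any of the linked examples.

```python
from aie.dialects.aie import *
from aie.dialects.aiex import *
from aie.dialects.scf import *
from aie.extras.context import mlir_mod_ctx

with mlir_mod_ctx() as ctx:

    @device(AIEDevice.npu)
    def device_body():
        line_ty = T.memref(512, T.i8())  # one image line per object

        # Kernels compiled separately from aie_kernels/aie2 (names assumed)
        filter_fn = external_func("filter2dLine", inputs=[line_ty, line_ty])
        thresh_fn = external_func("thresholdLine", inputs=[line_ty, line_ty])

        ShimTile = tile(0, 0)
        FilterTile = tile(0, 2)
        ThreshTile = tile(0, 3)

        # Stage-to-stage FIFOs: shim -> filter2D -> threshold -> shim
        of_in = object_fifo("inOF", ShimTile, FilterTile, 2, line_ty)
        of_mid = object_fifo("midOF", FilterTile, ThreshTile, 2, line_ty)
        of_out = object_fifo("outOF", ThreshTile, ShimTile, 2, line_ty)

        @core(FilterTile, "filter2d.o")
        def core_body():
            for _ in for_(0xFFFFFFFF):
                elem_in = of_in.acquire(ObjectFifoPort.Consume, 1)
                elem_out = of_mid.acquire(ObjectFifoPort.Produce, 1)
                call(filter_fn, [elem_in, elem_out])
                of_in.release(ObjectFifoPort.Consume, 1)
                of_mid.release(ObjectFifoPort.Produce, 1)
                yield_([])

        @core(ThreshTile, "threshold.o")
        def core_body2():
            for _ in for_(0xFFFFFFFF):
                elem_in = of_mid.acquire(ObjectFifoPort.Consume, 1)
                elem_out = of_out.acquire(ObjectFifoPort.Produce, 1)
                call(thresh_fn, [elem_in, elem_out])
                of_mid.release(ObjectFifoPort.Consume, 1)
                of_out.release(ObjectFifoPort.Produce, 1)
                yield_([])

    print(ctx.module)
```
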
diff --git a/programming_guide/section-5/README.md b/programming_guide/section-5/README.md
index 0e22fde08a..bfceae9e95 100644
--- a/programming_guide/section-5/README.md
+++ b/programming_guide/section-5/README.md
@@ -30,12 +30,14 @@ The [passthrough DMAs](../../programming_examples/basic/passthrough_dmas/) examp
|-|-|-|
| [Vector Scalar Add](../../programming_examples/basic/vector_scalar_add/) | i32 | Adds 1 to every element in vector |
| [Vector Scalar Mul](../../programming_examples/basic/vector_scalar_mul/) | i32 | Returns a vector multiplied by a scale factor |
+| [Vector Vector Add](../../programming_examples/basic/vector_vector_add/) | i32 | Returns the element-wise sum of two vectors |
+| [Vector Vector Multiply](../../programming_examples/basic/vector_vector_mul/) | i32 | Returns the element-wise product of two vectors |
| [Vector Reduce Add](../../programming_examples/basic/vector_reduce_add/) | bfloat16 | Returns the sum of all elements in a vector |
| [Vector Reduce Max](../../programming_examples/basic/vector_reduce_max/) | bfloat16 | Returns the maximum of all elements in a vector |
| [Vector Reduce Min](../../programming_examples/basic/vector_reduce_min/) | bfloat16 | Returns the minimum of all elements in a vector |
| [Vector Exp](../../programming_examples/basic/vector_exp/) | bfloat16 | Returns a vector representing $e^x$ of the inputs |
| [DMA Transpose](../../programming_examples/basic/dma_transpose/) | i32 | Transposes a matrix with the Shim DMA using `npu_dma_memcpy_nd` |
-| [Single core GEMM](../../programming_examples/basic/matrix_multiplication/single_core) | bfloat16 | A single core matrix-matrix multiply |
+| [Matrix Scalar Add](../../programming_examples/basic/matrix_scalar_add/) | i32 | Returns a matrix with a scalar added to every element |
+| [Single core GEMM](../../programming_examples/basic/matrix_multiplication/single_core) | bfloat16 | A single core matrix-matrix multiply |
| [Multi core GEMM](../../programming_examples/basic/matrix_multiplication/whole_array) | bfloat16 | A matrix-matrix multiply using 16 AIEs with operand broadcast. Uses a simple "accumulate in place" strategy |
| [GEMV](../../programming_examples/basic/matrix_multiplication/matrix_vector) | bfloat16 | A vector-matrix multiply returning a vector
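
> **Editor's note:** the DMA Transpose entry above is worth a closer look, since the whole design is an access-pattern trick: `npu_dma_memcpy_nd` reads the source matrix column-major by giving the row-walking dimension a stride of one full row. The fragment below is an editor's sketch of the relevant `sequence`, written in the same style as this patch's aie2.py designs; the shapes, FIFO names (`in`, `out`), and the exact `sizes`/`strides` convention are assumptions, so treat the linked dma_transpose example as authoritative.

```python
from aie.dialects.aie import *
from aie.dialects.aiex import *
from aie.extras.context import mlir_mod_ctx

M, K = 64, 64  # source matrix: M rows x K columns of i32 (illustrative)

with mlir_mod_ctx() as ctx:

    @device(AIEDevice.npu)
    def device_body():
        # ... tiles, a passthrough core, and object FIFOs "in"/"out" go here ...

        tensor_ty = T.memref(M * K, T.i32())

        @FuncOp.from_py_func(tensor_ty, tensor_ty)
        def sequence(A, B):
            # Transposed read of A: walk M elements of one column (stride of
            # K words between them), then hop to the next of the K columns.
            npu_dma_memcpy_nd(
                metadata="in",
                bd_id=1,
                mem=A,
                sizes=[1, K, M, 1],
                strides=[1, 1, K],
            )
            # Contiguous write of the transposed stream.
            npu_dma_memcpy_nd(
                metadata="out", bd_id=0, mem=B, sizes=[1, 1, 1, M * K]
            )
            npu_sync(column=0, row=0, direction=0, channel=0)

    print(ctx.module)
```
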