Getting started: An example of using MIGraphX 0.1
This page demonstrates how MIGraphX version 0.1 can be used to speed up deep learning inference. It walks through the steps of importing an ONNX model, compiling that model, and then executing it with input data. The example application uses ResNet-50 to classify images.
Release 0.1, the first release of MIGraphX, is a technology demonstration with enough functionality to run the end-to-end steps for speeding up machine learning inference. However, it still has a number of limitations that we expect to remove as functionality is added in future releases. Some of these limitations include:
- A formal API has not yet been defined. This example calls MIGraphX routines directly, and these may be replaced by a more formal interface. Header files, routines, and arguments may all change, so these instructions may not work with future versions of MIGraphX.
- Only partial support for ONNX layers and operators. The model used in this demo and some others should work, but other ONNX models may require functionality that is not yet implemented.
- No support for multiple GPUs.
- No support for machine learning training.
First, install the AMD ROCm platform; see the ROCm installation documentation for details.
Next, install other ROCm-related libraries, including MIOpen version 1.6 or later:
sudo apt install rocm-libs miopen-hip
Either install a pre-built version 0.1 of MIGraphX or follow the instructions to build MIGraphX from source. As part of this installation or build, the following dependencies need to be satisfied:
- MIGraphX uses the Google Protobuf library. The default version of this library provided with distributions such as Ubuntu (libprotobuf-dev) is not configured with large enough buffers for MIGraphX. One workaround is to build libprotobuf from source; release v3.2.1, built from https://github.com/protocolbuffers/protobuf, worked for this example.
Set up your compiler and linker environment to use the MIGraphX headers at compile time and to link against the MIGraphX libraries. If MIGraphX is installed under /opt/rocm (adjust the paths below for a different installation prefix), the following Makefile definitions can be used with g++:
CXXFLAGS=-std=c++14
MIGRAPHX_INCLUDES= -I /opt/rocm/include -I /opt/rocm/targets/gpu/include
MIGRAPHX_LIBS = -L /opt/rocm/lib -lmigraphx -lmigraphx_onnx -lmigraphx_gpu
Include the MIGraphX headers in the program source. This example uses the following:
#include <migraphx/onnx.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/generate.hpp>
MIGraphX works with a variety of ONNX files. This example uses a trained model exported from PyTorch 0.4.0. As described in the PyTorch installation instructions, PyTorch can be installed with:
conda install pytorch=0.4.0 -c pytorch
Once PyTorch is installed, the following Python code creates the ONNX file:
import torch
import torchvision.models as models
batch_size = 1
resnet50 = models.resnet50(pretrained=True)
resnet50.eval()
torch.onnx.export(resnet50, torch.randn(batch_size,3,224,224), "resnet50.onnx")
The result of the preceding steps is an ONNX file, resnet50.onnx, that can be loaded into our C++ application with a single call to the migraphx::parse_onnx routine. This call returns a migraphx::program object (as a reminder, these calls may change once a more formal API is defined):
using namespace migraphx;
program prog = parse_onnx("resnet50.onnx");
Once the application has parsed the ONNX file, the next step is to optimize the graph. Release 0.1 includes basic optimizations such as fusing adjacent operators, graph transformations, and memory-usage optimization. The compilation step also "lowers" the graph to a GPU target representation; lowering includes allocating space for the literal weights on the GPU and copying them there. All of these steps happen in a single call:
prog.compile(gpu::target{});
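To make "fusing adjacent operators" concrete, here is a toy CPU illustration of the general idea. It is not MIGraphX code and says nothing about how MIGraphX generates its GPU kernels; the function names are hypothetical. An add followed by a ReLU (the pattern that ends each ResNet residual block) is computed first as two separate passes with an intermediate buffer, then as one fused pass that produces the same result with less memory traffic.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Unfused: two operators, two passes over memory, one intermediate buffer.
std::vector<float> add_then_relu(const std::vector<float>& x, const std::vector<float>& bias) {
    assert(x.size() == bias.size());
    std::vector<float> sum(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) sum[i] = x[i] + bias[i];  // operator 1: add
    std::vector<float> out(sum.size());
    for (std::size_t i = 0; i < sum.size(); ++i) out[i] = sum[i] > 0.0f ? sum[i] : 0.0f;  // operator 2: ReLU
    return out;
}

// Fused: the same computation in a single pass with no intermediate buffer.
std::vector<float> fused_add_relu(const std::vector<float>& x, const std::vector<float>& bias) {
    assert(x.size() == bias.size());
    std::vector<float> out(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        const float v = x[i] + bias[i];
        out[i] = v > 0.0f ? v : 0.0f;
    }
    return out;
}
```

Both functions return identical values; the fused version simply avoids writing and re-reading the intermediate sum.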
After compilation, the model is ready to be executed on the GPU. The remaining steps are (1) allocating space for the input parameters, output parameters, and scratch space; (2) copying the input parameters to the GPU; (3) executing the model; and (4) copying the results back. Steps 2 through 4 are repeated for each batch of input that is processed.
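The split between one-time setup and per-batch work can be sketched without MIGraphX. In this hypothetical sketch, for_each_batch and the run_batch callback are illustrative names, not MIGraphX API; the callback stands in for steps 2 through 4. It assumes a 3x224x224 float input per image and, because the model was exported with a fixed batch_size, zero-pads a final partial batch (outputs for padded slots should be ignored).

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Floats per image for a 3x224x224 input in NCHW layout.
constexpr std::size_t image_elems = 3 * 224 * 224;

// Converts an element count to an iterator offset.
inline std::ptrdiff_t to_offset(std::size_t elems) { return static_cast<std::ptrdiff_t>(elems); }

// Calls run_batch once per batch of batch_size images taken from a host buffer
// holding several images back to back. The staging buffer is allocated once,
// outside the loop, like step 1. A final partial batch is zero-padded.
// Returns the number of batches processed.
std::size_t for_each_batch(const std::vector<float>& images, std::size_t batch_size,
                           const std::function<void(const std::vector<float>&)>& run_batch) {
    assert(batch_size > 0 && images.size() % image_elems == 0);
    const std::size_t num_images = images.size() / image_elems;
    std::vector<float> batch(batch_size * image_elems);
    std::size_t batches = 0;
    for (std::size_t first = 0; first < num_images; first += batch_size) {
        const std::size_t count = std::min(batch_size, num_images - first);
        const auto src = images.begin() + to_offset(first * image_elems);
        std::copy(src, src + to_offset(count * image_elems), batch.begin());
        std::fill(batch.begin() + to_offset(count * image_elems), batch.end(), 0.0f);
        run_batch(batch);  // hypothetical: copy to GPU, evaluate, copy back
        ++batches;
    }
    return batches;
}
```

With 5 images and a batch size of 2, run_batch is called three times, and the second image slot of the last batch is all zeros.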
In our example, space can be allocated with the following code:
program::parameter_map m;
m["scratch"] = gpu::allocate_gpu(prog.get_parameter_shape("scratch"));
m["output"] = gpu::allocate_gpu(prog.get_parameter_shape("output"));
In addition to allocating space, this code records the allocations in a parameter_map object, which maps parameter names such as "0", "output", and "scratch" to their GPU memory locations. Note that the code deliberately does not allocate space for the input. Our PyTorch model uses "0" as the name of its input parameter, and the code that follows allocates space for the input when it copies the data to the GPU; allocating it here as well would allocate that space twice.
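The rule of allocating every parameter except the input named "0" can be sketched in standard C++. This is a hypothetical illustration, not MIGraphX code: a std::map plays the role of parameter_map, a host byte vector stands in for GPU memory, and allocate_non_input is an illustrative name. The sizes in the usage check are illustrative only, apart from the 602,112-byte input (1x3x224x224 floats).

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for GPU memory: a host byte buffer.
using buffer = std::vector<unsigned char>;

// Given each parameter's size in bytes, allocate a buffer for every parameter
// except the input. The input gets its memory later, when the image data is
// copied over, so allocating it here would allocate that space twice.
std::map<std::string, buffer> allocate_non_input(const std::map<std::string, std::size_t>& param_bytes,
                                                 const std::string& input_name = "0") {
    assert(!input_name.empty());
    std::map<std::string, buffer> m;
    for (const auto& x : param_bytes) {
        if (x.first != input_name)  // skip the input parameter
            m[x.first] = buffer(x.second);
    }
    return m;
}
```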
At this point, the application is ready to receive batch input for the ResNet-50 model in NCHW format for ImageNet data. Assuming the input is in an object named "image" whose "image.data()" method returns a pointer to the underlying data, copying the input data to the GPU is simple:
auto s = shape{shape::float_type,{(unsigned) batch_size,3,224,224}};
m["0"] = gpu::to_gpu(argument{s,image.data()});
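Because the input shape is {batch_size, 3, 224, 224} with float elements, the buffer behind image.data() must hold the pixels in that NCHW order. The following standalone sketch (hypothetical helper names, not MIGraphX API) computes the element count and the flat offset of element (n, c, h, w); for a batch of one image the buffer holds 150,528 floats, or 602,112 bytes.

```cpp
#include <cassert>
#include <cstddef>

// Dimensions of an NCHW tensor: batch, channels, height, width.
struct nchw_dims {
    std::size_t n, c, h, w;
};

// Number of elements in the tensor.
constexpr std::size_t element_count(nchw_dims d) { return d.n * d.c * d.h * d.w; }

// Offset of element (n, c, h, w) in a contiguous NCHW buffer:
// width varies fastest, then height, then channel, then batch.
std::size_t nchw_offset(nchw_dims d, std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
    assert(n < d.n && c < d.c && h < d.h && w < d.w);
    return ((n * d.c + c) * d.h + h) * d.w + w;
}
```

For example, the first pixel of the second channel (c = 1) sits at offset 224 * 224 = 50176, right after the full first channel plane.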
The next step is to evaluate the model on the GPU. The value returned by eval refers to the result data, which still resides in GPU memory:
auto resarg = prog.eval(m);
The final step is to copy the results back from the GPU to our application:
argument result = gpu::from_gpu(resarg[0]);
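For classification, the ImageNet-trained ResNet-50 produces 1000 class scores per image, and the predicted class is the index of the highest score. This sketch assumes the contents of result have already been copied into a row-major std::vector<float> of shape {batch, num_classes}; how to extract them from a migraphx::argument is not covered here, and top_class is a hypothetical helper.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <vector>

// Returns the index of the highest score for image i of a batch, given
// row-major scores of shape {batch, num_classes}.
std::size_t top_class(const std::vector<float>& scores, std::size_t num_classes, std::size_t i) {
    assert(num_classes > 0 && (i + 1) * num_classes <= scores.size());
    const auto first = scores.begin() + static_cast<std::ptrdiff_t>(i * num_classes);
    const auto best = std::max_element(first, first + static_cast<std::ptrdiff_t>(num_classes));
    return static_cast<std::size_t>(std::distance(first, best));
}
```

Taking the maximum of the raw scores gives the same answer as taking it after a softmax, since softmax preserves ordering.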
The final thing to note is how memory on the GPU is deallocated. MIGraphX manages memory, including the GPU memory referenced by our parameter_map object (m), with reference-counted C++ smart pointers (std::shared_ptr). When the parameter_map object is destroyed, the GPU memory it refers to is freed if no other references to that memory remain.
The steps above cover the end-to-end process of using MIGraphX to import an ONNX model, compile it, and execute it with input data. The complete example is shown below as a C++ file and a Makefile.
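The reference-counting behavior described above can be sketched with standard C++ alone. This is not MIGraphX code: std::malloc and std::free stand in for a GPU allocator and its free function, and allocate_fake_device and live_allocations are hypothetical names. A std::shared_ptr with a custom deleter is stored in a name-to-buffer map, and the memory is released only when the last reference disappears.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <map>
#include <memory>
#include <new>
#include <string>

// Counts live "device" allocations so the example can show when memory is freed.
static int live_allocations = 0;

// Hypothetical stand-in for a GPU allocation: the custom deleter plays the
// role of the device free function and runs when the last owner goes away.
std::shared_ptr<void> allocate_fake_device(std::size_t bytes) {
    void* p = std::malloc(bytes);
    if (p == nullptr) throw std::bad_alloc();
    ++live_allocations;
    return std::shared_ptr<void>(p, [](void* q) {
        std::free(q);
        --live_allocations;
    });
}
```

Clearing the map alone does not free a buffer that is still referenced elsewhere; the deleter runs only when the last shared_ptr copy is destroyed.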
// Example program for using MIGraphX release 0.1
//
#include <iostream>
#include <migraphx/onnx.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/generate.hpp>
using namespace migraphx;
int main(){
    unsigned batch_size = 1;
    program prog;
    // loading an ONNX file
    try {
        prog = parse_onnx("resnet50.onnx");
    } catch(...){
        std::cerr << "Unable to load ONNX file: resnet50.onnx" << std::endl;
        return 1;  // report failure with a nonzero exit status
    }
    // compiling the graph for the GPU target
    prog.compile(gpu::target{});
    // allocating space for parameters other than the input (0)
    program::parameter_map m;
    // allocate space for scratch and output parameters
    m["scratch"] = gpu::allocate_gpu(prog.get_parameter_shape("scratch"));
    m["output"] = gpu::allocate_gpu(prog.get_parameter_shape("output"));
    // create dummy input data named image.
    // In an actual program, this would be copied in from elsewhere...
    auto s = shape{shape::float_type,{batch_size,3,224,224}};
    auto image = generate_argument(prog.get_parameter_shape("0"));
    m["0"] = gpu::to_gpu(argument{s,image.data()});
    // evaluating the graph
    auto resarg = prog.eval(m);
    // fetching the result
    argument result = gpu::from_gpu(resarg[0]);
}
The Makefile for this program, assuming the source is saved as example.cxx, is as follows:
CXXFLAGS=-g -std=c++14
MIGRAPHX_INCLUDES= -I /opt/rocm/include -I /opt/rocm/targets/gpu/include
MIGRAPHX_LIBS = -L /opt/rocm/lib -lmigraphx -lmigraphx_onnx -lmigraphx_gpu
CXX=/opt/rocm/hcc/bin/hcc
example: example.cxx
	$(CXX) -o example $(CXXFLAGS) $(MIGRAPHX_INCLUDES) example.cxx $(MIGRAPHX_LIBS)