
Commit

feat: rms norm, transpose
chenghuaWang committed Oct 16, 2024
1 parent ad3b148 commit 886eba7
Showing 10 changed files with 392 additions and 13 deletions.
1 change: 1 addition & 0 deletions src/backends/xnnpack/CMakeLists.txt
@@ -29,6 +29,7 @@ add_library(MllmXnnpackBackend
Ops/XpD2H.cpp

Functions/XpBinaryFunc.cpp
Functions/XpTransposeFunc.cpp
)
target_include_directories(MllmXnnpackBackend PUBLIC third_party/XNNPACK/src/)
target_include_directories(MllmXnnpackBackend PUBLIC ${CMAKE_CURRENT_LIST_DIR}/../)
76 changes: 76 additions & 0 deletions src/backends/xnnpack/Functions/XpTransposeFunc.cpp
@@ -0,0 +1,76 @@
#include <array>
#include "backends/xnnpack/Functions/XpTransposeFunc.hpp"

namespace mllm::xnnpack {

void XpTransposeFunction::setup(vector<Tensor *> outputs, vector<Tensor *> inputs, vector<float> args) {
Chl axis0_ = (Chl)args[0];
Chl axis1_ = (Chl)args[1];

// inputs[0]->transShape(SEQUENCE, DIMENSION);
if (axis0_ == SEQUENCE && axis1_ == DIMENSION) {
if (inputs[0]->ctype() == BSHD) {
outputs[0]->reshape(inputs[0]->batch(), inputs[0]->head(), inputs[0]->dimension(), inputs[0]->sequence());
}
} else if (axis0_ == THW && axis1_ == CHANNLE) {
if (inputs[0]->ctype() == BCTHW) {
outputs[0]->reshape(inputs[0]->batch(), inputs[0]->time(), inputs[0]->height(), inputs[0]->width(), inputs[0]->channel());
}
} else if (axis0_ == BATCH && axis1_ == SEQUENCE) {
if (inputs[0]->ctype() == BSHD) {
outputs[0]->reshape(inputs[0]->sequence(), inputs[0]->head(), inputs[0]->batch(), inputs[0]->dimension());
}
}
}

void XpTransposeFunction::execute(vector<Tensor *> outputs, vector<Tensor *> inputs, vector<float> args) {
auto xpb = (XnnpackBackend *)inputs[0]->backend();
tryDefineAllXpTensors(xpb, inputs);
tryDefineAllXpTensors(xpb, outputs);

std::array<size_t, 4> perm{3, 2, 1, 0};

Chl axis0_ = (Chl)args[0];
Chl axis1_ = (Chl)args[1];

// inputs[0]->transShape(SEQUENCE, DIMENSION);
if (axis0_ == SEQUENCE && axis1_ == DIMENSION) {
if (inputs[0]->ctype() == BSHD) {
std::swap(perm[2], perm[3]);
} else {
Log::error("XpTransposeFunction NYI");
exit(-1);
}
} else if (axis0_ == THW && axis1_ == CHANNLE) {
if (inputs[0]->ctype() == BCTHW) {
Log::error("XpTransposeFunction NYI");
exit(-1);
outputs[0]->reshape(inputs[0]->batch(), inputs[0]->time(), inputs[0]->height(), inputs[0]->width(), inputs[0]->channel());
} else {
Log::error("XpTransposeFunction NYI");
exit(-1);
}
} else if (axis0_ == BATCH && axis1_ == SEQUENCE) {
if (inputs[0]->ctype() == BSHD) {
Log::error("XpTransposeFunction NYI");
exit(-1);
outputs[0]->reshape(inputs[0]->sequence(), inputs[0]->head(), inputs[0]->batch(), inputs[0]->dimension());
} else {
Log::error("XpTransposeFunction NYI");
exit(-1);
}
} else {
Log::error("XpTransposeFunction NYI");
exit(-1);
}

auto status = xnn_define_static_transpose(xpb->getXnnSubgraph(), 4, perm.data(), inputs[0]->uuid(), outputs[0]->uuid(), 0);

if (status != xnn_status_success) {
Log::error("XpGeLU::execute Error");
exit(-1);
}
}

} // namespace mllm::xnnpack
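
For reference, the perm array handed to xnn_define_static_transpose describes where each output axis reads from. The helper below is a plain scalar illustration of that idea and is not part of mllm or XNNPACK; the NumPy-style convention (output axis i takes input axis perm[i]) is an assumption on my part. Under that convention, swapping the sequence and dimension axes of a row-major [B, H, S, D] tensor would correspond to perm = {0, 1, 3, 2}.

#include <array>
#include <cstddef>
#include <vector>

// Illustrative helper: transpose a row-major 4-D tensor according to `perm`,
// where output axis i takes its extent and indices from input axis perm[i].
std::vector<float> transpose4d(const std::vector<float> &in,
                               const std::array<size_t, 4> &in_shape,
                               const std::array<size_t, 4> &perm) {
    std::array<size_t, 4> out_shape{};
    for (size_t i = 0; i < 4; ++i) out_shape[i] = in_shape[perm[i]];

    // row-major strides of the input
    std::array<size_t, 4> in_stride{};
    in_stride[3] = 1;
    for (int i = 2; i >= 0; --i) in_stride[i] = in_stride[i + 1] * in_shape[i + 1];

    std::vector<float> out(in.size());
    size_t o = 0;
    for (size_t a = 0; a < out_shape[0]; ++a)
        for (size_t b = 0; b < out_shape[1]; ++b)
            for (size_t c = 0; c < out_shape[2]; ++c)
                for (size_t d = 0; d < out_shape[3]; ++d) {
                    const std::array<size_t, 4> out_idx{a, b, c, d};
                    size_t in_off = 0;
                    for (size_t i = 0; i < 4; ++i) in_off += out_idx[i] * in_stride[perm[i]];
                    out[o++] = in[in_off];
                }
    return out;
}

// Example: a [B, H, S, D] tensor with S and D swapped (under the assumed convention).
// auto y = transpose4d(x, {B, H, S, D}, {0, 1, 3, 2});
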
23 changes: 23 additions & 0 deletions src/backends/xnnpack/Functions/XpTransposeFunc.hpp
@@ -0,0 +1,23 @@
/**
* @file XpTransposeFunc.hpp
* @author Chenghua Wang (chenghua.wang.edu@gmail.com)
* @version 0.1
* @date 2024-10-16
*
* @copyright Copyright (c) 2024
*
*/
#pragma once

#include "Backend.hpp"
#include "xnnpack/XpInterface.hpp"
namespace mllm::xnnpack {

class XpTransposeFunction : public TensorFunction, public XpTensorDefineInterface<XpTransposeFunction> {
public:
void setup(vector<Tensor *> outputs, vector<Tensor *> inputs, vector<float> args) override;

void execute(vector<Tensor *> outputs, vector<Tensor *> inputs, vector<float> args) override;
};

} // namespace mllm::xnnpack
100 changes: 89 additions & 11 deletions src/backends/xnnpack/Ops/XpRMSNorm.cpp
@@ -1,6 +1,7 @@
#include "backends/xnnpack/Ops/XpRMSNorm.hpp"
#include "backends/xnnpack/XnnpackBackend.hpp"
#include "Types.hpp"
#include "xnnpack.h"

namespace mllm::xnnpack {

@@ -12,38 +12,115 @@ ErrorCode XpRMSNorm::setUp(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
ErrorCode XpRMSNorm::reshape(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
auto xpb = (XnnpackBackend *)backend();
outputs[0]->reshape(inputs[0]->batch(), inputs[0]->head(), inputs[0]->sequence(), inputs[0]->dimension());
defineWeightTensor(xpb, &weight_);
return Op::reshape(inputs, outputs);
}

ErrorCode XpRMSNorm::execute(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
auto xpb = (XnnpackBackend *)backend();
tryDefineAllXpTensors(xpb, inputs);
tryDefineAllXpTensors(xpb, outputs);
defineWeightTensor(xpb, &weight_params_);
defineWeightTensor(xpb, &epsilon_param_);

// TODO
auto dtype = inputs[0]->dtype();
size_t b = inputs[0]->shape()[0];
size_t s = inputs[0]->shape()[1];
size_t h = inputs[0]->shape()[2];
size_t d = inputs[0]->shape()[3];

// x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
auto x_powed = defineTemporaryTensor(xpb, {b, s, h, d}, dtype);
{
auto status = xnn_define_square(xpb->getXnnSubgraph(), inputs[0]->uuid(), x_powed, 0);
if (status != xnn_status_success) {
Log::error("XpRMSNorm: xnn_define_square failed");
exit(-1);
}
}
auto x_powed_mean = defineTemporaryTensor(xpb, {b, s, h, d}, dtype);
{
std::array<size_t, 1> along_axes{3};
auto status = xnn_define_static_mean(xpb->getXnnSubgraph(), 1, along_axes.data(), x_powed, x_powed_mean, XNN_FLAG_KEEP_DIMS);
if (status != xnn_status_success) {
Log::error("XpRMSNorm: xnn_define_static_mean failed");
exit(-1);
}
}
auto x_pow_mean_eps = defineTemporaryTensor(xpb, {b, s, h, d}, dtype);
{
auto status = xnn_define_binary(xpb->getXnnSubgraph(), xnn_binary_add, nullptr, x_powed_mean, epsilon_param_.uuid(), x_pow_mean_eps, 0);
if (status != xnn_status_success) {
Log::error("XpRMSNorm: xnn_define_binary xnn_binary_add failed");
exit(-1);
}
}
auto x_pme_rsqrt = defineTemporaryTensor(xpb, {b, s, h, d}, dtype);
{
auto status = xnn_define_reciprocal_square_root(xpb->getXnnSubgraph(), x_pow_mean_eps, x_pme_rsqrt, 0);
if (status != xnn_status_success) {
Log::error("XpRMSNorm: xnn_define_reciprocal_square_root failed");
exit(-1);
}
}
auto x_1 = defineTemporaryTensor(xpb, {b, s, h, d}, dtype);
{
auto status = xnn_define_binary(xpb->getXnnSubgraph(), xnn_binary_multiply, nullptr, inputs[0]->uuid(), x_pme_rsqrt, x_1, 0);
if (status != xnn_status_success) {
Log::error("XpRMSNorm: xnn_define_binary xnn_binary_multiply x * epsed failed");
exit(-1);
}
}

{
auto status = xnn_define_binary(xpb->getXnnSubgraph(), xnn_binary_multiply, nullptr, x_1, weight_params_.uuid(), outputs[0]->uuid(), 0);
if (status != xnn_status_success) {
Log::error("XpRMSNorm: xnn_define_binary xnn_binary_multiply x * weight failed");
exit(-1);
}
}

return MLLM_NO_ERROR;
}
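
The node chain above (square, mean over the last axis with XNN_FLAG_KEEP_DIMS, add epsilon, reciprocal square root, multiply by the input, multiply by the weight) realizes the formula quoted in the comment. Below is a minimal scalar sketch of the same computation; the function name and the layout assumption (row-major data normalized over the last axis, rows = b * s * h, dim = d) are illustrative, not part of the mllm API.

#include <cmath>
#include <cstddef>

// Illustrative scalar RMSNorm: y = x * rsqrt(mean(x^2, last axis) + eps) * weight.
void rms_norm_ref(const float *x, const float *weight, float *y,
                  size_t rows, size_t dim, float eps) {
    for (size_t r = 0; r < rows; ++r) {
        const float *xr = x + r * dim;
        float *yr = y + r * dim;
        // mean of squares over the last axis: x.pow(2).mean(-1)
        float mean_sq = 0.f;
        for (size_t i = 0; i < dim; ++i) mean_sq += xr[i] * xr[i];
        mean_sq /= static_cast<float>(dim);
        // rsqrt(mean + eps), then scale by the learned weight
        const float inv_rms = 1.f / std::sqrt(mean_sq + eps);
        for (size_t i = 0; i < dim; ++i) yr[i] = xr[i] * inv_rms * weight[i];
    }
}
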

ErrorCode XpRMSNorm::load(AbstructLoader &loader) {
auto xpb = (XnnpackBackend *)backend();

weight_.setName(name() + ".weight");
weight_.reshape(1, 1, 1, norm_size_);
if (loader.getDataType(weight_.name()) != MLLM_TYPE_COUNT) {
weight_.setDtype(loader.getDataType(weight_.name()));
weight_.alloc();
loader.load(&weight_);
weight_params_.setName(name() + ".weight");
weight_params_.reshape(1, 1, 1, norm_size_);
if (loader.getDataType(weight_params_.name()) != MLLM_TYPE_COUNT) {
weight_params_.setDtype(loader.getDataType(weight_params_.name()));
weight_params_.alloc();
loader.load(&weight_params_);
} else {
weight_.setDtype(MLLM_TYPE_F32);
weight_.alloc();
weight_params_.setDtype(MLLM_TYPE_F32);
weight_params_.alloc();
}

// preset epsilon
epsilon_param_.setName(name() + ".epsilon");
epsilon_param_.reshape(1, 1, 1, 1);
epsilon_param_.setDtype(MLLM_TYPE_F32);
epsilon_param_.alloc();
epsilon_param_.setDataAt(0, 0, 0, 0, epsilon_);

// FIXME: make this process more efficient.
if (add_unit_offset_) {
int batch = weight_params_.batch();
int dim = weight_params_.dimension();
int seq = weight_params_.sequence();
int head = weight_params_.head();
#pragma omp parallel for num_threads(thread_count_)
for (auto i = 0; i < batch * dim * seq * head; ++i) {
*(weight_params_.hostPtr<float>() + i) += 1.f;
}
}
}

return Op::load(loader);
}
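
For reference, the add_unit_offset_ branch above folds the +1 into the stored weight at load time, so execute() builds the same XNNPACK graph either way. In formula form, the two variants implemented here are

\[
y_i = \frac{x_i}{\sqrt{\tfrac{1}{d}\sum_{j=1}^{d} x_j^2 + \varepsilon}}\, w_i
\qquad\text{and, with the unit offset,}\qquad
y_i = \frac{x_i}{\sqrt{\tfrac{1}{d}\sum_{j=1}^{d} x_j^2 + \varepsilon}}\, (1 + w_i),
\]

where d is the normalized size (norm_size_) and ε is epsilon_.
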

ErrorCode XpRMSNorm::free(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
weight_.free();
weight_params_.free();
epsilon_param_.free();
return Op::free(inputs, outputs);
}

6 changes: 4 additions & 2 deletions src/backends/xnnpack/Ops/XpRMSNorm.hpp
@@ -20,7 +20,8 @@ class XpRMSNorm final : public Op, public XpTensorDefineInterface<XpRMSNorm> {
public:
XpRMSNorm(Backend *bk, const std::string &op_name, int norm_size, float epsilon = 1e-6, bool add_unit_offset = false, int thread_count = 4) :
Op(bk, op_name), norm_size_(norm_size), epsilon_(epsilon), add_unit_offset_(add_unit_offset), thread_count_(thread_count) {
weight_.setBackend(bk);
weight_params_.setBackend(bk);
epsilon_param_.setBackend(bk);
}

~XpRMSNorm() override = default;
@@ -36,7 +37,8 @@ class XpRMSNorm final : public Op, public XpTensorDefineInterface<XpRMSNorm> {
ErrorCode free(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) override;

private:
Tensor weight_;
Tensor weight_params_;
Tensor epsilon_param_;
float epsilon_;
int norm_size_;
bool add_unit_offset_;
84 changes: 84 additions & 0 deletions src/backends/xnnpack/Ops/XpTranspose.cpp
@@ -0,0 +1,84 @@
#include "backends/xnnpack/Ops/XpTranspose.hpp"
#include "Types.hpp"
#include "xnnpack.h"
#include <utility>

namespace mllm::xnnpack {

ErrorCode XpTranspose::setUp(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
return MLLM_NO_ERROR;
}

ErrorCode XpTranspose::reshape(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
// inputs[0]->transShape(SEQUENCE, DIMENSION);
if (axis0_ == SEQUENCE && axis1_ == DIMENSION) {
if (inputs[0]->ctype() == BSHD) {
outputs[0]->reshape(inputs[0]->batch(), inputs[0]->head(), inputs[0]->dimension(), inputs[0]->sequence());
}
} else if (axis0_ == THW && axis1_ == CHANNLE) {
if (inputs[0]->ctype() == BCTHW) {
outputs[0]->reshape(inputs[0]->batch(), inputs[0]->time(), inputs[0]->height(), inputs[0]->width(), inputs[0]->channel());
}
} else if (axis0_ == BATCH && axis1_ == SEQUENCE) {
if (inputs[0]->ctype() == BSHD) {
outputs[0]->reshape(inputs[0]->sequence(), inputs[0]->head(), inputs[0]->batch(), inputs[0]->dimension());
}
}
return MLLM_NO_ERROR;
}

ErrorCode XpTranspose::execute(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
auto xpb = (XnnpackBackend *)inputs[0]->backend();
tryDefineAllXpTensors(xpb, inputs);
tryDefineAllXpTensors(xpb, outputs);

std::array<size_t, 4> perm{3, 2, 1, 0};

// inputs[0]->transShape(SEQUENCE, DIMENSION);
if (axis0_ == SEQUENCE && axis1_ == DIMENSION) {
if (inputs[0]->ctype() == BSHD) {
std::swap(perm[2], perm[3]);
} else {
Log::error("XpTransposeFunction NYI");
exit(-1);
}
} else if (axis0_ == THW && axis1_ == CHANNLE) {
if (inputs[0]->ctype() == BCTHW) {
Log::error("XpTransposeFunction NYI");
exit(-1);
outputs[0]->reshape(inputs[0]->batch(), inputs[0]->time(), inputs[0]->height(), inputs[0]->width(), inputs[0]->channel());
} else {
Log::error("XpTransposeFunction NYI");
exit(-1);
}
} else if (axis0_ == BATCH && axis1_ == SEQUENCE) {
if (inputs[0]->ctype() == BSHD) {
Log::error("XpTransposeFunction NYI");
exit(-1);
outputs[0]->reshape(inputs[0]->sequence(), inputs[0]->head(), inputs[0]->batch(), inputs[0]->dimension());
} else {
Log::error("XpTransposeFunction NYI");
exit(-1);
}
} else {
Log::error("XpTransposeFunction NYI");
exit(-1);
}

auto status = xnn_define_static_transpose(xpb->getXnnSubgraph(), 4, perm.data(), inputs[0]->uuid(), outputs[0]->uuid(), 0);

if (status != xnn_status_success) {
Log::error("XpGeLU::execute Error");
exit(-1);
}

return MLLM_NO_ERROR;
}

Op *XpTransposeCreator::create(OpParam op_param, Backend *bk, const string &name, int thread_count) const {
int axis0 = (int)op_param["axis0"];
int axis1 = (int)op_param["axis1"];
return new XpTranspose(bk, axis0, axis1, name, thread_count);
}
} // namespace mllm::xnnpack
44 changes: 44 additions & 0 deletions src/backends/xnnpack/Ops/XpTranspose.hpp
@@ -0,0 +1,44 @@
/**
* @file XpTranspose.hpp
* @author Chenghua Wang (chenghua.wang.edu@gmail.com)
* @brief
* @version 0.1
* @date 2024-10-16
*
* @copyright Copyright (c) 2024
*
*/
#pragma once

#include "Backend.hpp"
#include "Op.hpp"
#include "backends/xnnpack/XnnpackBackend.hpp"
#include "backends/xnnpack/XpInterface.hpp"

namespace mllm::xnnpack {

class XpTranspose final : public Op, public XpTensorDefineInterface<XpTranspose> {
public:
XpTranspose(Backend *bk, int axis0, int axis1, const std::string &op_name, int thread_count) :
Op(bk, op_name), axis0_(axis0), axis1_(axis1), thread_count_(thread_count) {
}

~XpTranspose() override = default;

ErrorCode setUp(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) override;

ErrorCode reshape(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) override;

ErrorCode execute(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) override;

private:
int axis0_;
int axis1_;
int thread_count_ = 4;
};

struct XpTransposeCreator : public XnnpackBackend::Creator {
Op *create(OpParam op_param, Backend *bk, const string &name, int thread_count) const override;
};

} // namespace mllm::xnnpack