QNN Module API (new frontend) Preliminary Support #158

Merged · 57 commits · Oct 26, 2024
Commits (changes from all 57 commits)
e0d936e
feat : support multithread mergeouput copy.
liang1232018 Sep 7, 2024
006772d
update : support FP16 in FFN.
liang1232018 Sep 7, 2024
171d9a2
feat : using FP16 KVCache now.
liang1232018 Sep 8, 2024
2588ef0
feat : support reference SuperSiLU.
liang1232018 Sep 8, 2024
abd6afa
feat : implement SuperSiLU with bugs.
liang1232018 Sep 8, 2024
339f31b
fix : silu execution error.
liang1232018 Sep 8, 2024
f247393
fix : SuperSiLU bugs in scale.
liang1232018 Sep 8, 2024
99a0c4b
fix : SuperSiLU quant scale bug.
liang1232018 Sep 9, 2024
6b850fb
feat : optimize shadow execution with NEON and multithread.
liang1232018 Sep 9, 2024
ffe4717
fix: fp16 kvcache npu no copy
oreomaker Sep 10, 2024
5ac41a4
feat: create some tensor funcs' Layers for qnn exe
oreomaker Sep 26, 2024
5d7153d
chore: remove unused duplicate qnn linear
oreomaker Sep 26, 2024
602c807
fix: qnn cmake macro position
oreomaker Sep 26, 2024
4686b91
refactor: qwen qnn new module api modeling
oreomaker Sep 27, 2024
357a4e0
fix: demo qnn weights name
oreomaker Sep 28, 2024
887e5f9
fix: qnn tensor dtype population in module api
oreomaker Oct 5, 2024
6384e4b
refator: qnn transpose and kvcache channel
oreomaker Oct 6, 2024
050ffcb
Merge pull request #3 from liang1232018/develop-xdl-dev
liang1232018 Oct 8, 2024
31d6243
feat : support int8 output RMSNorm.
liang1232018 Oct 9, 2024
6bf38c2
feat : optimize RMSNorm intermediate type.
liang1232018 Oct 10, 2024
674a1b8
fix : res input tensor type error.
liang1232018 Oct 10, 2024
2b6cb3d
fix: qnn module api tensor alloc and dtype set
oreomaker Oct 11, 2024
ceebd84
Merge branch 'develop-zh-alpha' of github.com:liang1232018/mllm into …
oreomaker Oct 11, 2024
6063c0a
Merge branch 'main' into develop-zh
oreomaker Oct 11, 2024
ac16abd
chore: qnn cmake execcutables
oreomaker Oct 11, 2024
c771e14
feat : optimize quantization implementation in RMSNorm.
liang1232018 Oct 11, 2024
111437a
feat : optimize quantize and supersilu sf to int8 operation.
liang1232018 Oct 11, 2024
f4ff8db
feat : optimize SiLU execution without dequantize with uint8 mul.
liang1232018 Oct 13, 2024
6f6c12e
feat : optimize SiLU implementation without quantization.
liang1232018 Oct 14, 2024
8c92c65
fix: qnn new front end scale setting
oreomaker Oct 15, 2024
e1c0997
refactor: remove unecessary qnn backend multi models and graphs
oreomaker Oct 15, 2024
382a5f4
feat : optimize RMSNorm implementation.
liang1232018 Oct 15, 2024
21ea6a5
fix : use old accurate RMSNorm execution method.
liang1232018 Oct 15, 2024
18f0dda
fix : shadow output outlier execution bugs and optimize the input out…
liang1232018 Oct 16, 2024
4657817
Merge pull request #4 from liang1232018/develop-xdl-accuracy
liang1232018 Oct 16, 2024
cf96340
fix : wrong modification of CPUKVCache.
liang1232018 Oct 16, 2024
ed9abe3
feat: qnn new frontend prefill
oreomaker Oct 17, 2024
9eff6aa
fix : QNN backend with no openmp.
liang1232018 Oct 17, 2024
7081416
feat: qnn new frontend kvcache sharing
oreomaker Oct 18, 2024
5b446e5
Merge branch 'develop-xdl' into develop-zh
oreomaker Oct 18, 2024
fe36e69
Merge branch 'main' into develop-zh
yirongjie Oct 18, 2024
d778f72
Merge branch 'main' into develop-zh
yirongjie Oct 18, 2024
5c614c5
fix: bugs in cmake and layer.hpp
yirongjie Oct 18, 2024
d7d68f3
fix: add ign
yirongjie Oct 18, 2024
0805d64
Merge branch 'develop-zh' of github.com:liang1232018/mllm into develo…
oreomaker Oct 19, 2024
dd4c744
fix: qnn cmake executables
oreomaker Oct 19, 2024
f351610
fix: demo qnn padding tockenize
oreomaker Oct 19, 2024
c16091b
Merge branch 'main' into develop-zh
yirongjie Oct 21, 2024
6d79918
fix: cmake
yirongjie Oct 21, 2024
d7e9def
Merge branch 'main' into pr/oreomaker/158
yirongjie Oct 24, 2024
6ef60c5
Merge branch 'main' into develop-zh
UbiquitousLearning Oct 24, 2024
b93e504
fix: change module loader to global
oreomaker Oct 24, 2024
6eea3bd
Merge branch 'develop-zh' of github.com:liang1232018/mllm into develo…
oreomaker Oct 25, 2024
2388adb
Merge branch 'main' into develop-zh
UbiquitousLearning Oct 26, 2024
ea065d0
fix: qnn qwen new frontend modeling
oreomaker Oct 26, 2024
c182fbc
Merge branch 'develop-zh' of github.com:liang1232018/mllm into develo…
oreomaker Oct 26, 2024
c9f732b
fix: cmake
yirongjie Oct 26, 2024
10 changes: 8 additions & 2 deletions CMakeLists.txt
@@ -68,6 +68,13 @@ endif()

# backend options
option(QNN "Enable QNN" OFF)
option(QNN_OLD_FRONTEND "Enable Old QNN" OFF)
if(QNN)
add_definitions(-DUSE_QNN) # the USE_QNN should come before cpu subdirectory
endif()
if(QNN_OLD_FRONTEND)
add_definitions(-DOLD_QNN)
endif()

if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0")
cmake_policy(SET CMP0135 NEW)
@@ -116,8 +123,7 @@ include_directories(${PROJECT_SOURCE_DIR}/third_party/pybind11/include)

add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/src/backends/cpu)

if(QNN)
add_definitions(-DUSE_QNN)
if(QNN) # QNN lib
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/src/backends/qnn)
endif()

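The comment added in the first hunk captures the ordering constraint: `USE_QNN` has to be defined before `src/backends/cpu` is added so the CPU sources compile their QNN-aware paths, while `QNN_OLD_FRONTEND` maps to `OLD_QNN` to select the legacy frontend. A minimal sketch of how these compile definitions are typically consumed downstream (the helper below is an assumption for illustration, not code from this repository):

```cpp
// Illustrative only: shows the compile-time gating implied by the two
// definitions added in CMakeLists.txt above.
inline const char *selectedQnnFrontend() {
#if defined(USE_QNN) && defined(OLD_QNN)
    return "qnn-old-frontend"; // built with -DQNN=ON -DQNN_OLD_FRONTEND=ON
#elif defined(USE_QNN)
    return "qnn-module-api";   // new frontend path added by this PR
#else
    return "cpu-only";         // QNN disabled
#endif
}
```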
2 changes: 1 addition & 1 deletion examples/CMakeLists.txt
@@ -53,7 +53,6 @@ endmacro()


## new demos

func_llm_add_executable(benchmark)
func_llm_add_executable(demo_llama)
func_llm_add_executable(demo_tinyllama)
@@ -82,6 +81,7 @@ func_vlm_add_executable(demo_imagebind_1mod)

# QNN demo
if(QNN)
func_llm_add_executable(demo_qnn)
func_llm_add_executable(main_qwen_npu)
endif()

33 changes: 28 additions & 5 deletions examples/demo_qnn.cpp
@@ -1,6 +1,7 @@
#include "cmdline.h"
#include "models/qwen/configuration_qwen.hpp"
#include "models/qwen/modeling_qwen_npu.hpp"
#include "models/qwen/modeling_qwen.hpp"
#include "models/qwen/tokenization_qwen.hpp"

using namespace mllm;
@@ -9,7 +10,7 @@ int main(int argc, char **argv) {
cmdline::parser cmdParser;
cmdParser.add<string>("vocab", 'v', "specify mllm tokenizer model path", false, "../vocab/qwen_vocab.mllm");
cmdParser.add<string>("merge", 'e', "specify mllm merge file path", false, "../vocab/qwen_merges.txt");
cmdParser.add<string>("model", 'm', "specify mllm model path", false, "../models/qwen-1.5-1.8b-q8_0.mllm");
cmdParser.add<string>("model", 'm', "specify mllm model path", false, "../models/qwen-1.5-1.8b-chat-int8.mllm");
cmdParser.add<string>("billion", 'b', "[0.5B | 1.8B]", false, "1.8B");
cmdParser.add<int>("limits", 'l', "max KV cache size", false, 400);
cmdParser.add<int>("thread", 't', "num of threads", false, 4);
@@ -24,25 +25,29 @@ int main(int argc, char **argv) {

auto tokenizer = QWenTokenizer(vocab_path, merge_path);
QWenConfig config(tokens_limit, model_billion, RoPEType::HFHUBROPE);
auto model = QWenForCausalLM(config);
auto model = QWenForCausalLM_NPU(config);
model.load(model_path);
// auto decoding_model = QWenForCausalLM(config);
// decoding_model.load("../models/qwen-1.5-1.8b-chat-q4k.mllm");

vector<string> in_strs = {
" Give me a short introduction to large language model.",
};

for (int i = 0; i < in_strs.size(); ++i) {
auto input_str = tokenizer.apply_chat_template(in_strs[i]);
auto input_tensor = tokenizer.tokenize(input_str);
auto [real_seq_length, input_tensor] = tokenizer.tokenizeWithPadding(input_str, 64, config.vocab_size);
std::cout << "[Q] " << in_strs[i] << std::endl;
std::cout << "[A] " << std::flush;

LlmTextGeneratorOpts opt{
.max_new_tokens = 100,
.do_sample = true,
.do_sample = false,
.temperature = 0.3f,
.top_k = 50,
.top_p = 0.f,
.is_padding = true,
.seq_before_padding = real_seq_length,
};
model.generate(input_tensor, opt, [&](unsigned int out_token) -> bool {
auto out_string = tokenizer.detokenize({out_token});
@@ -51,6 +56,24 @@ int main(int argc, char **argv) {
std::cout << output_string << std::flush;
return true;
});
std::cout << "FINISH\n";

LlmTextGeneratorOpts decoding_opt{
.max_new_tokens = 100,
.do_sample = false,
.temperature = 0.3f,
.top_k = 50,
.top_p = 0.f,
.is_padding = false,
};
// decoding_model.generate(input_tensor, decoding_opt, [&](unsigned int out_token) -> bool {
// auto out_string = tokenizer.detokenize({out_token});
// auto [isOk, print_string] = processOutput(out_string);
// if (isOk) {
// std::cout << print_string << std::flush;
// } else {
// return false;
// }
// return true;
// });
}
}
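The demo now pads the prompt before NPU prefill and keeps the real token count so generation can read logits at the correct position. A minimal sketch of the contract this assumes for `tokenizeWithPadding` (the helper and padding id are assumptions based on the `64` and `config.vocab_size` arguments above, not the tokenizer's actual implementation):

```cpp
#include <utility>
#include <vector>

// Illustrative stand-in for tokenizeWithPadding: pad the token ids up to the
// fixed sequence length the QNN prefill graph expects, and return the real,
// unpadded length alongside the padded ids.
inline std::pair<int, std::vector<int>> padForPrefill(std::vector<int> ids,
                                                      int fixed_len, int pad_id) {
    const int real_len = static_cast<int>(ids.size());
    if (real_len < fixed_len) {
        ids.resize(fixed_len, pad_id); // fill the tail with the padding id
    }
    return {real_len, std::move(ids)};
}
```

The returned real length is what the demo then passes on as `seq_before_padding` in the generation options.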
13 changes: 4 additions & 9 deletions examples/main_qwen_npu.cpp
@@ -138,17 +138,12 @@ int main(int argc, char **argv) {
std::cout << "[Q] " << in_str << std::endl;
std::cout << "[A] " << std::flush;

// cpuExe.run(&cpuNet, {input});
// auto result = cpuExe.result();
// auto token_idx = postProcessing(result[0], input);

// auto out_token = tokenizer.detokenize({token_idx});
// std::cout << out_token << std::flush;
// exit(0);

do {
// 1: Prefill stage using NPU chunk execute
npuExe.run(npu_ctx, &npuNet, {input});
if (chunk == 1)
npuExe.run(npu_ctx, &npuNet, {input});
else
npuExe.runExp(npu_ctx, &npuNet, {input});
auto result = npuExe.result();

// inter model for prefill-decode
56 changes: 11 additions & 45 deletions examples/main_qwen_npu.hpp
@@ -10,11 +10,12 @@ namespace modeling {
NetTensor *Qwen_FFN_NPU(Context *c, NetTensor *i, int hidden_dim, int ffn_hidden_dim, string name) {
auto *x = _LinearINT8({i}, hidden_dim, ffn_hidden_dim, false, name + ".gate_proj");
auto *y = _LinearINT8({i}, hidden_dim, ffn_hidden_dim, false, name + ".up_proj");
x = _Dequantize({x}, true, (string)name + ".gate_proj.dequantize", true);
y = _Dequantize({y}, true, (string)name + ".up_proj.dequantize", true);
x = _SiLU({x}, name + ".silu");
x = *x * y;
x = _Quantize({x}, true, (string)name + ".down_proj.quantize");
x = _SuperSiLU({x,y}, name + ".supersilu");
// x = _Dequantize({x}, true, (string)name + ".gate_proj.dequantize", false);
// y = _Dequantize({y}, true, (string)name + ".up_proj.dequantize", false);
// x = _SiLU({x}, name + ".silu");
// x = *x * y;
// x = _Quantize({x}, true, (string)name + ".down_proj.quantize");
x = _LinearINT8({x}, ffn_hidden_dim, hidden_dim, false, name + ".down_proj");
x = _Dequantize({x}, true, (string)name + ".down_proj.dequantize");
return x;
@@ -29,9 +30,9 @@ std::vector<NetTensor *> Qwen_CPUNPUAttention(Context *c, NetTensor *x, NetTenso
k = k->view(1, head_size, seq / chunk, hidden_size);
v = v->view(1, head_size, seq / chunk, hidden_size);

q = _Dequantize({q}, true, (string)name + ".q_proj.dequantize");
k = _Dequantize({k}, true, (string)name + ".k_proj.dequantize");
v = _Dequantize({v}, true, (string)name + ".v_proj.dequantize");
q = _Dequantize({q}, true, (string)name + ".q_proj.dequantize", true);
k = _Dequantize({k}, true, (string)name + ".k_proj.dequantize", false);
v = _Dequantize({v}, true, (string)name + ".v_proj.dequantize", false);

v = _Transpose({v}, {0, 2, 3, 1}, (string)name + ".v_proj.transpose");

@@ -153,41 +154,6 @@ NetTensor *Qwen_FFN_CPU_q4k(Context *c, NetTensor *i, int hidden_dim, int ffn_hi
return x;
}

void qwen_cpu(Context *c, int vocab_size = 32000, int hidden_dim = 4096, int ffn_hidden_dim = 11008, int mutil_head_size = 32, int cache_max = 200, int seq = 256, int chunk = 2) {
auto *i = _Input(c);
i = _Embedding({i}, vocab_size, hidden_dim, (string) "model.embed_tokens");

for (int layer = 0; layer < 24; ++layer) {
auto res = _RMSNorm({i}, hidden_dim, 1e-6, (string) "model.layers." + std::to_string(layer) + ".input_layernorm");

i = *Qwen_CPUAttention(c, res, hidden_dim, hidden_dim / mutil_head_size, mutil_head_size, cache_max, (string) "model.layers." + std::to_string(layer) + ".self_attn", seq, chunk) + i;

res = _RMSNorm({i}, hidden_dim, 1e-6, (string) "model.layers." + std::to_string(layer) + ".post_attention_layernorm");

if (layer != 6 && layer != 1 && layer != 2) {
i = *Qwen_FFN_CPU(c, res, hidden_dim, ffn_hidden_dim, (string) "model.layers." + std::to_string(layer) + ".mlp") + i;
} else {
auto name = (string) "model.layers." + std::to_string(layer) + ".mlp";

auto *x = _LinearINT8({res}, hidden_dim, ffn_hidden_dim, false, name + ".gate_proj");
x = _SiLU({x}, name + ".silu");
auto *y = _LinearINT8({res}, hidden_dim, ffn_hidden_dim, false, name + ".up_proj");
x = *x * y; // x = _Mul( {x, y}, name+".dot");

auto *i1 = x;
x = _LinearINT8({x}, ffn_hidden_dim, hidden_dim, false, name + ".down_proj");

auto *i2 = x;

i = *x + i;

i = _LinearINT8Shadow({i1, i2, i}, ffn_hidden_dim, hidden_dim, false, name + ".down_proj.shadow");
}
}
i = _RMSNorm({i}, hidden_dim, 1e-6, (string) "model.norm");
i = _Linear({i}, hidden_dim, vocab_size, false, "lm_head");
}

void qwen_cpu_q4k(Context *c, int vocab_size = 32000, int hidden_dim = 4096, int ffn_hidden_dim = 11008, int mutil_head_size = 32, int cache_max = 200, int seq = 256, int chunk = 2) {
auto *i = _Input(c);
i = _Embedding({i}, vocab_size, hidden_dim, (string) "model.embed_tokens");
@@ -242,9 +208,9 @@ void qwen_npu(Context *c, int vocab_size = 32000, int hidden_dim = 4096, int ffn

res = i;

i = _RMSNorm({i}, hidden_dim, 1e-6, (string) "model.layers." + std::to_string(layer) + ".post_attention_layernorm");
i = _RMSNorm({i}, hidden_dim, 1e-6, (string) "model.layers." + std::to_string(layer) + ".post_attention_layernorm", false);

i = _Quantize({i}, true, (string) "model.layers." + std::to_string(layer) + ".mlp.up_proj.quantize");
// i = _Quantize({i}, true, (string) "model.layers." + std::to_string(layer) + ".mlp.up_proj.quantize");

i = i->view(1, static_cast<int>(seq / chunk / 32), static_cast<int>(32), hidden_dim);

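The FFN change above replaces the explicit dequantize → SiLU → multiply → quantize chain with the fused `_SuperSiLU` op. A reference sketch of the element-wise computation that chain performs (scale handling and the function name are illustrative, not the op's actual NEON implementation):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Unfused reference: dequantize the int8 gate/up projections, apply SiLU to
// the gate, multiply elementwise, and re-quantize to int8 for down_proj.
inline std::vector<int8_t> superSiluReference(const std::vector<int8_t> &gate,
                                              const std::vector<int8_t> &up,
                                              float gate_scale, float up_scale,
                                              float out_scale) {
    std::vector<int8_t> out(gate.size());
    for (size_t i = 0; i < gate.size(); ++i) {
        const float g = gate[i] * gate_scale;      // dequantize gate_proj output
        const float u = up[i] * up_scale;          // dequantize up_proj output
        const float s = g / (1.0f + std::exp(-g)); // SiLU(g)
        const long q = std::lroundf(s * u / out_scale); // re-quantize
        out[i] = static_cast<int8_t>(std::clamp<long>(q, -128, 127));
    }
    return out;
}
```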
3 changes: 3 additions & 0 deletions include/OpDefined.hpp
@@ -61,6 +61,7 @@ enum OpType {
MERGEOUTPUT,
SPLITINPUT,
IROPE,
SUPERSILU,
OP_NUM
};

@@ -107,6 +108,7 @@ static const vector<string> OpNames = {
"Range",
"Where",
"Replace",
"Predictor",
"SparseLinear",
"SparseIdLinear",
"ElasticLinear",
@@ -117,6 +119,7 @@ static const vector<string> OpNames = {
"MergeOutput",
"SplitInput",
"IRoPE",
"SuperSiLU",
"OP_NUM"};

enum TensorFuncType {
4 changes: 2 additions & 2 deletions include/Types.hpp
@@ -93,9 +93,9 @@ inline std::map<std::vector<int>, ChlType> Chls2Type = {
{{0, 3, 4, 1, 2}, BWCTH}};

enum TensorType {
INPUT_TENSOR = 0,
INPUT_TENSOR = 0, // used for input of the model
NORMAL_TENSOR,
OUTPUT_TENSOR,
GRAPH_OUTPUT, // used for output of a graph
};

enum Chl {
3 changes: 2 additions & 1 deletion scripts/build_qnn_android.sh
@@ -12,6 +12,7 @@ cmake .. \
-DQNN=ON \
-DDEBUG=OFF \
-DTEST=OFF \
-DQUANT=OFF
-DQUANT=OFF \
-DQNN_OLD_FRONTEND=ON

make -j4
4 changes: 4 additions & 0 deletions src/Backend.cpp
@@ -3,6 +3,7 @@
#include <memory>
#include <unordered_map>
#include <mutex>
#include "Layer.hpp"

namespace mllm {
extern void registerCPUBackendCreator();
@@ -29,6 +30,9 @@ static std::unordered_map<BackendType, std::shared_ptr<BackendCreator>> &GetBack
}

const std::shared_ptr<BackendCreator> GetBackendCreator(BackendType type) {
if (type == MLLM_QNN) {
Layer::use_layername_2_tensorname = false;
}
registerBackend();

auto &gExtraCreator = GetBackendCreatorMap();
5 changes: 5 additions & 0 deletions src/Backend.hpp
@@ -13,6 +13,11 @@ class Op;
class Tensor;
class Backend;

// KVCache map for QNN-CPU KVCache sharing
#ifdef USE_QNN
static std::unordered_map<string, Op *> kv_cache_map;
#endif

class TensorFunction {
public:
virtual void setup(vector<Tensor *> outputs, vector<Tensor *> inputs, vector<float> args) = 0;
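Moving the map from Graph.cpp into Backend.hpp makes it visible to both the QNN and CPU sides. A minimal sketch of how such a registry is typically used for QNN–CPU KV cache sharing (the publish/lookup helpers are assumptions for illustration; only the map itself appears in the diff):

```cpp
#include <string>
#include <unordered_map>

class Op; // stands in for mllm::Op

// Shared registry keyed by layer name: the QNN prefill side publishes its
// KVCache ops, and the CPU decode side looks them up so both read and write
// the same cache storage.
static std::unordered_map<std::string, Op *> kv_cache_map;

inline void publishKVCache(const std::string &layer_name, Op *cache_op) {
    kv_cache_map[layer_name] = cache_op;
}

inline Op *findSharedKVCache(const std::string &layer_name) {
    const auto it = kv_cache_map.find(layer_name);
    return it == kv_cache_map.end() ? nullptr : it->second;
}
```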
17 changes: 17 additions & 0 deletions src/Generate.hpp
@@ -27,6 +27,8 @@ struct LlmTextGeneratorOpts {
float temperature = 0.7;
int top_k = 5;
float top_p = 0.92;
bool is_padding = false;
int seq_before_padding = 0;
};

template <typename T>
@@ -47,14 +49,24 @@ enum class LLmTextGeneratorType : int32_t {
};

class _LlmTextGenerateMethod {
bool is_padding = false;
int seq_before_padding = 0;
public:
virtual ~_LlmTextGenerateMethod() = default;
virtual unsigned int generate(Tensor &t) = 0;
inline void setPadding(bool is_padding, int seq_before_padding) {
this->is_padding = is_padding;
this->seq_before_padding = seq_before_padding;
}
inline void _tensor_to_vec(Tensor &t, std::vector<float> &scores) {
assert(t.batch() == 1 && "Batch size of result is not 1. Which is not supported for now.");
assert(t.head() == 1 && "The 3rd dim of result should be one. e.g.:[1, 1, seq, hidden]");
int _dims = t.dimension();
int _seq = t.sequence() - 1;
// padding prefill for QNN
if (is_padding) {
_seq = seq_before_padding - 1;
}
for (int i = 0; i < _dims; ++i) {
auto value = t.dataAt<float>(0, 0, _seq, i);
scores.push_back(value);
@@ -144,6 +156,11 @@ class LlmTextGenerator {
assert(false && "NIY");
break;
}

// padding prefill for QNN
if (opt.is_padding) {
m_method_class->setPadding(opt.is_padding, opt.seq_before_padding);
}
}

inline unsigned int generate(Tensor &t) {
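The `is_padding` / `seq_before_padding` plumbing exists so that, after a padded QNN prefill, the generator samples from the logits of the last real prompt token rather than the last (padded) position. A self-contained sketch of that index selection, with a plain vector of per-position logit rows standing in for the mllm Tensor:

```cpp
#include <cassert>
#include <vector>

// Pick the logit row to sample from: the last position by default, or the
// last unpadded prompt token when the input was padded for QNN prefill.
// Mirrors the _seq adjustment in _tensor_to_vec() above.
inline const std::vector<float> &selectLastLogits(
    const std::vector<std::vector<float>> &logits,
    bool is_padding, int seq_before_padding) {
    int row = static_cast<int>(logits.size()) - 1;
    if (is_padding) {
        row = seq_before_padding - 1;
    }
    assert(row >= 0 && row < static_cast<int>(logits.size()));
    return logits[row];
}
```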
5 changes: 1 addition & 4 deletions src/Graph.cpp
@@ -18,9 +18,6 @@ std::string intToStringWithLeadingZero(int num) {

namespace mllm {

#ifdef USE_QNN
static unordered_map<string, Op *> kv_cache_map;
#endif

Graph::Graph(const NetParameter &param, Backend *bn,
unordered_map<string, shared_ptr<Tensor>> &external_tensors,
@@ -149,7 +146,7 @@ void Graph::setUpTensors() {
// set graph out tensor TensorType
auto &graph_out_tensors = ops_output_tensors_[op_names_[op_names_.size() - 1]];
for (auto &t : graph_out_tensors) {
t->setTtype(OUTPUT_TENSOR);
t->setTtype(GRAPH_OUTPUT);
}

this->backend_->onSetUpStart(graph_in_tensors, graph_out_tensors);