refactor: add TransformerConfig #162

Merged 3 commits on Oct 21, 2024
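Summary: this PR makes the per-model config structs in the diff below (DCLMConfig, GemmaConfig, ImagebindConfig, LLaMAConfig, MiniCPMConfig) inherit from a common TransformerConfig base and turns Module::clear_kvcache() into a virtual hook that the LLaMA models override. The TransformerConfig definition itself is not part of the rendered hunks; the snippet below is only a sketch of such a shared base, since the derived configs keep declaring their own fields (vocab_size, cache_limit, RoPE_type, ...).

// Illustrative sketch only — the real TransformerConfig added by this PR is
// defined elsewhere in the repository and is not shown in these hunks.
// The derived configs below still declare their own fields, so the base is
// sketched here as a plain polymorphic root with no members of its own.
struct TransformerConfig {
    virtual ~TransformerConfig() = default;
};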
3 changes: 2 additions & 1 deletion .gitignore
@@ -37,4 +37,5 @@ src/models/deepseek/*
examples/demo_phonellm.cpp
src/models/phonellm/*
examples/demo_minicpm3.cpp
src/models/minicpm3/*
src/models/minicpm3/*
examples/demo.cpp
4 changes: 3 additions & 1 deletion examples/CMakeLists.txt
@@ -40,6 +40,8 @@ macro(func_vlm_add_executable target)
${DIR_SRC}
${PROJECT_SOURCE_DIR}/src/tokenizers/Tokenizer.cpp
${PROJECT_SOURCE_DIR}/src/tokenizers/Unigram/Unigram.cpp
${PROJECT_SOURCE_DIR}/src/tokenizers/Unicode.cpp
${PROJECT_SOURCE_DIR}/src/tokenizers/UnicodeData.cpp
${PROJECT_SOURCE_DIR}/src/tokenizers/BPE/Bpe.cpp
${PROJECT_SOURCE_DIR}/src/processor/PreProcess.cpp
${DIR_SRC_PROCESSOE}
@@ -53,7 +55,6 @@ endmacro()
## new demos

func_llm_add_executable(benchmark)
# func_llm_add_executable(demo)
func_llm_add_executable(demo_llama)
func_llm_add_executable(demo_tinyllama)
func_llm_add_executable(demo_stablelm)
@@ -77,6 +78,7 @@ func_vlm_add_executable(demo_vit)
func_vlm_add_executable(demo_clip)
func_vlm_add_executable(demo_imagebind)
func_vlm_add_executable(demo_imagebind_1mod)
# func_vlm_add_executable(demo)

# QNN demo
if(QNN)
3 changes: 3 additions & 0 deletions src/Module.hpp
@@ -241,6 +241,9 @@ class Module {
void setNoLoadWeightsDtype(DataType dtype) {
Op::noLoadWeightsDtype() = dtype;
}
virtual void clear_kvcache() {
;
}

vector<double> profiling(string name = "") {
vector<double> output;
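The new virtual clear_kvcache() hook on Module lets callers reset a model's KV caches through the base class; ElasticLLaMAModel and LLaMAModel below override it. A minimal usage sketch, assuming the model and prompt tensors are already built (construction and tokenization are outside this diff):

// Sketch: reset the KV cache between two independent prompts.
// Assumes `model` is a Module subclass that overrides clear_kvcache(),
// e.g. the LLaMAModel further down in this diff.
void run_two_prompts(LLaMAModel &model, Tensor prompt_a, Tensor prompt_b) {
    auto out_a = model({prompt_a})[0]; // fills each block's k_cache / v_cache
    model.clear_kvcache();             // new virtual hook from Module.hpp
    auto out_b = model({prompt_b})[0]; // starts again from an empty cache
}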
5 changes: 3 additions & 2 deletions src/models/dclm/configuration_dclm.hpp
@@ -27,14 +27,15 @@ class DCLMNameConfig : public TransformerNameConfig {
}
};

struct DCLMConfig {
explicit DCLMConfig(int token_limit, const string billions = "1B", RoPEType type = RoPEType::HFHUBROPE) :
struct DCLMConfig : public TransformerConfig {
explicit DCLMConfig(int token_limit, const string billions = "1B", RoPEType type = RoPEType::HFHUBROPE, int vocab = 50432) :
cache_limit(token_limit) {
names_config.init(type);
if (!(billions == "1B" || billions == "1b")) {
throw std::runtime_error("Unsupported model size");
}
RoPE_type = type;
vocab_size = vocab;
};

int dim = 2048;
4 changes: 2 additions & 2 deletions src/models/gemma/configuration_gemma.hpp
@@ -74,7 +74,7 @@ class GemmaNameConfig : public TransformerNameConfig {
std::string _gate_proj_name;
};

struct GemmaConfig {
struct GemmaConfig : public TransformerConfig {
explicit GemmaConfig(int token_limit, const string billions = "2B", RoPEType type = RoPEType::HFHUBROPE) :
cache_limit(token_limit) {
names_config.init(type);
@@ -93,7 +93,7 @@ struct GemmaConfig {
int intermediate_size = 16384;
int head_dim = 256;
float rms_norm_eps = 1e-6;
float rope_theta= 10000;
float rope_theta = 10000;

int cache_limit;
RoPEType RoPE_type = RoPEType::HFHUBROPE;
8 changes: 4 additions & 4 deletions src/models/imagebind/configuration_imagebind.hpp
@@ -73,7 +73,7 @@ class ImagebindNameConfig : public TransformerNameConfig {
_audio_blocks_name = "modality_trunks.audio.blocks.";
}
};
class ImagebindConfig {
class ImagebindConfig : public TransformerConfig {
public:
ImagebindNameConfig names_config;

@@ -99,13 +99,13 @@ class ImagebindConfig {
int audio_stride = 10;
int audio_h = 128;
int audio_w = 204;
int audio_block_num =12;
int audio_block_num = 12;

int head_hidden_dim = 1024;

explicit ImagebindConfig(const string& model_type = "huge") {
explicit ImagebindConfig(const string &model_type = "huge") {
if (model_type != "huge") {
std::cerr<<"model type not supported"<<std::endl;
std::cerr << "model type not supported" << std::endl;
}
names_config.init();
}
8 changes: 4 additions & 4 deletions src/models/llama/configuration_llama.hpp
@@ -61,7 +61,7 @@ class LLaMANameConfig : public TransformerNameConfig {
}
};

class LLaMAConfig {
class LLaMAConfig : public TransformerConfig {
public:
int vocab_size{};
int hidden_dim{};
@@ -84,7 +84,7 @@ class LLaMAConfig {
num_key_value_heads = 32;
ffn_hidden = 11008;
block_num = 32;
max_position_embeddings= 16384;
max_position_embeddings = 16384;
rope_theta = 10000;
} else if (billions == "6B" || billions == "6b") {
// Yi @https://arxiv.org/abs/2403.04652
@@ -93,10 +93,10 @@ class LLaMAConfig {
num_key_value_heads = 4;
ffn_hidden = 11008;
block_num = 32;
max_position_embeddings= 4096;
max_position_embeddings = 4096;
rope_theta = 5000000.0;
vocab_size = 64000;
}else {
} else {
throw std::runtime_error("Unsupported model size");
}
RoPE_type = type;
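The size presets above are selected by the billions string; building a model from the resulting config then goes through the delegating constructor in modeling_llama.hpp below. A sketch under the assumption that LLaMAConfig's constructor takes (token_limit, billions, RoPE type) like the DCLM and Gemma configs earlier in this diff:

// Sketch only: the LLaMAConfig constructor signature is assumed, not shown here.
void build_and_run_llama(Tensor input_ids) {
    LLaMAConfig config(/*token_limit=*/1024, "7B", RoPEType::HFHUBROPE);
    LLaMAModel model(config);            // delegating constructor in modeling_llama.hpp
    auto logits = model({input_ids})[0]; // standard, non-elastic forward pass
}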
52 changes: 26 additions & 26 deletions src/models/llama/modeling_elastic_llama.hpp
@@ -29,36 +29,36 @@ class ElasticMultiHeadAttention final : public Module {

public:
ElasticMultiHeadAttention() = default;
ElasticMultiHeadAttention(int hidden_dim, int head_size,int kv_head_size, int attn_hidden_dim,
RoPEType RoPE_type, int cache_limit, bool do_mask, bool bias,
const TransformerNameConfig &names, const string &base_name) {
ElasticMultiHeadAttention(int hidden_dim, int head_size, int kv_head_size, int attn_hidden_dim,
RoPEType RoPE_type, int cache_limit, bool do_mask, bool bias,
const TransformerNameConfig &names, const string &base_name) {
assert(kv_head_size_ == head_size_);
attn_hidden_dim_ = attn_hidden_dim;
head_size_ = head_size;
kv_head_size_ = kv_head_size;
q_proj = ElasticLinear(hidden_dim, head_size * attn_hidden_dim, bias, base_name + names._q_proj_name);
k_proj = ElasticLinear(hidden_dim, kv_head_size * attn_hidden_dim, bias, base_name + names._k_proj_name);
v_proj = ElasticLinear(hidden_dim, kv_head_size * attn_hidden_dim, bias, base_name + names._v_proj_name);

if (RoPE_type > 0) {
q_rope = RoPE(RoPE_type, base_name + "q_rope");
k_rope = RoPE(RoPE_type, base_name + "k_rope");
}
if (cache_limit > 0) {
k_cache = KVCache(head_size/kv_head_size, cache_limit, base_name + "k_cache");
v_cache = KVCache(head_size/kv_head_size, cache_limit, base_name + "v_cache");
k_cache = KVCache(head_size / kv_head_size, cache_limit, base_name + "k_cache");
v_cache = KVCache(head_size / kv_head_size, cache_limit, base_name + "v_cache");
}
softmax = Softmax(DIMENSION, do_mask, base_name + "softmax");
o_proj = ElasticLinear(head_size * attn_hidden_dim, hidden_dim, bias, base_name + names._o_proj_name);
}
vector<Tensor> Forward(vector<Tensor> inputs, vector<std::any> args) override {
vector<Tensor> Forward(vector<Tensor> inputs, vector<std::any> args) override {
vector<int> activate_head_dims = std::any_cast<vector<int>>(args[0]);
int activate_head_dim = activate_head_dims[0];
activate_head_dim = (activate_head_dim==-1)? kv_head_size_: (activate_head_dim);
activate_head_dim = (activate_head_dim == -1) ? kv_head_size_ : (activate_head_dim);
Tensor q, k, v;
q = q_proj(inputs[0], -1, activate_head_dim*attn_hidden_dim_);
k = k_proj(inputs[1], -1, activate_head_dim*attn_hidden_dim_);
v = v_proj(inputs[2], -1, activate_head_dim*attn_hidden_dim_);
q = q_proj(inputs[0], -1, activate_head_dim * attn_hidden_dim_);
k = k_proj(inputs[1], -1, activate_head_dim * attn_hidden_dim_);
v = v_proj(inputs[2], -1, activate_head_dim * attn_hidden_dim_);
q = q.view(-1, activate_head_dim, -1, attn_hidden_dim_);
k = k.view(-1, activate_head_dim, -1, attn_hidden_dim_);
v = v.view(-1, activate_head_dim, -1, attn_hidden_dim_);
@@ -72,19 +72,19 @@ class ElasticMultiHeadAttention final : public Module {
}
k = k.transpose(SEQUENCE, DIMENSION);
auto qk = Tensor::mm(q, k);
qk = qk / std::sqrt(attn_hidden_dim_);//attn_hidden_dim_
qk = qk / std::sqrt(attn_hidden_dim_); // attn_hidden_dim_
if (k_cache.ready() && v_cache.ready()) {
qk = softmax(qk, k_cache.getCacheSeqLen());
}else{
} else {
qk = softmax(qk);
}
auto o = Tensor::mm(qk, v);
o = o.view(-1, 1, -1, attn_hidden_dim_ * activate_head_dim);
o = o_proj(o, activate_head_dim*attn_hidden_dim_, -1);
o = o_proj(o, activate_head_dim * attn_hidden_dim_, -1);
return {o};
}
vector<KVCache*> get_cache() {
return {&k_cache,&v_cache};
vector<KVCache *> get_cache() {
return {&k_cache, &v_cache};
}
};

@@ -102,7 +102,7 @@ class ElasticLLaMAMLP final : public Module {
up_proj = ElasticLinear(hidden_dim, ffn_hidden, false, base_name + names._up_proj_name);
down_proj = ElasticLinear(ffn_hidden, hidden_dim, false, base_name + names._down_proj_name);
}
vector<Tensor> Forward(vector<Tensor> inputs, vector<std::any> args) override {
vector<Tensor> Forward(vector<Tensor> inputs, vector<std::any> args) override {
vector<int> activate_dims = std::any_cast<vector<int>>(args[0]);
int activate_dim = activate_dims[0];
auto x = gate_proj(inputs[0], -1, activate_dim);
@@ -124,12 +124,12 @@ class ElasticLLaMABlock final : public Module {
ElasticLLaMABlock() = default;
ElasticLLaMABlock(int hidden_dim, int head_size, int ffn_hidden, RoPEType RoPE_type, int cache_limit, const LLaMANameConfig &names, const string &base_name) {
attention = ElasticMultiHeadAttention(hidden_dim, head_size, head_size, hidden_dim / head_size,
RoPE_type, cache_limit, true, false, names, base_name + names._attn_base_name);
RoPE_type, cache_limit, true, false, names, base_name + names._attn_base_name);
mlp = ElasticLLaMAMLP(hidden_dim, ffn_hidden, names, base_name + names._ffn_base_name);
norm1 = RMSNorm(hidden_dim, 1e-6, base_name + names._attn_norm_name);
norm2 = RMSNorm(hidden_dim, 1e-6, base_name + names._ffn_norm_name);
}
vector<Tensor> Forward(vector<Tensor> inputs, vector<std::any> args) override {
vector<Tensor> Forward(vector<Tensor> inputs, vector<std::any> args) override {
vector<int> activate_dims = std::any_cast<vector<int>>(args[0]);
vector<int> dim_attns = {activate_dims[0]};
vector<int> dim_mlps = {activate_dims[1]};
@@ -141,7 +141,7 @@ class ElasticLLaMABlock final : public Module {
x = x + tmp;
return {x};
}
ElasticMultiHeadAttention& get_attention() {
ElasticMultiHeadAttention &get_attention() {
return attention;
}
};
@@ -156,31 +156,31 @@ class ElasticLLaMAModel final : public Module {
public:
explicit ElasticLLaMAModel(const LLaMAConfig &config) :
ElasticLLaMAModel(config.vocab_size, config.hidden_dim, config.head_size, config.ffn_hidden, config.block_num, config.RoPE_type, config.cache_limit,
config.names_config, config.names_config.blk_name) {
config.names_config, config.names_config.blk_name) {
}
ElasticLLaMAModel(int vocab_size, int hidden_dim, int head_size, int ffn_hidden, int block_num, RoPEType RoPE_type, int cache_limit,
const LLaMANameConfig &names, const string &base_name) {
const LLaMANameConfig &names, const string &base_name) {
embedding = Embedding(vocab_size, hidden_dim, names.token_embd_name);
blocks = List<ElasticLLaMABlock>(block_num, hidden_dim, head_size, ffn_hidden, RoPE_type, cache_limit, names, base_name);
norm = RMSNorm(hidden_dim, 1e-6, names.post_norm_name);
lm_head = Linear(hidden_dim, vocab_size, false, names.lm_head_name);
num_layer_size = block_num;
}
vector<Tensor> Forward(vector<Tensor> inputs, vector<std::any> args) override {
vector<Tensor> Forward(vector<Tensor> inputs, vector<std::any> args) override {
vector<vector<int>> activate_dims = std::any_cast<vector<vector<int>>>(args[0]);
assert(activate_dims.size() == num_layer_size);
auto x = embedding(inputs[0]);
for (int id = 0; id<blocks.size(); id ++){
for (int id = 0; id < blocks.size(); id++) {
x = blocks[id]({x}, activate_dims[id])[0];
}
x = norm(x);
x = lm_head(x);
return {x};
}

void clear_kvcache() {
void clear_kvcache() override {
for (auto &block : blocks) {
auto kvcahce =block.get_attention().get_cache();
auto kvcahce = block.get_attention().get_cache();
for (auto &cache : kvcahce) {
cache->clearCache();
}
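ElasticLLaMAModel::Forward expects args[0] to be a vector<vector<int>> holding one {attention width, MLP width} pair per block, and the attention path treats -1 as "use all heads". A call-site sketch under those assumptions (the exact ElasticLinear width semantics are not shown in this diff):

// Sketch: run the elastic model at full width for every block.
// `model` and `input_ids` are assumed to exist already; -1 is the full-width
// sentinel that ElasticMultiHeadAttention::Forward handles explicitly.
vector<Tensor> run_elastic(ElasticLLaMAModel &model, Tensor input_ids, int block_num) {
    vector<vector<int>> activate_dims(block_num, {-1, -1}); // {attn, mlp} per block
    return model({input_ids}, activate_dims);               // becomes args[0] in Forward
}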
18 changes: 9 additions & 9 deletions src/models/llama/modeling_llama.hpp
@@ -26,7 +26,7 @@ class LLaMAMLP final : public Module {
up_proj = Linear(hidden_dim, ffn_hidden, false, base_name + names._up_proj_name);
down_proj = Linear(ffn_hidden, hidden_dim, false, base_name + names._down_proj_name);
}
vector<Tensor> Forward(vector<Tensor> inputs, vector<std::any> args) override {
vector<Tensor> Forward(vector<Tensor> inputs, vector<std::any> args) override {
auto x = gate_proj(inputs[0]);
x = silu(x);
auto y = up_proj(inputs[0]);
@@ -51,7 +51,7 @@ class LLaMABlock final : public Module {
norm1 = RMSNorm(hidden_dim, 1e-6, base_name + names._attn_norm_name);
norm2 = RMSNorm(hidden_dim, 1e-6, base_name + names._ffn_norm_name);
}
vector<Tensor> Forward(vector<Tensor> inputs, vector<std::any> args) override {
vector<Tensor> Forward(vector<Tensor> inputs, vector<std::any> args) override {
auto x = norm1(inputs[0]);
x = attention({x, x, x})[0];
auto tmp = x + inputs[0];
@@ -60,8 +60,8 @@ class LLaMABlock final : public Module {
x = x + tmp;
return {x};
}
MultiHeadAttention& get_attention() {

MultiHeadAttention &get_attention() {
return attention;
}
};
@@ -74,8 +74,8 @@ class LLaMAModel final : public Module {

public:
explicit LLaMAModel(const LLaMAConfig &config) :
LLaMAModel(config.vocab_size, config.hidden_dim, config.head_size, config.num_key_value_heads, config.ffn_hidden, config.block_num,
config.RoPE_type, config.rope_theta, config.max_position_embeddings, config.cache_limit,
LLaMAModel(config.vocab_size, config.hidden_dim, config.head_size, config.num_key_value_heads, config.ffn_hidden, config.block_num,
config.RoPE_type, config.rope_theta, config.max_position_embeddings, config.cache_limit,
config.names_config, config.names_config.blk_name) {
}
LLaMAModel(int vocab_size, int hidden_dim, int head_size, int kv_head_size, int ffn_hidden, int block_num, RoPEType RoPE_type, float rope_theta, int max_position_embeddings, int cache_limit,
@@ -85,7 +85,7 @@ class LLaMAModel final : public Module {
norm = RMSNorm(hidden_dim, 1e-6, names.post_norm_name);
lm_head = Linear(hidden_dim, vocab_size, false, names.lm_head_name);
}
vector<Tensor> Forward(vector<Tensor> inputs, vector<std::any> args) override {
vector<Tensor> Forward(vector<Tensor> inputs, vector<std::any> args) override {
auto x = embedding(inputs[0]);
for (auto &block : blocks) {
x = block({x})[0];
Expand All @@ -95,9 +95,9 @@ class LLaMAModel final : public Module {
return {x};
}

void clear_kvcache() {
void clear_kvcache() override {
for (auto &block : blocks) {
auto kvcahce =block.get_attention().get_cache();
auto kvcahce = block.get_attention().get_cache();
for (auto &cache : kvcahce) {
cache->clearCache();
}
5 changes: 2 additions & 3 deletions src/models/minicpm/configuration_minicpm.hpp
@@ -28,11 +28,10 @@ class MiniCPMNameConfig : public TransformerNameConfig {
token_embd_name = "model.embed_tokens";
post_norm_name = "model.norm";
lm_head_name = "lm_head";

}
};

struct MiniCPMConfig {
struct MiniCPMConfig : public TransformerConfig {
explicit MiniCPMConfig(int token_limit, string billions = "2B") :
cache_limit(token_limit) {
names_config.init();
@@ -79,7 +78,7 @@ struct MiniCPMConfig {
double rms_norm_eps = 1e-05;
float rope_theta = 10000.0;
int vocab_size = 122753;
int head_dim = 64; //self.hidden_size // self.num_heads
int head_dim = 64; // self.hidden_size // self.num_heads
float scale_depth = 1.4;
float scale_emb = 12;
float dim_model_base = 256;