feat: mistral v0.2 7B support #83

Merged · 2 commits · May 29, 2024
11 changes: 11 additions & 0 deletions CMakeLists.txt
@@ -433,6 +433,17 @@ else ()
    target_link_libraries(demo_qwen PUBLIC MLLM_CPU)
endif ()

add_executable(demo_mistral ${PROJECT_SOURCE_DIR}/examples/demo_mistral.cpp ${DIR_SRC_CPU} ${DIR_SRC_MEM_MANAGER} ${DIR_SRC_EXP} ${DIR_SRC}
        src/tokenizers/Tokenizer.cpp
        src/tokenizers/BPE/Bpe.cpp
        src/processor/PreProcess.cpp
)
if (ARM AND NOT APK)
    target_compile_options(demo_mistral PRIVATE -fopenmp)
    target_link_libraries(demo_mistral PUBLIC MLLM_CPU -fopenmp -static-openmp)
else ()
    target_link_libraries(demo_mistral PUBLIC MLLM_CPU)
endif ()

if (APK)
add_library(mllm_lib STATIC ${DIR_SRC_CPU} ${DIR_SRC_EXP} ${DIR_SRC} ${DIR_SRC_MEM_MANAGER}
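For orientation, a minimal build sketch for the new target, assuming a standard out-of-source CMake build from the repository root; platform-specific flags (e.g. Android cross-compilation) are omitted here:

```bash
# Configure once, then compile only the new Mistral demo target.
cmake -S . -B build
cmake --build build --target demo_mistral -j4
```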
7 changes: 4 additions & 3 deletions README.md
@@ -68,6 +68,7 @@ Wait.. why on-device multimodal LLM? - It's a key building block for [intelligen
| [LLaVA 7B](https://github.com/haotian-liu/LLaVA) | [✔️](https://huggingface.co/mllmTeam/llava-1.5-7b-mllm/tree/main) | [✔️](https://huggingface.co/mllmTeam/llava-1.5-7b-mllm/tree/main) |
| [Gemma 2B](https://github.com/google/gemma_pytorch) | [✔️](https://huggingface.co/mllmTeam/gemma-2b-mllm/tree/main) | [✔️](https://huggingface.co/mllmTeam/gemma-2b-mllm/tree/main) |
| [Qwen 0.5B](https://github.com/QwenLM/Qwen) | [✔️](https://huggingface.co/mllmTeam/qwen-1.5-0.5b-mllm/tree/main) | [✔️](https://huggingface.co/mllmTeam/qwen-1.5-0.5b-mllm/tree/main) |
| [Mistral 7B](https://github.com/mistralai/mistral-src) | [✔️](https://huggingface.co/mllmTeam/mistral-7b-instruct-v0.2-mllm/tree/main) | [✔️](https://huggingface.co/mllmTeam/mistral-7b-instruct-v0.2-mllm/tree/main) |

## Quick Start

@@ -208,22 +209,22 @@ cd scripts

```bash
cd ./bin
-./demo_fuyu -m ../models/fuyu-8b-q4_k.mllm -v ../vacob/fuyu_vocab.mllm
+./demo_fuyu -m ../models/fuyu-8b-q4_k.mllm -v ../vocab/fuyu_vocab.mllm
```

#### Run LLaMA-2-7B

```bash
cd ./bin
-./demo_llama -m ../models/llama-2-7b-chat-q4_k.mllm -v ../vacob/llama_vocab.mllm
+./demo_llama -m ../models/llama-2-7b-chat-q4_k.mllm -v ../vocab/llama_vocab.mllm
```


#### Run ImageBind

```bash
cd ./bin
-./demo_imagebind -m ../models/imagebind_huge-q4_k.mllm -v ../vacob/clip_vocab.mllm
+./demo_imagebind -m ../models/imagebind_huge-q4_k.mllm -v ../vocab/clip_vocab.mllm
```


70 changes: 70 additions & 0 deletions examples/demo_mistral.cpp
@@ -0,0 +1,70 @@
/**
* @file demo_mistral.cpp
* @author Chenghua Wang (chenghua.wang.edu@gmail.com)
* @brief Mistral demo
* @version 0.1
* @date 2024-05-29
*
* @copyright Copyright (c) 2024
*
*/
#include "cmdline.h"
#include "models/mistral/configuration_mistral.hpp"
#include "models/mistral/modeling_mistral.hpp"
#include "models/mistral/tokenization_mistral.hpp"
#include "processor/PostProcess.hpp"

using namespace mllm;

int main(int argc, char **argv) {
    cmdline::parser cmdParser;
    cmdParser.add<string>("vocab", 'v', "specify mllm tokenizer model path", false, "../vocab/mistral_vocab.mllm");
    cmdParser.add<string>("model", 'm', "specify mllm model path", false, "../models/mistral-7b-q4_k.mllm");
    cmdParser.add<int>("limits", 'l', "max KV cache size", false, 400);
    cmdParser.add<int>("thread", 't', "num of threads", false, 4);
    cmdParser.parse_check(argc, argv);

    string vocab_path = cmdParser.get<string>("vocab");
    string model_path = cmdParser.get<string>("model");
    int tokens_limit = cmdParser.get<int>("limits");
    CPUBackend::cpu_threads = cmdParser.get<int>("thread");

    auto tokenizer = MistralTokenizer(vocab_path);
    MistralConfig config(tokens_limit, "7B", RoPEType::HFHUBROPE);
    auto model = MistralForCausalLM(config);
    model.load(model_path);

    vector<string> in_strs = {
        " Hello, who are you?",
        " What can you do?",
        "Please introduce Beijing University of Posts and Telecommunications.",
    };

    // Map SentencePiece artifacts back to plain text: "▁" marks a word boundary,
    // "<0x0A>" is the newline byte token, and "</s>" ends generation.
    auto processOutput = [&](std::string &text) -> std::pair<bool, std::string> {
        text = std::regex_replace(text, std::regex("▁"), " ");
        if (text == "<0x0A>") return {true, "\n"};
        if (text == "</s>") return {false, ""};
        return {true, text};
    };

    for (int i = 0; i < in_strs.size(); ++i) {
        auto in_str = in_strs[i];
        auto input_tensor = tokenizer.tokenize(in_str, i);
        std::cout << "[Q] " << in_str << std::endl;
        std::cout << "[A] " << std::flush;
        for (int step = 0; step < 100; step++) {
            auto result = model({input_tensor});
            auto outputs = tokenizer.detokenize(result[0]);
            auto out_string = outputs.first;
            auto out_token = outputs.second;
            auto [isOk, print_string] = processOutput(out_string);
            if (isOk) {
                std::cout << print_string << std::flush;
            } else {
                break;
            }
            chatPostProcessing(out_token, input_tensor, {});
        }
        printf("\n");
    }
}
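A usage sketch built from the demo's own defaults above (the -m and -v paths are just the hard-coded fallbacks; actual file names depend on where the converted weights and vocab live):

```bash
cd ./bin
# -l caps the KV cache size, -t sets the CPU thread count (defaults: 400 and 4).
./demo_mistral -m ../models/mistral-7b-q4_k.mllm -v ../vocab/mistral_vocab.mllm -l 400 -t 4
```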
124 changes: 124 additions & 0 deletions src/models/mistral/configuration_mistral.hpp
@@ -0,0 +1,124 @@
/**
* @file configuration_mistral.hpp
* @author Chenghua Wang (chenghua.wang.edu@gmail.com)
 * @brief Mistral 7B Instruct v0.2, Hugging Face version
* @version 0.1
* @date 2024-05-25
*
* @copyright Copyright (c) 2024
*
*/
#ifndef CONFIG_MISTRAL_HPP
#define CONFIG_MISTRAL_HPP

#include "Types.hpp"
#include "models/transformer/configuration_transformer.hpp"
#include <cctype>
#include <iterator>

using namespace mllm;

class MistralNameConfig : public TransformerNameConfig {
public:
    void init(RoPEType type = RoPEType::HFHUBROPE) {
        switch (type) {
        case RoPEType::HFHUBROPE: {
            blk_name = "model.layers.";
            _attn_base_name = "self_attn.";
            _ffn_base_name = "mlp.";
            _q_proj_name = "q_proj";
            _k_proj_name = "k_proj";
            _v_proj_name = "v_proj";
            _o_proj_name = "o_proj";
            _gate_proj_name = "gate_proj";
            _up_proj_name = "up_proj";
            _down_proj_name = "down_proj";
            _attn_norm_name = "input_layernorm";
            _ffn_norm_name = "post_attention_layernorm";
            token_embd_name = "model.embed_tokens";
            post_norm_name = "model.norm";
            lm_head_name = "lm_head";
            break;
        }
        case RoPEType::LLAMAROPE: /* Mistral uses the same weight naming as LLaMA */ {
            blk_name = "layers.";
            _attn_base_name = "attention.";
            _ffn_base_name = "feed_forward.";
            _q_proj_name = "wq";
            _k_proj_name = "wk";
            _v_proj_name = "wv";
            _o_proj_name = "wo";
            _gate_proj_name = "w1";
            _up_proj_name = "w3";
            _down_proj_name = "w2";
            _attn_norm_name = "attention_norm";
            _ffn_norm_name = "ffn_norm";
            token_embd_name = "tok_embeddings";
            post_norm_name = "norm";
            lm_head_name = "output";
            break;
        }
        default: {
throw std::runtime_error("Unsupported gemma RoPE type");
        }
        }
    }

    std::string blk_name;
    std::string token_embd_name;
    std::string post_norm_name;
    std::string lm_head_name;
    std::string _gate_proj_name;
};

struct MistralConfig {
    explicit MistralConfig(int token_limit, string billions = "7B", RoPEType type = RoPEType::HFHUBROPE) :
        cache_limit(token_limit) {
        names_config.init(type);
        string billionsType;
        std::transform(billions.begin(), billions.end(), std::back_inserter(billionsType),
                       ::tolower);
        if (billionsType == "7b") {
            attention_dropout = 0.0;
            bos_token_id = 1;
            eos_token_id = 2;
            hidden_act = "silu";
            hidden_size = 4096;
            initializer_range = 0.02;
            intermediate_size = 14336;
            max_position_embeddings = 32768;
            model_type = "mistral";
            num_attention_heads = 32;
            num_hidden_layers = 32;
            num_key_value_heads = 8;
            rms_norm_eps = 1e-05;
            rope_theta = 1000000.0;
            vocab_size = 32000;
        } else {
            throw std::runtime_error("Unsupported model size");
        }
        RoPE_type = type;
    }

    float attention_dropout = 0.0;
    int bos_token_id = 1;
    int eos_token_id = 2;
    std::string hidden_act = "silu";
    int hidden_size = 4096;
    float initializer_range = 0.02;
    int intermediate_size = 14336;
    int max_position_embeddings = 32768;
    std::string model_type = "mistral";
    int num_attention_heads = 32;
    int num_hidden_layers = 32;
    int num_key_value_heads = 8; // grouped-query attention: 32 query heads share 8 KV heads
    double rms_norm_eps = 1e-05;
    float rope_theta = 1000000.0;
    int vocab_size = 32000;

    int cache_limit;
    RoPEType RoPE_type = RoPEType::HFHUBROPE;
    MistralNameConfig names_config;
};

#endif //! CONFIG_MISTRAL_HPP