diff --git a/ChatQnA/Dockerfile.dynamic_batching b/ChatQnA/Dockerfile.dynamic_batching
new file mode 100644
index 000000000..f6b4fda40
--- /dev/null
+++ b/ChatQnA/Dockerfile.dynamic_batching
@@ -0,0 +1,31 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+FROM python:3.11-slim
+
+RUN apt-get update -y && apt-get install -y --no-install-recommends --fix-missing \
+    git \
+    libgl1-mesa-glx \
+    libjemalloc-dev
+
+RUN useradd -m -s /bin/bash user && \
+    mkdir -p /home/user && \
+    chown -R user /home/user/
+
+WORKDIR /home/user/
+RUN git clone https://github.com/opea-project/GenAIComps.git
+
+WORKDIR /home/user/GenAIComps
+RUN pip install --no-cache-dir --upgrade pip && \
+    pip install --no-cache-dir -r /home/user/GenAIComps/requirements.txt && \
+    pip install --no-cache-dir langchain_core
+
+COPY ./chatqna_dynamic_batching.py /home/user/chatqna_dynamic_batching.py
+
+ENV PYTHONPATH=$PYTHONPATH:/home/user/GenAIComps
+
+USER user
+
+WORKDIR /home/user
+
+ENTRYPOINT ["python", "chatqna_dynamic_batching.py"]
diff --git a/ChatQnA/chatqna_dynamic_batching.py b/ChatQnA/chatqna_dynamic_batching.py
new file mode 100644
index 000000000..9d4c6164f
--- /dev/null
+++ b/ChatQnA/chatqna_dynamic_batching.py
@@ -0,0 +1,170 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import os
+import re
+
+from comps import ChatQnAGateway, MicroService, ServiceOrchestrator, ServiceType
+from langchain_core.prompts import PromptTemplate
+
+
+class ChatTemplate:
+    @staticmethod
+    def generate_rag_prompt(question, documents):
+        context_str = "\n".join(documents)
+        if context_str and len(re.findall("[\u4E00-\u9FFF]", context_str)) / len(context_str) >= 0.3:
+            # Chinese context
+            template = """
+### 你将扮演一个乐于助人、尊重他人并诚实的助手，你的目标是帮助用户解答问题。有效地利用来自本地知识库的搜索结果。确保你的回答中只包含相关信息。如果你不确定问题的答案，请避免分享不准确的信息。
+### 搜索结果：{context}
+### 问题：{question}
+### 回答：
+"""
+        else:
+            template = """
+### You are a helpful, respectful and honest assistant to help the user with questions. \
+Please refer to the search results obtained from the local knowledge base. \
+But be careful to not incorporate the information that you think is not relevant to the question. \
+If you don't know the answer to a question, please don't share false information. \n
+### Search results: {context} \n
+### Question: {question} \n
+### Answer:
+"""
+        return template.format(context=context_str, question=question)
+
+
+MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
+MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
+RETRIEVER_SERVICE_HOST_IP = os.getenv("RETRIEVER_SERVICE_HOST_IP", "0.0.0.0")
+RETRIEVER_SERVICE_PORT = int(os.getenv("RETRIEVER_SERVICE_PORT", 7000))
+# Embedding and rerank share the same host and port, but expose different routes
+EMBEDDING_RERANK_SERVICE_HOST_IP = os.getenv("EMBEDDING_RERANK_SERVICE_HOST_IP", "0.0.0.0")
+EMBEDDING_RERANK_SERVICE_PORT = int(os.getenv("EMBEDDING_RERANK_SERVICE_PORT", 6001))
+LLM_SERVER_HOST_IP = os.getenv("LLM_SERVER_HOST_IP", "0.0.0.0")
+LLM_SERVER_PORT = int(os.getenv("LLM_SERVER_PORT", 9009))
+
+
+def align_inputs(self, inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs):
+    # if self.services[cur_node].service_type == ServiceType.EMBEDDING:
+    #     inputs["inputs"] = inputs["text"]
+    #     del inputs["text"]
+    if self.services[cur_node].service_type == ServiceType.RETRIEVER:
+        # prepare the retriever params
+        retriever_parameters = kwargs.get("retriever_parameters", None)
+        if retriever_parameters:
+            inputs.update(retriever_parameters.dict())
+    elif self.services[cur_node].service_type == ServiceType.RERANK:
+        reranker_parameters = kwargs.get("reranker_parameters", None)
+        top_n = reranker_parameters.top_n if reranker_parameters else 1
+        inputs["top_n"] = top_n
+    elif self.services[cur_node].service_type == ServiceType.LLM:
+        prompt = inputs["query"]
+        docs: list[str] = inputs["documents"]
+        chat_template = llm_parameters_dict["chat_template"]
+        if chat_template:
+            prompt_template = PromptTemplate.from_template(chat_template)
+            input_variables = prompt_template.input_variables
+            if sorted(input_variables) == ["context", "question"]:
+                prompt = prompt_template.format(question=prompt, context="\n".join(docs))
+            elif input_variables == ["question"]:
+                prompt = prompt_template.format(question=prompt)
+            else:
+                print(f"{prompt_template} not used, we only support 2 input variables ['question', 'context']")
+                prompt = ChatTemplate.generate_rag_prompt(prompt, docs)
+        else:
+            prompt = ChatTemplate.generate_rag_prompt(prompt, docs)
+
+        # inputs: LLMParamsDoc
+        # {'id': 'd52f75f7bd602526073d933dab541a8c', 'model': None, 'query': 'What is the revenue of Nike in 2023?', 'max_tokens': 1024, 'max_new_tokens': 1024, 'top_k': 10, 'top_p': 0.95, 'typical_p': 0.95, 'temperature': 0.01, 'frequency_penalty': 0.0, 'presence_penalty': 0.0, 'repetition_penalty': 1.0, 'streaming': True, 'chat_template': None, 'documents': []}
+        # convert TGI/vLLM request fields to the unified OpenAI /v1/chat/completions format
+        next_inputs = {}
+        next_inputs["model"] = "tgi"  # placeholder model name so the request format stays unified
+        next_inputs["messages"] = [{"role": "user", "content": prompt}]
+        next_inputs["max_tokens"] = llm_parameters_dict["max_new_tokens"]
+        next_inputs["top_p"] = llm_parameters_dict["top_p"]
+        next_inputs["stream"] = inputs["streaming"]
+        next_inputs["frequency_penalty"] = inputs["repetition_penalty"]
+        next_inputs["temperature"] = inputs["temperature"]
+        inputs = next_inputs
+
+    return inputs
+
+
+def align_generator(self, gen, **kwargs):
+    # OpenAI response format example:
+    # b'data:{"id":"","object":"text_completion","created":1725530204,"model":"meta-llama/Meta-Llama-3-8B-Instruct","system_fingerprint":"2.0.1-native","choices":[{"index":0,"delta":{"role":"assistant","content":"?"},"logprobs":null,"finish_reason":null}]}\n\n'
+    for line in gen:
+        line = line.decode("utf-8")
+        start = line.find("{")
+        end = line.rfind("}") + 1
+
+        json_str = line[start:end]
+        try:
+            # the stream sometimes yields an empty chunk; fall back to the raw string below
+            json_data = json.loads(json_str)
+            if json_data["choices"][0]["finish_reason"] != "eos_token":
+                yield f"data: {repr(json_data['choices'][0]['delta']['content'].encode('utf-8'))}\n\n"
+        except Exception as e:
+            yield f"data: {repr(json_str.encode('utf-8'))}\n\n"
+    yield "data: [DONE]\n\n"
+
+
+class ChatQnAService:
+    def __init__(self, host="0.0.0.0", port=8000):
+        self.host = host
+        self.port = port
+        ServiceOrchestrator.align_inputs = align_inputs
+        # ServiceOrchestrator.align_outputs = align_outputs
+        ServiceOrchestrator.align_generator = align_generator
+        self.megaservice = ServiceOrchestrator()
+
+    def add_remote_service(self):
+
+        embedding = MicroService(
+            name="embedding",
+            host=EMBEDDING_RERANK_SERVICE_HOST_IP,
+            port=EMBEDDING_RERANK_SERVICE_PORT,
+            # endpoint="/embed",
+            endpoint="/v1/embeddings",
+            use_remote_service=True,
+            service_type=ServiceType.EMBEDDING,
+        )
+
+        retriever = MicroService(
+            name="retriever",
+            host=RETRIEVER_SERVICE_HOST_IP,
+            port=RETRIEVER_SERVICE_PORT,
+            endpoint="/v1/retrieval",
+            use_remote_service=True,
+            service_type=ServiceType.RETRIEVER,
+        )
+
+        rerank = MicroService(
+            name="rerank",
+            host=EMBEDDING_RERANK_SERVICE_HOST_IP,
+            port=EMBEDDING_RERANK_SERVICE_PORT,
+            # endpoint="/rerank",
+            endpoint="/v1/reranking",
+            use_remote_service=True,
+            service_type=ServiceType.RERANK,
+        )
+
+        llm = MicroService(
+            name="llm",
+            host=LLM_SERVER_HOST_IP,
+            port=LLM_SERVER_PORT,
+            endpoint="/v1/chat/completions",
+            use_remote_service=True,
+            service_type=ServiceType.LLM,
+        )
+        self.megaservice.add(embedding).add(retriever).add(rerank).add(llm)
+        self.megaservice.flow_to(embedding, retriever)
+        self.megaservice.flow_to(retriever, rerank)
+        self.megaservice.flow_to(rerank, llm)
+        self.gateway = ChatQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
+
+
+if __name__ == "__main__":
+    chatqna = ChatQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
+    chatqna.add_remote_service()
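For illustration, a minimal sketch (not from the patch) of the heuristic that ChatTemplate.generate_rag_prompt uses to pick a prompt language: if at least 30% of the characters in the concatenated retrieved documents fall in the CJK range U+4E00–U+9FFF, the Chinese template is chosen, otherwise the English one. The function name and sample documents below are hypothetical.

```python
import re


def uses_chinese_template(documents: list[str], threshold: float = 0.3) -> bool:
    """Mirror of the language check in ChatTemplate.generate_rag_prompt (sketch only)."""
    context_str = "\n".join(documents)
    if not context_str:
        return False
    cjk_chars = re.findall("[\u4E00-\u9FFF]", context_str)
    return len(cjk_chars) / len(context_str) >= threshold


# Hypothetical retrieved documents, for illustration only.
print(uses_chinese_template(["Nike reported revenues of $51.2 billion in fiscal 2023."]))  # False -> English template
print(uses_chinese_template(["耐克公司2023财年营收约为512亿美元。"]))  # True -> Chinese template
```
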
diff --git a/ChatQnA/docker_compose/intel/hpu/gaudi/compose_dynamic_batching.yaml b/ChatQnA/docker_compose/intel/hpu/gaudi/compose_dynamic_batching.yaml
new file mode 100644
index 000000000..155b5918f
--- /dev/null
+++ b/ChatQnA/docker_compose/intel/hpu/gaudi/compose_dynamic_batching.yaml
@@ -0,0 +1,127 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+services:
+  redis-vector-db:
+    image: redis/redis-stack:7.2.0-v9
+    container_name: redis-vector-db
+    ports:
+      - "6379:6379"
+      - "8001:8001"
+  dataprep-redis-service:
+    image: ${REGISTRY:-opea}/dataprep-redis:${TAG:-latest}
+    container_name: dataprep-redis-server
+    depends_on:
+      - redis-vector-db
+    ports:
+      - "6007:6007"
+    environment:
+      no_proxy: ${no_proxy}
+      http_proxy: ${http_proxy}
+      https_proxy: ${https_proxy}
+      REDIS_URL: ${REDIS_URL}
+      INDEX_NAME: ${INDEX_NAME}
+      # TEI_ENDPOINT: ${TEI_EMBEDDING_ENDPOINT}
+      EMBED_MODEL: ${EMBED_MODEL}
+      HUGGINGFACEHUB_API_TOKEN: ${HUGGINGFACEHUB_API_TOKEN}
+  embedding-reranking-service:
+    image: ${REGISTRY:-opea}/embedding-reranking-local:${TAG:-latest}
+    container_name: embedding-reranking-server
+    ports:
+      - "6001:6001"
+    runtime: habana
+    cap_add:
+      - SYS_NICE
+    ipc: host
+    environment:
+      no_proxy: ${no_proxy}
+      http_proxy: ${http_proxy}
+      https_proxy: ${https_proxy}
+      HABANA_VISIBLE_DEVICES: all # only use 1 gaudi card
+      OMPI_MCA_btl_vader_single_copy_mechanism: none
+      LOGFLAG: ${LOGFLAG}
+      DYNAMIC_BATCHING_TIMEOUT: 0.01
+      DYNAMIC_BATCHING_MAX_BATCH_SIZE: 32
+      PAD_SEQUENCE_TO_MULTIPLE_OF: 128
+      EMBEDDING_MODEL_ID: "BAAI/bge-base-en-v1.5"
+      RERANK_MODEL_ID: "BAAI/bge-reranker-base"
+  retriever:
+    image: ${REGISTRY:-opea}/retriever-redis:${TAG:-latest}
+    container_name: retriever-redis-server
+    depends_on:
+      - redis-vector-db
+    ports:
+      - "7000:7000"
+    ipc: host
+    environment:
+      no_proxy: ${no_proxy}
+      http_proxy: ${http_proxy}
+      https_proxy: ${https_proxy}
+      REDIS_URL: ${REDIS_URL}
+      INDEX_NAME: ${INDEX_NAME}
+    restart: unless-stopped
+  tgi-service:
+    image: ghcr.io/huggingface/tgi-gaudi:2.0.1
+    container_name: tgi-gaudi-server
+    ports:
+      - "8005:80"
+    volumes:
+      - "./data:/data"
+    environment:
+      no_proxy: ${no_proxy}
+      http_proxy: ${http_proxy}
+      https_proxy: ${https_proxy}
+      HF_TOKEN: ${HUGGINGFACEHUB_API_TOKEN}
+      HF_HUB_DISABLE_PROGRESS_BARS: 1
+      HF_HUB_ENABLE_HF_TRANSFER: 0
+      HABANA_VISIBLE_DEVICES: all
+      OMPI_MCA_btl_vader_single_copy_mechanism: none
+    runtime: habana
+    cap_add:
+      - SYS_NICE
+    ipc: host
+    command: --model-id ${LLM_MODEL_ID} --max-input-length 4096 --max-total-tokens 8192
+  chatqna-gaudi-backend-server:
+    image: ${REGISTRY:-opea}/chatqna-dynamic-batching:${TAG:-latest}
+    container_name: chatqna-gaudi-backend-server
+    depends_on:
+      - redis-vector-db
+      - retriever
+      - tgi-service
+      - embedding-reranking-service
+    ports:
+      - "8888:8888"
+    environment:
+      - no_proxy=${no_proxy}
+      - https_proxy=${https_proxy}
+      - http_proxy=${http_proxy}
+      - MEGA_SERVICE_HOST_IP=${MEGA_SERVICE_HOST_IP}
+      - RETRIEVER_SERVICE_HOST_IP=${RETRIEVER_SERVICE_HOST_IP}
+      - EMBEDDING_RERANK_SERVICE_HOST_IP=${EMBEDDING_RERANK_SERVICE_HOST_IP}
+      - EMBEDDING_RERANK_SERVICE_PORT=${EMBEDDING_RERANK_SERVICE_PORT:-6001}
+      - LLM_SERVER_HOST_IP=${LLM_SERVER_HOST_IP}
+      - LLM_SERVER_PORT=${LLM_SERVER_PORT:-8005}
+      - LOGFLAG=${LOGFLAG}
+    ipc: host
+    restart: always
+  chatqna-gaudi-ui-server:
+    image: ${REGISTRY:-opea}/chatqna-ui:${TAG:-latest}
+    container_name: chatqna-gaudi-ui-server
+    depends_on:
+      - chatqna-gaudi-backend-server
+    ports:
+      - "5173:5173"
+    environment:
+      - no_proxy=${no_proxy}
+      - https_proxy=${https_proxy}
+      - http_proxy=${http_proxy}
+      - CHAT_BASE_URL=${BACKEND_SERVICE_ENDPOINT}
+      - UPLOAD_FILE_BASE_URL=${DATAPREP_SERVICE_ENDPOINT}
+      - GET_FILE=${DATAPREP_GET_FILE_ENDPOINT}
+      - DELETE_FILE=${DATAPREP_DELETE_FILE_ENDPOINT}
+    ipc: host
+    restart: always
+
+networks:
+  default:
+    driver: bridge
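For context, DYNAMIC_BATCHING_TIMEOUT and DYNAMIC_BATCHING_MAX_BATCH_SIZE suggest that the embedding-reranking-local service groups concurrent requests into one forward pass on the Gaudi card, flushing a batch either when it is full or when the timeout expires. The sketch below shows the usual shape of that pattern; the queue-based structure, function names, and run_batch callback are illustrative assumptions, not taken from the service's code.

```python
import asyncio

DYNAMIC_BATCHING_TIMEOUT = 0.01       # seconds to wait for more requests after the first one
DYNAMIC_BATCHING_MAX_BATCH_SIZE = 32  # flush as soon as this many requests are queued


async def dynamic_batching_loop(queue: asyncio.Queue, run_batch):
    """Collect queued requests into batches and hand each batch to run_batch (sketch only)."""
    while True:
        batch = [await queue.get()]  # block until at least one request arrives
        deadline = asyncio.get_running_loop().time() + DYNAMIC_BATCHING_TIMEOUT
        while len(batch) < DYNAMIC_BATCHING_MAX_BATCH_SIZE:
            remaining = deadline - asyncio.get_running_loop().time()
            if remaining <= 0:
                break
            try:
                batch.append(await asyncio.wait_for(queue.get(), timeout=remaining))
            except asyncio.TimeoutError:
                break
        await run_batch(batch)  # e.g. one padded embedding or rerank forward pass per batch
```

PAD_SEQUENCE_TO_MULTIPLE_OF: 128 fits the same picture: padding each batch to a small set of fixed sequence lengths keeps tensor shapes stable, which is generally desirable on Gaudi to limit graph recompilation. Once the stack is up, the megaservice is reachable on port 8888 through the ChatQnAGateway, as in the other ChatQnA compose setups.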