-
Notifications
You must be signed in to change notification settings - Fork 180
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add megaservice definition without microservice wrappers (#700)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
0629696
commit ebe6b47
Showing
7 changed files
with
1,192 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
|
||
|
||
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

FROM python:3.11-slim

# System packages required by GenAIComps (libgl for OpenCV-style deps,
# jemalloc for allocator performance). The apt lists are removed in the
# same layer so the cache does not bloat the final image.
RUN apt-get update -y && apt-get install -y --no-install-recommends --fix-missing \
    libgl1-mesa-glx \
    libjemalloc-dev \
    vim \
    git && \
    rm -rf /var/lib/apt/lists/*

# Unprivileged user that will run the service.
RUN useradd -m -s /bin/bash user && \
    mkdir -p /home/user && \
    chown -R user /home/user/

WORKDIR /home/user/
RUN git clone https://github.com/opea-project/GenAIComps.git

WORKDIR /home/user/GenAIComps
RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir -r /home/user/GenAIComps/requirements.txt && \
    pip install --no-cache-dir langchain_core

COPY ./chatqna_no_wrapper.py /home/user/chatqna_no_wrapper.py

# Make the cloned GenAIComps importable as `comps`.
ENV PYTHONPATH=$PYTHONPATH:/home/user/GenAIComps

# Drop root privileges before starting the service.
USER user

WORKDIR /home/user

ENTRYPOINT ["python", "chatqna_no_wrapper.py"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,265 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import json | ||
import os | ||
import re | ||
|
||
from comps import ChatQnAGateway, MicroService, ServiceOrchestrator, ServiceType | ||
from langchain_core.prompts import PromptTemplate | ||
|
||
|
||
class ChatTemplate:
    """Builds the default RAG prompt, picking a Chinese or English template
    based on the language of the retrieved context."""

    @staticmethod
    def generate_rag_prompt(question, documents):
        """Format *question* and the retrieved *documents* into one prompt.

        The Chinese template is chosen when at least 30% of the characters
        of the joined context are CJK Unified Ideographs; otherwise the
        English template is used (also for empty context).
        """
        context_str = "\n".join(documents)
        # Ratio of CJK ideographs decides which template to apply.
        is_chinese = bool(context_str) and (
            len(re.findall("[\u4E00-\u9FFF]", context_str)) / len(context_str) >= 0.3
        )
        if is_chinese:
            template = """
### 你将扮演一个乐于助人、尊重他人并诚实的助手,你的目标是帮助用户解答问题。有效地利用来自本地知识库的搜索结果。确保你的回答中只包含相关信息。如果你不确定问题的答案,请避免分享不准确的信息。
### 搜索结果:{context}
### 问题:{question}
### 回答:
"""
        else:
            template = """
### You are a helpful, respectful and honest assistant to help the user with questions. \
Please refer to the search results obtained from the local knowledge base. \
But be careful to not incorporate the information that you think is not relevant to the question. \
If you don't know the answer to a question, please don't share false information. \n
### Search results: {context} \n
### Question: {question} \n
### Answer:
"""
        return template.format(context=context_str, question=question)
|
||
|
||
# Host/port the aggregated (mega) service gateway listens on.
MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))

# Endpoints of the backing model servers, reached directly without
# microservice wrappers: embedding (/embed), retriever (/v1/retrieval),
# reranker (/rerank) and an OpenAI-compatible TGI/vLLM LLM server
# (/v1/chat/completions).
EMBEDDING_SERVER_HOST_IP = os.getenv("EMBEDDING_SERVER_HOST_IP", "0.0.0.0")
EMBEDDING_SERVER_PORT = int(os.getenv("EMBEDDING_SERVER_PORT", 6006))
RETRIEVER_SERVICE_HOST_IP = os.getenv("RETRIEVER_SERVICE_HOST_IP", "0.0.0.0")
RETRIEVER_SERVICE_PORT = int(os.getenv("RETRIEVER_SERVICE_PORT", 7000))
RERANK_SERVER_HOST_IP = os.getenv("RERANK_SERVER_HOST_IP", "0.0.0.0")
RERANK_SERVER_PORT = int(os.getenv("RERANK_SERVER_PORT", 8808))
LLM_SERVER_HOST_IP = os.getenv("LLM_SERVER_HOST_IP", "0.0.0.0")
LLM_SERVER_PORT = int(os.getenv("LLM_SERVER_PORT", 9009))
|
||
|
||
def align_inputs(self, inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs):
    """Adapt the payload bound for *cur_node* to the request schema its
    backing server expects (patched onto ServiceOrchestrator)."""
    service_type = self.services[cur_node].service_type
    if service_type == ServiceType.EMBEDDING:
        # The embedding server expects {"inputs": ...} rather than {"text": ...}.
        inputs["inputs"] = inputs.pop("text")
    elif service_type == ServiceType.RETRIEVER:
        # Fold optional retriever parameters into the request, if supplied.
        retriever_parameters = kwargs.get("retriever_parameters", None)
        if retriever_parameters:
            inputs.update(retriever_parameters.dict())
    elif service_type == ServiceType.LLM:
        # Rewrite the TGI/vLLM-style payload into the unified OpenAI
        # /v1/chat/completions request format.
        inputs = {
            "model": "tgi",  # fake model name, only to keep the format uniform
            "messages": [{"role": "user", "content": inputs["inputs"]}],
            "max_tokens": llm_parameters_dict["max_new_tokens"],
            "top_p": llm_parameters_dict["top_p"],
            "stream": inputs["streaming"],
            "frequency_penalty": inputs["repetition_penalty"],
            "temperature": inputs["temperature"],
        }
    return inputs
|
||
|
||
def _build_llm_prompt(query, docs, chat_template):
    """Render the final LLM prompt for *query* over *docs*.

    If the user supplied *chat_template*, it is used when it declares exactly
    the supported input variables (["context", "question"] or ["question"]);
    otherwise the built-in ChatTemplate RAG prompt is used as a fallback.
    """
    if chat_template:
        prompt_template = PromptTemplate.from_template(chat_template)
        input_variables = prompt_template.input_variables
        if sorted(input_variables) == ["context", "question"]:
            return prompt_template.format(question=query, context="\n".join(docs))
        elif input_variables == ["question"]:
            return prompt_template.format(question=query)
        print(f"{prompt_template} not used, we only support 2 input variables ['question', 'context']")
    return ChatTemplate.generate_rag_prompt(query, docs)


def align_outputs(self, data, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs):
    """Adapt the raw response of *cur_node* into the request payload for its
    downstream node (patched onto ServiceOrchestrator)."""
    next_data = {}
    if self.services[cur_node].service_type == ServiceType.EMBEDDING:
        # Embedding server returns a list of vectors; keep the first.
        assert isinstance(data, list)
        next_data = {"text": inputs["inputs"], "embedding": data[0]}
    elif self.services[cur_node].service_type == ServiceType.RETRIEVER:
        docs = [doc["text"] for doc in data["retrieved_docs"]]

        with_rerank = runtime_graph.downstream(cur_node)[0].startswith("rerank")
        if with_rerank and docs:
            # Forward the retrieved docs to the rerank stage.
            next_data["query"] = data["initial_query"]
            next_data["texts"] = docs
        else:
            if not docs:
                # Nothing retrieved: cut the rerank node out of the graph,
                # rewiring retriever -> llm directly.
                for ds in reversed(runtime_graph.downstream(cur_node)):
                    for nds in runtime_graph.downstream(ds):
                        runtime_graph.add_edge(cur_node, nds)
                    runtime_graph.delete_node_if_exists(ds)
            # Skip rerank and build the LLM prompt directly.
            next_data["inputs"] = _build_llm_prompt(
                data["initial_query"], docs, llm_parameters_dict["chat_template"]
            )
    elif self.services[cur_node].service_type == ServiceType.RERANK:
        # Keep the top_n best-scored docs; the rerank server returns
        # scored entries referencing the input texts by index.
        reranker_parameters = kwargs.get("reranker_parameters", None)
        top_n = reranker_parameters.top_n if reranker_parameters else 1
        docs = inputs["texts"]
        reranked_docs = [docs[best_response["index"]] for best_response in data[:top_n]]

        # BUGFIX: build the prompt from the reranked top_n docs; the original
        # computed reranked_docs but then formatted the prompt with all docs,
        # silently discarding the rerank result.
        next_data["inputs"] = _build_llm_prompt(
            inputs["query"], reranked_docs, llm_parameters_dict["chat_template"]
        )
    return next_data
|
||
|
||
def align_generator(self, gen, **kwargs):
    """Re-emit an OpenAI-style SSE stream (patched onto ServiceOrchestrator).

    Each upstream chunk is a bytes SSE line carrying one JSON object with a
    ``choices[0].delta.content`` field. The JSON object is sliced out of the
    line and its delta content forwarded; chunks that cannot be parsed
    (e.g. empty keep-alives) are forwarded raw as a fallback.
    """
    for raw_chunk in gen:
        chunk = raw_chunk.decode("utf-8")
        # Slice out the JSON object embedded in the SSE line.
        payload = chunk[chunk.find("{") : chunk.rfind("}") + 1]
        try:
            parsed = json.loads(payload)
            if parsed["choices"][0]["finish_reason"] != "eos_token":
                yield f"data: {repr(parsed['choices'][0]['delta']['content'].encode('utf-8'))}\n\n"
        except Exception:
            # Empty or malformed chunk: pass the raw payload through.
            yield f"data: {repr(payload.encode('utf-8'))}\n\n"
    yield "data: [DONE]\n\n"
|
||
|
||
class ChatQnAService:
    """Assembles the ChatQnA megaservice graph from remote model servers.

    The embedding/retriever/rerank/LLM stages are plain remote endpoints
    (no microservice wrappers); the module-level align_* hooks installed on
    ServiceOrchestrator translate payloads between stages.
    """

    def __init__(self, host="0.0.0.0", port=8000):
        self.host = host
        self.port = port
        # ServiceOrchestrator does not natively speak the raw server schemas,
        # so the adaptation hooks are patched onto the class before use.
        ServiceOrchestrator.align_inputs = align_inputs
        ServiceOrchestrator.align_outputs = align_outputs
        ServiceOrchestrator.align_generator = align_generator
        self.megaservice = ServiceOrchestrator()

    # -- private factories for the remote stages (deduplicates the two
    # -- pipeline builders below) ------------------------------------------

    @staticmethod
    def _embedding_service():
        # Remote embedding server exposing /embed.
        return MicroService(
            name="embedding",
            host=EMBEDDING_SERVER_HOST_IP,
            port=EMBEDDING_SERVER_PORT,
            endpoint="/embed",
            use_remote_service=True,
            service_type=ServiceType.EMBEDDING,
        )

    @staticmethod
    def _retriever_service():
        # Remote retriever exposing /v1/retrieval.
        return MicroService(
            name="retriever",
            host=RETRIEVER_SERVICE_HOST_IP,
            port=RETRIEVER_SERVICE_PORT,
            endpoint="/v1/retrieval",
            use_remote_service=True,
            service_type=ServiceType.RETRIEVER,
        )

    @staticmethod
    def _rerank_service():
        # Remote reranker exposing /rerank.
        return MicroService(
            name="rerank",
            host=RERANK_SERVER_HOST_IP,
            port=RERANK_SERVER_PORT,
            endpoint="/rerank",
            use_remote_service=True,
            service_type=ServiceType.RERANK,
        )

    @staticmethod
    def _llm_service():
        # Remote OpenAI-compatible (TGI/vLLM) chat completion server.
        return MicroService(
            name="llm",
            host=LLM_SERVER_HOST_IP,
            port=LLM_SERVER_PORT,
            endpoint="/v1/chat/completions",
            use_remote_service=True,
            service_type=ServiceType.LLM,
        )

    def add_remote_service(self):
        """Wire the full pipeline: embedding -> retriever -> rerank -> llm."""
        embedding = self._embedding_service()
        retriever = self._retriever_service()
        rerank = self._rerank_service()
        llm = self._llm_service()
        self.megaservice.add(embedding).add(retriever).add(rerank).add(llm)
        self.megaservice.flow_to(embedding, retriever)
        self.megaservice.flow_to(retriever, rerank)
        self.megaservice.flow_to(rerank, llm)
        self.gateway = ChatQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)

    def add_remote_service_without_rerank(self):
        """Wire the reduced pipeline: embedding -> retriever -> llm."""
        embedding = self._embedding_service()
        retriever = self._retriever_service()
        llm = self._llm_service()
        self.megaservice.add(embedding).add(retriever).add(llm)
        self.megaservice.flow_to(embedding, retriever)
        self.megaservice.flow_to(retriever, llm)
        self.gateway = ChatQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
|
||
|
||
if __name__ == "__main__":
    service = ChatQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
    service.add_remote_service()
    # Alternative topology without the rerank stage:
    # service.add_remote_service_without_rerank()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.