Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Part 4 - Introduce Main RAG Service API and its tests #629

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ jobs:
- name: Run unit tests & Generate coverage
run: |
make unit-test
make rag-service-test
make tuning-metrics-server-test

- name: Run inference api unit tests
- name: Run inference api e2e tests
run: |
make inference-api-e2e

Expand Down
14 changes: 12 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,23 @@ unit-test: ## Run unit tests.
-race -coverprofile=coverage.txt -covermode=atomic
go tool cover -func=coverage.txt

.PHONY: rag-service-test
rag-service-test:
pip install -r presets/rag_service/requirements.txt
pytest -o log_cli=true -o log_cli_level=INFO presets/rag_service/tests

.PHONY: tuning-metrics-server-test
tuning-metrics-server-test:
pytest -o log_cli=true -o log_cli_level=INFO presets/tuning/text-generation/metrics

## --------------------------------------
## E2E tests
## --------------------------------------

inference-api-e2e:
.PHONY: inference-api-e2e
inference-api-e2e:
pip install -r presets/inference/text-generation/requirements.txt
pytest -o log_cli=true -o log_cli_level=INFO .
pytest -o log_cli=true -o log_cli_level=INFO presets/inference/text-generation/tests

# Ginkgo configurations
GINKGO_FOCUS ?=
Expand Down
15 changes: 15 additions & 0 deletions ragengine/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# config.py
#
# Central configuration for the RAG service. Every value can be overridden
# via an environment variable so deployments never require a code change.
import os

# Embedding backend: "local" runs the model in-process, "remote" calls an
# HTTP embedding service (dispatched in main.py).
EMBEDDING_TYPE = os.getenv("EMBEDDING_TYPE", "local")
EMBEDDING_URL = os.getenv("EMBEDDING_URL")  # only needed when EMBEDDING_TYPE == "remote"

# NOTE(review): default points at a hard-coded public IP — confirm this is
# intentional and not a leftover development endpoint.
INFERENCE_URL = os.getenv("INFERENCE_URL", "http://52.190.41.209/chat")
# NOTE(review): env var name "AccessSecret" breaks the UPPER_SNAKE_CASE
# convention used by every other setting; kept as-is for compatibility.
INFERENCE_ACCESS_SECRET = os.getenv("AccessSecret", "default-inference-secret")

MODEL_ID = os.getenv("MODEL_ID", "BAAI/bge-small-en-v1.5")
VECTOR_DB_TYPE = os.getenv("VECTOR_DB_TYPE", "faiss")
INDEX_SERVICE_NAME = os.getenv("INDEX_SERVICE_NAME", "default-index-service")
ACCESS_SECRET = os.getenv("ACCESS_SECRET", "default-access-secret")
# Where the vector store persists its index; overridable like the rest.
PERSIST_DIR = os.getenv("PERSIST_DIR", "./storage")
110 changes: 110 additions & 0 deletions ragengine/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from typing import Dict, List

from embedding.huggingface_local import LocalHuggingFaceEmbedding
from embedding.huggingface_remote import RemoteHuggingFaceEmbedding
from fastapi import FastAPI, HTTPException
from llama_index.core.schema import TextNode
from llama_index.core.storage.docstore.types import RefDocInfo
from vector_store.faiss_store import FaissVectorStoreHandler
from vector_store_manager.manager import VectorStoreManager
Comment on lines +8 to +9
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Included in future PR


from config import ACCESS_SECRET, EMBEDDING_TYPE, MODEL_ID
from models import Document, IndexRequest, ListDocumentsResponse, QueryRequest

# FastAPI application exposing the RAG service endpoints.
app = FastAPI()

# Initialize embedding model
# Dispatch on EMBEDDING_TYPE: "local" loads a HuggingFace model in-process,
# "remote" calls an HTTP embedding service authenticated with ACCESS_SECRET.
if EMBEDDING_TYPE.lower() == "local":
    embedding_manager = LocalHuggingFaceEmbedding(MODEL_ID)
elif EMBEDDING_TYPE.lower() == "remote":
    embedding_manager = RemoteHuggingFaceEmbedding(MODEL_ID, ACCESS_SECRET)
else:
    raise ValueError("Invalid Embedding Type Specified (Must be Local or Remote)")

# Initialize vector store
# TODO: Dynamically set VectorStore from EnvVars (which ultimately comes from CRD StorageSpec)
vector_store_handler = FaissVectorStoreHandler(embedding_manager)

# Initialize RAG operations
# NOTE: rag_ops is module-level shared state — every request handler (and
# every test in the same process) operates on this single manager instance.
rag_ops = VectorStoreManager(vector_store_handler)

@app.post("/index", response_model=List[Document])
async def index_documents(request: IndexRequest):  # TODO: Research async/sync what to use (inference is calling)
    """Index ``request.documents`` under ``request.index_name``.

    Returns the documents in submission order with their (possibly
    auto-generated) ``doc_id`` values filled in. Any backend failure is
    surfaced as HTTP 500.
    """
    try:
        doc_ids = rag_ops.create(request.index_name, request.documents)
        return [
            Document(doc_id=doc_id, text=doc.text, metadata=doc.metadata)
            for doc_id, doc in zip(doc_ids, request.documents)
        ]
    except Exception as e:
        # Chain the original exception so the root cause survives in tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.post("/query", response_model=Dict[str, str])
async def query_index(request: QueryRequest):
    """Query ``request.index_name`` and return ``{"response": <completion>}``.

    ``llm_params`` is forwarded to the inference backend; ``None`` means
    "no overrides". Any backend failure is surfaced as HTTP 500.
    """
    try:
        llm_params = request.llm_params or {}  # Default to empty dict if no params provided
        response = rag_ops.read(request.index_name, request.query, request.top_k, llm_params)
        return {"response": str(response)}
    except Exception as e:
        # Chain the original exception so the root cause survives in tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e

# The endpoints below are intentionally disabled (kept as a bare module-level
# string so they never register): update/refresh/delete are planned for a
# later PR — presumably alongside the UpdateRequest/RefreshRequest models
# they reference, which do not exist in models.py yet (TODO confirm).
"""
@app.put("/update", response_model=Dict[str, List[str]])
async def update_documents(request: UpdateRequest):
    try:
        result = rag_ops.update(request.documents)
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/refresh", response_model=List[bool])
async def refresh_documents(request: RefreshRequest):
    try:
        result = rag_ops.refresh(request.documents)
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.delete("/document/{doc_id}")
async def delete_document(doc_id: str):
    try:
        rag_ops.delete(doc_id)
        return {"message": "Document deleted successfully"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
"""

@app.get("/document/{index_name}/{doc_id}", response_model=RefDocInfo)
async def get_document(index_name: str, doc_id: str):
    """Fetch a single document's ref-doc info by index name and document ID.

    Returns 404 when the index or the document is missing, 500 for any other
    backend failure.
    """
    try:
        document = rag_ops.get(index_name, doc_id)
    except ValueError as ve:
        # e.g. the index itself does not exist.
        raise HTTPException(status_code=404, detail=str(ve)) from ve
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    if document:
        return document
    # BUG FIX: this 404 was previously raised inside the try block, where the
    # blanket `except Exception` caught it and re-raised it as a 500. Raising
    # it after the try preserves the intended 404 status.
    raise HTTPException(status_code=404, detail=f"Document with ID {doc_id} "
                                                f"not found in index '{index_name}'.")

@app.get("/indexed-documents", response_model=ListDocumentsResponse)
async def list_all_indexed_documents():
    """List every indexed document, grouped by index.

    Result shape: ``{index_name: {doc_name: {"text": ..., "hash": ...}}}``.
    Any backend failure is surfaced as HTTP 500.
    """
    try:
        documents = rag_ops.list_all_indexed_documents()
        serialized_documents = {
            index_name: {
                doc_name: {"text": doc_info.text, "hash": doc_info.hash}
                for doc_name, doc_info in vector_store_index.docstore.docs.items()
            }
            for index_name, vector_store_index in documents.items()
        }
        return ListDocumentsResponse(documents=serialized_documents)
    except Exception as e:
        # Chain the original exception so the root cause survives in tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e

if __name__ == "__main__":
    # Local/dev entry point; bind on all interfaces, port 8000.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
22 changes: 22 additions & 0 deletions ragengine/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Dict, List, Optional

from pydantic import BaseModel


class Document(BaseModel):
    """A document to index: its text plus optional metadata and ID."""
    # Free-text content to be embedded and indexed.
    text: str
    # Arbitrary key/value metadata attached to the document.
    # NOTE: pydantic copies field defaults per instance, so the mutable {}
    # default is safe here (unlike a plain-function default argument).
    metadata: Optional[dict] = {}
    # If omitted, the store auto-generates an ID (tests expect a 36-char
    # string — presumably str(uuid4); confirm against the vector store).
    doc_id: Optional[str] = None

class IndexRequest(BaseModel):
    """Body for POST /index: target index name plus the documents to add."""
    index_name: str
    documents: List[Document]

class QueryRequest(BaseModel):
    """Body for POST /query."""
    # Name of the index to search.
    index_name: str
    # Natural-language query text.
    query: str
    # Number of top matches to retrieve.
    top_k: int = 10
    llm_params: Optional[Dict] = None  # Accept a dictionary for parameters

class ListDocumentsResponse(BaseModel):
    """Response for GET /indexed-documents.

    Maps index name -> doc name -> {"text": ..., "hash": ...}.
    """
    # PEP 8: space after the annotation colon (was `documents:Dict[...]`).
    documents: Dict[str, Dict[str, Dict[str, str]]]
Empty file added ragengine/tests/__init__.py
Empty file.
Empty file added ragengine/tests/api/__init.py
Empty file.
7 changes: 7 additions & 0 deletions ragengine/tests/api/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Test bootstrap: make the ragengine package importable and force a
# deterministic, CPU-only runtime. These env vars must be set before any
# heavy library (torch/transformers) is imported by the code under test.
import os
import sys

# Put the ragengine root (two levels up from tests/api/) on sys.path so the
# tests can `from main import app` without installing the package.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # Force CPU-only execution for testing
os.environ["OMP_NUM_THREADS"] = "1"  # Force single-threaded for testing to prevent segfault while loading embedding model
os.environ["MKL_NUM_THREADS"] = "1"  # Force MKL to use a single thread
168 changes: 168 additions & 0 deletions ragengine/tests/api/test_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
import os
from tempfile import TemporaryDirectory
from unittest.mock import MagicMock, patch

import pytest
from embedding.huggingface_local import LocalHuggingFaceEmbedding
from fastapi.testclient import TestClient
from main import app, rag_ops
from vector_store.faiss_store import FaissVectorStoreHandler

from config import INFERENCE_ACCESS_SECRET, INFERENCE_URL, MODEL_ID
from models import Document

# Expected length of an auto-generated doc_id (36 chars — matches
# len(str(uuid.uuid4())); presumably uuid-based, confirm in the store).
AUTO_GEN_DOC_ID_LEN = 36

# Shared in-process test client driving the FastAPI app — no real server.
client = TestClient(app)

def test_index_documents_success():
    """Indexing two documents returns them with generated IDs and no metadata."""
    payload = {
        "index_name": "test_index",
        "documents": [
            {"text": "This is a test document"},
            {"text": "Another test document"},
        ],
    }

    resp = client.post("/index", json=payload)
    assert resp.status_code == 200

    returned = resp.json()
    assert len(returned) == 2
    for doc, submitted in zip(returned, payload["documents"]):
        assert doc["text"] == submitted["text"]
        assert len(doc["doc_id"]) == AUTO_GEN_DOC_ID_LEN
        assert not doc["metadata"]

@patch('requests.post')
def test_query_index_success(mock_post):
    """A query against an indexed store returns the (mocked) completion."""
    # Stub the custom inference API so no network call is made.
    fake_completion = {"result": "This is the completion from the API"}
    mock_post.return_value.json.return_value = fake_completion

    # First index a couple of documents to query against.
    index_payload = {
        "index_name": "test_index",
        "documents": [
            {"text": "This is a test document"},
            {"text": "Another test document"},
        ],
    }
    index_resp = client.post("/index", json=index_payload)
    assert index_resp.status_code == 200

    # Issue the query and verify the stubbed completion comes back verbatim.
    query_payload = {
        "index_name": "test_index",
        "query": "test query",
        "top_k": 1,
        "llm_params": {"temperature": 0.7},
    }
    query_resp = client.post("/query", json=query_payload)
    assert query_resp.status_code == 200
    assert query_resp.json() == {"response": "This is the completion from the API"}
    assert mock_post.call_count == 1

def test_query_index_failure():
    """Querying an index that was never created surfaces as HTTP 500."""
    payload = {
        "index_name": "non_existent_index",  # Use an index name that doesn't exist
        "query": "test query",
        "top_k": 1,
        "llm_params": {"temperature": 0.7},
    }

    resp = client.post("/query", json=payload)
    assert resp.status_code == 500
    assert resp.json()["detail"] == "No such index: 'non_existent_index' exists."


def test_get_document_success():
    """Indexing with an explicit doc_id allows fetching that document back."""
    request_data = {
        "index_name": "test_index",
        "documents": [
            {"doc_id": "doc1", "text": "This is a test document"},
            {"text": "Another test document"},
        ],
    }

    index_response = client.post("/index", json=request_data)
    assert index_response.status_code == 200

    # Call the GET document endpoint.
    get_response = client.get("/document/test_index/doc1")
    assert get_response.status_code == 200

    response_json = get_response.json()

    # RefDocInfo serializes to exactly these two keys.
    assert response_json.keys() == {"node_ids", "metadata"}
    assert response_json["metadata"] == {}

    # A single TextNode is created for the one document fetched.
    assert isinstance(response_json["node_ids"], list) and len(response_json["node_ids"]) == 1


def test_get_document_failure():
    # Fetching a document that was never indexed should surface as 404.
    # NOTE(review): this assumes 'test_index'/'doc1' do not exist at this
    # point, yet test_get_document_success indexes exactly that pair into the
    # shared module-level rag_ops — confirm per-test isolation/ordering.
    # Call the GET document endpoint.
    response = client.get("/document/test_index/doc1")
    assert response.status_code == 404

def test_list_all_indexed_documents_success():
    # NOTE(review): the first assertion expects an empty store, and the later
    # ones expect exactly the two documents indexed here. rag_ops is shared
    # module-level state, so this relies on per-test isolation (or on this
    # test running before any other indexing test) — confirm fixture setup.
    response = client.get("/indexed-documents")
    assert response.status_code == 200
    assert response.json() == {'documents': {}}

    request_data = {
        "index_name": "test_index",
        "documents": [
            {"text": "This is a test document"},
            {"text": "Another test document"}
        ]
    }

    response = client.post("/index", json=request_data)
    assert response.status_code == 200

    # The newly indexed documents must appear under their index name,
    # with exactly the submitted texts (order-independent set compare).
    response = client.get("/indexed-documents")
    assert response.status_code == 200
    assert "test_index" in response.json()["documents"]
    response_idx = response.json()["documents"]["test_index"]
    assert len(response_idx) == 2  # Two Documents Indexed
    assert ({item["text"] for item in response_idx.values()}
            == {item["text"] for item in request_data["documents"]})


# Disabled live-test example, kept as a bare string so pytest never collects
# it; see the explanation in the string's opening lines.
"""
Example of a live query test. This test is currently commented out as it requires a valid
INFERENCE_URL in config.py. To run the test, ensure that a valid INFERENCE_URL is provided.
Upon execution, RAG results should be observed.
def test_live_query_test():
    # Index
    request_data = {
        "index_name": "test_index",
        "documents": [
            {"text": "Polar bear – can lift 450Kg (approximately 0.7 times their body weight) \
                Adult male polar bears can grow to be anywhere between 300 and 700kg"},
            {"text": "Giraffes are the tallest mammals and are well-adapted to living in trees. \
                They have few predators as adults."}
        ]
    }

    response = client.post("/index", json=request_data)
    assert response.status_code == 200

    # Query
    request_data = {
        "index_name": "test_index",
        "query": "What is the strongest bear?",
        "top_k": 1,
        "llm_params": {"temperature": 0.7}
    }

    response = client.post("/query", json=request_data)
    assert response.status_code == 200
"""