-
Notifications
You must be signed in to change notification settings - Fork 47
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Part 4 - Introduce Main RAG Service API and its tests #629
Open
ishaansehgal99
wants to merge
2
commits into
main
Choose a base branch
from
Ishaan/RAG-Part-2-API
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# config.py
#
# Centralized configuration for the RAG service. Every value is read from the
# environment once at import time so the rest of the code can import plain
# constants.
import os

# Embedding backend: "local" runs the HuggingFace model in-process,
# "remote" calls an external embedding endpoint.
EMBEDDING_TYPE = os.getenv("EMBEDDING_TYPE", "local")
EMBEDDING_URL = os.getenv("EMBEDDING_URL")  # only meaningful when EMBEDDING_TYPE == "remote"

# Inference (LLM) endpoint used to answer /query requests.
INFERENCE_URL = os.getenv("INFERENCE_URL", "http://52.190.41.209/chat")
# NOTE(review): the env var name "AccessSecret" breaks the UPPER_SNAKE_CASE
# convention used by every other setting — confirm no deployment sets it
# before renaming.
INFERENCE_ACCESS_SECRET = os.getenv("AccessSecret", "default-inference-secret")
# RESPONSE_FIELD = os.getenv("RESPONSE_FIELD", "result")

MODEL_ID = os.getenv("MODEL_ID", "BAAI/bge-small-en-v1.5")
VECTOR_DB_TYPE = os.getenv("VECTOR_DB_TYPE", "faiss")
INDEX_SERVICE_NAME = os.getenv("INDEX_SERVICE_NAME", "default-index-service")
ACCESS_SECRET = os.getenv("ACCESS_SECRET", "default-access-secret")
# Where on-disk index state is persisted; now overridable via env like the
# other settings (default unchanged, so existing deployments are unaffected).
PERSIST_DIR = os.getenv("PERSIST_DIR", "./storage")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
from typing import Dict, List | ||
|
||
from embedding.huggingface_local import LocalHuggingFaceEmbedding | ||
from embedding.huggingface_remote import RemoteHuggingFaceEmbedding | ||
from fastapi import FastAPI, HTTPException | ||
from llama_index.core.schema import TextNode | ||
from llama_index.core.storage.docstore.types import RefDocInfo | ||
from vector_store.faiss_store import FaissVectorStoreHandler | ||
from vector_store_manager.manager import VectorStoreManager | ||
|
||
from config import ACCESS_SECRET, EMBEDDING_TYPE, MODEL_ID | ||
from models import Document, IndexRequest, ListDocumentsResponse, QueryRequest | ||
|
||
app = FastAPI()

# Select the embedding backend once at startup; anything other than the two
# supported types is a configuration error.
embedding_type = EMBEDDING_TYPE.lower()
if embedding_type == "remote":
    embedding_manager = RemoteHuggingFaceEmbedding(MODEL_ID, ACCESS_SECRET)
elif embedding_type == "local":
    embedding_manager = LocalHuggingFaceEmbedding(MODEL_ID)
else:
    raise ValueError("Invalid Embedding Type Specified (Must be Local or Remote)")

# Vector store backing every index.
# TODO: Dynamically set VectorStore from EnvVars (which ultimately comes from CRD StorageSpec)
vector_store_handler = FaissVectorStoreHandler(embedding_manager)

# High-level CRUD operations over the vector store.
rag_ops = VectorStoreManager(vector_store_handler)
|
||
@app.post("/index", response_model=List[Document]) | ||
async def index_documents(request: IndexRequest): # TODO: Research async/sync what to use (inference is calling) | ||
try: | ||
doc_ids = rag_ops.create(request.index_name, request.documents) | ||
documents = [ | ||
Document(doc_id=doc_id, text=doc.text, metadata=doc.metadata) | ||
for doc_id, doc in zip(doc_ids, request.documents) | ||
] | ||
return documents | ||
except Exception as e: | ||
raise HTTPException(status_code=500, detail=str(e)) | ||
|
||
@app.post("/query", response_model=Dict[str, str]) | ||
async def query_index(request: QueryRequest): | ||
try: | ||
llm_params = request.llm_params or {} # Default to empty dict if no params provided | ||
response = rag_ops.read(request.index_name, request.query, request.top_k, llm_params) | ||
return {"response": str(response)} | ||
except Exception as e: | ||
raise HTTPException(status_code=500, detail=str(e)) | ||
|
||
""" | ||
@app.put("/update", response_model=Dict[str, List[str]]) | ||
async def update_documents(request: UpdateRequest): | ||
try: | ||
result = rag_ops.update(request.documents) | ||
return result | ||
except Exception as e: | ||
raise HTTPException(status_code=500, detail=str(e)) | ||
|
||
@app.post("/refresh", response_model=List[bool]) | ||
async def refresh_documents(request: RefreshRequest): | ||
try: | ||
result = rag_ops.refresh(request.documents) | ||
return result | ||
except Exception as e: | ||
raise HTTPException(status_code=500, detail=str(e)) | ||
|
||
@app.delete("/document/{doc_id}") | ||
async def delete_document(doc_id: str): | ||
try: | ||
rag_ops.delete(doc_id) | ||
return {"message": "Document deleted successfully"} | ||
except Exception as e: | ||
raise HTTPException(status_code=500, detail=str(e)) | ||
""" | ||
|
||
@app.get("/document/{index_name}/{doc_id}", response_model=RefDocInfo) | ||
async def get_document(index_name: str, doc_id: str): | ||
try: | ||
document = rag_ops.get(index_name, doc_id) | ||
if document: | ||
return document | ||
else: | ||
raise HTTPException(status_code=404, detail=f"Document with ID {doc_id} " | ||
f"not found in index '{index_name}'.") | ||
except ValueError as ve: | ||
raise HTTPException(status_code=404, detail=str(ve)) | ||
except Exception as e: | ||
raise HTTPException(status_code=500, detail=str(e)) | ||
|
||
@app.get("/indexed-documents", response_model=ListDocumentsResponse) | ||
async def list_all_indexed_documents(): | ||
try: | ||
documents = rag_ops.list_all_indexed_documents() | ||
serialized_documents = { | ||
index_name: { | ||
doc_name: { | ||
"text": doc_info.text, "hash": doc_info.hash | ||
} for doc_name, doc_info in vector_store_index.docstore.docs.items() | ||
} | ||
for index_name, vector_store_index in documents.items() | ||
} | ||
return ListDocumentsResponse(documents=serialized_documents) | ||
except Exception as e: | ||
raise HTTPException(status_code=500, detail=str(e)) | ||
|
||
if __name__ == "__main__": | ||
import uvicorn | ||
uvicorn.run(app, host="0.0.0.0", port=8000) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from typing import Dict, List, Optional

from pydantic import BaseModel, Field
|
||
|
||
class Document(BaseModel):
    """A single document to index: free text plus optional metadata.

    ``doc_id`` may be supplied by the caller; when omitted, the service
    assigns one during indexing.
    """
    text: str
    # default_factory is the idiomatic pydantic spelling for a mutable
    # default; it guarantees a fresh dict per instance instead of relying on
    # pydantic's copy-on-default behavior for a shared `{}` literal.
    metadata: Optional[dict] = Field(default_factory=dict)
    doc_id: Optional[str] = None
|
||
class IndexRequest(BaseModel):
    """Request body for POST /index: documents to add under ``index_name``."""
    index_name: str
    documents: List[Document]
|
||
class QueryRequest(BaseModel):
    """Request body for POST /query."""
    index_name: str
    query: str
    # Number of top-scoring chunks to retrieve for the query.
    top_k: int = 10
    llm_params: Optional[Dict] = None  # Accept a dictionary for parameters
|
||
class ListDocumentsResponse(BaseModel):
    """Response for GET /indexed-documents.

    Maps index name -> doc name -> {"text": ..., "hash": ...}.
    """
    documents:Dict[str, Dict[str, Dict[str, str]]]
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
"""Test bootstrap: make the service source importable and force a small,
CPU-only, single-threaded runtime so the embedding model loads reliably
under pytest. The environment variables must be set before any test module
imports the embedding libraries."""
import os
import sys

# Put the service source root (two directories up) on sys.path so tests can
# `import main`, `import config`, etc. without installing the package.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # Force CPU-only execution for testing
os.environ["OMP_NUM_THREADS"] = "1"  # Force single-threaded for testing to prevent segfault while loading embedding model
os.environ["MKL_NUM_THREADS"] = "1"  # Force MKL to use a single thread
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
import os | ||
from tempfile import TemporaryDirectory | ||
from unittest.mock import MagicMock, patch | ||
|
||
import pytest | ||
from embedding.huggingface_local import LocalHuggingFaceEmbedding | ||
from fastapi.testclient import TestClient | ||
from main import app, rag_ops | ||
from vector_store.faiss_store import FaissVectorStoreHandler | ||
|
||
from config import INFERENCE_ACCESS_SECRET, INFERENCE_URL, MODEL_ID | ||
from models import Document | ||
|
||
# Expected length of an auto-generated doc_id. 36 matches the string form of
# a UUID — presumably the store assigns uuid4 ids; confirm in VectorStoreManager.
AUTO_GEN_DOC_ID_LEN = 36

# Shared in-process client; all tests hit the same app (and its module-level state).
client = TestClient(app)
|
||
def test_index_documents_success():
    """POST /index echoes the documents back with generated doc_ids."""
    payload = {
        "index_name": "test_index",
        "documents": [
            {"text": "This is a test document"},
            {"text": "Another test document"},
        ],
    }

    response = client.post("/index", json=payload)
    assert response.status_code == 200

    returned = response.json()
    expected_texts = ["This is a test document", "Another test document"]
    assert len(returned) == len(expected_texts)
    for doc, expected_text in zip(returned, expected_texts):
        assert doc["text"] == expected_text
        assert len(doc["doc_id"]) == AUTO_GEN_DOC_ID_LEN
        assert not doc["metadata"]
|
||
@patch('requests.post')
def test_query_index_success(mock_post):
    """Index two documents, then query; the mocked inference API supplies the answer."""
    # Only the JSON body of the inference call matters; the HTTP layer is mocked.
    mock_post.return_value.json.return_value = {
        "result": "This is the completion from the API"
    }

    index_payload = {
        "index_name": "test_index",
        "documents": [
            {"text": "This is a test document"},
            {"text": "Another test document"},
        ],
    }
    assert client.post("/index", json=index_payload).status_code == 200

    query_payload = {
        "index_name": "test_index",
        "query": "test query",
        "top_k": 1,
        "llm_params": {"temperature": 0.7},
    }
    query_response = client.post("/query", json=query_payload)

    assert query_response.status_code == 200
    assert query_response.json() == {"response": "This is the completion from the API"}
    assert mock_post.call_count == 1
|
||
def test_query_index_failure():
    """Querying an index that was never created yields a 500 with a clear message."""
    payload = {
        "index_name": "non_existent_index",
        "query": "test query",
        "top_k": 1,
        "llm_params": {"temperature": 0.7},
    }

    response = client.post("/query", json=payload)

    assert response.status_code == 500
    assert response.json()["detail"] == "No such index: 'non_existent_index' exists."
|
||
|
||
def test_get_document_success():
    """GET /document/{index}/{doc_id} returns ref-doc info for an explicit doc_id."""
    request_data = {
        "index_name": "test_index",
        "documents": [
            {"doc_id": "doc1", "text": "This is a test document"},
            {"text": "Another test document"},
        ],
    }

    index_response = client.post("/index", json=request_data)
    assert index_response.status_code == 200

    get_response = client.get("/document/test_index/doc1")
    assert get_response.status_code == 200

    response_json = get_response.json()

    # RefDocInfo serializes to exactly these two keys; no metadata was supplied.
    assert response_json.keys() == {"node_ids", "metadata"}
    assert response_json["metadata"] == {}

    # A single indexed document should be backed by exactly one node.
    assert isinstance(response_json["node_ids"], list) and len(response_json["node_ids"]) == 1
|
||
|
||
def test_get_document_failure():
    """Fetching a document that was never indexed returns 404."""
    missing = client.get("/document/test_index/doc1")
    assert missing.status_code == 404
|
||
def test_list_all_indexed_documents_success():
    # NOTE(review): this test assumes the service starts with no indexed
    # documents and that only its own two documents end up in "test_index",
    # yet rag_ops is module-level state shared by every test in this file —
    # confirm test ordering/isolation, or reset state in a fixture.
    response = client.get("/indexed-documents")
    assert response.status_code == 200
    assert response.json() == {'documents': {}}

    request_data = {
        "index_name": "test_index",
        "documents": [
            {"text": "This is a test document"},
            {"text": "Another test document"}
        ]
    }

    response = client.post("/index", json=request_data)
    assert response.status_code == 200

    # The listing must contain exactly the two documents just indexed.
    response = client.get("/indexed-documents")
    assert response.status_code == 200
    assert "test_index" in response.json()["documents"]
    response_idx = response.json()["documents"]["test_index"]
    assert len(response_idx) == 2  # Two Documents Indexed
    assert ({item["text"] for item in response_idx.values()}
            == {item["text"] for item in request_data["documents"]})
|
||
|
||
""" | ||
Example of a live query test. This test is currently commented out as it requires a valid | ||
INFERENCE_URL in config.py. To run the test, ensure that a valid INFERENCE_URL is provided. | ||
Upon execution, RAG results should be observed. | ||
def test_live_query_test(): | ||
# Index | ||
request_data = { | ||
"index_name": "test_index", | ||
"documents": [ | ||
{"text": "Polar bear – can lift 450Kg (approximately 0.7 times their body weight) \ | ||
Adult male polar bears can grow to be anywhere between 300 and 700kg"}, | ||
{"text": "Giraffes are the tallest mammals and are well-adapted to living in trees. \ | ||
They have few predators as adults."} | ||
] | ||
} | ||
|
||
response = client.post("/index", json=request_data) | ||
assert response.status_code == 200 | ||
|
||
# Query | ||
request_data = { | ||
"index_name": "test_index", | ||
"query": "What is the strongest bear?", | ||
"top_k": 1, | ||
"llm_params": {"temperature": 0.7} | ||
} | ||
|
||
response = client.post("/query", json=request_data) | ||
assert response.status_code == 200 | ||
""" |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Included in future PR