[GenAI] Added and Updated steps (#22)
* [GenAI] Added and Updated steps

* First fix after Guy's review
ZeevRispler authored Sep 16, 2024
1 parent 04ceda1 commit 933b873
Showing 2 changed files with 216 additions and 54 deletions.
202 changes: 148 additions & 54 deletions genai_factory/src/genai_factory/chains/retrieval.py
@@ -12,30 +12,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict, List, Optional

from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.prompts import PromptTemplate
from langchain.schema import Document

from genai_factory.chains.base import ChainRunner
from genai_factory.config import get_llm, get_vector_db
from genai_factory.schemas import WorkflowEvent

#TODO use workflow server logger
logger = logging.getLogger(__name__)


class DocumentCallbackHandler(BaseCallbackHandler):
"""A callback handler that adds index number to documents retrieved."""
"""Callback handler that adds index numbers to retrieved documents."""

def on_retriever_end(self, documents: List[Document], **kwargs):
"""
Add index numbers to the retrieved documents.
:param documents: The retrieved documents.
"""
logger.debug(f"Retrieved documents: {documents}")
for i, doc in enumerate(documents):
doc.metadata["index"] = str(i)


class DocumentRetriever:
@@ -56,36 +60,71 @@ class DocumentRetriever:
"""

def __init__(
self,
llm,
vector_store,
verbose: bool = False,
chain_type: Optional[str] = None,
**search_kwargs,
):
"""
Initialize the document retriever.
:param llm: A language model to use for answering questions.
:param vector_store: A vector store to use for storing and retrieving documents.
:param verbose: Whether to print debug information.
:param chain_type: Type of document combining chain to use. Should be one of "stuff",
"map_reduce", "refine" and "map_rerank".
:param search_kwargs: Additional keyword arguments to pass to the vector store.
"""
# Create a prompt template for the documents for when they are retrieved to the llm
document_prompt = PromptTemplate(
template="Content: {page_content}\nSource: {index}",
input_variables=["page_content", "index"],
)

self.chain = RetrievalQAWithSourcesChain.from_chain_type(
llm=llm,
retriever=vector_store.as_retriever(search_kwargs=search_kwargs),
chain_type=chain_type or "stuff",
return_source_documents=True,
chain_type_kwargs={"document_prompt": document_prompt},
verbose=verbose,
)
self.chain_type = chain_type
self.cb = DocumentCallbackHandler()
self.cb.verbose = verbose
self.verbose = verbose

@classmethod
def from_config(
cls, config, collection_name: Optional[str] = None, **search_kwargs
):
"""
Create a document retriever from a config object.
:param config: The config object to use for creating the retriever.
:param collection_name: The name of the collection to use.
:param search_kwargs: Additional keyword arguments to pass to the vector store.
:return: A new DocumentRetriever instance.
"""
vector_db = get_vector_db(config, collection_name=collection_name)
llm = get_llm(config)
return cls(llm, vector_db, verbose=config.verbose, **search_kwargs)

def _get_answer(self, query: str) -> tuple[str, List[Document]]:
"""
Get the answer to a question and the source documents used.
:param query: The question to answer.
:return: A tuple containing the answer and the source documents.
"""
# Run the chain to get the answer and source documents
result = self.chain({"question": query}, callbacks=[self.cb])

# Filter the source documents to only include the ones that were used as sources and clean up the metadata
sources = [s.strip() for s in result["sources"].split(",")]
source_docs = [
doc
@@ -97,63 +136,118 @@ def _get_answer(self, query):
logger.info(f"Source documents:\n{docs_string}")
return result["answer"], source_docs

def run(self, event: WorkflowEvent) -> Dict[str, any]:
"""
Run the retrieval with the given event.
:param event: The event to run the retrieval with.
:return: A dictionary containing the answer and the source documents.
"""
# TODO: use text when is_cli
logger.debug(f"Retriever Question: {event.query}\n")
answer, sources = self._get_answer(event.query)
logger.debug(f"answer: {answer} \nSources: {sources}")
logger.debug(f"Retriever Question: {event.query}")
# event.query.content is not always present
query = event.query.content if hasattr(event.query, "content") else event.query
answer, sources = self._get_answer(query)
logger.debug(f"Answer: {answer}\nSources: {sources}")
return {"answer": answer, "sources": sources}


class MultiRetriever(ChainRunner):
"""A class that manages multiple document retrievers."""

def __init__(self, llm=None, default_collection: Optional[str] = None, **kwargs):
"""
Initialize the multi retriever.
:param llm: The language model to use.
:param default_collection: The default collection to use.
"""
super().__init__(**kwargs)
self.llm = llm
self.default_collection = default_collection
self._retrievers: Dict[str, DocumentRetriever] = {}

def post_init(self, mode="sync"):
def post_init(self, mode: str = "sync"):
"""
Post initialization function, set the language model and default collection.
:param mode: The mode to use. #TODO what is this?
"""
self.llm = self.llm or get_llm(self.context._config)
if not self.default_collection:
self.default_collection = self.context._config.default_collection()

def _get_retriever(
self, collection_name: Optional[str] = None
) -> DocumentRetriever:
"""
Get a retriever for a given collection.
:param collection_name: The name of the collection to get the retriever for.
:return: The retriever for the given collection.
"""
collection_name = collection_name or self.default_collection
logger.debug(f"Selected collection: {collection_name}")
# Create a new retriever if one does not exist for the given collection
if collection_name not in self._retrievers:
# Get the vector database for the collection
vector_db = get_vector_db(
self.context._config, collection_name=collection_name
)
# Create a new retriever and store it
retriever = DocumentRetriever(self.llm, vector_db, verbose=self.verbose)
self._retrievers[collection_name] = retriever

return self._retrievers[collection_name]

def _run(self, event: WorkflowEvent) -> Dict[str, any]:
"""
Run the multi retriever.
:param event: The event to run the retriever with.
:return: A dictionary containing the answer and the source documents.
"""
collection_name = event.kwargs.get(
"collection_name"
) # TODO name always in kwargs?
retriever = self._get_retriever(collection_name)
return retriever.run(event)


def fix_milvus_filter_arg(vector_db, search_kwargs: Dict[str, any]):
"""
Fix the Milvus filter argument if necessary.
:param vector_db: The vector database to fix the filter argument for.
:param search_kwargs: The search keyword arguments to fix.
"""
if "filter" in search_kwargs and hasattr(vector_db, "_create_connection_alias"):
filter_arg = search_kwargs.pop("filter")
if isinstance(filter_arg, dict):
filter_str = " and ".join(f"{k}={v}" for k, v in filter_arg.items())
else:
filter_str = filter_arg
search_kwargs["expr"] = filter_str


def get_retriever_from_config(
config,
verbose: bool = False,
collection_name: Optional[str] = None,
**search_kwargs,
) -> DocumentRetriever:
"""
Create a document retriever from a config object.
:param config: The config object to use for creating the retriever.
:param verbose: Whether to print debug information.
:param collection_name: The name of the collection to use.
:param search_kwargs: Additional keyword arguments to pass to the vector store.
:return: A new DocumentRetriever instance.
"""
vector_db = get_vector_db(config, collection_name=collection_name)
llm = get_llm(config)
verbose = verbose or config.verbose
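
As a quick orientation, the sketch below shows one way the retriever defined in this file might be exercised. It is an illustration only, not part of the commit: the config object, the "my-docs" collection name, and the query string are assumed, and it calls the internal _get_answer helper simply because its signature is visible above.

# Hypothetical usage sketch, assuming a loaded genai_factory config object named `config`
# and an existing vector collection called "my-docs".
retriever = DocumentRetriever.from_config(config, collection_name="my-docs")
answer, source_docs = retriever._get_answer("What does the GenAI factory do?")
print(answer)
for doc in source_docs:
    # DocumentCallbackHandler tagged each retrieved document with an "index"
    print(doc.metadata.get("index"), doc.metadata.get("source"))
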
68 changes: 68 additions & 0 deletions genai_factory/src/genai_factory/chains/sentiment_analysis.py
@@ -0,0 +1,68 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers import pipeline

from genai_factory.chains.base import ChainRunner


class SentimentAnalysisStep(ChainRunner):
"""
Processes sentiment analysis on a given text.
"""

# Default model to use as model and tokenizer if not given
DEFAULT_MODEL = "cardiffnlp/twitter-roberta-base-sentiment"

def __init__(
self,
tokenizer: str = None,
model: str = None,
pipeline_kwargs: dict = None,
**kwargs,
):
"""
Initialize the sentiment analysis step.
:param model: The name of the model to use, if not given, the default model will be used, has to be
from the roberta model family.
:param tokenizer: The name of the tokenizer to use, if not given, the default tokenizer will be used,
has to be compatible with the model.
:param pipeline_kwargs: Additional keyword arguments to pass to the HuggingFace pipeline.
"""
super().__init__(**kwargs)
self.tokenizer = tokenizer or self.DEFAULT_MODEL
self.model = model or self.DEFAULT_MODEL
# Load the HuggingFace sentiment analysis pipeline
self.sentiment_classifier = pipeline(
"sentiment-analysis",
tokenizer=self.tokenizer,
model=self.model,
**(pipeline_kwargs or {}),
)

def _run(self, event):
"""
Run the sentiment analysis step.
:param event: The event to process.
:return: The processed event with the sentiment analysis result.
"""
query = event.query
sentiment = self.sentiment_classifier(
query
) # Is a list of dictionaries (in tested examples)
return {
"answer": sentiment[0]["label"],
"sources": "",
} # TODO: Can only return string

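For a sense of what this step returns, here is a minimal, hypothetical sketch; it is not part of the commit, and it assumes the step can be constructed with its defaults and fed a plain string query (the ChainRunner-specific keyword arguments are omitted).

# Hypothetical usage sketch for the sentiment analysis step.
from types import SimpleNamespace

step = SentimentAnalysisStep()
event = SimpleNamespace(query="I really enjoyed this workshop!")  # stand-in for a WorkflowEvent
result = step._run(event)
print(result)  # e.g. {"answer": "LABEL_2", "sources": ""} -- LABEL_2 is the positive class for this model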