diff --git a/controller/requirements.txt b/controller/requirements.txt index 4f3c6b1..5ca77c1 100644 --- a/controller/requirements.txt +++ b/controller/requirements.txt @@ -6,5 +6,6 @@ SQLAlchemy~=2.0.23 uvicorn python-dotenv pyyaml +click requests tabulate \ No newline at end of file diff --git a/controller/src/__init__.py b/controller/src/controller/__init__.py similarity index 100% rename from controller/src/__init__.py rename to controller/src/controller/__init__.py diff --git a/controller/src/api/__init__.py b/controller/src/controller/api/__init__.py similarity index 97% rename from controller/src/api/__init__.py rename to controller/src/controller/api/__init__.py index eb605b4..0f8a0b5 100644 --- a/controller/src/api/__init__.py +++ b/controller/src/controller/api/__init__.py @@ -15,7 +15,7 @@ from fastapi import APIRouter, FastAPI from fastapi.middleware.cors import CORSMiddleware -from controller.src.api.endpoints import ( +from controller.api.endpoints import ( data_sources, datasets, documents, diff --git a/controller/src/api/endpoints/__init__.py b/controller/src/controller/api/endpoints/__init__.py similarity index 100% rename from controller/src/api/endpoints/__init__.py rename to controller/src/controller/api/endpoints/__init__.py diff --git a/controller/src/api/endpoints/data_sources.py b/controller/src/controller/api/endpoints/data_sources.py similarity index 98% rename from controller/src/api/endpoints/data_sources.py rename to controller/src/controller/api/endpoints/data_sources.py index ac5f19d..122e44b 100644 --- a/controller/src/api/endpoints/data_sources.py +++ b/controller/src/controller/api/endpoints/data_sources.py @@ -17,14 +17,14 @@ from fastapi import APIRouter, Depends -from controller.src.api.utils import ( +from controller.api.utils import ( AuthInfo, _send_to_application, get_auth_user, get_db, ) -from controller.src.db import client -from controller.src.schemas import ( +from controller.db import client +from genai_factory.schemas 
import ( APIResponse, DataSource, DataSourceType, diff --git a/controller/src/api/endpoints/datasets.py b/controller/src/controller/api/endpoints/datasets.py similarity index 96% rename from controller/src/api/endpoints/datasets.py rename to controller/src/controller/api/endpoints/datasets.py index f662ed1..4666f4c 100644 --- a/controller/src/api/endpoints/datasets.py +++ b/controller/src/controller/api/endpoints/datasets.py @@ -16,9 +16,9 @@ from fastapi import APIRouter, Depends -from controller.src.api.utils import AuthInfo, get_auth_user, get_db -from controller.src.db import client -from controller.src.schemas import APIResponse, Dataset, OutputMode +from controller.api.utils import AuthInfo, get_auth_user, get_db +from controller.db import client +from genai_factory.schemas import APIResponse, Dataset, OutputMode router = APIRouter(prefix="/projects/{project_name}") diff --git a/controller/src/api/endpoints/documents.py b/controller/src/controller/api/endpoints/documents.py similarity index 96% rename from controller/src/api/endpoints/documents.py rename to controller/src/controller/api/endpoints/documents.py index 3f3149a..6450dd1 100644 --- a/controller/src/api/endpoints/documents.py +++ b/controller/src/controller/api/endpoints/documents.py @@ -16,9 +16,9 @@ from fastapi import APIRouter, Depends -from controller.src.api.utils import AuthInfo, get_auth_user, get_db -from controller.src.db import client -from controller.src.schemas import APIResponse, Document, OutputMode +from controller.api.utils import AuthInfo, get_auth_user, get_db +from controller.db import client +from genai_factory.schemas import APIResponse, Document, OutputMode router = APIRouter(prefix="/projects/{project_name}") diff --git a/controller/src/api/endpoints/models.py b/controller/src/controller/api/endpoints/models.py similarity index 96% rename from controller/src/api/endpoints/models.py rename to controller/src/controller/api/endpoints/models.py index 1c5e72f..c2bf9c5 100644 --- 
a/controller/src/api/endpoints/models.py +++ b/controller/src/controller/api/endpoints/models.py @@ -16,9 +16,9 @@ from fastapi import APIRouter, Depends -from controller.src.api.utils import AuthInfo, get_auth_user, get_db -from controller.src.db import client -from controller.src.schemas import APIResponse, Model, OutputMode +from controller.api.utils import AuthInfo, get_auth_user, get_db +from controller.db import client +from genai_factory.schemas import APIResponse, Model, OutputMode router = APIRouter(prefix="/projects/{project_name}") diff --git a/controller/src/api/endpoints/projects.py b/controller/src/controller/api/endpoints/projects.py similarity index 96% rename from controller/src/api/endpoints/projects.py rename to controller/src/controller/api/endpoints/projects.py index 486ed65..87fb6b0 100644 --- a/controller/src/api/endpoints/projects.py +++ b/controller/src/controller/api/endpoints/projects.py @@ -16,9 +16,9 @@ from fastapi import APIRouter, Depends -from controller.src.api.utils import get_db -from controller.src.db import client -from controller.src.schemas import APIResponse, OutputMode, Project +from controller.api.utils import get_db +from controller.db import client +from genai_factory.schemas import APIResponse, OutputMode, Project router = APIRouter() diff --git a/controller/src/api/endpoints/prompt_templates.py b/controller/src/controller/api/endpoints/prompt_templates.py similarity index 96% rename from controller/src/api/endpoints/prompt_templates.py rename to controller/src/controller/api/endpoints/prompt_templates.py index b0c0957..e4ada2f 100644 --- a/controller/src/api/endpoints/prompt_templates.py +++ b/controller/src/controller/api/endpoints/prompt_templates.py @@ -16,9 +16,9 @@ from fastapi import APIRouter, Depends -from controller.src.api.utils import AuthInfo, get_auth_user, get_db -from controller.src.db import client -from controller.src.schemas import APIResponse, OutputMode, PromptTemplate +from controller.api.utils 
import AuthInfo, get_auth_user, get_db +from controller.db import client +from genai_factory.schemas import APIResponse, OutputMode, PromptTemplate router = APIRouter(prefix="/projects/{project_name}") diff --git a/controller/src/api/endpoints/sessions.py b/controller/src/controller/api/endpoints/sessions.py similarity index 97% rename from controller/src/api/endpoints/sessions.py rename to controller/src/controller/api/endpoints/sessions.py index 9accb23..15303bc 100644 --- a/controller/src/api/endpoints/sessions.py +++ b/controller/src/controller/api/endpoints/sessions.py @@ -14,9 +14,9 @@ from fastapi import APIRouter, Depends -from controller.src.api.utils import get_db -from controller.src.db import client -from controller.src.schemas import APIResponse, ChatSession, OutputMode +from controller.api.utils import get_db +from controller.db import client +from genai_factory.schemas import APIResponse, ChatSession, OutputMode router = APIRouter(prefix="/users/{user_name}") diff --git a/controller/src/api/endpoints/users.py b/controller/src/controller/api/endpoints/users.py similarity index 96% rename from controller/src/api/endpoints/users.py rename to controller/src/controller/api/endpoints/users.py index 5892dc9..f7ca4b1 100644 --- a/controller/src/api/endpoints/users.py +++ b/controller/src/controller/api/endpoints/users.py @@ -14,9 +14,9 @@ from fastapi import APIRouter, Depends -from controller.src.api.utils import get_db -from controller.src.db import client -from controller.src.schemas import APIResponse, OutputMode, User +from controller.api.utils import get_db +from controller.db import client +from genai_factory.schemas import APIResponse, OutputMode, User router = APIRouter() diff --git a/controller/src/api/endpoints/workflows.py b/controller/src/controller/api/endpoints/workflows.py similarity index 98% rename from controller/src/api/endpoints/workflows.py rename to controller/src/controller/api/endpoints/workflows.py index 768e02b..9c0b3db 100644 --- 
a/controller/src/api/endpoints/workflows.py +++ b/controller/src/controller/api/endpoints/workflows.py @@ -17,14 +17,14 @@ from fastapi import APIRouter, Depends -from controller.src.api.utils import ( +from controller.api.utils import ( AuthInfo, _send_to_application, get_auth_user, get_db, ) -from controller.src.db import client -from controller.src.schemas import ( +from controller.db import client +from genai_factory.schemas import ( APIResponse, ChatSession, OutputMode, diff --git a/controller/src/api/utils.py b/controller/src/controller/api/utils.py similarity index 97% rename from controller/src/api/utils.py rename to controller/src/controller/api/utils.py index e076940..5cae73e 100644 --- a/controller/src/api/utils.py +++ b/controller/src/controller/api/utils.py @@ -18,8 +18,8 @@ from fastapi import Header, Request from pydantic import BaseModel -from controller.src.config import config -from controller.src.db import client +from controller.config import config +from controller.db import client def get_db(): diff --git a/controller/src/config.py b/controller/src/controller/config.py similarity index 100% rename from controller/src/config.py rename to controller/src/controller/config.py diff --git a/controller/src/db/__init__.py b/controller/src/controller/db/__init__.py similarity index 89% rename from controller/src/db/__init__.py rename to controller/src/controller/db/__init__.py index 28da560..55c3bc1 100644 --- a/controller/src/db/__init__.py +++ b/controller/src/controller/db/__init__.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from controller.src.config import config - -from .sqlclient import SqlClient +from controller.config import config +from controller.db.sqlclient import SqlClient client = None diff --git a/controller/src/db/sqlclient.py b/controller/src/controller/db/sqlclient.py similarity index 99% rename from controller/src/db/sqlclient.py rename to controller/src/controller/db/sqlclient.py index f5080c2..2ae7736 100644 --- a/controller/src/db/sqlclient.py +++ b/controller/src/controller/db/sqlclient.py @@ -19,9 +19,9 @@ import sqlalchemy from sqlalchemy.orm import sessionmaker -import controller.src.db.sqldb as db -import controller.src.schemas as api_models -from controller.src.config import logger +import controller.db.sqldb as db +import genai_factory.schemas as api_models +from controller.config import logger class SqlClient: diff --git a/controller/src/db/sqldb.py b/controller/src/controller/db/sqldb.py similarity index 98% rename from controller/src/db/sqldb.py rename to controller/src/controller/db/sqldb.py index 468814d..7c8ad2a 100644 --- a/controller/src/db/sqldb.py +++ b/controller/src/controller/db/sqldb.py @@ -283,7 +283,7 @@ class DataSource(VersionedOwnerBaseSchema): :arg project_id: The project's id. :arg data_source_type: The type of the data source. - Can be one of the values in controller.src.schemas.data_source.DataSourceType. + Can be one of the values in genai_factory.schemas.data_source.DataSourceType. """ # Columns: @@ -386,7 +386,7 @@ class Model(VersionedOwnerBaseSchema): The Model table which is used to define models for the project. :arg project_id: The project's id. - :arg model_type: The type of the model. Can be one of the values in controller.src.schemas.model.ModelType. + :arg model_type: The type of the model. Can be one of the values in genai_factory.schemas.model.ModelType. :arg task: The task of the model. For example, "classification", "text-generation", etc. 
""" @@ -570,7 +570,7 @@ class Workflow(VersionedOwnerBaseSchema): :arg project_id: The project's id. :arg workflow_type: The type of the workflow. - Can be one of the values in controller.src.schemas.workflow.WorkflowType. + Can be one of the values in genai_factory.schemas.workflow.WorkflowType. """ # Columns: diff --git a/controller/src/main.py b/controller/src/controller/main.py similarity index 98% rename from controller/src/main.py rename to controller/src/controller/main.py index b51ef8f..19b0808 100644 --- a/controller/src/main.py +++ b/controller/src/controller/main.py @@ -20,10 +20,10 @@ import yaml from tabulate import tabulate -from controller.src.api.utils import _send_to_application -from controller.src.config import config -from controller.src.db import client -from controller.src.schemas import ( +from controller.api.utils import _send_to_application +from controller.config import config +from controller.db import client +from genai_factory.schemas import ( DataSource, Document, Project, diff --git a/controller/src/schemas/workflow.py b/controller/src/schemas/workflow.py deleted file mode 100644 index 2a376d1..0000000 --- a/controller/src/schemas/workflow.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2023 Iguazio -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -from enum import Enum -from typing import List, Optional - -from controller.src.schemas.base import BaseWithVerMetadata - - -class WorkflowType(str, Enum): - INGESTION = "ingestion" - APPLICATION = "application" - DATA_PROCESSING = "data-processing" - TRAINING = "training" - EVALUATION = "evaluation" - - -class Workflow(BaseWithVerMetadata): - _top_level_fields = ["workflow_type"] - - workflow_type: WorkflowType - project_id: str - deployment: Optional[str] = None - workflow_function: Optional[str] = None - configuration: Optional[dict] = None - graph: Optional[List[dict]] = None - - def get_infer_path(self): - if self.deployment is None: - return None - return os.path.join(self.deployment, "infer") diff --git a/genai_factory/requirements.txt b/genai_factory/requirements.txt new file mode 100644 index 0000000..1bb81d4 --- /dev/null +++ b/genai_factory/requirements.txt @@ -0,0 +1,7 @@ +langchain_community +langchain_openai +langchain_huggingface +pymilvus +langchain-milvus +chromadb==0.5.3 +mlrun==1.6.0 \ No newline at end of file diff --git a/genai_factory/src/genai_factory/__init__.py b/genai_factory/src/genai_factory/__init__.py new file mode 100644 index 0000000..84f1919 --- /dev/null +++ b/genai_factory/src/genai_factory/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# genai_factory/src/genai_factory/__main__.py
# main file with cli commands using python click library
import importlib.util
import sys

import click
import uvicorn

from genai_factory.api import router
from genai_factory.chains.base import HistorySaver, SessionLoader
from genai_factory.chains.refine import RefineQuery
from genai_factory.chains.retrieval import MultiRetriever
from genai_factory.config import config, username
from genai_factory.workflows import AppServer

# Graph used when no workflow file is passed with --path.
default_graph = [
    SessionLoader(),
    RefineQuery(),
    MultiRetriever(),
    HistorySaver(),
]


# Root command group; sub-commands are registered below with cli.add_command().
@click.group()
def cli():
    pass


def _load_module_from_path(path: str):
    """Execute the Python file at ``path`` and return the resulting module.

    :param path: File system path of the Python source file to load.

    :return: The loaded module object.

    :raises click.ClickException: If no import spec/loader can be built for ``path``
                                  (e.g. the file has no recognized Python suffix).
    """
    spec = importlib.util.spec_from_file_location("module_name", path)
    # spec_from_file_location returns None when it cannot pick a loader for the file.
    if spec is None or spec.loader is None:
        raise click.ClickException(f"Cannot load a Python module from '{path}'")
    module = importlib.util.module_from_spec(spec)
    # Register before executing so the module can be resolved while it is running.
    sys.modules["module_name"] = module
    spec.loader.exec_module(module)
    return module


@click.command()
@click.argument("workflow-name", type=str)
@click.option("-p", "--path", type=str, default=None, help="Path to the workflow file")
@click.option("-r", "--runner", type=str, default="fastapi", help="Runner to use")
@click.option(
    "-t",
    "--workflow-type",
    type=str,
    default="application",
    help="Type of the workflow",
)
def run(
    workflow_name: str,
    runner: str,
    path: str,
    workflow_type: str,
):
    """
    Run workflow application

    :param workflow_name: The workflow name
    :param runner:        The runner to use, default is fastapi.
    :param path:          The path to the workflow file.
    :param workflow_type: The type of the workflow. Can be one of genai_factory.schemas.WorkflowType

    :return: None
    """
    # Import the workflow's graph from the path
    if path:
        # Load the module from the given file path
        module = _load_module_from_path(path)

        # Retrieve the desired object from the module
        click.echo(f"Using graph from {path}")
        if not hasattr(module, "workflow_graph"):
            raise click.ClickException(
                f"'{path}' does not define a 'workflow_graph' object"
            )
        graph = module.workflow_graph
    else:
        # Use the default graph
        click.echo("Using default graph")
        graph = default_graph

    if runner == "nuclio":
        click.echo("Running nuclio is not supported yet")
    elif runner == "fastapi":
        app_server = AppServer()
        app_server.add_workflow(
            project_name="default",
            name=workflow_name,
            graph=graph,
            deployment=config.infer_path(workflow_name),
            workflow_type=workflow_type,
            username=username,
            update=True,
        )
        app = app_server.to_fastapi(router=router)

        # Deploy the fastapi app
        host = config.workflow_deployment["host"]
        port = config.workflow_deployment["port"]
        click.echo(f"Running workflow {workflow_name} with fastapi on {host}")
        uvicorn.run(app, host=host, port=port)

    else:
        click.echo(
            f"Runner {runner} not supported. Supported runners are: nuclio, fastapi"
        )


cli.add_command(run)


if __name__ == "__main__":
    cli()
# genai_factory/src/genai_factory/actions.py
from typing import List, Optional, Tuple

import openai
from pydantic import BaseModel

from genai_factory.config import config, logger
from genai_factory.data.doc_loader import get_data_loader, get_loader_obj
from genai_factory.schemas import APIResponse


class IngestItem(BaseModel):
    """Description of a single item to ingest into a collection."""

    # Location of the content to load (passed to get_loader_obj).
    path: str
    # Loader type name (passed to get_loader_obj as loader_type).
    loader: str
    # Optional (key, value) pairs forwarded to the data loader.
    metadata: Optional[List[Tuple[str, str]]] = None
    # Optional version label forwarded to the data loader.
    version: Optional[str] = None


def ingest(collection_name, item: IngestItem):
    """This is the data ingestion command

    :param collection_name: Name of the data source to ingest into.
    :param item:            The item to ingest (path, loader type, metadata, version).

    :return: An APIResponse with success=True once the loader call returns.
    """
    logger.debug(
        f"Running Data Ingestion: collection_name={collection_name}, path={item.path}, loader={item.loader}"
    )
    data_loader = get_data_loader(
        config,
        data_source_name=collection_name,
    )
    loader_obj = get_loader_obj(item.path, loader_type=item.loader)
    data_loader.load(loader_obj, metadata=item.metadata, version=item.version)
    return APIResponse(success=True)


def transcribe_file(file_handler):
    """transcribe audio file using openai API

    :param file_handler: The audio file object handed to the OpenAI transcription call.

    :return: An APIResponse whose data is the transcription result.
    """
    logger.debug("Transcribing file")
    # NOTE(review): openai.Audio.transcribe is the pre-1.0 client API - confirm the
    # pinned openai version still provides it.
    text = openai.Audio.transcribe("whisper-1", file_handler)
    # Report through the module logger instead of printing the transcript to stdout.
    logger.debug(f"Transcription result: {text}")
    return APIResponse(success=True, data=text)
# genai_factory/src/genai_factory/api.py
from typing import List, Union

from fastapi import APIRouter, Depends, FastAPI, Header, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from genai_factory.client import Client
from genai_factory.config import config, logger
from genai_factory.data.doc_loader import get_data_loader, get_loader_obj
from genai_factory.schemas import Document, QueryItem, Workflow

app = FastAPI()

# Add CORS middleware, remove in production
origins = ["*"]  # React app
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Create a router with a prefix
router = APIRouter(prefix="/api")

client = Client(base_url=config.api_url)


class AuthInfo(BaseModel):
    """Identity information attached to a request."""

    username: str
    token: str
    roles: List[str] = []


# placeholder for extracting the Auth info from the request
async def get_auth_user(
    request: Request, x_username: Union[str, None] = Header(None)
) -> AuthInfo:
    """Build the AuthInfo for a request.

    The token is read from the "Authorization" cookie (empty string when absent) and the
    username from the ``x_username`` header parameter, falling back to "guest@example.com".
    """
    token = request.cookies.get("Authorization", "")
    # An absent or empty header value falls back to the guest user.
    return AuthInfo(username=x_username or "guest@example.com", token=token)


@router.post("/data_sources/{data_source_name}/ingest")
async def ingest(
    data_source_name: str,
    database_kwargs: dict,
    loader: str,
    metadata: dict = None,
    document: Document = None,
    from_file: bool = False,
):
    """Ingest documents into the vector database

    When ``from_file`` is set, ``document.path`` is read as a text file holding one path
    per line; blank lines and lines starting with "#" are skipped. Otherwise
    ``document.path`` itself is loaded.

    :raises ValueError: If no document was supplied.
    """
    # The signature allows document to be omitted, but every code path needs it;
    # fail with a clear message instead of an AttributeError on None.
    if document is None:
        raise ValueError("document is required for ingestion")

    data_loader = get_data_loader(
        config=config,
        data_source_name=data_source_name,
        database_kwargs=database_kwargs,
    )

    def _load(path):
        # Load one path with the requested loader type into the data source.
        loader_obj = get_loader_obj(path, loader_type=loader)
        data_loader.load(loader_obj, metadata=metadata, version=document.version)

    if from_file:
        # Stream the listing line by line rather than materializing it first.
        with open(document.path, "r", encoding="utf-8") as fp:
            for line in fp:
                path = line.strip()
                if path and not path.startswith("#"):
                    _load(path)
    else:
        _load(document.path)
    return {"status": "ok"}


@router.post("/workflows/{name}/infer")
async def infer_workflow(
    request: Request,
    name: str,
    workflow: Workflow,
    item: QueryItem,
    auth=Depends(get_auth_user),
):
    """This is the query command

    Builds the workflow event from the authenticated user and the query item, runs the
    named workflow on the app server stored in the FastAPI app's extras and returns its
    response.

    :raises ValueError: If the application has no "app_server" extra.
    """
    app_server = request.app.extra.get("app_server")
    if not app_server:
        raise ValueError("app_server not found in app")

    event = {
        "username": auth.username,
        "session_id": item.session_id,
        "query": item.question,
        "workflow_id": workflow.uid,
    }
    resp = app_server.run_workflow(name, event)
    # Log through the project logger instead of printing every response to stdout.
    logger.debug(f"resp: {resp}")
    return resp
# genai_factory/src/genai_factory/chains/base.py
import asyncio
import logging
from typing import Optional

import storey

from genai_factory.schemas import WorkflowEvent

logger = logging.getLogger(__name__)


class ChainRunner(storey.Flow):
    """Base storey step that runs a chain on each workflow event.

    Subclasses implement ``_run`` (sync or async). A truthy dict returned from ``_run``
    is merged into ``event.results``; when it contains an "answer" key, that value also
    replaces ``event.query`` for the following steps.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Decide once whether _run must be awaited.
        self._is_async = asyncio.iscoroutinefunction(self._run)

    def _run(self, event: WorkflowEvent):
        """Process one event; must be overridden by subclasses."""
        raise NotImplementedError()

    def __call__(self, event: WorkflowEvent):
        """Run the step directly on an event, outside of the flow."""
        return self._run(event)

    def post_init(self, mode="sync"):
        """Initialization hook; a no-op in the base class."""
        pass

    async def _do(self, event):
        """Handle one flow event: run the chain, merge its results, pass it downstream."""
        if event is storey.dtypes._termination_obj:
            return await self._do_downstream(storey.dtypes._termination_obj)

        # Debug-level trace instead of printing on every event.
        logger.debug("step name: %s", self.name)
        element = self._get_event_or_body(event)
        if self._is_async:
            resp = await self._run(element)
        else:
            resp = self._run(element)
        if resp:
            for key, val in resp.items():
                element.results[key] = val
            if "answer" in resp:
                # The answer of this step becomes the query of the next one.
                element.query = resp["answer"]
        mapped_event = self._user_fn_output_to_event(event, element)
        await self._do_downstream(mapped_event)


class SessionLoader(storey.Flow):
    """Flow step that converts the incoming body to a WorkflowEvent and reads its session state."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    async def _do(self, event):
        """Wrap dict bodies in a WorkflowEvent, read the session state and pass the event on."""
        if event is storey.dtypes._termination_obj:
            return await self._do_downstream(storey.dtypes._termination_obj)

        element = self._get_event_or_body(event)
        if isinstance(element, dict):
            element = WorkflowEvent(**element)

        self.context.session_store.read_state(element)
        mapped_event = self._user_fn_output_to_event(event, element)
        await self._do_downstream(mapped_event)


class HistorySaver(ChainRunner):
    """Chain step that appends the question and answer to the conversation and saves the session.

    :param answer_key:   Key in ``event.results`` holding the answer ("answer" when not set).
    :param question_key: Key in ``event.results`` holding the question; when not set,
                         ``event.original_query`` is used.
    :param save_sources: Whether to store the metadata of ``event.results["sources"]``
                         with the AI message.
    """

    def __init__(
        self,
        answer_key: Optional[str] = None,
        question_key: Optional[str] = None,
        save_sources: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.answer_key = answer_key
        self.question_key = question_key
        self.save_sources = save_sources

    async def _run(self, event: WorkflowEvent):
        """Record the exchange in the conversation, save the session and return the results."""
        question = (
            event.results[self.question_key]
            if self.question_key
            else event.original_query
        )
        sources = None
        if self.save_sources and "sources" in event.results:
            # Keep only the metadata of each source document in the results.
            sources = [src.metadata for src in event.results["sources"]]
            event.results["sources"] = sources
        event.conversation.add_message("Human", question)
        event.conversation.add_message(
            "AI", event.results[self.answer_key or "answer"], sources
        )

        self.context.session_store.save(event)
        return event.results
# genai_factory/src/genai_factory/chains/refine.py
from langchain_core.prompts.prompt import PromptTemplate

from genai_factory.chains.base import ChainRunner
from genai_factory.config import get_llm, logger
from genai_factory.schemas import WorkflowEvent

# Default prompt asking the LLM to rewrite a follow-up as a standalone request.
_refine_prompt_template = """
You are a helpful AI assistant, given the following conversation and a follow up request,
 rephrase the follow up request to be a standalone request, keeping the same user language.
Your rephrasing must include any relevant history element to get a precise standalone request
 and not losing previous context.

Chat History:
{chat_history}

Follow Up Input: {question}

Standalone request:
"""


class RefineQuery(ChainRunner):
    """Chain step that rewrites the current query as a standalone request using the chat history.

    :param llm:             The language model to use; resolved from the context config in
                            ``post_init`` when not given.
    :param prompt_template: Prompt template text; the module default is used when not given.
    """

    def __init__(self, llm=None, prompt_template=None, **kwargs):
        super().__init__(**kwargs)
        self._chain = None
        self.prompt_template = prompt_template
        self.llm = llm

    def post_init(self, mode="sync"):
        """Resolve the LLM if needed and build the prompt-to-LLM chain."""
        if not self.llm:
            self.llm = get_llm(self.context._config)
        template_text = self.prompt_template or _refine_prompt_template
        self._chain = PromptTemplate.from_template(template_text) | self.llm

    def _run(self, event: WorkflowEvent):
        """Invoke the chain on the event's query and conversation; return it as the "answer"."""
        history_text = str(event.conversation)
        logger.debug(f"Question: {event.query}\nChat history: {history_text}")
        chain_inputs = {"question": event.query, "chat_history": history_text}
        refined = self._chain.invoke(chain_inputs)
        logger.debug(f"Refined question: {refined}")
        return {"answer": refined}


def get_refine_chain(config, verbose=False, prompt_template=None):
    """Create a RefineQuery step using the LLM configured in ``config``.

    :param config:          The configuration object (its ``verbose`` flag is used as fallback).
    :param verbose:         Verbosity flag; ``config.verbose`` applies when falsy.
    :param prompt_template: Optional prompt template text.

    :return: The RefineQuery instance.
    """
    return RefineQuery(
        llm=get_llm(config),
        verbose=verbose or config.verbose,
        prompt_template=prompt_template,
    )
# genai_factory/src/genai_factory/chains/retrieval.py
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.prompts import PromptTemplate

from genai_factory.chains.base import ChainRunner
from genai_factory.config import get_llm, get_vector_db, logger
from genai_factory.schemas import WorkflowEvent


class DocumentCallbackHandler(BaseCallbackHandler):
    """Callback handler that tags every retrieved document with its position."""

    def on_retriever_end(
        self,
        documents,
        *,
        run_id,
        parent_run_id,
        **kwargs,
    ):
        """Store the running index (as a string) under ``metadata["index"]`` of each document."""
        logger.debug(f"on_retriever: {documents}")
        if not documents:
            return
        position = 0
        for document in documents:
            document.metadata["index"] = str(position)
            position += 1


class DocumentRetriever:
    """Retrieval QA chain wrapper that returns the answer together with its source documents.

    Example:
        vector_store = get_vector_db(config)
        llm = get_llm(config)
        retriever = DocumentRetriever(llm, vector_store)
        retriever.run(event)  # event.query must expose a ``content`` attribute

    Args:
        llm: A language model.
        vector_store: A vector store.
        verbose: Whether to print debug information.
        chain_type: The chain type to build ("stuff" when not given).
        search_kwargs: Extra keyword arguments passed to the retriever as search_kwargs.

    """

    def __init__(
        self, llm, vector_store, verbose=False, chain_type: str = None, **search_kwargs
    ):
        self.verbose = verbose
        self.chain_type = chain_type

        # Callback that numbers the retrieved documents so the prompt can cite them.
        self.cb = DocumentCallbackHandler()
        self.cb.verbose = verbose

        per_document_prompt = PromptTemplate(
            template="Content: {page_content}\nSource: {index}",
            input_variables=["page_content", "index"],
        )
        docs_retriever = vector_store.as_retriever(search_kwargs=search_kwargs)
        self.chain = RetrievalQAWithSourcesChain.from_chain_type(
            chain_type=chain_type or "stuff",
            llm=llm,
            retriever=docs_retriever,
            return_source_documents=True,
            chain_type_kwargs={"document_prompt": per_document_prompt},
            verbose=verbose,
        )

    @classmethod
    def from_config(cls, config, collection_name: str = None, **search_kwargs):
        """Creates a document retriever from a config object."""
        store = get_vector_db(config, collection_name=collection_name)
        model = get_llm(config)
        return cls(model, store, verbose=config.verbose, **search_kwargs)

    def _get_answer(self, query):
        """Run the chain on ``query.content``; return the answer and the documents it cited."""
        outcome = self.chain({"question": query.content}, callbacks=[self.cb])
        cited = [name.strip() for name in outcome["sources"].split(",")]
        source_docs = []
        for document in outcome["source_documents"]:
            # The temporary index tag is removed from every document while matching.
            tag = document.metadata.pop("index", "")
            if tag in cited:
                source_docs.append(document)
        if self.verbose:
            docs_string = "\n".join(str(document.metadata) for document in source_docs)
            logger.info(f"Source documents:\n{docs_string}")
        return outcome["answer"], source_docs

    def run(self, event: WorkflowEvent):
        """Answer the event's query; return a dict with "answer" and "sources"."""
        # TODO: use text when is_cli
        logger.debug(f"Retriever Question: {event.query}\n")
        answer, sources = self._get_answer(event.query)
        logger.debug(f"answer: {answer} \nSources: {sources}")
        return {"answer": answer, "sources": sources}


class MultiRetriever(ChainRunner):
    """Chain step that answers from a per-collection DocumentRetriever, created on first use.

    :param llm:                The language model; resolved from the context config in
                               ``post_init`` when not given.
    :param default_collection: Collection used when the event does not name one; resolved
                               from the context config in ``post_init`` when not given.
    """

    def __init__(self, llm=None, default_collection=None, **kwargs):
        super().__init__(**kwargs)
        self._retrievers = {}
        self.default_collection = default_collection
        self.llm = llm

    def post_init(self, mode="sync"):
        """Fill in the LLM and default collection from the context config when missing."""
        context_config = self.context._config
        if not self.llm:
            self.llm = get_llm(context_config)
        if not self.default_collection:
            self.default_collection = context_config.default_collection()

    def _get_retriever(self, collection_name: str = None):
        """Return the retriever of the collection, building and caching it if needed."""
        collection_name = collection_name or self.default_collection
        logger.debug(f"Selected collection: {collection_name}")
        if collection_name in self._retrievers:
            return self._retrievers[collection_name]

        store = get_vector_db(self.context._config, collection_name=collection_name)
        self._retrievers[collection_name] = DocumentRetriever(
            self.llm,
            store,
            verbose=self.verbose,
        )
        return self._retrievers[collection_name]

    def _run(self, event: WorkflowEvent):
        """Run the retriever selected by the event's "collection_name" kwarg."""
        selected = self._get_retriever(event.kwargs.get("collection_name"))
        return selected.run(event)
def _milvus_literal(value):
    """Render a Python value as a literal for a Milvus boolean expression."""
    # bool must be tested first: it is a subclass of int and str(True) is "True".
    if isinstance(value, bool):
        return "true" if value else "false"
    if isinstance(value, str):
        import json

        # JSON quoting gives a double-quoted literal with quotes/backslashes escaped.
        return json.dumps(value, ensure_ascii=False)
    return str(value)


def fix_milvus_filter_arg(vector_db, search_kwargs):
    """Swap a ``filter`` search argument for the ``expr`` argument Milvus expects.

    Milvus is detected by the presence of a ``_create_connection_alias``
    attribute on the vector store. For any other store, or when no ``filter``
    key is present, ``search_kwargs`` is left untouched.

    Args:
        vector_db: The vector store instance the search will run against.
        search_kwargs: Search keyword arguments; modified in place.

    Returns:
        None. ``search_kwargs["filter"]`` is removed and ``search_kwargs["expr"]``
        is set. A dict filter becomes ``key == value`` terms joined with ``and``
        (string values are quoted); any other filter value is passed through as is.
    """
    is_milvus = hasattr(vector_db, "_create_connection_alias")
    if "filter" not in search_kwargs or not is_milvus:
        return

    # Named filter_arg so the builtin filter() is not shadowed.
    filter_arg = search_kwargs.pop("filter")
    if isinstance(filter_arg, dict):
        # Milvus compares with "==" (a single "=" is not a valid operator) and
        # needs string literals quoted.
        filter_arg = " and ".join(
            f"{key} == {_milvus_literal(value)}" for key, value in filter_arg.items()
        )
    search_kwargs["expr"] = filter_arg
100644 --- a/controller/src/schemas/__init__.py +++ b/genai_factory/src/genai_factory/chains/sql.py @@ -12,13 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base import APIResponse, Base, OutputMode -from .data_source import DataSource, DataSourceType -from .dataset import Dataset -from .document import Document -from .model import Model, ModelType -from .project import Project -from .prompt_template import PromptTemplate -from .session import ChatSession, QueryItem -from .user import User -from .workflow import Workflow, WorkflowType +# todo diff --git a/genai_factory/src/genai_factory/client.py b/genai_factory/src/genai_factory/client.py new file mode 100644 index 0000000..532fba5 --- /dev/null +++ b/genai_factory/src/genai_factory/client.py @@ -0,0 +1,160 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class Client:
    """Thin HTTP client for the controller REST API.

    Args:
        base_url: Root URL of the controller; ``/api/<path>`` is appended per request.
        username: User name sent in the ``x_username`` header (``"guest"`` if omitted).
        token: Stored on the instance; no request in this class sends it.
    """

    def __init__(self, base_url, username=None, token=None):
        self.base_url = base_url
        self.username = username or "guest"
        self.token = token

    def post_request(
        self, path, data=None, params=None, method="GET", files=None, json=None
    ):
        """Send a request to ``<base_url>/api/<path>`` and return the decoded JSON.

        Despite the name, the HTTP verb is chosen by ``method`` (default ``GET``).

        Args:
            path: API path relative to ``/api/``.
            data: Object serialized with ``dict_to_json`` and sent as the body.
            params: Query parameters; entries whose value is None are dropped.
            method: HTTP method name.
            files: Passed through to ``requests`` unchanged.
            json: Passed through to ``requests`` unchanged.

        Returns:
            The decoded JSON body, or None when the response has no body.

        Raises:
            requests.HTTPError: If the server answers with a 4xx/5xx status.
        """
        url = f"{self.base_url}/api/{path}"
        kw = {}
        if params is not None:
            kw["params"] = {k: v for k, v in params.items() if v is not None} or None
        if data is not None:
            kw["data"] = dict_to_json(data)
        if json is not None:
            kw["json"] = json
        if files is not None:
            kw["files"] = files

        logger.debug(
            f"Sending {method} request to {url}, params: {params}, data: {data}"
        )
        response = requests.request(
            method,
            url,
            headers={"x_username": self.username},
            **kw,
        )

        # Any non-error status counts as success. Checking for exactly 200 made
        # responses such as 201 Created fall through and return None.
        response.raise_for_status()
        if not response.content:
            return None
        return response.json()

    def get_collection(self, name):
        """Return the ``data`` payload of collection ``name``."""
        response = self.post_request(f"collection/{name}")
        return response["data"]

    def get_session(self, uid: str, user_name: str):
        """Return the ``data`` payload of session ``uid`` owned by ``user_name``."""
        response = self.post_request(f"users/{user_name}/sessions/{uid}")
        return response["data"]

    def get_user(self, username: str = "", email: str = None):
        """Return the ``data`` payload of a user, optionally passing ``email`` as a query."""
        params = {}
        if email:
            params["email"] = email
        response = self.post_request(f"users/{username}", params=params)
        return response["data"]

    def create_session(
        self,
        name,
        user_id,
        username=None,
        workflow_id=None,
        history=None,
    ):
        """Create a chat session and return the raw API response.

        Args:
            name: Session name.
            user_id: Sent as the session's ``owner_id``.
            username: User whose sessions endpoint is used; defaults to the
                client's own username.
            workflow_id: Sent as the session's ``workflow_id``.
            history: Initial message history (empty list if omitted).
        """
        # Without this fallback the path became "users/None/sessions".
        username = username or self.username
        chat_session = {
            "name": name,
            "owner_id": user_id,
            "workflow_id": workflow_id,
            "history": history or [],
        }
        response = self.post_request(
            f"users/{username}/sessions", data=chat_session, method="POST"
        )
        return response

    def update_session(
        self,
        chat_session: ChatSession,
        username: str,
        history=None,
    ):
        """Overwrite the session's history and PUT it; return the ``success`` field.

        Note that ``chat_session.history`` is replaced on the passed object.
        """
        chat_session.history = history or []
        response = self.post_request(
            f"users/{username}/sessions/{chat_session.name}",
            data=chat_session.to_dict(),
            method="PUT",
        )
        return response["success"]

    def get_project(self, project_name: str):
        """Fetch project ``project_name`` and return it as a ``Project``."""
        response = self.post_request(f"projects/{project_name}")
        return Project(**response["data"])

    def create_workflow(self, project_name: str, workflow: Union[Workflow, dict]):
        """Create a workflow under ``project_name`` and return the stored ``Workflow``.

        Args:
            project_name: Name of the owning project.
            workflow: A ``Workflow``, or a dict of its fields. A dict may carry a
                ``graph`` entry, which is attached through ``Workflow.add_graph``;
                the project id is filled in from the project lookup.
        """
        # Use this instance, not the module-level default client, so a Client
        # built with another base_url or username talks to its own controller.
        project_id = self.get_project(project_name=project_name).uid
        if isinstance(workflow, dict):
            # Work on a copy so the caller's dict is not modified.
            fields = dict(workflow)
            fields["project_id"] = project_id
            graph = fields.pop("graph", None)
            workflow = Workflow(**fields)
            if graph is not None:
                workflow.add_graph(graph)
        response = self.post_request(
            f"projects/{project_name}/workflows", method="POST", data=workflow.to_dict()
        )
        return Workflow(**response["data"])

    def get_workflow(
        self, project_name: str, workflow_name: str = None, workflow_id: str = None
    ):
        """Fetch a workflow by id, or else by name.

        Returns:
            The ``Workflow``, or None when the lookup by name finds nothing.
        """
        if workflow_id:
            response = self.post_request(
                f"projects/{project_name}/workflows/{workflow_id}"
            )["data"]
        else:
            response = self.post_request(
                f"projects/{project_name}/workflows", params={"name": workflow_name}
            )
            if not response["data"]:
                return None
            # The lookup by name returns a list; take the first match.
            response = response["data"][0]
        return Workflow(**response)

    def update_workflow(self, project_name: str, workflow: Workflow):
        """PUT ``workflow`` to the controller and return the stored ``Workflow``."""
        payload = workflow.to_dict()
        logger.debug(f"Updating workflow: {payload}")
        response = self.post_request(
            f"projects/{project_name}/workflows/{workflow.uid}",
            data=payload,
            method="PUT",
        )
        return Workflow(**response["data"])
b/genai_factory/src/genai_factory/config.py @@ -0,0 +1,180 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import logging +import os +from pathlib import Path + +import dotenv +import yaml +from pydantic import BaseModel + +root_path = Path(__file__).parent.parent.parent +dotenv.load_dotenv(os.environ.get("AGENT_ENV_PATH", str(root_path / ".env"))) +default_data_path = os.environ.get("AGENT_DATA_PATH", str(root_path / "data")) + + +class AppConfig(BaseModel): + """Configuration for the agent.""" + + api_url: str = "http://localhost:8001" # url of the controller API + verbose: bool = True + log_level: str = "DEBUG" + use_local_db: bool = True + + chunk_size: int = 1024 + chunk_overlap: int = 20 + + # Embeddings + embeddings: dict = {"class_name": "huggingface", "model_name": "all-MiniLM-L6-v2"} + + # Default LLM + default_llm: dict = { + "class_name": "langchain_openai.ChatOpenAI", + "temperature": 0, + "model_name": "gpt-3.5-turbo", + } + # Vector store + default_vector_store: dict = { + "class_name": "milvus", + "collection_name": "default", + "connection_args": {"address": "localhost:19530"}, + } + + workflow_deployment: dict = { + "host": "localhost", + "port": 8000, + } + + # Workflow kwargs + workflow_args: dict = {} + + def infer_path(self, workflow_name: str): + if self.workflow_deployment: + host = self.workflow_deployment.get("host", "localhost") + port = self.workflow_deployment.get("port", 8000) + return 
f"http://{host}:{port}/api/workflows/{workflow_name}" + return "" + + def default_collection(self): + return self.default_vector_store.get("collection_name", "default") + + def print(self): + print(yaml.dump(self.dict())) + + @classmethod + def load_from_yaml(cls, path: str): + with open(path, "r") as f: + data = yaml.safe_load(f) + return cls.parse_obj(data) + + @classmethod + def local_config(cls): + """Create a local config for testing oe local deployment.""" + config = cls() + config.verbose = True + config.default_vector_store = { + "class_name": "chroma", + "collection_name": "default", + "persist_directory": str((Path(default_data_path) / "chroma").absolute()), + } + return config + + +username = os.environ.get("GENAI_USER_NAME", "") +is_local_config = os.environ.get("IS_LOCAL_CONFIG", "0").lower().strip() in [ + "true", + "1", +] +config_path = os.environ.get("AGENT_CONFIG_PATH") + +if config_path: + config = AppConfig.load_from_yaml(config_path) +elif is_local_config: + config = AppConfig.local_config() +else: + config = AppConfig() + +logger = logging.getLogger("llmagent") +logger.setLevel(config.log_level.upper()) +logger.addHandler(logging.StreamHandler()) +logger.info("Logger initialized...") +# logger.info(f"Using config:\n {yaml.dump(config.model_dump())}") + + +embeddings_shortcuts = { + "huggingface": "langchain_huggingface.embeddings.huggingface.HuggingFaceEmbeddings", + "openai": "langchain_openai.embeddings.base.OpenAIEmbeddings", +} + +vector_db_shortcuts = { + "milvus": "langchain_community.vectorstores.Milvus", + "chroma": "langchain_community.vectorstores.chroma.Chroma", +} + +llm_shortcuts = { + "chat": "langchain_openai.ChatOpenAI", + "gpt": "langchain_community.chat_models.GPT", +} + + +def get_embedding_function(config: AppConfig, embeddings_args: dict = None): + return get_object_from_dict( + embeddings_args or config.embeddings, embeddings_shortcuts + ) + + +def get_llm(config: AppConfig, llm_args: dict = None): + """Get a language 
def get_class_from_string(class_path, shortcuts: dict = None) -> type:
    """Import and return the object named by a dotted ``module.Class`` path.

    Args:
        class_path: A dotted path, or a key of ``shortcuts``.
        shortcuts: Optional mapping of short names to dotted paths.

    Returns:
        The attribute named by the last path component of the imported module.

    Raises:
        ValueError: If the resolved path has no module part.
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no such attribute.
    """
    if shortcuts and class_path in shortcuts:
        class_path = shortcuts[class_path]
    module_name, _, class_name = class_path.rpartition(".")
    if not module_name:
        # Same exception type the failed tuple-unpacking used to raise, but
        # with a message that names the offending value.
        raise ValueError(
            f"Expected a dotted 'module.ClassName' path or a known shortcut, "
            f"got '{class_path}'"
        )
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


def get_object_from_dict(obj_dict: dict, shortcuts: dict = None):
    """Instantiate the class described by ``obj_dict``.

    Args:
        obj_dict: Mapping with a ``class_name`` entry (dotted path or shortcut);
            every other entry is passed to the class as a keyword argument.
            Anything that is not a dict is returned unchanged.
        shortcuts: Optional mapping of short names to dotted class paths.

    Returns:
        The new instance, or ``obj_dict`` itself when it is not a dict.

    Raises:
        KeyError: If ``obj_dict`` has no ``class_name`` entry.
    """
    if not isinstance(obj_dict, dict):
        return obj_dict
    # Copy so that popping class_name does not alter the caller's dict
    # (these dicts are typically config defaults that get reused).
    kwargs = obj_dict.copy()
    class_name = kwargs.pop("class_name")
    class_ = get_class_from_string(class_name, shortcuts)
    return class_(**kwargs)
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/genai_factory/src/genai_factory/data/doc_loader.py b/genai_factory/src/genai_factory/data/doc_loader.py new file mode 100644 index 0000000..6df586a --- /dev/null +++ b/genai_factory/src/genai_factory/data/doc_loader.py @@ -0,0 +1,138 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class DataLoader:
    """Loads documents into a vector store.

    Example:

        data_loader = get_data_loader(config)
        loader = get_loader_obj("https://milvus.io/docs/overview.md", loader_type="web")
        data_loader.load(loader, metadata={"xx": "web"})

    Args:
        config: Supplies ``chunk_size`` and ``chunk_overlap`` for the text splitter.
        vector_store: Store that receives the documents through ``add_documents``.
    """

    def __init__(self, config: AppConfig, vector_store=None):
        self.vector_store = vector_store
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size, chunk_overlap=config.chunk_overlap
        )

    def load(self, loader, metadata: dict = None, version: int = None):
        """Loads documents into the vector store.

        Args:
            loader: A document loader. A loader with a truthy ``chunked``
                attribute is taken to return ready-made chunks, which are
                stored without further splitting.
            metadata: A dictionary of metadata to attach to the documents.
            version: A version number for the documents.
        """
        docs = loader.load()
        # Read the flag's value: a loader declaring chunked = False still
        # needs its documents split.
        to_chunk = not getattr(loader, "chunked", False)
        for doc in docs:
            self.ingest_document(doc, metadata, version, to_chunk=to_chunk)

    def ingest_document(
        self,
        doc,
        metadata: dict = None,
        version: int = None,
        doc_uid: str = None,
        to_chunk: bool = True,
    ):
        """Ingests a document into the vector store.

        Args:
            doc: A document.
            metadata: A dictionary of extra metadata to attach to the document.
            version: A version number for the document.
            doc_uid: A unique identifier for the document (will be generated if None).
            to_chunk: Split the document with the text splitter first; when
                False the document is stored as a single chunk.

        Raises:
            ValueError: If the loader was created without a vector store.
        """
        if self.vector_store is None:
            raise ValueError("DataLoader has no vector store to ingest into")
        if not doc_uid:
            doc_uid = uuid.uuid4().hex

        chunks = self.text_splitter.split_documents([doc]) if to_chunk else [doc]
        for index, chunk in enumerate(chunks):
            if to_chunk:
                chunk.metadata["chunk"] = index
            if metadata:
                chunk.metadata.update(metadata)
            chunk.metadata["doc_uid"] = doc_uid
            # Compare with None so that version 0 is recorded too.
            if version is not None:
                chunk.metadata["version"] = version
            logger.debug(
                f"Loading doc chunk:\n{chunk.page_content}\nMetadata: {chunk.metadata}"
            )

        # An empty document splits into no chunks; there is nothing to store.
        if chunks:
            self.vector_store.add_documents(chunks)
class SmartWebLoader:
    """Loads accordion-style Q&A web pages as one document per question.

    Each page is scraped for ``cmp-accordion__title`` spans (questions),
    ``cmp-text`` divs (answers) and ``cmp-accordion__button`` buttons (anchors);
    the three lists are zipped into documents.

    Args:
        urls: A URL or a list of URLs to load.
        timeout: Seconds to wait for each HTTP request.
        **kwargs: Accepted for compatibility with other loaders; ignored.
    """

    # Tells DataLoader.load that the returned documents are already chunks.
    chunked = True

    def __init__(self, urls: list, timeout: float = 30.0, **kwargs):
        if isinstance(urls, str):
            urls = [urls]
        self.urls = urls
        self.timeout = timeout

    def _parse_page(self, url: str) -> list[Document]:
        """Fetch ``url`` and return one ``Document`` per question/answer pair.

        Raises:
            requests.HTTPError: If the page answers with a 4xx/5xx status.
        """
        # The last four path components are service/topic/subtopic/section.
        # Left-pad so a shorter path gives empty strings instead of IndexError.
        path_parts = urlparse(url).path.rsplit("/", 4)
        path_parts = [""] * (4 - len(path_parts)) + path_parts
        service, topic, subtopic, section = path_parts[-4:]

        # Bound the wait and fail loudly: an error page would otherwise parse
        # to zero chunks without any indication.
        page = requests.get(url, timeout=self.timeout)
        page.raise_for_status()
        soup = BeautifulSoup(page.content, "html.parser")

        titles = [
            span.text for span in soup.find_all("span", class_="cmp-accordion__title")
        ]
        answers = [div.get_text() for div in soup.find_all("div", class_="cmp-text")]

        # Deep links to each accordion item. Keep one entry per button, even
        # when it has no id, so the zip below stays aligned.
        specific_links = []
        for button in soup.find_all("button", class_="cmp-accordion__button"):
            anchor = button.attrs.get("id")
            specific_links.append(f"{url}#{anchor}" if anchor else url)

        chunks = []
        for title, answer, specific_link in zip(titles, answers, specific_links):
            metadata = {
                "service": service,
                "topic": topic,
                "subtopic": subtopic,
                "section": section.removesuffix(".html"),
                "title": f"{service}/{topic}/{subtopic}/{title}",
                "source": specific_link,
            }
            chunks.append(
                Document(
                    page_content=f"Question: {title}\nAnswer: {answer}\n",
                    metadata=metadata,
                )
            )
        return chunks

    def load(self):
        """Return the documents of all configured URLs, in URL order."""
        docs = []
        for url in self.urls:
            docs.extend(self._parse_page(url))
        return docs
Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from genai_factory.schemas.base import APIDictResponse, APIResponse, Base, OutputMode +from genai_factory.schemas.data_source import DataSource, DataSourceType +from genai_factory.schemas.dataset import Dataset +from genai_factory.schemas.document import Document +from genai_factory.schemas.model import Model, ModelType +from genai_factory.schemas.project import Project +from genai_factory.schemas.prompt_template import PromptTemplate +from genai_factory.schemas.session import ChatSession, Conversation, QueryItem +from genai_factory.schemas.user import User +from genai_factory.schemas.workflow import Workflow, WorkflowEvent, WorkflowType diff --git a/controller/src/schemas/base.py b/genai_factory/src/genai_factory/schemas/base.py similarity index 98% rename from controller/src/schemas/base.py rename to genai_factory/src/genai_factory/schemas/base.py index 3b9c7f9..3080eac 100644 --- a/controller/src/schemas/base.py +++ b/genai_factory/src/genai_factory/schemas/base.py @@ -180,6 +180,10 @@ def with_raise_http(self, format=None) -> "APIResponse": return self +class APIDictResponse(APIResponse): + data: Optional[dict] = None + + class OutputMode(str, Enum): NAMES = "names" SHORT = "short" diff --git a/controller/src/schemas/data_source.py b/genai_factory/src/genai_factory/schemas/data_source.py similarity index 94% rename from controller/src/schemas/data_source.py rename to genai_factory/src/genai_factory/schemas/data_source.py index 
308506a..c08bbf3 100644 --- a/controller/src/schemas/data_source.py +++ b/genai_factory/src/genai_factory/schemas/data_source.py @@ -14,7 +14,7 @@ from enum import Enum -from controller.src.schemas.base import BaseWithVerMetadata +from genai_factory.schemas.base import BaseWithVerMetadata class DataSourceType(str, Enum): diff --git a/controller/src/schemas/dataset.py b/genai_factory/src/genai_factory/schemas/dataset.py similarity index 93% rename from controller/src/schemas/dataset.py rename to genai_factory/src/genai_factory/schemas/dataset.py index 132576e..07368b4 100644 --- a/controller/src/schemas/dataset.py +++ b/genai_factory/src/genai_factory/schemas/dataset.py @@ -14,7 +14,7 @@ from typing import List, Optional -from controller.src.schemas.base import BaseWithVerMetadata +from genai_factory.schemas.base import BaseWithVerMetadata class Dataset(BaseWithVerMetadata): diff --git a/controller/src/schemas/document.py b/genai_factory/src/genai_factory/schemas/document.py similarity index 92% rename from controller/src/schemas/document.py rename to genai_factory/src/genai_factory/schemas/document.py index 0f1ee39..40728bb 100644 --- a/controller/src/schemas/document.py +++ b/genai_factory/src/genai_factory/schemas/document.py @@ -14,7 +14,7 @@ from typing import Optional -from controller.src.schemas.base import BaseWithVerMetadata +from genai_factory.schemas.base import BaseWithVerMetadata class Document(BaseWithVerMetadata): diff --git a/controller/src/schemas/model.py b/genai_factory/src/genai_factory/schemas/model.py similarity index 94% rename from controller/src/schemas/model.py rename to genai_factory/src/genai_factory/schemas/model.py index 9000124..5fbf03c 100644 --- a/controller/src/schemas/model.py +++ b/genai_factory/src/genai_factory/schemas/model.py @@ -15,7 +15,7 @@ from enum import Enum from typing import Optional -from controller.src.schemas.base import BaseWithVerMetadata +from genai_factory.schemas.base import BaseWithVerMetadata class 
ModelType(str, Enum): diff --git a/controller/src/schemas/project.py b/genai_factory/src/genai_factory/schemas/project.py similarity index 91% rename from controller/src/schemas/project.py rename to genai_factory/src/genai_factory/schemas/project.py index ae996e9..682aea9 100644 --- a/controller/src/schemas/project.py +++ b/genai_factory/src/genai_factory/schemas/project.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from controller.src.schemas.base import BaseWithVerMetadata +from genai_factory.schemas.base import BaseWithVerMetadata class Project(BaseWithVerMetadata): diff --git a/controller/src/schemas/prompt_template.py b/genai_factory/src/genai_factory/schemas/prompt_template.py similarity index 92% rename from controller/src/schemas/prompt_template.py rename to genai_factory/src/genai_factory/schemas/prompt_template.py index 448e1c4..30cb5d8 100644 --- a/controller/src/schemas/prompt_template.py +++ b/genai_factory/src/genai_factory/schemas/prompt_template.py @@ -14,7 +14,7 @@ from typing import List, Optional -from controller.src.schemas.base import BaseWithVerMetadata +from genai_factory.schemas.base import BaseWithVerMetadata class PromptTemplate(BaseWithVerMetadata): diff --git a/controller/src/schemas/session.py b/genai_factory/src/genai_factory/schemas/session.py similarity index 60% rename from controller/src/schemas/session.py rename to genai_factory/src/genai_factory/schemas/session.py index a54ca2d..c200564 100644 --- a/controller/src/schemas/session.py +++ b/genai_factory/src/genai_factory/schemas/session.py @@ -17,7 +17,7 @@ from pydantic import BaseModel -from controller.src.schemas.base import BaseWithOwner +from genai_factory.schemas.base import BaseWithOwner class QueryItem(BaseModel): @@ -43,9 +43,36 @@ class Message(BaseModel): human_feedback: Optional[str] = None +class Conversation(BaseModel): + messages: list[Message] = [] + saved_index: int = 0 + + def 
__str__(self): + return "\n".join([f"{m.role}: {m.content}" for m in self.messages]) + + def add_message(self, role, content, sources=None): + self.messages.append(Message(role=role, content=content, sources=sources)) + + def to_list(self): + return self.dict()["messages"] + # return self.model_dump(mode="json")["messages"] + + def to_dict(self): + return self.dict()["messages"] + # return self.model_dump(mode="json")["messages"] + + @classmethod + def from_list(cls, data: list): + return cls.parse_obj({"messages": data or []}) + # return cls.model_validate({"messages": data or []}) + + class ChatSession(BaseWithOwner): _extra_fields = ["history"] _top_level_fields = ["workflow_id"] workflow_id: str history: List[Message] = [] + + def to_conversation(self): + return Conversation.from_list(self.history) diff --git a/controller/src/schemas/user.py b/genai_factory/src/genai_factory/schemas/user.py similarity index 93% rename from controller/src/schemas/user.py rename to genai_factory/src/genai_factory/schemas/user.py index 909e6ea..1c58af8 100644 --- a/controller/src/schemas/user.py +++ b/genai_factory/src/genai_factory/schemas/user.py @@ -14,7 +14,7 @@ from typing import Optional -from controller.src.schemas.base import BaseWithMetadata +from genai_factory.schemas.base import BaseWithMetadata class User(BaseWithMetadata): diff --git a/genai_factory/src/genai_factory/schemas/workflow.py b/genai_factory/src/genai_factory/schemas/workflow.py new file mode 100644 index 0000000..0bbaec7 --- /dev/null +++ b/genai_factory/src/genai_factory/schemas/workflow.py @@ -0,0 +1,95 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class WorkflowType(str, Enum):
    """Closed set of workflow kinds; the ``str`` mixin makes members compare as strings."""

    INGESTION = "ingestion"
    APPLICATION = "application"
    DATA_PROCESSING = "data-processing"
    TRAINING = "training"
    EVALUATION = "evaluation"


class Workflow(BaseWithVerMetadata):
    """Workflow schema: type, owning project, deployment target and step graph."""

    # NOTE(review): presumably read by the base class to decide which fields are
    # stored at the top level -- confirm against BaseWithVerMetadata.
    _top_level_fields = ["workflow_type"]

    workflow_type: WorkflowType
    project_id: str
    deployment: Optional[str] = None
    workflow_function: Optional[str] = None
    configuration: Optional[dict] = None
    graph: Optional[List[dict]] = None

    def get_infer_path(self):
        """Return ``deployment`` joined with ``"infer"``, or ``None`` without a deployment."""
        # Local import keeps this block self-contained (the module imports ``os`` only).
        import posixpath

        if self.deployment is None:
            return None
        # ``os.path.join`` uses the host separator, i.e. "\\" on Windows.
        # ``posixpath.join`` always uses "/" and is the very same function on
        # POSIX hosts, so existing results there are unchanged.
        # NOTE(review): assumes ``deployment`` is an endpoint URL/path -- confirm.
        return posixpath.join(self.deployment, "infer")

    def add_graph(self, graph: List[storey.Flow]):
        """Replace ``self.graph`` with the ``to_dict()`` form of every step in *graph*."""
        self.graph = [step.to_dict() for step in graph]


class WorkflowEvent:
    """
    A workflow event.

    Carries the request fields (query, username, session and workflow ids)
    together with the mutable slots (``session``, ``user``, ``results``,
    ``state``, ``conversation``) that are initialised here and read back by
    ``to_dict``.
    """

    def __init__(
        self,
        query=None,
        username=None,
        session_id=None,
        db_session=None,
        workflow_id=None,
        **kwargs,
    ):
        """Initialise the event; any extra keyword arguments are kept in ``self.kwargs``."""
        self.username = username
        self.session_id = session_id
        # Both start from the same value; ``original_query`` is a second
        # reference kept alongside ``query``.
        self.original_query = query
        self.query = query
        self.kwargs = kwargs

        self.session = None
        self.user = None
        self.results = {}
        self.state = {}
        self.conversation: Conversation = Conversation()
        self.workflow_id = workflow_id

        self.db_session = db_session  # SQL db session (from FastAPI)

    def to_dict(self):
        """Return the event as a dict.

        ``original_query``, ``user`` and ``db_session`` are not included.
        ``conversation`` is emitted through its ``to_list()`` and ``session``
        through its ``to_dict()`` (``None`` when no session is attached).
        """
        return {
            "username": self.username,
            "session_id": self.session_id,
            "query": self.query,
            "kwargs": self.kwargs,
            "results": self.results,
            "state": self.state,
            "conversation": self.conversation.to_list(),
            "workflow_id": self.workflow_id,
            "session": self.session.to_dict() if self.session else None,
        }

    def __getitem__(self, item):
        """Allow ``event["name"]`` as an alias for attribute access.

        Unknown names raise ``AttributeError`` (from ``getattr``), not ``KeyError``.
        """
        return getattr(self, item)
class SessionStore:
    """Loads and saves chat-session state for workflow events through a client."""

    def __init__(self, client):
        """Keep the client used for all user/session calls."""
        self.db_session = None
        self.client = client

    def read_state(self, event: WorkflowEvent):
        """Fill in ``event.user``/``event.username`` and, if needed, the session.

        The session and its conversation are fetched only when the event has a
        ``session_id`` but no ``session`` attached yet.
        """
        event.user = self.client.get_user(username=event.username, email=event.username)
        # An empty/None user name falls back to "guest".
        event.username = event.user["name"] or "guest"

        if event.session or not event.session_id:
            return

        payload = self.client.get_session(
            uid=event.session_id, user_name=event.username
        )
        event.session = ChatSession(**payload)
        event.conversation = event.session.to_conversation()

    def save(self, event: WorkflowEvent):
        """Save the session and conversation to the database"""
        if not event.session_id:
            return
        self.client.update_session(
            chat_session=event.session,
            username=event.username,
            history=event.conversation.to_list(),
        )


def get_session_store(config=None):
    """Return a ``SessionStore`` bound to a client.

    A truthy *config* yields a new ``Client`` pointed at ``config.api_url``;
    otherwise the module's default client is used.
    """
    client = Client(base_url=config.api_url) if config else default_client
    return SessionStore(client=client)
class AppServer:
    """Registry and runner for application workflows.

    Keeps an in-memory map of workflow name -> ``{"uid", "project_name"}`` and
    uses the module-level ``client`` to read and write the workflow records.
    """

    def __init__(self, config=None, verbose=False):
        """Store *config* (or the default config) and build the session store from it."""
        self._config = config or default_config
        self._session_store = get_session_store(self._config)
        self._workflows = {}
        self.verbose = verbose

    def set_config(self, config):
        """Replace the configuration and rebuild the session store from it."""
        self._config = config
        self._session_store = get_session_store(self._config)
        for entry in self._workflows.values():
            # ``add_workflow`` registers plain dicts, which cannot take a
            # ``_server`` attribute (the unconditional assignment raised
            # AttributeError). Only reset objects that really cache a server.
            if hasattr(entry, "_server"):
                entry._server = None

    def add_workflow(
        self,
        project_name: str,
        name: str,
        graph: list = None,
        deployment: str = None,
        update: bool = False,
        workflow_type: str = None,
        username: str = None,
    ):
        """Create or update a workflow record and register it on this server.

        :param project_name:  project passed to every client call
        :param name:          workflow name; also the key in the local registry
        :param graph:         optional graph; on update it is applied through
                              ``workflow.add_graph``, on create it is stored as given
        :param deployment:    optional deployment value to store on the workflow
        :param update:        must be True to modify a workflow the client already returns
        :param workflow_type: stored on newly created workflows
        :param username:      user whose ``uid`` becomes ``owner_id`` of a new workflow
        :return: the workflow object returned by the client's create/update call
        :raises ValueError: if *name* is already registered on this server, or
                            the workflow exists remotely and ``update`` is False
        """
        # Check if workflow already exists:
        if name in self._workflows:
            raise ValueError(f"workflow {name} already exists")

        workflow = client.get_workflow(project_name=project_name, workflow_name=name)
        if workflow:
            if not update:
                raise ValueError(
                    f"workflow {name} already exists, to update set update=True"
                )
            # Update workflow:
            if graph:
                workflow.add_graph(graph)
            if deployment:
                workflow.deployment = deployment
            workflow = client.update_workflow(
                project_name=project_name, workflow=workflow
            )
        else:
            # Workflow does not exist, create it:
            owner_id = client.get_user(username=username)["uid"]
            # Add workflow to database:
            workflow = client.create_workflow(
                project_name=project_name,
                workflow={
                    "name": name,
                    "deployment": deployment,
                    "graph": graph,
                    "workflow_type": workflow_type,
                    "owner_id": owner_id,
                },
            )

        # Add workflow to app server:
        self._workflows[name] = {
            "uid": workflow.uid,
            "project_name": project_name,
        }
        return workflow

    def add_workflows(self, project_name: str, workflows: dict):
        """Register every ``name -> keyword-arguments`` entry via ``add_workflow``.

        The mapping key is used as the workflow name. A ``"name"`` entry inside
        the per-workflow dict still takes precedence, so specs that already
        carried their own name behave as before.
        """
        for name, workflow in workflows.items():
            # The key used to be dropped, so ``add_workflow`` failed with a
            # missing ``name`` argument unless every spec repeated it.
            params = {"name": name, **workflow}
            self.add_workflow(project_name=project_name, **params)

    def get_workflow(self, name):
        """Fetch a registered workflow through the client.

        :return: the client's workflow object, or ``None`` when *name* was
                 never registered on this server
        """
        entry = self._workflows.get(name)
        if entry is None:
            # Lets ``run_workflow`` raise its "not found" ValueError instead of
            # crashing here with AttributeError on ``None.get``.
            return None
        return client.get_workflow(
            project_name=entry.get("project_name"), workflow_id=entry.get("uid")
        )

    def run_workflow(self, name, event):
        """Run the named workflow on *event* and return the result of ``AppWorkflow.run``.

        :raises ValueError: if the workflow cannot be found
        """
        workflow = self.get_workflow(name)
        if not workflow:
            raise ValueError(f"workflow {name} not found")
        app_workflow = AppWorkflow(self, name=workflow.name, graph=workflow.graph)
        return app_workflow.run(event)

    def api_startup(self):
        """Print a marker line; does nothing else."""
        print("\nstartup event\n")

    def to_fastapi(self, router=None):
        """Build a FastAPI app that carries this server in ``app.extra["app_server"]``.

        :param router: optional router to include in the app
        :return: the configured ``FastAPI`` instance
        """
        from fastapi import FastAPI
        from fastapi.middleware.cors import CORSMiddleware

        app = FastAPI()

        # Add CORS middleware, remove in production
        # NOTE(review): wildcard origins together with allow_credentials=True is
        # a permissive combination -- restrict the origin list before shipping.
        origins = ["*"]  # React app
        app.add_middleware(
            CORSMiddleware,
            allow_origins=origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        extra = app.extra or {}
        extra["app_server"] = self
        app.extra = extra
        if router:
            app.include_router(router)
        return app
app_server = AppServer()

# workflows cache
workflows = {}


class AppWorkflow:
    """Wraps a workflow graph and runs events through an MLRun graph server."""

    def __init__(self, parent, name=None, graph=None):
        """
        :param parent: owning server object; its ``verbose``, ``_config`` and
                       ``_session_store`` attributes are read later
        :param name:   workflow name, stored as ``""`` when omitted
        :param graph:  optional graph (list, dict or graph object), applied
                       through the ``graph`` setter when truthy
        """
        self.name = name or ""
        self._parent = parent
        self._graph = None
        self._server = None

        if graph:
            self.graph = graph

    @property
    def graph(self) -> serving.states.RootFlowStep:
        """The stored graph object, or ``None`` when no graph was set."""
        return self._graph

    @graph.setter
    def graph(self, graph):
        """Accept a list of steps, a dict, or an already-built graph object.

        A list is chained step by step onto a new ``RootFlowStep``; dict items
        are passed as keyword arguments to ``to`` and other items positionally.
        The last step is marked with ``respond()``. A dict goes through
        ``RootFlowStep.from_dict``; anything else is stored unchanged.

        :raises ValueError: if *graph* is an empty list
        """
        if isinstance(graph, list):
            if not graph:
                raise ValueError("graph list must not be empty")
            root = mlrun.serving.states.RootFlowStep()
            step = root
            for item in graph:
                step = step.to(**item) if isinstance(item, dict) else step.to(item)
            step.respond()
            self._graph = root
            return

        if isinstance(graph, dict):
            graph = mlrun.serving.states.RootFlowStep.from_dict(graph)
        self._graph = graph

    def get_server(self):
        """Return the graph server, creating and initialising it on first use."""
        if self._server is None:
            # Must stay in this frame: the helper inspects the call stack.
            namespace = get_caller_globals()
            server = serving.create_graph_server(
                graph=self.graph,
                parameters={},
                # Was ``self._parent.verbose or True``, which is always True and
                # made the parent's ``verbose`` flag meaningless.
                verbose=self._parent.verbose,
                graph_initializer=self.lc_initializer,
            )
            server.init_states(context=None, namespace=namespace)
            server.init_object(namespace)
            self._server = server
        return self._server

    def lc_initializer(self, server):
        """Graph initializer: put the parent's config and session store on the context.

        Values already present on the context (non-``None``) are left untouched.
        """
        context = server.context

        # NOTE(review): this helper is defined but never called or attached to
        # ``context`` in this method -- presumably it was meant to be exposed to
        # graph steps; confirm the intent before removing or wiring it up.
        def register_prompt(
            name, template, description: str = None, llm_args: dict = None
        ):
            if not hasattr(context, "prompts"):
                context.prompts = {}
            context.prompts[name] = (template, llm_args)

        if getattr(context, "_config", None) is None:
            context._config = self._parent._config
        if getattr(context, "session_store", None) is None:
            context.session_store = self._parent._session_store

    def run(self, event, db_session=None):
        """Send *event* through the graph server and wrap the outcome.

        :param event:      passed as the request body to ``server.test``
        :param db_session: currently unused (see todo below)
        :return: ``APIDictResponse`` with the ``answer`` and ``sources`` entries
                 of ``resp.results`` and an empty ``returned_state``
        :raises Exception: whatever ``server.test`` raised, re-raised after
                           ``server.wait_for_completion()`` has been called
        """
        # todo: pass sql db_session to steps via context or event
        server = self.get_server()
        try:
            resp = server.test("", body=event)
        except Exception:
            server.wait_for_completion()
            # Bare ``raise`` re-raises the active exception with its traceback.
            raise

        return APIDictResponse(
            success=True,
            data={
                "answer": resp.results["answer"],
                "sources": resp.results["sources"],
                "returned_state": {},
            },
        )