From 04ceda13b664b8aff89cf991f11b201feaded233 Mon Sep 17 00:00:00 2001 From: Yonatan Shelach <92271540+yonishelach@users.noreply.github.com> Date: Sun, 15 Sep 2024 14:06:24 +0300 Subject: [PATCH] [Controller] API changes (#18) * uid to name * update or create - step 1 * remove version from user and session endpoints * update by name & docstring space alignment * update cli * query item: session name instead of session uid * fix list projects * remove default workflow creation from init db * lint * fix owner id issue in list endpoints --- controller/src/controller/api/__init__.py | 2 +- .../controller/api/endpoints/data_sources.py | 172 ++-- .../src/controller/api/endpoints/datasets.py | 131 ++- .../src/controller/api/endpoints/documents.py | 133 +-- .../src/controller/api/endpoints/models.py | 135 +-- .../src/controller/api/endpoints/projects.py | 100 ++- .../api/endpoints/prompt_templates.py | 135 +-- .../src/controller/api/endpoints/sessions.py | 102 ++- .../src/controller/api/endpoints/users.py | 88 +- .../src/controller/api/endpoints/workflows.py | 199 +++-- controller/src/controller/api/utils.py | 23 +- controller/src/controller/db/sqlclient.py | 813 +++++++++--------- controller/src/controller/db/sqldb.py | 32 +- controller/src/controller/main.py | 177 ++-- .../src/genai_factory/schemas/session.py | 2 +- 15 files changed, 1258 insertions(+), 986 deletions(-) diff --git a/controller/src/controller/api/__init__.py b/controller/src/controller/api/__init__.py index 0f8a0b5..54aa7e1 100644 --- a/controller/src/controller/api/__init__.py +++ b/controller/src/controller/api/__init__.py @@ -78,7 +78,7 @@ ) api_router.include_router( sessions.router, - tags=["chat_sessions"], + tags=["sessions"], ) # Include the router in the main app diff --git a/controller/src/controller/api/endpoints/data_sources.py b/controller/src/controller/api/endpoints/data_sources.py index 122e44b..0c38272 100644 --- a/controller/src/controller/api/endpoints/data_sources.py +++ b/controller/src/controller/api/endpoints/data_sources.py @@ -22,6 +22,7 @@ _send_to_application, get_auth_user, get_db, + parse_version, ) from controller.db import client from genai_factory.schemas import ( @@ -39,19 +40,19 @@ def create_data_source( project_name: str, data_source: DataSource, - session=Depends(get_db), + db_session=Depends(get_db), ) -> APIResponse: """ Create a new data source in the database. - :param project_name: The name of the project to create the data source in. - :param data_source: The data source to create. - :param session: The database session. + :param project_name: The name of the project to create the data source in. + :param data_source: The data source to create. + :param db_session: The database session. - :return: The response from the database. + :return: The response from the database. """ try: - data = client.create_data_source(data_source=data_source, session=session) + data = client.create_data_source(data_source=data_source, db_session=db_session) return APIResponse(success=True, data=data) except Exception as e: return APIResponse( @@ -60,22 +61,38 @@ def create_data_source( ) -@router.get("/data_sources/{uid}") +@router.get("/data_sources/{name}") def get_data_source( - project_name: str, uid: str, session=Depends(get_db) + project_name: str, + name: str, + uid: str = None, + version: str = None, + db_session=Depends(get_db), ) -> APIResponse: """ Get a data source from the database. - :param project_name: The name of the project to get the data source from. - :param uid: The uid of the data source to get. - :param session: The database session. + :param project_name: The name of the project to get the data source from. + :param name: The name of the data source to get. + :param uid: The uid of the data source to get. + :param version: The version of the data source to get. + :param db_session: The database session. - :return: The data source from the database. + :return: The data source from the database. """ - project_id = client.get_project(project_name=project_name, session=session).uid + project_id = client.get_project( + project_name=project_name, db_session=db_session + ).uid try: - data = client.get_data_source(project_id=project_id, uid=uid, session=session) + # Parse the version if provided: + uid, version = parse_version(uid, version) + data = client.get_data_source( + project_id=project_id, + name=name, + uid=uid, + version=version, + db_session=db_session, + ) if data is None: return APIResponse( success=False, error=f"Data source with uid = {uid} not found" @@ -88,49 +105,66 @@ def get_data_source( ) -@router.put("/data_sources/{data_source_name}") +@router.put("/data_sources/{name}") def update_data_source( project_name: str, data_source: DataSource, - data_source_name: str, - session=Depends(get_db), + name: str, + db_session=Depends(get_db), ) -> APIResponse: """ Update a data source in the database. - :param project_name: The name of the project to update the data source in. - :param data_source: The data source to update. - :param data_source_name: The name of the data source to update. - :param session: The database session. + :param project_name: The name of the project to update the data source in. + :param data_source: The data source to update. + :param name: The name of the data source to update. + :param db_session: The database session. - :return: The response from the database. + :return: The response from the database. """ try: - data = client.update_data_source(data_source=data_source, session=session) + data = client.update_data_source( + name=name, data_source=data_source, db_session=db_session + ) return APIResponse(success=True, data=data) except Exception as e: return APIResponse( success=False, - error=f"Failed to update data source {data_source_name} in project {project_name}: {e}", + error=f"Failed to update data source {name} in project {project_name}: {e}", ) -@router.delete("/data_sources/{uid}") +@router.delete("/data_sources/{name}") def delete_data_source( - project_name: str, uid: str, session=Depends(get_db) + project_name: str, + name: str, + uid: str = None, + version: str = None, + db_session=Depends(get_db), ) -> APIResponse: """ Delete a data source from the database. - :param project_name: The name of the project to delete the data source from. - :param uid: The ID of the data source to delete. - :param session: The database session. + :param project_name: The name of the project to delete the data source from. + :param name: The name of the data source to delete. + :param uid: The ID of the data source to delete. + :param version: The version of the data source to delete. + :param db_session: The database session. - :return: The response from the database. + :returThe response from the database. """ - project_id = client.get_project(project_name=project_name, session=session).uid + project_id = client.get_project( + project_name=project_name, db_session=db_session + ).uid + uid, version = parse_version(uid, version) try: - client.delete_data_source(project_id=project_id, uid=uid, session=session) + client.delete_data_source( + name=name, + project_id=project_id, + uid=uid, + version=version, + db_session=db_session, + ) except Exception as e: return APIResponse( success=False, @@ -147,25 +181,28 @@ def list_data_sources( data_source_type: Union[DataSourceType, str] = None, labels: Optional[List[Tuple[str, str]]] = None, mode: OutputMode = OutputMode.DETAILS, - session=Depends(get_db), + db_session=Depends(get_db), auth: AuthInfo = Depends(get_auth_user), ) -> APIResponse: """ List data sources in the database. - :param project_name: The name of the project to list the data sources from. - :param name: The name to filter by. - :param version: The version to filter by. - :param data_source_type: The data source type to filter by. - :param labels: The labels to filter by. - :param mode: The output mode. - :param session: The database session. - :param auth: The authentication information. + :param project_name: The name of the project to list the data sources from. + :param name: The name to filter by. + :param version: The version to filter by. + :param data_source_type: The data source type to filter by. + :param labels: The labels to filter by. + :param mode: The output mode. + :param db_session: The database session. + :param auth: The authentication information. - :return: The response from the database. + :return: The response from the database. """ - owner_id = client.get_user(user_name=auth.username, session=session).uid - project_id = client.get_project(project_name=project_name, session=session).uid + owner = client.get_user(user_name=auth.username, db_session=db_session) + owner_id = getattr(owner, "uid", None) + project_id = client.get_project( + project_name=project_name, db_session=db_session + ).uid try: data = client.list_data_sources( project_id=project_id, @@ -175,7 +212,7 @@ def list_data_sources( data_source_type=data_source_type, labels_match=labels, output_mode=mode, - session=session, + db_session=db_session, ) return APIResponse(success=True, data=data) except Exception as e: @@ -185,36 +222,45 @@ def list_data_sources( ) -@router.post("/data_sources/{uid}/ingest") +@router.post("/data_sources/{name}/ingest") def ingest( - project_name, - uid, + project_name: str, + name: str, loader: str, path: str, + uid: str = None, metadata=None, version: str = None, from_file: bool = False, - session=Depends(get_db), + db_session=Depends(get_db), auth=Depends(get_auth_user), ): """ Ingest document into the vector database. - :param project_name: The name of the project to ingest the documents into. - :param uid: The UID of the data source to ingest the documents into. - :param loader: The data loader type to use. - :param path: The path to the document to ingest. - :param metadata: The metadata to associate with the documents. - :param version: The version of the documents. - :param from_file: Whether the documents are from a file. - :param session: The database session. - :param auth: The authentication information. - - :return: The response from the application. + :param project_name: The name of the project to ingest the documents into. + :param name: The name of the data source to ingest the documents into. + :param loader: The data loader type to use. + :param path: The path to the document to ingest. + :param uid: The UID of the data source to ingest the documents into. + :param metadata: The metadata to associate with the documents. + :param version: The version of the documents. + :param from_file: Whether the documents are from a file. + :param db_session: The database session. + :param auth: The authentication information. + + :return: The response from the application. """ - project_id = client.get_project(project_name=project_name, session=session).uid + project_id = client.get_project( + project_name=project_name, db_session=db_session + ).uid + uid, ds_version = parse_version(uid, version) data_source = client.get_data_source( - project_id=project_id, uid=uid, session=session + name=name, + project_id=project_id, + uid=uid, + version=ds_version, + db_session=db_session, ) # Create document from path: @@ -227,7 +273,7 @@ def ingest( ) # Add document to the database: - document = client.create_document(document=document, session=session) + document = client.create_document(document=document, db_session=db_session) # Send ingest to application: params = { diff --git a/controller/src/controller/api/endpoints/datasets.py b/controller/src/controller/api/endpoints/datasets.py index 4666f4c..ffd2314 100644 --- a/controller/src/controller/api/endpoints/datasets.py +++ b/controller/src/controller/api/endpoints/datasets.py @@ -16,7 +16,7 @@ from fastapi import APIRouter, Depends -from controller.api.utils import AuthInfo, get_auth_user, get_db +from controller.api.utils import AuthInfo, get_auth_user, get_db, parse_version from controller.db import client from genai_factory.schemas import APIResponse, Dataset, OutputMode @@ -27,19 +27,19 @@ def create_dataset( project_name: str, dataset: Dataset, - session=Depends(get_db), + db_session=Depends(get_db), ) -> APIResponse: """ Create a new dataset in the database. - :param project_name: The name of the project to create the dataset in. - :param dataset: The dataset to create. - :param session: The database session. + :param project_name: The name of the project to create the dataset in. + :param dataset: The dataset to create. + :param db_session: The database session. - :return: The response from the database. + :return: The response from the database. """ try: - data = client.create_dataset(dataset=dataset, session=session) + data = client.create_dataset(dataset=dataset, db_session=db_session) return APIResponse(success=True, data=data) except Exception as e: return APIResponse( @@ -48,20 +48,37 @@ def create_dataset( ) -@router.get("/datasets/{uid}") -def get_dataset(project_name: str, uid: str, session=Depends(get_db)) -> APIResponse: +@router.get("/datasets/{name}") +def get_dataset( + project_name: str, + name: str, + uid: str = None, + version: str = None, + db_session=Depends(get_db), +) -> APIResponse: """ Get a dataset from the database. - :param project_name: The name of the project to get the dataset from. - :param uid: The name of the dataset to get. - :param session: The database session. + :param project_name: The name of the project to get the dataset from. + :param name: The name of the dataset to get. + :param uid: The name of the dataset to get. + :param version: The version of the dataset to get. + :param db_session: The database session. - :return: The dataset from the database. + :return: The dataset from the database. """ - project_id = client.get_project(project_name=project_name, session=session).uid + project_id = client.get_project( + project_name=project_name, db_session=db_session + ).uid try: - data = client.get_dataset(project_id=project_id, uid=uid, session=session) + uid, version = parse_version(uid, version) + data = client.get_dataset( + name=name, + project_id=project_id, + uid=uid, + version=version, + db_session=db_session, + ) if data is None: return APIResponse( success=False, error=f"Dataset with uid = {uid} not found" @@ -74,52 +91,69 @@ def get_dataset(project_name: str, uid: str, session=Depends(get_db)) -> APIResp ) -@router.put("/datasets/{dataset_name}") +@router.put("/datasets/{name}") def update_dataset( project_name: str, dataset: Dataset, - dataset_name: str, - session=Depends(get_db), + name: str, + db_session=Depends(get_db), ) -> APIResponse: """ Update a dataset in the database. - :param project_name: The name of the project to update the dataset in. - :param dataset: The dataset to update. - :param dataset_name: The name of the dataset to update. - :param session: The database session. + :param project_name: The name of the project to update the dataset in. + :param dataset: The dataset to update. + :param name: The name of the dataset to update. + :param db_session: The database session. - :return: The response from the database. + :return: The response from the database. """ try: - data = client.update_dataset(dataset=dataset, session=session) + data = client.update_dataset(name=name, dataset=dataset, db_session=db_session) return APIResponse(success=True, data=data) except Exception as e: return APIResponse( success=False, - error=f"Failed to update dataset {dataset_name} in project {project_name}: {e}", + error=f"Failed to update dataset {name} in project {project_name}: {e}", ) -@router.delete("/datasets/{uid}") -def delete_dataset(project_name: str, uid: str, session=Depends(get_db)) -> APIResponse: +@router.delete("/datasets/{name}") +def delete_dataset( + project_name: str, + name: str, + uid: str = None, + version: str = None, + db_session=Depends(get_db), +) -> APIResponse: """ Delete a dataset from the database. - :param project_name: The name of the project to delete the dataset from. - :param uid: The UID of the dataset to delete. - :param session: The database session. + :param project_name: The name of the project to delete the dataset from. + :param name: The name of the dataset to delete. + :param uid: The UID of the dataset to delete. + :param version: The version of the dataset to delete. + :param db_session: The database session. - :return: The response from the database. + :return: The response from the database. """ - project_id = client.get_project(project_name=project_name, session=session).uid + project_id = client.get_project( + project_name=project_name, db_session=db_session + ).uid + uid, version = parse_version(uid, version) try: - client.delete_dataset(project_id=project_id, uid=uid, session=session) + client.delete_dataset( + name=name, + project_id=project_id, + uid=uid, + version=version, + db_session=db_session, + ) return APIResponse(success=True) except Exception as e: return APIResponse( success=False, - error=f"Failed to delete dataset {uid} in project {project_name}: {e}", + error=f"Failed to delete dataset {name} in project {project_name}: {e}", ) @@ -131,25 +165,28 @@ def list_datasets( task: str = None, labels: Optional[List[Tuple[str, str]]] = None, mode: OutputMode = OutputMode.DETAILS, - session=Depends(get_db), + db_session=Depends(get_db), auth: AuthInfo = Depends(get_auth_user), ) -> APIResponse: """ List datasets in the database. - :param project_name: The name of the project to list the datasets from. - :param name: The name to filter by. - :param version: The version to filter by. - :param task: The task to filter by. - :param labels: The labels to filter by. - :param mode: The output mode. - :param session: The database session. - :param auth: The authentication information. + :param project_name: The name of the project to list the datasets from. + :param name: The name to filter by. + :param version: The version to filter by. + :param task: The task to filter by. + :param labels: The labels to filter by. + :param mode: The output mode. + :param db_session: The database session. + :param auth: The authentication information. - :return: The response from the database. + :return: The response from the database. """ - owner_id = client.get_user(user_name=auth.username, session=session).uid - project_id = client.get_project(project_name=project_name, session=session).uid + owner = client.get_user(user_name=auth.username, db_session=db_session) + owner_id = getattr(owner, "uid", None) + project_id = client.get_project( + project_name=project_name, db_session=db_session + ).uid try: data = client.list_datasets( project_id=project_id, @@ -159,7 +196,7 @@ def list_datasets( task=task, labels_match=labels, output_mode=mode, - session=session, + db_session=db_session, ) return APIResponse(success=True, data=data) except Exception as e: diff --git a/controller/src/controller/api/endpoints/documents.py b/controller/src/controller/api/endpoints/documents.py index 6450dd1..b78565e 100644 --- a/controller/src/controller/api/endpoints/documents.py +++ b/controller/src/controller/api/endpoints/documents.py @@ -16,7 +16,7 @@ from fastapi import APIRouter, Depends -from controller.api.utils import AuthInfo, get_auth_user, get_db +from controller.api.utils import AuthInfo, get_auth_user, get_db, parse_version from controller.db import client from genai_factory.schemas import APIResponse, Document, OutputMode @@ -27,19 +27,19 @@ def create_document( project_name: str, document: Document, - session=Depends(get_db), + db_session=Depends(get_db), ) -> APIResponse: """ Create a new document in the database. - :param project_name: The name of the project to create the document in. - :param document: The document to create. - :param session: The database session. + :param project_name: The name of the project to create the document in. + :param document: The document to create. + :param db_session: The database session. - :return: The response from the database. + :return: The response from the database. """ try: - data = client.create_document(document=document, session=session) + data = client.create_document(document=document, db_session=db_session) return APIResponse(success=True, data=data) except Exception as e: return APIResponse( @@ -48,80 +48,114 @@ def create_document( ) -@router.get("/documents/{uid}") -def get_document(project_name: str, uid: str, session=Depends(get_db)) -> APIResponse: +@router.get("/documents/{name}") +def get_document( + project_name: str, + name: str, + uid: str = None, + version: str = None, + db_session=Depends(get_db), +) -> APIResponse: """ Get a document from the database. - :param project_name: The name of the project to get the document from. - :param uid: The UID of the document to get. - :param session: The database session. + :param project_name: The name of the project to get the document from. + :param name: The name of the document to get. + :param uid: The UID of the document to get. + :param version: The version of the document to get. + :param db_session: The database session. - :return: The document from the database. + :return: The document from the database. """ - project_id = client.get_project(project_name=project_name, session=session).uid + project_id = client.get_project( + project_name=project_name, db_session=db_session + ).uid + uid, version = parse_version(uid, version) try: - data = client.get_document(project_id=project_id, uid=uid, session=session) + data = client.get_document( + project_id=project_id, + name=name, + uid=uid, + version=version, + db_session=db_session, + ) if data is None: return APIResponse( - success=False, error=f"Document with uid = {uid} not found" + success=False, error=f"Document with name = {name} not found" ) return APIResponse(success=True, data=data) except Exception as e: return APIResponse( success=False, - error=f"Failed to get document {uid} in project {project_name}: {e}", + error=f"Failed to get document {name} in project {project_name}: {e}", ) -@router.put("/documents/{document_name}") +@router.put("/documents/{name}") def update_document( project_name: str, document: Document, - document_name: str, - session=Depends(get_db), + name: str, + db_session=Depends(get_db), ) -> APIResponse: """ Update a document in the database. - :param project_name: The name of the project to update the document in. - :param document: The document to update. - :param document_name: The name of the document to update. - :param session: The database session. + :param project_name: The name of the project to update the document in. + :param document: The document to update. + :param name: The name of the document to update. + :param db_session: The database session. - :return: The response from the database. + :return: The response from the database. """ try: - data = client.update_document(document=document, session=session) + data = client.update_document( + name=name, document=document, db_session=db_session + ) return APIResponse(success=True, data=data) except Exception as e: return APIResponse( success=False, - error=f"Failed to update document {document.name} in project {project_name}: {e}", + error=f"Failed to update document {name} in project {project_name}: {e}", ) -@router.delete("/documents/{uid}") +@router.delete("/documents/{name}") def delete_document( - project_name: str, uid: str, session=Depends(get_db) + project_name: str, + name: str, + uid: str = None, + version: str = None, + db_session=Depends(get_db), ) -> APIResponse: """ Delete a document from the database. - :param project_name: The name of the project to delete the document from. - :param uid: The UID of the document to delete. - :param session: The database session. + :param project_name: The name of the project to delete the document from. + :param name: The name of the document to delete. + :param uid: The UID of the document to delete. + :param version: The version of the document to delete. + :param db_session: The database session. - :return: The response from the database. + :return: The response from the database. """ - project_id = client.get_project(project_name=project_name, session=session).uid + project_id = client.get_project( + project_name=project_name, db_session=db_session + ).uid + uid, version = parse_version(uid, version) try: - client.delete_document(project_id=project_id, uid=uid, session=session) + client.delete_document( + project_id=project_id, + name=name, + uid=uid, + version=version, + db_session=db_session, + ) return APIResponse(success=True) except Exception as e: return APIResponse( success=False, - error=f"Failed to delete document {uid} in project {project_name}: {e}", + error=f"Failed to delete document {name} in project {project_name}: {e}", ) @@ -132,24 +166,27 @@ def list_documents( version: str = None, labels: Optional[List[Tuple[str, str]]] = None, mode: OutputMode = OutputMode.DETAILS, - session=Depends(get_db), + db_session=Depends(get_db), auth: AuthInfo = Depends(get_auth_user), ) -> APIResponse: """ List documents in the database. - :param project_name: The name of the project to list the documents from. - :param name: The name to filter by. - :param version: The version to filter by. - :param labels: The labels to filter by. - :param mode: The output mode. - :param session: The database session. - :param auth: The authentication information. + :param project_name: The name of the project to list the documents from. + :param name: The name to filter by. + :param version: The version to filter by. + :param labels: The labels to filter by. + :param mode: The output mode. + :param db_session: The database session. + :param auth: The authentication information. - :return: The response from the database. + :return: The response from the database. """ - owner_id = client.get_user(user_name=auth.username, session=session).uid - project_id = client.get_project(project_name=project_name, session=session).uid + owner = client.get_user(user_name=auth.username, db_session=db_session) + owner_id = getattr(owner, "uid", None) + project_id = client.get_project( + project_name=project_name, db_session=db_session + ).uid try: data = client.list_documents( project_id=project_id, @@ -158,7 +195,7 @@ def list_documents( version=version, labels_match=labels, output_mode=mode, - session=session, + db_session=db_session, ) return APIResponse(success=True, data=data) except Exception as e: diff --git a/controller/src/controller/api/endpoints/models.py b/controller/src/controller/api/endpoints/models.py index c2bf9c5..db71efd 100644 --- a/controller/src/controller/api/endpoints/models.py +++ b/controller/src/controller/api/endpoints/models.py @@ -16,7 +16,7 @@ from fastapi import APIRouter, Depends -from controller.api.utils import AuthInfo, get_auth_user, get_db +from controller.api.utils import AuthInfo, get_auth_user, get_db, parse_version from controller.db import client from genai_factory.schemas import APIResponse, Model, OutputMode @@ -27,19 +27,19 @@ def create_model( project_name: str, model: Model, - session=Depends(get_db), + db_session=Depends(get_db), ) -> APIResponse: """ Create a new model in the database. - :param project_name: The name of the project to create the model in. - :param model: The model to create. - :param session: The database session. + :param project_name: The name of the project to create the model in. + :param model: The model to create. + :param db_session: The database session. - :return: The response from the database. + :return: The response from the database. """ try: - data = client.create_model(model=model, session=session) + data = client.create_model(model=model, db_session=db_session) return APIResponse(success=True, data=data) except Exception as e: return APIResponse( @@ -48,71 +48,107 @@ def create_model( ) -@router.get("/models/{uid}") -def get_model(project_name: str, uid: str, session=Depends(get_db)) -> APIResponse: +@router.get("/models/{name}") +def get_model( + project_name: str, + name: str, + uid: str = None, + version: str = None, + db_session=Depends(get_db), +) -> APIResponse: """ Get a model from the database. - :param project_name: The name of the project to get the model from. - :param uid: The UID of the model to get. - :param session: The database session. + :param project_name: The name of the project to get the model from. + :param name: The name of the model to get. + :param uid: The UID of the model to get. + :param version: The version of the model to get. + :param db_session: The database session. - :return: The model from the database. + :return: The model from the database. """ - project_id = client.get_project(project_name=project_name, session=session).uid + project_id = client.get_project( + project_name=project_name, db_session=db_session + ).uid + uid, version = parse_version(uid, version) try: - data = client.get_model(project_id=project_id, uid=uid, session=session) + data = client.get_model( + project_id=project_id, + name=name, + uid=uid, + version=version, + db_session=db_session, + ) if data is None: - return APIResponse(success=False, error=f"Model with uid = {uid} not found") + return APIResponse( + success=False, error=f"Model with name = {name} not found" + ) return APIResponse(success=True, data=data) except Exception as e: return APIResponse( success=False, - error=f"Failed to get model {uid} in project {project_name}: {e}", + error=f"Failed to get model {name} in project {project_name}: {e}", ) -@router.put("/models/{model_name}") +@router.put("/models/{name}") def update_model( project_name: str, model: Model, - model_name: str, - session=Depends(get_db), + name: str, + db_session=Depends(get_db), ) -> APIResponse: """ Update a model in the database. - :param project_name: The name of the project to update the model in. - :param model: The model to update. - :param model_name: The name of the model to update. - :param session: The database session. + :param project_name: The name of the project to update the model in. + :param model: The model to update. + :param name: The name of the model to update. + :param db_session: The database session. - :return: The response from the database. + :return: The response from the database. """ try: - data = client.update_model(model=model, session=session) + data = client.update_model(name=name, model=model, db_session=db_session) return APIResponse(success=True, data=data) except Exception as e: return APIResponse( success=False, - error=f"Failed to update model {model.name} in project {project_name}: {e}", + error=f"Failed to update model {name} in project {project_name}: {e}", ) -@router.delete("/models/{uid}") -def delete_model(project_name: str, uid: str, session=Depends(get_db)) -> APIResponse: +@router.delete("/models/{name}") +def delete_model( + project_name: str, + name: str, + uid: str = None, + version: str = None, + db_session=Depends(get_db), +) -> APIResponse: """ Delete a model from the database. - :param project_name: The name of the project to delete the model from. - :param uid: The ID of the model to delete. - :param session: The database session. + :param project_name: The name of the project to delete the model from. + :param name: The name of the model to delete. + :param uid: The ID of the model to delete. + :param version: The version of the model to delete. + :param db_session: The database session. - :return: The response from the database. + :return: The response from the database. """ - project_id = client.get_project(project_name=project_name, session=session).uid + project_id = client.get_project( + project_name=project_name, db_session=db_session + ).uid + uid, version = parse_version(uid, version) try: - client.delete_model(project_id=project_id, uid=uid, session=session) + client.delete_model( + project_id=project_id, + name=name, + uid=uid, + version=version, + db_session=db_session, + ) return APIResponse(success=True) except Exception as e: return APIResponse( @@ -129,25 +165,28 @@ def list_models( model_type: str = None, labels: Optional[List[Tuple[str, str]]] = None, mode: OutputMode = OutputMode.DETAILS, - session=Depends(get_db), + db_session=Depends(get_db), auth: AuthInfo = Depends(get_auth_user), ) -> APIResponse: """ List models in the database. - :param project_name: The name of the project to list the models from. - :param name: The name to filter by. - :param version: The version to filter by. - :param model_type: The model type to filter by. - :param labels: The labels to filter by. - :param mode: The output mode. - :param session: The database session. - :param auth: The authentication information. + :param project_name: The name of the project to list the models from. + :param name: The name to filter by. + :param version: The version to filter by. + :param model_type: The model type to filter by. + :param labels: The labels to filter by. + :param mode: The output mode. + :param db_session: The database session. + :param auth: The authentication information. - :return: The response from the database. + :return: The response from the database. """ - owner_id = client.get_user(user_name=auth.username, session=session).uid - project_id = client.get_project(project_name=project_name, session=session).uid + owner = client.get_user(user_name=auth.username, db_session=db_session) + owner_id = getattr(owner, "uid", None) + project_id = client.get_project( + project_name=project_name, db_session=db_session + ).uid try: data = client.list_models( project_id=project_id, @@ -157,7 +196,7 @@ def list_models( model_type=model_type, labels_match=labels, output_mode=mode, - session=session, + db_session=db_session, ) return APIResponse(success=True, data=data) except Exception as e: diff --git a/controller/src/controller/api/endpoints/projects.py b/controller/src/controller/api/endpoints/projects.py index 87fb6b0..6b35bf8 100644 --- a/controller/src/controller/api/endpoints/projects.py +++ b/controller/src/controller/api/endpoints/projects.py @@ -16,7 +16,7 @@ from fastapi import APIRouter, Depends -from controller.api.utils import get_db +from controller.api.utils import get_db, parse_version from controller.db import client from genai_factory.schemas import APIResponse, OutputMode, Project @@ -26,18 +26,18 @@ @router.post("/projects") def create_project( project: Project, - session=Depends(get_db), + db_session=Depends(get_db), ) -> APIResponse: """ Create a new project in the database. - :param project: The project to create. - :param session: The database session. + :param project: The project to create. + :param db_session: The database session. - :return: The response from the database. + :return: The response from the database. """ try: - data = client.create_project(project=project, session=session) + data = client.create_project(project=project, db_session=db_session) return APIResponse(success=True, data=data) except Exception as e: return APIResponse( @@ -45,72 +45,78 @@ def create_project( ) -@router.get("/projects/{project_name}") -def get_project(project_name: str, session=Depends(get_db)) -> APIResponse: +@router.get("/projects/{name}") +def get_project( + name: str, uid: str = None, version: str = None, db_session=Depends(get_db) +) -> APIResponse: """ Get a project from the database. - :param project_name: The name of the project to get. - :param session: The database session. + :param name: The name of the project to get. + :param uid: The UID of the project to get. + :param version: The version of the project to get. + :param db_session: The database session. - :return: The project from the database. + :return: The project from the database. """ + uid, version = parse_version(uid=uid, version=version) try: - data = client.get_project(project_name=project_name, session=session) + data = client.get_project( + name=name, uid=uid, version=version, db_session=db_session + ) if data is None: return APIResponse( - success=False, error=f"Project with name {project_name} not found" + success=False, error=f"Project with name {name} not found" ) return APIResponse(success=True, data=data) except Exception as e: - return APIResponse( - success=False, error=f"Failed to get project {project_name}: {e}" - ) + return APIResponse(success=False, error=f"Failed to get project {name}: {e}") -@router.put("/projects/{project_name}") +@router.put("/projects/{name}") def update_project( project: Project, - project_name: str, - session=Depends(get_db), + name: str, + db_session=Depends(get_db), ) -> APIResponse: """ Update a project in the database. - :param project: The project to update. - :param project_name: The name of the project to update. - :param session: The database session. + :param project: The project to update. + :param name: The name of the project to update. + :param db_session: The database session. - :return: The response from the database. + :return: The response from the database. """ try: - data = client.update_project(project=project, session=session) + data = client.update_project(name=name, project=project, db_session=db_session) return APIResponse(success=True, data=data) except Exception as e: - return APIResponse( - success=False, error=f"Failed to update project {project_name}: {e}" - ) + return APIResponse(success=False, error=f"Failed to update project {name}: {e}") -@router.delete("/projects/{project_name}") -def delete_project(project_name: str, session=Depends(get_db)) -> APIResponse: +@router.delete("/projects/{name}") +def delete_project( + name: str, uid: str = None, version: str = None, db_session=Depends(get_db) +) -> APIResponse: """ Delete a project from the database. - :param project_name: The name of the project to delete. - :param session: The database session. + :param name: The name of the project to delete. + :param uid: The UID of the project to delete. + :param version: The version of the project to delete. + :param db_session: The database session. - :return: The response from the database. + :return: The response from the database. """ - project = client.get_project(project_name=project_name, session=session) - + uid, version = parse_version(uid=uid, version=version) try: - client.delete_project(uid=project.uid, session=session) + client.delete_project( + name=name, uid=uid, version=version, db_session=db_session + ) return APIResponse(success=True) except Exception as e: - return APIResponse( - success=False, error=f"Failed to delete project {project_name}: {e}" - ) + return APIResponse(success=False, error=f"Failed to delete project {name}: {e}") @router.get("/projects") @@ -119,21 +125,21 @@ def list_projects( owner_name: str = None, labels: Optional[List[Tuple[str, str]]] = None, mode: OutputMode = OutputMode.DETAILS, - session=Depends(get_db), + db_session=Depends(get_db), ) -> APIResponse: """ List projects in the database. - :param name: The name of the project to filter by. - :param owner_name: The name of the owner to filter by. - :param labels: The labels to filter by. - :param mode: The output mode. - :param session: The database session. + :param name: The name of the project to filter by. + :param owner_name: The name of the owner to filter by. + :param labels: The labels to filter by. + :param mode: The output mode. + :param db_session: The database session. - :return: The response from the database. + :return: The response from the database. """ if owner_name is not None: - owner_id = client.get_user(user_name=owner_name, session=session).uid + owner_id = client.get_user(user_name=owner_name, db_session=db_session).uid else: owner_id = None try: @@ -141,7 +147,7 @@ def list_projects( owner_id=owner_id, labels_match=labels, output_mode=mode, - session=session, + db_session=db_session, name=name, ) return APIResponse(success=True, data=data) diff --git a/controller/src/controller/api/endpoints/prompt_templates.py b/controller/src/controller/api/endpoints/prompt_templates.py index e4ada2f..0cbb6d5 100644 --- a/controller/src/controller/api/endpoints/prompt_templates.py +++ b/controller/src/controller/api/endpoints/prompt_templates.py @@ -16,7 +16,7 @@ from fastapi import APIRouter, Depends -from controller.api.utils import AuthInfo, get_auth_user, get_db +from controller.api.utils import AuthInfo, get_auth_user, get_db, parse_version from controller.db import client from genai_factory.schemas import APIResponse, OutputMode, PromptTemplate @@ -27,19 +27,19 @@ def create_prompt( project_name: str, prompt: PromptTemplate, - session=Depends(get_db), + db_session=Depends(get_db), ) -> APIResponse: """ Create a new prompt in the database. - :param project_name: The name of the project to create the prompt in. - :param prompt: The prompt to create. - :param session: The database session. + :param project_name: The name of the project to create the prompt in. + :param prompt: The prompt to create. + :param db_session: The database session. - :return: The response from the database. + :return: The response from the database. """ try: - data = client.create_prompt_template(prompt=prompt, session=session) + data = client.create_prompt_template(prompt=prompt, db_session=db_session) return APIResponse(success=True, data=data) except Exception as e: return APIResponse( @@ -48,78 +48,114 @@ def create_prompt( ) -@router.get("/prompt_templates/{uid}") -def get_prompt(project_name: str, uid: str, session=Depends(get_db)) -> APIResponse: +@router.get("/prompt_templates/{name}") +def get_prompt( + project_name: str, + name: str, + uid: str = None, + version: str = None, + db_session=Depends(get_db), +) -> APIResponse: """ Get a prompt from the database. - :param project_name: The name of the project to get the prompt from. - :param uid: The UID of the prompt to get. - :param session: The database session. + :param project_name: The name of the project to get the prompt from. + :param name: The name of the prompt to get. + :param uid: The UID of the prompt to get. + :param version: The version of the prompt to get. + :param db_session: The database session. - :return: The prompt from the database. + :return: The prompt from the database. """ - project_id = client.get_project(project_name=project_name, session=session).uid + project_id = client.get_project( + project_name=project_name, db_session=db_session + ).uid + uid, version = parse_version(uid, version) try: - data = client.get_prompt(project_id=project_id, uid=uid, session=session) + data = client.get_prompt_template( + project_id=project_id, + name=name, + uid=uid, + version=version, + db_session=db_session, + ) if data is None: return APIResponse( - success=False, error=f"Prompt with uid = {uid} not found" + success=False, error=f"Prompt with name = {name} not found" ) return APIResponse(success=True, data=data) except Exception as e: return APIResponse( success=False, - error=f"Failed to get prompt {uid} in project {project_name}: {e}", + error=f"Failed to get prompt {name} in project {project_name}: {e}", ) -@router.put("/prompt_templates/{prompt_name}") +@router.put("/prompt_templates/{name}") def update_prompt( project_name: str, prompt: PromptTemplate, - prompt_name: str, - session=Depends(get_db), + name: str, + db_session=Depends(get_db), ) -> APIResponse: """ Update a prompt in the database. - :param project_name: The name of the project to update the prompt in. - :param prompt: The prompt to update. - :param prompt_name: The name of the prompt to update. - :param session: The database session. + :param project_name: The name of the project to update the prompt in. + :param prompt: The prompt to update. + :param name: The name of the prompt to update. + :param db_session: The database session. - :return: The response from the database. + :return: The response from the database. """ try: - data = client.update_prompt_template(prompt=prompt, session=session) + data = client.update_prompt_template( + name=name, prompt=prompt, db_session=db_session + ) return APIResponse(success=True, data=data) except Exception as e: return APIResponse( success=False, - error=f"Failed to update prompt {prompt_name} in project {project_name}: {e}", + error=f"Failed to update prompt {name} in project {project_name}: {e}", ) -@router.delete("/prompt_templates/{uid}") -def delete_prompt(project_name: str, uid: str, session=Depends(get_db)) -> APIResponse: +@router.delete("/prompt_templates/{name}") +def delete_prompt( + project_name: str, + name: str, + uid: str = None, + version: str = None, + db_session=Depends(get_db), +) -> APIResponse: """ Delete a prompt from the database. - :param project_name: The name of the project to delete the prompt from. - :param uid: The UID of the prompt to delete. - :param session: The database session. + :param project_name: The name of the project to delete the prompt from. + :param name: The name of the prompt to delete. + :param uid: The UID of the prompt to delete. + :param version: The version of the prompt to delete. + :param db_session: The database session. - :return: The response from the database. + :return: The response from the database. """ - project_id = client.get_project(project_name=project_name, session=session).uid + project_id = client.get_project( + project_name=project_name, db_session=db_session + ).uid + uid, version = parse_version(uid, version) try: - client.delete_prompt_template(project_id=project_id, uid=uid, session=session) + client.delete_prompt_template( + project_id=project_id, + name=name, + uid=uid, + version=version, + db_session=db_session, + ) return APIResponse(success=True) except Exception as e: return APIResponse( success=False, - error=f"Failed to delete prompt {uid} in project {project_name}: {e}", + error=f"Failed to delete prompt {name} in project {project_name}: {e}", ) @@ -130,24 +166,27 @@ def list_prompts( version: str = None, labels: Optional[List[Tuple[str, str]]] = None, mode: OutputMode = OutputMode.DETAILS, - session=Depends(get_db), + db_session=Depends(get_db), auth: AuthInfo = Depends(get_auth_user), ) -> APIResponse: """ List prompts in the database. - :param project_name: The name of the project to list the prompts from. - :param name: The name to filter by. - :param version: The version to filter by. - :param labels: The labels to filter by. - :param mode: The output mode. - :param session: The database session. - :param auth: The authentication information. + :param project_name: The name of the project to list the prompts from. + :param name: The name to filter by. + :param version: The version to filter by. + :param labels: The labels to filter by. + :param mode: The output mode. + :param db_session: The database session. + :param auth: The authentication information. - :return: The response from the database. + :return: The response from the database. """ - owner_id = client.get_user(user_name=auth.username, session=session).uid - project_id = client.get_project(project_name=project_name, session=session).uid + owner = client.get_user(user_name=auth.username, db_session=db_session) + owner_id = getattr(owner, "uid", None) + project_id = client.get_project( + project_name=project_name, db_session=db_session + ).uid try: data = client.list_prompt_templates( project_id=project_id, @@ -156,7 +195,7 @@ def list_prompts( version=version, labels_match=labels, output_mode=mode, - session=session, + db_session=db_session, ) return APIResponse(success=True, data=data) except Exception as e: diff --git a/controller/src/controller/api/endpoints/sessions.py b/controller/src/controller/api/endpoints/sessions.py index 15303bc..b9c86c4 100644 --- a/controller/src/controller/api/endpoints/sessions.py +++ b/controller/src/controller/api/endpoints/sessions.py @@ -24,48 +24,56 @@ @router.post("/sessions") def create_session( user_name: str, - chat_session: ChatSession, - session=Depends(get_db), + session: ChatSession, + db_session=Depends(get_db), ) -> APIResponse: """ Create a new session in the database. - :param user_name: The name of the user to create the session for. - :param chat_session: The session to create. - :param session: The database session. + :param user_name: The name of the user to create the session for. + :param session: The session to create. + :param db_session: The database session. - :return: The response from the database. + :return: The response from the database. """ try: - data = client.create_chat_session(chat_session=chat_session, session=session) + data = client.create_session(session=session, db_session=db_session) return APIResponse(success=True, data=data) except Exception as e: return APIResponse( success=False, - error=f"Failed to create session {chat_session.uid} for user {user_name}: {e}", + error=f"Failed to create session {session.uid} for user {user_name}: {e}", ) -@router.get("/sessions/{uid}") -def get_session(user_name: str, uid: str, session=Depends(get_db)) -> APIResponse: +@router.get("/sessions/{name}") +def get_session( + user_name: str, + name: str, + uid: str = None, + db_session=Depends(get_db), +) -> APIResponse: """ Get a session from the database. If the session ID is "$last", get the last session for the user. - :param user_name: The name of the user to get the session for. - :param uid: The UID of the session to get. if "$last" bring the last user's session. - :param session: The database session. + :param user_name: The name of the user to get the session for. + :param name: The name of the session to get. + :param uid: The UID of the session to get. if "$last" bring the last user's session. + :param db_session: The database session. - :return: The session from the database. + :return: The session from the database. """ user_id = None - if uid == "$last": - user_id = client.get_user(user_name=user_name, session=session).uid - uid = None + if name == "$last": + user_id = client.get_user(user_name=user_name, db_session=db_session).uid + name = None try: - data = client.get_chat_session(uid=uid, user_id=user_id, session=session) + data = client.get_session( + user_id=user_id, name=name, uid=uid, db_session=db_session + ) if data is None: return APIResponse( - success=False, error=f"Session with uid = {uid} not found" + success=False, error=f"Session with name = {name} not found" ) return APIResponse(success=True, data=data) except Exception as e: @@ -75,50 +83,60 @@ def get_session(user_name: str, uid: str, session=Depends(get_db)) -> APIRespons ) -@router.put("/sessions/{session_name}") +@router.put("/sessions/{name}") def update_session( user_name: str, - chat_session: ChatSession, - session=Depends(get_db), + name: str, + session: ChatSession, + db_session=Depends(get_db), ) -> APIResponse: """ Update a session in the database. - :param user_name: The name of the user to update the session for. - :param chat_session: The session to update. - :param session: The database session. + :param user_name: The name of the user to update the session for. + :param name: The name of the session to update. + :param session: The session to update. + :param db_session: The database session. - :return: The response from the database. + :return: The response from the database. """ try: - data = client.update_chat_session(chat_session=chat_session, session=session) + data = client.update_session(name=name, session=session, db_session=db_session) return APIResponse(success=True, data=data) except Exception as e: return APIResponse( success=False, - error=f"Failed to update session {chat_session.uid} for user {user_name}: {e}", + error=f"Failed to update session {name} for user {user_name}: {e}", ) -@router.delete("/sessions/{uid}") -def delete_session(user_name: str, uid: str, session=Depends(get_db)) -> APIResponse: +@router.delete("/sessions/{name}") +def delete_session( + user_name: str, + name: str, + uid: str = None, + db_session=Depends(get_db), +) -> APIResponse: """ Delete a session from the database. - :param user_name: The name of the user to delete the session for. - :param uid: The UID of the session to delete. - :param session: The database session. + :param user_name: The name of the user to delete the session for. + :param name: The name of the session to delete. + :param uid: The UID of the session to delete. + :param db_session: The database session. :return: The response from the database. """ - user_id = client.get_user(user_name=user_name, session=session).uid + user_id = client.get_user(user_name=user_name, db_session=db_session).uid try: - client.delete_chat_session(uid=uid, user_id=user_id, session=session) + client.delete_session( + name=name, uid=uid, user_id=user_id, db_session=db_session + ) return APIResponse(success=True) except Exception as e: return APIResponse( success=False, - error=f"Failed to delete session {uid} for user {user_name}: {e}", + error=f"Failed to delete session {name} for user {user_name}: {e}", ) @@ -130,7 +148,7 @@ def list_sessions( created: str = None, workflow_id: str = None, mode: OutputMode = OutputMode.DETAILS, - session=Depends(get_db), + db_session=Depends(get_db), ) -> APIResponse: """ List sessions in the database. @@ -141,20 +159,20 @@ def list_sessions( :param created: The date to filter by. :param workflow_id: The ID of the workflow to filter by. :param mode: The output mode. - :param session: The database session. + :param db_session: The database session. - :return: The response from the database. + :return: The response from the database. """ - user_id = client.get_user(user_name=user_name, session=session).uid + user_id = client.get_user(user_name=user_name, db_session=db_session).uid try: - data = client.list_chat_sessions( + data = client.list_sessions( user_id=user_id, name=name, last=last, created_after=created, workflow_id=workflow_id, output_mode=mode, - session=session, + db_session=db_session, ) return APIResponse(success=True, data=data) except Exception as e: diff --git a/controller/src/controller/api/endpoints/users.py b/controller/src/controller/api/endpoints/users.py index f7ca4b1..65d951c 100644 --- a/controller/src/controller/api/endpoints/users.py +++ b/controller/src/controller/api/endpoints/users.py @@ -24,18 +24,18 @@ @router.post("/users") def create_user( user: User, - session=Depends(get_db), + db_session=Depends(get_db), ) -> APIResponse: """ Create a new user in the database. - :param user: The user to create. - :param session: The database session. + :param user: The user to create. + :param db_session: The database session. - :return: The response from the database. + :return: The response from the database. """ try: - data = client.create_user(user=user, session=session) + data = client.create_user(user=user, db_session=db_session) return APIResponse(success=True, data=data) except Exception as e: return APIResponse( @@ -43,74 +43,76 @@ def create_user( ) -@router.get("/users/{user_name}") -def get_user(user_name: str, email: str = None, session=Depends(get_db)) -> APIResponse: +@router.get("/users/{name}") +def get_user( + name: str, + email: str = None, + uid: str = None, + db_session=Depends(get_db), +) -> APIResponse: """ Get a user from the database. - :param user_name: The name of the user to get. - :param email: The email address to get the user by if the name is not provided. - :param session: The database session. + :param name: The name of the user to get. + :param email: The email address to get the user by if the name is not provided. + :param uid: The UID of the user to get. + :param db_session: The database session. - :return: The user from the database. + :return: The user from the database. """ try: - data = client.get_user(user_name=user_name, email=email, session=session) + data = client.get_user(name=name, email=email, uid=uid, db_session=db_session) if data is None: return APIResponse( success=False, - error=f"User with name = {user_name}, email = {email} not found", + error=f"User with name = {name}, email = {email} not found", ) return APIResponse(success=True, data=data) except Exception as e: return APIResponse( success=False, - error=f"Failed to get user with name = {user_name}, email = {email}: {e}", + error=f"Failed to get user with name = {name}, email = {email}: {e}", ) -@router.put("/users/{user_name}") +@router.put("/users/{name}") def update_user( user: User, - user_name: str, - session=Depends(get_db), + name: str, + db_session=Depends(get_db), ) -> APIResponse: """ Update a user in the database. - :param user: The user to update. - :param user_name: The name of the user to update. - :param session: The database session. + :param user: The user to update. + :param name: The name of the user to update. + :param db_session: The database session. - :return: The response from the database. + :return: The response from the database. """ try: - data = client.update_user(user=user, session=session) + data = client.update_user(name=name, user=user, db_session=db_session) return APIResponse(success=True, data=data) except Exception as e: - return APIResponse( - success=False, error=f"Failed to update user {user_name}: {e}" - ) + return APIResponse(success=False, error=f"Failed to update user {name}: {e}") -@router.delete("/users/{user_name}") -def delete_user(user_name: str, session=Depends(get_db)) -> APIResponse: +@router.delete("/users/{name}") +def delete_user(name: str, uid: str = None, db_session=Depends(get_db)) -> APIResponse: """ Delete a user from the database. - :param user_name: The name of the user to delete. - :param session: The database session. + :param name: The name of the user to delete. + :param uid: The UID of the user to delete. + :param db_session: The database session. - :return: The response from the database. + :return: The response from the database. """ - user = client.get_user(user_name=user_name, session=session) try: - client.delete_user(uid=user.uid, session=session) + client.delete_user(name=name, uid=uid, db_session=db_session) return APIResponse(success=True) except Exception as e: - return APIResponse( - success=False, error=f"Failed to delete user {user_name}: {e}" - ) + return APIResponse(success=False, error=f"Failed to delete user {name}: {e}") @router.get("/users") @@ -119,18 +121,18 @@ def list_users( email: str = None, full_name: str = None, mode: OutputMode = OutputMode.DETAILS, - session=Depends(get_db), + db_session=Depends(get_db), ) -> APIResponse: """ List users in the database. - :param name: The name to filter by. - :param email: The email address to filter by. - :param full_name: The full name to filter by. - :param mode: The output mode. - :param session: The database session. + :param name: The name to filter by. + :param email: The email address to filter by. + :param full_name: The full name to filter by. + :param mode: The output mode. + :param db_session: The database session. - :return: The response from the database. + :return: The response from the database. """ try: data = client.list_users( @@ -138,7 +140,7 @@ def list_users( email=email, full_name=full_name, output_mode=mode, - session=session, + db_session=db_session, ) return APIResponse(success=True, data=data) except Exception as e: diff --git a/controller/src/controller/api/endpoints/workflows.py b/controller/src/controller/api/endpoints/workflows.py index 9c0b3db..a47d936 100644 --- a/controller/src/controller/api/endpoints/workflows.py +++ b/controller/src/controller/api/endpoints/workflows.py @@ -22,6 +22,7 @@ _send_to_application, get_auth_user, get_db, + parse_version, ) from controller.db import client from genai_factory.schemas import ( @@ -40,19 +41,19 @@ def create_workflow( project_name: str, workflow: Workflow, - session=Depends(get_db), + db_session=Depends(get_db), ) -> APIResponse: """ Create a new workflow in the database. - :param project_name: The name of the project to create the workflow in. - :param workflow: The workflow to create. - :param session: The database session. + :param project_name: The name of the project to create the workflow in. + :param workflow: The workflow to create. + :param db_session: The database session. - :return: The response from the database. + :return: The response from the database. """ try: - data = client.create_workflow(workflow=workflow, session=session) + data = client.create_workflow(workflow=workflow, db_session=db_session) return APIResponse(success=True, data=data) except Exception as e: return APIResponse( @@ -61,79 +62,114 @@ def create_workflow( ) -@router.get("/workflows/{uid}") -def get_workflow(project_name: str, uid: str, session=Depends(get_db)) -> APIResponse: +@router.get("/workflows/{name}") +def get_workflow( + project_name: str, + name: str, + uid: str = None, + version: str = None, + db_session=Depends(get_db), +) -> APIResponse: """ Get a workflow from the database. - :param project_name: The name of the project to get the workflow from. - :param uid: The UID of the workflow to get. - :param session: The database session. + :param project_name: The name of the project to get the workflow from. + :param name: The name of the workflow to get. + :param uid: The UID of the workflow to get. + :param version: The version of the workflow to get. + :param db_session: The database session. - :return: The workflow from the database. + :return: The workflow from the database. """ - project_id = client.get_project(project_name=project_name, session=session).uid + project_id = client.get_project( + project_name=project_name, db_session=db_session + ).uid + uid, version = parse_version(uid, version) try: - data = client.get_workflow(project_id=project_id, uid=uid, session=session) + data = client.get_workflow( + name=name, + project_id=project_id, + uid=uid, + version=version, + db_session=db_session, + ) if data is None: return APIResponse( - success=False, error=f"Workflow with uid = {uid} not found" + success=False, error=f"Workflow with name = {name} not found" ) return APIResponse(success=True, data=data) except Exception as e: return APIResponse( success=False, - error=f"Failed to get workflow {uid} in project {project_name}: {e}", + error=f"Failed to get workflow {name} in project {project_name}: {e}", ) -@router.put("/workflows/{workflow_name}") +@router.put("/workflows/{name}") def update_workflow( project_name: str, workflow: Workflow, - workflow_name: str, - session=Depends(get_db), + name: str, + db_session=Depends(get_db), ) -> APIResponse: """ Update a workflow in the database. - :param project_name: The name of the project to update the workflow in. - :param workflow: The workflow to update. - :param workflow_name: The name of the workflow to update. - :param session: The database session. + :param project_name: The name of the project to update the workflow in. + :param workflow: The workflow to update. + :param name: The name of the workflow to update. + :param db_session: The database session. - :return: The response from the database. + :return: The response from the database. """ try: - data = client.update_workflow(workflow=workflow, session=session) + data = client.update_workflow( + name=name, workflow=workflow, db_session=db_session + ) return APIResponse(success=True, data=data) except Exception as e: return APIResponse( success=False, - error=f"Failed to update workflow {workflow_name} in project {project_name}: {e}", + error=f"Failed to update workflow {name} in project {project_name}: {e}", ) -@router.delete("/workflows/{uid}") +@router.delete("/workflows/{name}") def delete_workflow( - project_name: str, uid: str, session=Depends(get_db) + project_name: str, + name: str, + uid: str = None, + version: str = None, + db_session=Depends(get_db), ) -> APIResponse: """ Delete a workflow from the database. - :param project_name: The name of the project to delete the workflow from. - :param uid: The UID of the workflow to delete. - :param session: The database session. + :param project_name: The name of the project to delete the workflow from. + :param name: The name of the workflow to delete. + :param uid: The UID of the workflow to delete. + :param version: The version of the workflow to delete. + :param db_session: The database session. - :return: The response from the database. + :return: The response from the database. """ + project_id = client.get_project( + project_name=project_name, db_session=db_session + ).uid + uid, version = parse_version(uid=uid, version=version) try: - client.delete_workflow(uid=uid, session=session) + client.delete_workflow( + project_id=project_id, + name=name, + uid=uid, + version=version, + db_session=db_session, + ) return APIResponse(success=True) except Exception as e: return APIResponse( success=False, - error=f"Failed to delete workflow {uid} in project {project_name}: {e}", + error=f"Failed to delete workflow {name} in project {project_name}: {e}", ) @@ -145,27 +181,28 @@ def list_workflows( workflow_type: Union[WorkflowType, str] = None, labels: Optional[List[Tuple[str, str]]] = None, mode: OutputMode = OutputMode.DETAILS, - session=Depends(get_db), + db_session=Depends(get_db), auth: AuthInfo = Depends(get_auth_user), ) -> APIResponse: """ List workflows in the database. - :param project_name: The name of the project to list the workflows from. - :param name: The name to filter by. - :param version: The version to filter by. - :param workflow_type: The workflow type to filter by. - :param labels: The labels to filter by. - :param mode: The output mode. - :param session: The database session. - :param auth: The authentication information. + :param project_name: The name of the project to list the workflows from. + :param name: The name to filter by. + :param version: The version to filter by. + :param workflow_type: The workflow type to filter by. + :param labels: The labels to filter by. + :param mode: The output mode. + :param db_session: The database session. + :param auth: The authentication information. - :return: The response from the database. + :return: The response from the database. """ - owner_id = client.get_user( - user_name=auth.username, email=auth.username, session=session + owner = client.get_user(user_name=auth.username, db_session=db_session) + owner_id = getattr(owner, "uid", None) + project_id = client.get_project( + project_name=project_name, db_session=db_session ).uid - project_id = client.get_project(project_name=project_name, session=session).uid try: data = client.list_workflows( name=name, @@ -175,7 +212,7 @@ def list_workflows( workflow_type=workflow_type, labels_match=labels, output_mode=mode, - session=session, + db_session=db_session, ) return APIResponse(success=True, data=data) except Exception as e: @@ -185,51 +222,51 @@ def list_workflows( ) -@router.post("/workflows/{uid}/infer") +@router.post("/workflows/{name}/infer") def infer_workflow( project_name: str, - uid: str, + name: str, query: QueryItem, - session=Depends(get_db), + db_session=Depends(get_db), auth: AuthInfo = Depends(get_auth_user), ) -> APIResponse: """ Run application workflow. - :param project_name: The name of the project to run the workflow in. - :param uid: The UID of the workflow to run. - :param query: The query to run. - :param session: The database session. - :param auth: The authentication information. + :param project_name: The name of the project to run the workflow in. + :param name: The name of the workflow to run. + :param query: The query to run. + :param db_session: The database session. + :param auth: The authentication information. - :return: The response from the database. + :return: The response from the database. """ # Get workflow from the database - project_id = client.get_project(project_name=project_name, session=session).uid - workflow = client.get_workflow(project_id=project_id, uid=uid, session=session) + project_id = client.get_project( + project_name=project_name, db_session=db_session + ).uid + workflow = client.get_workflow( + project_id=project_id, name=name, db_session=db_session + ) + if workflow is None: + return APIResponse( + success=False, error=f"Workflow with name = {name} not found" + ) path = workflow.get_infer_path() - if query.session_id: - # Get session by id: - chat_session = client.get_chat_session(uid=query.session_id, session=session) - if chat_session is None: - # If not id found, get session by name: - session_name = query.session_id - chat_session = client.list_chat_sessions(name=session_name, session=session) - # If not name found, create a new session: - if chat_session: - chat_session = chat_session[0] - else: - chat_session = client.create_chat_session( - chat_session=ChatSession( - name=session_name, - workflow_id=uid, - owner_id=client.get_user( - user_name=auth.username, session=session - ).uid, - ), - ) - query.session_id = chat_session.uid + if query.session_name: + # Get session by name: + session = client.get_session(name=query.session_name, db_session=db_session) + if session is None: + client.create_session( + session=ChatSession( + name=query.session_name, + workflow_id=workflow.uid, + owner_id=client.get_user( + user_name=auth.username, db_session=db_session + ).uid, + ), + ) # Prepare the data to send to the application's workflow data = { "item": query.dict(), @@ -248,5 +285,5 @@ def infer_workflow( except Exception as e: return APIResponse( success=False, - error=f"Failed to infer workflow {uid} in project {project_name}: {e}", + error=f"Failed to infer workflow {name} in project {project_name}: {e}", ) diff --git a/controller/src/controller/api/utils.py b/controller/src/controller/api/utils.py index 5cae73e..8348a85 100644 --- a/controller/src/controller/api/utils.py +++ b/controller/src/controller/api/utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from typing import List, Tuple, Union import requests from fastapi import Header, Request @@ -62,7 +62,7 @@ def _send_to_application( :param auth: The authentication information to use. If provided, the username will be added to the headers. :param kwargs: Additional keyword arguments to pass in the request function. For example, headers, params, etc. - :return: The JSON response from the application. + :return: The JSON response from the application. """ if config.application_url not in path: url = f"{config.application_url}/api/{path}" @@ -88,3 +88,22 @@ def _send_to_application( else: # If the request failed, raise an exception response.raise_for_status() + + +def parse_version(uid: str = None, version: str = None) -> Tuple[str, str]: + """ + Parse the version string from the uid if uid = uid:version. Otherwise, return the version as is. + + :param uid: The UID string. + :param version: The version string to parse. + + :return: The UID and version strings. + """ + if uid and ":" in uid: + uid, version_from_uid = uid.split(":") + if version_from_uid and version: + raise ValueError( + "Version cannot be specified in both the UID and the version parameter." + ) + version = version_from_uid + return uid, version diff --git a/controller/src/controller/db/sqlclient.py b/controller/src/controller/db/sqlclient.py index 2ae7736..6c6147b 100644 --- a/controller/src/controller/db/sqlclient.py +++ b/controller/src/controller/db/sqlclient.py @@ -43,9 +43,9 @@ def get_db_session(self, session: sqlalchemy.orm.Session = None): """ Get a session from the session maker. - :param session: The session to use. If None, a new session will be created. + :param session: The session to use. If None, a new session will be created. - :return: The session. + :return: The session. """ return session or self._session_maker() @@ -53,18 +53,16 @@ def get_local_session(self): """ Get a local session from the local session maker. - :return: The session. + :return: The session. """ return self._local_maker() - def create_tables(self, drop_old: bool = False, names: list = None) -> None: + def create_tables(self, drop_old: bool = False, names: list = None): """ Create the tables in the database. - :param drop_old: Whether to drop the old tables before creating the new ones. - :param names: The names of the tables to create. If None, all tables will be created. - - :return: A response object with the success status. + :param drop_old: Whether to drop the old tables before creating the new ones. + :param names: The names of the tables to create. If None, all tables will be created. """ tables = None if names: @@ -80,11 +78,11 @@ def _create( Create an object in the database. This method generates a UID to the object and adds the object to the session and commits the transaction. - :param session: The session to use. - :param db_class: The DB class of the object. - :param obj: The object to create. + :param session: The session to use. + :param db_class: The DB class of the object. + :param obj: The object to create. - :return: The created object. + :return: The created object. """ session = self.get_db_session(session) # try: @@ -100,16 +98,23 @@ def _get( """ Get an object from the database. - :param session: The session to use. - :param db_class: The DB class of the object. - :param api_class: The API class of the object. - :param kwargs: The keyword arguments to filter the object. + :param session: The session to use. + :param db_class: The DB class of the object. + :param api_class: The API class of the object. + :param kwargs: The keyword arguments to filter the object. - :return: the object. + :return: The object. """ + kwargs = self._drop_none(**kwargs) session = self.get_db_session(session) - obj = session.query(db_class).filter_by(**kwargs).one_or_none() + obj = session.query(db_class).filter_by(**kwargs) if obj: + if obj.count() > 1: + if not kwargs.get("version"): + # Take the latest created: + obj = obj.order_by(db_class.created.desc()).first() + else: + obj = obj.one_or_none() return api_class.from_orm_object(obj) def _update( @@ -118,13 +123,14 @@ def _update( """ Update an object in the database. - :param session: The session to use. - :param db_class: The DB class of the object. - :param api_object: The API object with the new data. - :param kwargs: The keyword arguments to filter the object. + :param session: The session to use. + :param db_class: The DB class of the object. + :param api_object: The API object with the new data. + :param kwargs: The keyword arguments to filter the object. - :return: The updated object. + :return: The updated object. """ + kwargs = self._drop_none(**kwargs) session = self.get_db_session(session) obj = session.query(db_class).filter_by(**kwargs).one_or_none() if obj: @@ -133,16 +139,19 @@ def _update( session.commit() return api_object.__class__.from_orm_object(obj) else: - raise ValueError(f"{db_class} object ({kwargs}) not found") + # Create a new object if not found + logger.debug(f"Object not found, creating a new one: {api_object}") + return self._create(session, db_class, api_object) - def _delete(self, session: sqlalchemy.orm.Session, db_class, **kwargs) -> None: + def _delete(self, session: sqlalchemy.orm.Session, db_class, **kwargs): """ Delete an object from the database. - :param session: The session to use. - :param db_class: The DB class of the object. - :param kwargs: The keyword arguments to filter the object. + :param session: The session to use. + :param db_class: The DB class of the object. + :param kwargs: The keyword arguments to filter the object. """ + kwargs = self._drop_none(**kwargs) session = self.get_db_session(session) query = session.query(db_class).filter_by(**kwargs) for obj in query: @@ -161,14 +170,14 @@ def _list( """ List objects from the database. - :param session: The session to use. - :param db_class: The DB class of the object. - :param api_class: The API class of the object. - :param output_mode: The output mode. - :param labels_match: The labels to match, filter the objects by labels. - :param filters: The filters to apply. + :param session: The session to use. + :param db_class: The DB class of the object. + :param api_class: The API class of the object. + :param output_mode: The output mode. + :param labels_match: The labels to match, filter the objects by labels. + :param filters: The filters to apply. - :return: A list of the desired objects. + :return: A list of the desired objects. """ session = self.get_db_session(session) @@ -184,78 +193,94 @@ def _list( logger.debug(f"output: {output}") return _process_output(output, api_class, output_mode) + @staticmethod + def _drop_none(**kwargs): + return {k: v for k, v in kwargs.items() if v is not None} + def create_user( - self, user: Union[api_models.User, dict], session: sqlalchemy.orm.Session = None + self, + user: Union[api_models.User, dict], + db_session: sqlalchemy.orm.Session = None, ): """ Create a new user in the database. - :param user: The user object to create. - :param session: The session to use. + :param user: The user object to create. + :param db_session: The session to use. - :return: The created user. + :return: The created user. """ logger.debug(f"Creating user: {user}") if isinstance(user, dict): user = api_models.User.from_dict(user) user.name = user.name or user.email - return self._create(session, db.User, user) + return self._create(db_session, db.User, user) def get_user( self, - user_id: str = None, - user_name: str = None, + uid: str = None, + name: str = None, email: str = None, - session: sqlalchemy.orm.Session = None, + db_session: sqlalchemy.orm.Session = None, + **kwargs, ): """ Get a user from the database. Either user_id or user_name or email must be provided. - :param user_id: The UID of the user to get. - :param user_name: The name of the user to get. - :param email: The email of the user to get. - :param session: The session to use. + :param uid: The UID of the user to get. + :param name: The name of the user to get. + :param email: The email of the user to get. + :param db_session: The session to use. + :param kwargs: Additional keyword arguments to filter the user. - :return: The user. + :return: The user. """ args = {} if email: args["email"] = email - elif user_name: - args["name"] = user_name - elif user_id: - args["uid"] = user_id + elif name: + args["name"] = name + elif uid: + args["uid"] = uid else: raise ValueError("Either user_id or user_name or email must be provided") - logger.debug(f"Getting user: user_id={user_id}, user_name={user_name}") - return self._get(session, db.User, api_models.User, **args) + # add additional filters + args.update(kwargs) + logger.debug(f"Getting user: name={name}") + return self._get(db_session, db.User, api_models.User, **args) def update_user( - self, user: Union[api_models.User, dict], session: sqlalchemy.orm.Session = None + self, + name: str, + user: Union[api_models.User, dict], + db_session: sqlalchemy.orm.Session = None, ): """ Update an existing user in the database. - :param user: The user object with the new data. - :param session: The session to use. + :param name: The name of the user to update. + :param user: The user object with the new data. + :param db_session: The session to use. - :return: The updated user. + :return: The updated user. """ logger.debug(f"Updating user: {user}") if isinstance(user, dict): user = api_models.User.from_dict(user) - return self._update(session, db.User, user, uid=user.uid) + return self._update(db_session, db.User, user, name=name, uid=user.uid) - def delete_user(self, uid: str, session: sqlalchemy.orm.Session = None): + def delete_user( + self, name: str, db_session: sqlalchemy.orm.Session = None, **kwargs + ): """ Delete a user from the database. - :param uid: The UID of the user to delete. - :param session: The session to use. + :param name: The name of the user to delete. + :param db_session: The session to use. """ - logger.debug(f"Deleting user: user_uid={uid}") - self._delete(session, db.User, uid=uid) + logger.debug(f"Deleting user: name={name}") + self._delete(db_session, db.User, name=name, **kwargs) def list_users( self, @@ -264,19 +289,19 @@ def list_users( full_name: str = None, labels_match: Union[list, str] = None, output_mode: api_models.OutputMode = api_models.OutputMode.DETAILS, - session: sqlalchemy.orm.Session = None, + db_session: sqlalchemy.orm.Session = None, ): """ List users from the database. - :param name: The name to filter the users by. - :param email: The email to filter the users by. - :param full_name: The full name to filter the users by. - :param labels_match: The labels to match, filter the users by labels. - :param output_mode: The output mode. - :param session: The session to use. + :param name: The name to filter the users by. + :param email: The email to filter the users by. + :param full_name: The full name to filter the users by. + :param labels_match: The labels to match, filter the users by labels. + :param output_mode: The output mode. + :param db_session: The session to use. - :return: List of users. + :return: List of users. """ logger.debug( f"Getting users: email={email}, full_name={full_name}, mode={output_mode}" @@ -289,7 +314,7 @@ def list_users( if full_name: filters.append(db.User.full_name.like(f"%{full_name}%")) return self._list( - session=session, + session=db_session, db_class=db.User, api_class=api_models.User, output_mode=output_mode, @@ -300,60 +325,69 @@ def list_users( def create_project( self, project: Union[api_models.Project, dict], - session: sqlalchemy.orm.Session = None, + db_session: sqlalchemy.orm.Session = None, ): """ Create a new project in the database. - :param project: The project object to create. - :param session: The session to use. + :param project: The project object to create. + :param db_session: The session to use. - :return: The created project. + :return: The created project. """ logger.debug(f"Creating project: {project}") if isinstance(project, dict): project = api_models.Project.from_dict(project) - return self._create(session, db.Project, project) + return self._create(db_session, db.Project, project) - def get_project(self, project_name: str, session: sqlalchemy.orm.Session = None): + def get_project( + self, name: str, db_session: sqlalchemy.orm.Session = None, **kwargs + ): """ Get a project from the database. - :param project_name: The name of the project to get. - :param session: The session to use. + :param name: The name of the project to get. + :param db_session: The session to use. + :param kwargs: Additional keyword arguments to filter the project. - :return: The requested project. + :return: The requested project. """ - logger.debug(f"Getting project: project_name={project_name}") - return self._get(session, db.Project, api_models.Project, name=project_name) + logger.debug(f"Getting project: name={name}") + return self._get( + db_session, db.Project, api_models.Project, name=name, **kwargs + ) def update_project( self, + name: str, project: Union[api_models.Project, dict], - session: sqlalchemy.orm.Session = None, + db_session: sqlalchemy.orm.Session = None, ): """ Update an existing project in the database. - :param project: The project object with the new data. - :param session: The session to use. + :param name: The name of the project to update. + :param project: The project object with the new data. + :param db_session: The session to use. - :return: The updated project. + :return: The updated project. """ logger.debug(f"Updating project: {project}") if isinstance(project, dict): project = api_models.Project.from_dict(project) - return self._update(session, db.Project, project, uid=project.uid) + return self._update(db_session, db.Project, project, name=name, uid=project.uid) - def delete_project(self, uid: str, session: sqlalchemy.orm.Session = None): + def delete_project( + self, name: str, db_session: sqlalchemy.orm.Session = None, **kwargs + ): """ Delete a project from the database. - :param uid: The UID of the project to delete. - :param session: The session to use. + :param name: The name of the project to delete. + :param db_session: The session to use. """ - logger.debug(f"Deleting project: project_uid={uid}") - self._delete(session, db.Project, uid=uid) + logger.debug(f"Deleting project: name={name}") + self._delete(db_session, db.Project, name=name, **kwargs) def list_projects( self, @@ -362,19 +396,19 @@ def list_projects( version: str = None, labels_match: Union[list, str] = None, output_mode: api_models.OutputMode = api_models.OutputMode.DETAILS, - session: sqlalchemy.orm.Session = None, + db_session: sqlalchemy.orm.Session = None, ): """ List projects from the database. - :param name: The name to filter the projects by. - :param owner_id: The owner to filter the projects by. - :param version: The version to filter the projects by. - :param labels_match: The labels to match, filter the projects by labels. - :param output_mode: The output mode. - :param session: The session to use. + :param name: The name to filter the projects by. + :param owner_id: The owner to filter the projects by. + :param version: The version to filter the projects by. + :param labels_match: The labels to match, filter the projects by labels. + :param output_mode: The output mode. + :param db_session: The session to use. - :return: List of projects. + :return: List of projects. """ logger.debug( f"Getting projects: owner_id={owner_id}, version={version}, labels_match={labels_match}, mode={output_mode}" @@ -387,9 +421,9 @@ def list_projects( if version: filters.append(db.Project.version == version) return self._list( - session=session, - db_class=db.User, - api_class=api_models.User, + session=db_session, + db_class=db.Project, + api_class=api_models.Project, output_mode=output_mode, labels_match=labels_match, filters=filters, @@ -398,80 +432,73 @@ def list_projects( def create_data_source( self, data_source: Union[api_models.DataSource, dict], - session: sqlalchemy.orm.Session = None, + db_session: sqlalchemy.orm.Session = None, ): """ Create a new data source in the database. :param data_source: The data source object to create. - :param session: The session to use. + :param db_session: The session to use. - :return: The created data source. + :return: The created data source. """ logger.debug(f"Creating data source: {data_source}") if isinstance(data_source, dict): data_source = api_models.DataSource.from_dict(data_source) - return self._create(session, db.DataSource, data_source) + return self._create(db_session, db.DataSource, data_source) def get_data_source( - self, - project_id: str, - uid: str, - session: sqlalchemy.orm.Session = None, + self, name: str, db_session: sqlalchemy.orm.Session = None, **kwargs ): """ Get a data source from the database. - :param project_id: The ID of the project to get the data source from. - :param uid: The UID of the data source to get. - :param session: The session to use. + :param name: The name of the data source to get. + :param db_session: The session to use. - :return: The requested data source. + :return: The requested data source. """ - logger.debug(f"Getting data source: data_source_uid={uid}") + logger.debug(f"Getting data source: name={name}") return self._get( - session, - db.DataSource, - api_models.DataSource, - uid=uid, - project_id=project_id, + db_session, db.DataSource, api_models.DataSource, name=name, **kwargs ) def update_data_source( self, + name: str, data_source: Union[api_models.DataSource, dict], - session: sqlalchemy.orm.Session = None, + db_session: sqlalchemy.orm.Session = None, ): """ Update an existing data source in the database. + :param name: The name of the data source to update. :param data_source: The data source object with the new data. - :param session: The session to use. + :param db_session: The session to use. - :return: The updated data source. + :return: The updated data source. """ logger.debug(f"Updating data source: {data_source}") if isinstance(data_source, dict): data_source = api_models.DataSource.from_dict(data_source) - return self._update(session, db.DataSource, data_source, uid=data_source.uid) + return self._update( + db_session, db.DataSource, data_source, name=name, uid=data_source.uid + ) def delete_data_source( - self, - project_id: str, - uid: str, - session: sqlalchemy.orm.Session = None, + self, name: str, db_session: sqlalchemy.orm.Session = None, **kwargs ): """ Delete a data source from the database. - :param project_id: The ID of the project to delete the data source from. - :param uid: The ID of the data source to delete. - :param session: The session to use. + :param name: The name of the data source to delete. + :param db_session: The session to use. + :param kwargs: Additional keyword arguments to filter the data source. - :return: A response object with the success status. + :return: A response object with the success status. """ - logger.debug(f"Deleting data source: data_source_id={uid}") - self._delete(session, db.DataSource, project_id=project_id, uid=uid) + logger.debug(f"Deleting data source: name={name}") + self._delete(db_session, db.DataSource, name=name, **kwargs) def list_data_sources( self, @@ -482,21 +509,21 @@ def list_data_sources( data_source_type: Union[api_models.DataSourceType, str] = None, labels_match: Union[list, str] = None, output_mode: api_models.OutputMode = api_models.OutputMode.DETAILS, - session: sqlalchemy.orm.Session = None, + db_session: sqlalchemy.orm.Session = None, ): """ List data sources from the database. - :param name: The name to filter the data sources by. - :param owner_id: The owner to filter the data sources by. - :param version: The version to filter the data sources by. - :param project_id: The project to filter the data sources by. - :param data_source_type: The data source type to filter the data sources by. - :param labels_match: The labels to match, filter the data sources by labels. - :param output_mode: The output mode. - :param session: The session to use. + :param name: The name to filter the data sources by. + :param owner_id: The owner to filter the data sources by. + :param version: The version to filter the data sources by. + :param project_id: The project to filter the data sources by. + :param data_source_type: The data source type to filter the data sources by. + :param labels_match: The labels to match, filter the data sources by labels. + :param output_mode: The output mode. + :param db_session: The session to use. - :return: List of data sources. + :return: List of data sources. """ logger.debug( f"Getting data sources: name={name}, owner_id={owner_id}, version={version}," @@ -514,7 +541,7 @@ def list_data_sources( if data_source_type: filters.append(db.DataSource.data_source_type == data_source_type) return self._list( - session=session, + session=db_session, db_class=db.DataSource, api_class=api_models.DataSource, output_mode=output_mode, @@ -525,72 +552,70 @@ def list_data_sources( def create_dataset( self, dataset: Union[api_models.Dataset, dict], - session: sqlalchemy.orm.Session = None, + db_session: sqlalchemy.orm.Session = None, ): """ Create a new dataset in the database. - :param dataset: The dataset object to create. - :param session: The session to use. + :param dataset: The dataset object to create. + :param db_session: The session to use. - :return: The created dataset. + :return: The created dataset. """ logger.debug(f"Creating dataset: {dataset}") if isinstance(dataset, dict): dataset = api_models.Dataset.from_dict(dataset) - return self._create(session, db.Dataset, dataset) + return self._create(db_session, db.Dataset, dataset) def get_dataset( - self, project_id: str, uid: str, session: sqlalchemy.orm.Session = None + self, name: str, db_session: sqlalchemy.orm.Session = None, **kwargs ): """ Get a dataset from the database. - :param project_id: The ID of the project to get the dataset from. - :param uid: The UID of the dataset to get. - :param session: The session to use. + :param name: The name of the dataset to get. + :param db_session: The session to use. + :param kwargs: Additional keyword arguments to filter the dataset. - :return: The requested dataset. + :return: The requested dataset. """ - logger.debug(f"Getting dataset: dataset_id={uid}") + logger.debug(f"Getting dataset: name={name}") return self._get( - session, - db.Dataset, - api_models.Dataset, - uid=uid, - project_id=project_id, + db_session, db.Dataset, api_models.Dataset, name=name, **kwargs ) def update_dataset( self, + name: str, dataset: Union[api_models.Dataset, dict], - session: sqlalchemy.orm.Session = None, + db_session: sqlalchemy.orm.Session = None, ): """ Update an existing dataset in the database. - :param dataset: The dataset object with the new data. - :param session: The session to use. + :param name: The name of the dataset to update. + :param dataset: The dataset object with the new data. + :param db_session: The session to use. - :return: The updated dataset. + :return: The updated dataset. """ logger.debug(f"Updating dataset: {dataset}") if isinstance(dataset, dict): dataset = api_models.Dataset.from_dict(dataset) - return self._update(session, db.Dataset, dataset, uid=dataset.uid) + return self._update(db_session, db.Dataset, dataset, name=name, uid=dataset.uid) def delete_dataset( - self, project_id: str, uid: str, session: sqlalchemy.orm.Session = None + self, name: str, db_session: sqlalchemy.orm.Session = None, **kwargs ): """ Delete a dataset from the database. - :param project_id: The ID of the project to delete the dataset from. - :param uid: The ID of the dataset to delete. - :param session: The session to use. + :param name: The name of the dataset to delete. + :param db_session: The session to use. + :param kwargs: Additional keyword arguments to filter the dataset. """ - logger.debug(f"Deleting dataset: dataset_id={uid}") - self._delete(session, db.Dataset, project_id=project_id, uid=uid) + logger.debug(f"Deleting dataset: name={name}") + self._delete(db_session, db.Dataset, name=name, **kwargs) def list_datasets( self, @@ -601,21 +626,21 @@ def list_datasets( task: str = None, labels_match: Union[list, str] = None, output_mode: api_models.OutputMode = api_models.OutputMode.DETAILS, - session: sqlalchemy.orm.Session = None, + db_session: sqlalchemy.orm.Session = None, ): """ List datasets from the database. - :param name: The name to filter the datasets by. - :param owner_id: The owner to filter the datasets by. - :param version: The version to filter the datasets by. - :param project_id: The project to filter the datasets by. - :param task: The task to filter the datasets by. - :param labels_match: The labels to match, filter the datasets by labels. - :param output_mode: The output mode. - :param session: The session to use. + :param name: The name to filter the datasets by. + :param owner_id: The owner to filter the datasets by. + :param version: The version to filter the datasets by. + :param project_id: The project to filter the datasets by. + :param task: The task to filter the datasets by. + :param labels_match: The labels to match, filter the datasets by labels. + :param output_mode: The output mode. + :param db_session: The session to use. - :return: The list of datasets. + :return: The list of datasets. """ logger.debug( f"Getting datasets: owner_id={owner_id}, version={version}, task={task}, labels_match={labels_match}," @@ -633,7 +658,7 @@ def list_datasets( if task: filters.append(db.Dataset.task == task) return self._list( - session=session, + session=db_session, db_class=db.Dataset, api_class=api_models.Dataset, output_mode=output_mode, @@ -644,68 +669,66 @@ def list_datasets( def create_model( self, model: Union[api_models.Model, dict], - session: sqlalchemy.orm.Session = None, + db_session: sqlalchemy.orm.Session = None, ): """ Create a new model in the database. - :param model: The model object to create. - :param session: The session to use. + :param model: The model object to create. + :param db_session: The session to use. - :return: The created model. + :return: The created model. """ logger.debug(f"Creating model: {model}") if isinstance(model, dict): model = api_models.Model.from_dict(model) - return self._create(session, db.Model, model) + return self._create(db_session, db.Model, model) - def get_model( - self, project_id: str, uid: str, session: sqlalchemy.orm.Session = None - ): + def get_model(self, name: str, db_session: sqlalchemy.orm.Session = None, **kwargs): """ Get a model from the database. - :param project_id: The ID of the project to get the model from. - :param uid: The UID of the model to get. - :param session: The session to use. + :param name: The name of the model to get. + :param db_session: The session to use. + :param kwargs: Additional keyword arguments to filter the model. - :return: The requested model. + :return: The requested model. """ - logger.debug(f"Getting model: model_id={uid}") - return self._get( - session, db.Model, api_models.Model, project_id=project_id, uid=uid - ) + logger.debug(f"Getting model: name={name}") + return self._get(db_session, db.Model, api_models.Model, name=name, **kwargs) def update_model( self, + name: str, model: Union[api_models.Model, dict], - session: sqlalchemy.orm.Session = None, + db_session: sqlalchemy.orm.Session = None, ): """ Update an existing model in the database. - :param model: The model object with the new data. - :param session: The session to use. + :param name: The name of the model to update. + :param model: The model object with the new data. + :param db_session: The session to use. - :return: The updated model. + :return: The updated model. """ logger.debug(f"Updating model: {model}") if isinstance(model, dict): model = api_models.Model.from_dict(model) - return self._update(session, db.Model, model, uid=model.uid) + return self._update(db_session, db.Model, model, name=name, uid=model.uid) def delete_model( - self, project_id: str, uid: str, session: sqlalchemy.orm.Session = None + self, name: str, db_session: sqlalchemy.orm.Session = None, **kwargs ): """ Delete a model from the database. - :param project_id: The ID of the project to delete the model from. - :param uid: The UID of the model to delete. - :param session: The session to use. + :param name: The name of the model to delete. + :param db_session: The session to use. + :param kwargs: Additional keyword arguments to filter the model. """ - logger.debug(f"Deleting model: model_id={uid}") - self._delete(session, db.Model, project_id=project_id, uid=uid) + logger.debug(f"Deleting model: name={name}") + self._delete(db_session, db.Model, name=name, **kwargs) def list_models( self, @@ -717,22 +740,22 @@ def list_models( task: str = None, labels_match: Union[list, str] = None, output_mode: api_models.OutputMode = api_models.OutputMode.DETAILS, - session: sqlalchemy.orm.Session = None, + db_session: sqlalchemy.orm.Session = None, ): """ List models from the database. - :param name: The name to filter the models by. - :param owner_id: The owner to filter the models by. - :param version: The version to filter the models by. - :param project_id: The project to filter the models by. - :param model_type: The model type to filter the models by. - :param task: The task to filter the models by. - :param labels_match: The labels to match, filter the models by labels. - :param output_mode: The output mode. - :param session: The session to use. + :param name: The name to filter the models by. + :param owner_id: The owner to filter the models by. + :param version: The version to filter the models by. + :param project_id: The project to filter the models by. + :param model_type: The model type to filter the models by. + :param task: The task to filter the models by. + :param labels_match: The labels to match, filter the models by labels. + :param output_mode: The output mode. + :param db_session: The session to use. - :return: The list of models. + :return: The list of models. """ logger.debug( f"Getting models: owner_id={owner_id}, version={version}, project_id={project_id}," @@ -752,7 +775,7 @@ def list_models( if task: filters.append(db.Model.task == task) return self._list( - session=session, + session=db_session, db_class=db.Model, api_class=api_models.Model, output_mode=output_mode, @@ -763,80 +786,79 @@ def list_models( def create_prompt_template( self, prompt_template: Union[api_models.PromptTemplate, dict], - session: sqlalchemy.orm.Session = None, + db_session: sqlalchemy.orm.Session = None, ): """ Create a new prompt template in the database. :param prompt_template: The prompt template object to create. - :param session: The session to use. + :param db_session: The session to use. - :return: The created prompt template. + :return: The created prompt template. """ logger.debug(f"Creating prompt template: {prompt_template}") if isinstance(prompt_template, dict): prompt_template = api_models.PromptTemplate.from_dict(prompt_template) - return self._create(session, db.PromptTemplate, prompt_template) + return self._create(db_session, db.PromptTemplate, prompt_template) def get_prompt_template( - self, - project_id: str, - uid: str, - session: sqlalchemy.orm.Session = None, + self, name: str = None, db_session: sqlalchemy.orm.Session = None, **kwargs ): """ Get a prompt template from the database. - :param project_id: The ID of the project to get the prompt template from. - :param uid: The UID of the prompt template to get. - :param session: The session to use. + :param name: The name of the prompt template to get. + :param db_session: The session to use. - :return: The requested prompt template. + :return: The requested prompt template. """ - logger.debug(f"Getting prompt template: prompt_template_id={uid}") + logger.debug(f"Getting prompt template: name={name}") return self._get( - session, + db_session, db.PromptTemplate, api_models.PromptTemplate, - project_id=project_id, - uid=uid, + name=name, + **kwargs, ) def update_prompt_template( self, + name: str, prompt_template: Union[api_models.PromptTemplate, dict], - session: sqlalchemy.orm.Session = None, + db_session: sqlalchemy.orm.Session = None, ): """ Update an existing prompt template in the database. + :param name: The name of the prompt template to update. :param prompt_template: The prompt template object with the new data. - :param session: The session to use. + :param db_session: The session to use. - :return: The updated prompt template. + :return: The updated prompt template. """ logger.debug(f"Updating prompt template: {prompt_template}") if isinstance(prompt_template, dict): prompt_template = api_models.PromptTemplate.from_dict(prompt_template) return self._update( - session, db.PromptTemplate, prompt_template, uid=prompt_template.uid + db_session, + db.PromptTemplate, + prompt_template, + name=name, + uid=prompt_template.uid, ) def delete_prompt_template( - self, - project_id: str, - uid: str, - session: sqlalchemy.orm.Session = None, + self, name: str = None, db_session: sqlalchemy.orm.Session = None, **kwargs ): """ Delete a prompt template from the database. - :param project_id: The ID of the project to delete the prompt template from. - :param uid: The ID of the prompt template to delete. - :param session: The session to use. + :param name: The name of the prompt template to delete. + :param db_session: The session to use. + :param kwargs: Additional keyword arguments to filter the prompt template. """ - logger.debug(f"Deleting prompt template: prompt_template_id={uid}") - self._delete(session, db.PromptTemplate, project_id=project_id, uid=uid) + logger.debug(f"Deleting prompt template: name={name}") + self._delete(db_session, db.PromptTemplate, name=name, **kwargs) def list_prompt_templates( self, @@ -846,20 +868,20 @@ def list_prompt_templates( project_id: str = None, labels_match: Union[list, str] = None, output_mode: api_models.OutputMode = api_models.OutputMode.DETAILS, - session: sqlalchemy.orm.Session = None, + db_session: sqlalchemy.orm.Session = None, ): """ List prompt templates from the database. - :param name: The name to filter the prompt templates by. - :param owner_id: The owner to filter the prompt templates by. - :param version: The version to filter the prompt templates by. - :param project_id: The project to filter the prompt templates by. - :param labels_match: The labels to match, filter the prompt templates by labels. - :param output_mode: The output mode. - :param session: The session to use. + :param name: The name to filter the prompt templates by. + :param owner_id: The owner to filter the prompt templates by. + :param version: The version to filter the prompt templates by. + :param project_id: The project to filter the prompt templates by. + :param labels_match: The labels to match, filter the prompt templates by labels. + :param output_mode: The output mode. + :param db_session: The session to use. - :return: The list of prompt templates. + :return: The list of prompt templates. """ logger.debug( f"Getting prompt templates: owner_id={owner_id}, version={version}, project_id={project_id}," @@ -875,7 +897,7 @@ def list_prompt_templates( if project_id: filters.append(db.PromptTemplate.project_id == project_id) return self._list( - session=session, + session=db_session, db_class=db.PromptTemplate, api_class=api_models.PromptTemplate, output_mode=output_mode, @@ -886,72 +908,72 @@ def list_prompt_templates( def create_document( self, document: Union[api_models.Document, dict], - session: sqlalchemy.orm.Session = None, + db_session: sqlalchemy.orm.Session = None, ): """ Create a new document in the database. - :param document: The document object to create. - :param session: The session to use. + :param document: The document object to create. + :param db_session: The session to use. - :return: The created document. + :return: The created document. """ logger.debug(f"Creating document: {document}") if isinstance(document, dict): document = api_models.Document.from_dict(document) - return self._create(session, db.Document, document) + return self._create(db_session, db.Document, document) def get_document( - self, project_id: str, uid: str, session: sqlalchemy.orm.Session = None + self, name: str, db_session: sqlalchemy.orm.Session = None, **kwargs ): """ Get a document from the database. - :param project_id: The ID of the project to get the document from. - :param uid: The UID of the document to get. - :param session: The session to use. + :param name: The name of the document to get. + :param db_session: The session to use. + :param kwargs: Additional keyword arguments to filter the document. - :return: The requested document. + :return: The requested document. """ - logger.debug(f"Getting document: document_id={uid}") + logger.debug(f"Getting document: name={name}") return self._get( - session, - db.Document, - api_models.Document, - project_id=project_id, - uid=uid, + db_session, db.Document, api_models.Document, name=name, **kwargs ) def update_document( self, + name: str, document: Union[api_models.Document, dict], - session: sqlalchemy.orm.Session = None, + db_session: sqlalchemy.orm.Session = None, ): """ Update an existing document in the database. - :param document: The document object with the new data. - :param session: The session to use. + :param name: The name of the document to update. + :param document: The document object with the new data. + :param db_session: The session to use. - :return: The updated document. + :return: The updated document. """ logger.debug(f"Updating document: {document}") if isinstance(document, dict): document = api_models.Document.from_dict(document) - return self._update(session, db.Document, document, uid=document.uid) + return self._update( + db_session, db.Document, document, name=name, uid=document.uid + ) def delete_document( - self, project_id: str, uid: str, session: sqlalchemy.orm.Session = None + self, name: str, db_session: sqlalchemy.orm.Session = None, **kwargs ): """ Delete a document from the database. - :param project_id: The ID of the project to delete the document from. - :param uid: The UID of the document to delete. - :param session: The session to use. + :param name: The name of the document to delete. + :param db_session: The session to use. + :param kwargs: Additional keyword arguments to filter the document. """ - logger.debug(f"Deleting document: document_id={uid}") - self._delete(session, db.Document, project_id=project_id, uid=uid) + logger.debug(f"Deleting document: name={name}") + self._delete(db_session, db.Document, name=name, **kwargs) def list_documents( self, @@ -961,20 +983,20 @@ def list_documents( project_id: str = None, labels_match: Union[list, str] = None, output_mode: api_models.OutputMode = api_models.OutputMode.DETAILS, - session: sqlalchemy.orm.Session = None, + db_session: sqlalchemy.orm.Session = None, ): """ List documents from the database. - :param name: The name to filter the documents by. - :param owner_id: The owner to filter the documents by. - :param version: The version to filter the documents by. - :param project_id: The project to filter the documents by. - :param labels_match: The labels to match, filter the documents by labels. - :param output_mode: The output mode. - :param session: The session to use. + :param name: The name to filter the documents by. + :param owner_id: The owner to filter the documents by. + :param version: The version to filter the documents by. + :param project_id: The project to filter the documents by. + :param labels_match: The labels to match, filter the documents by labels. + :param output_mode: The output mode. + :param db_session: The session to use. - :return: The list of documents. + :return: The list of documents. """ logger.debug( f"Getting documents: owner_id={owner_id}, version={version}, project_id={project_id}," @@ -990,7 +1012,7 @@ def list_documents( if project_id: filters.append(db.Document.project_id == project_id) return self._list( - session=session, + session=db_session, db_class=db.Document, api_class=api_models.Document, output_mode=output_mode, @@ -1001,72 +1023,72 @@ def list_documents( def create_workflow( self, workflow: Union[api_models.Workflow, dict], - session: sqlalchemy.orm.Session = None, + db_session: sqlalchemy.orm.Session = None, ): """ Create a new workflow in the database. - :param workflow: The workflow object to create. - :param session: The session to use. + :param workflow: The workflow object to create. + :param db_session: The session to use. - :return: The created workflow. + :return: The created workflow. """ logger.debug(f"Creating workflow: {workflow}") if isinstance(workflow, dict): workflow = api_models.Workflow.from_dict(workflow) - return self._create(session, db.Workflow, workflow) + return self._create(db_session, db.Workflow, workflow) def get_workflow( - self, - project_id: str, - uid: str, - session: sqlalchemy.orm.Session = None, + self, name: str, db_session: sqlalchemy.orm.Session = None, **kwargs ): """ Get a workflow from the database. - :param project_id: The ID of the project to get the workflow from. - :param uid: The UID of the workflow to get. - :param session: The session to use. + :param name: The name of the workflow to get. + :param db_session: The session to use. + :param kwargs: Additional keyword arguments to filter the workflow. - :return: The requested workflow. + :return: The requested workflow. """ - logger.debug(f"Getting workflow: workflow_uid={uid}") + logger.debug(f"Getting workflow: name={name}") return self._get( - session, - db.Workflow, - api_models.Workflow, - project_id=project_id, - uid=uid, + db_session, db.Workflow, api_models.Workflow, name=name, **kwargs ) def update_workflow( self, + name: str, workflow: Union[api_models.Workflow, dict], - session: sqlalchemy.orm.Session = None, + db_session: sqlalchemy.orm.Session = None, ): """ Update an existing workflow in the database. - :param workflow: The workflow object with the new data. - :param session: The session to use. + :param name: The name of the workflow to update. + :param workflow: The workflow object with the new data. + :param db_session: The session to use. - :return: The updated workflow. + :return: The updated workflow. """ logger.debug(f"Updating workflow: {workflow}") if isinstance(workflow, dict): workflow = api_models.Workflow.from_dict(workflow) - return self._update(session, db.Workflow, workflow, uid=workflow.uid) + return self._update( + db_session, db.Workflow, workflow, name=name, uid=workflow.uid + ) - def delete_workflow(self, uid: str, session: sqlalchemy.orm.Session = None): + def delete_workflow( + self, name: str, db_session: sqlalchemy.orm.Session = None, **kwargs + ): """ Delete a workflow from the database. - :param uid: The ID of the workflow to delete. - :param session: The session to use. + :param name: The name of the workflow to delete. + :param db_session: The session to use. + :param kwargs: Additional keyword arguments to filter the workflow. """ - logger.debug(f"Deleting workflow: workflow_id={uid}") - self._delete(session, db.Workflow, uid=uid) + logger.debug(f"Deleting workflow: name={name}") + self._delete(db_session, db.Workflow, name=name, **kwargs) def list_workflows( self, @@ -1077,21 +1099,21 @@ def list_workflows( workflow_type: Union[api_models.WorkflowType, str] = None, labels_match: Union[list, str] = None, output_mode: api_models.OutputMode = api_models.OutputMode.DETAILS, - session: sqlalchemy.orm.Session = None, + db_session: sqlalchemy.orm.Session = None, ): """ List workflows from the database. - :param name: The name to filter the workflows by. - :param owner_id: The owner to filter the workflows by. - :param version: The version to filter the workflows by. - :param project_id: The project to filter the workflows by. - :param workflow_type: The workflow type to filter the workflows by. - :param labels_match: The labels to match, filter the workflows by labels. - :param output_mode: The output mode. - :param session: The session to use. + :param name: The name to filter the workflows by. + :param owner_id: The owner to filter the workflows by. + :param version: The version to filter the workflows by. + :param project_id: The project to filter the workflows by. + :param workflow_type: The workflow type to filter the workflows by. + :param labels_match: The labels to match, filter the workflows by labels. + :param output_mode: The output mode. + :param db_session: The session to use. - :return: The list of workflows. + :return: The list of workflows. """ logger.debug( f"Getting workflows: name={name}, owner_id={owner_id}, version={version}, project_id={project_id}," @@ -1109,7 +1131,7 @@ def list_workflows( if workflow_type: filters.append(db.Workflow.workflow_type == workflow_type) return self._list( - session=session, + session=db_session, db_class=db.Workflow, api_class=api_models.Workflow, output_mode=output_mode, @@ -1117,74 +1139,87 @@ def list_workflows( filters=filters, ) - def create_chat_session( + def create_session( self, - chat_session: Union[api_models.ChatSession, dict], - session: sqlalchemy.orm.Session = None, + session: Union[api_models.ChatSession, dict], + db_session: sqlalchemy.orm.Session = None, ): """ - Create a new chat session in the database. + Create a new session in the database. - :param chat_session: The chat session object to create. - :param session: The session to use. + :param session: The chat session object to create. + :param db_session: The session to use. - :return: The created chat session. + :return: The created session. """ - logger.debug(f"Creating chat session: {chat_session}") - if isinstance(chat_session, dict): - chat_session = api_models.ChatSession.from_dict(chat_session) - return self._create(session, db.Session, chat_session) + logger.debug(f"Creating session: {session}") + if isinstance(session, dict): + session = api_models.ChatSession.from_dict(session) + return self._create(db_session, db.Session, session) - def get_chat_session( + def get_session( self, + name: str = None, uid: str = None, user_id: str = None, - session: sqlalchemy.orm.Session = None, + db_session: sqlalchemy.orm.Session = None, + **kwargs, ): """ - Get a chat session from the database. + Get a session from the database. - :param uid: The ID of the chat session to get. - :param user_id: The UID of the user to get the last session for. - :param session: The DB session to use. + :param name: The name of the session to get. + :param uid: The ID of the session to get. + :param user_id: The UID of the user to get the last session for. + :param db_session: The DB session to use. + :param kwargs: Additional keyword arguments to filter the session. - :return: The requested chat session. + :return: The requested session. """ - logger.debug(f"Getting chat session: session_uid={uid}, user_id={user_id}") + logger.debug(f"Getting session: name={name}, uid={uid}, user_id={user_id}") if uid: - return self._get(session, db.Session, api_models.ChatSession, uid=uid) + return self._get( + db_session, db.Session, api_models.ChatSession, uid=uid, **kwargs + ) elif user_id: # get the last session for the user - return self.list_chat_sessions(user_id=user_id, last=1, session=session)[0] + return self.list_sessions( + user_id=user_id, last=1, db_session=db_session, **kwargs + )[0] raise ValueError("session_name or user_id must be provided") - def update_chat_session( + def update_session( self, - chat_session: Union[api_models.ChatSession, dict], - session: sqlalchemy.orm.Session = None, + name: str, + session: Union[api_models.ChatSession, dict], + db_session: sqlalchemy.orm.Session = None, ): """ - Update a chat session in the database. + Update a session in the database. - :param chat_session: The chat session object with the new data. - :param session: The DB session to use. + :param name: The name of the session to update. + :param session: The session object with the new data. + :param db_session: The DB session to use. - :return: The updated chat session. + :return: The updated chat session. """ - logger.debug(f"Updating chat session: {chat_session}") - return self._update(session, db.Session, chat_session, uid=chat_session.uid) + logger.debug(f"Updating chat session: {session}") + return self._update(db_session, db.Session, session, name=name, uid=session.uid) - def delete_chat_session(self, uid: str, session: sqlalchemy.orm.Session = None): + def delete_session( + self, name: str, db_session: sqlalchemy.orm.Session = None, **kwargs + ): """ - Delete a chat session from the database. + Delete a session from the database. - :param uid: The UID of the chat session to delete. - :param session: The DB session to use. + :param name: The name of the session to delete. + :param db_session: The DB session to use. + :param kwargs: Additional keyword arguments to filter the session. """ - logger.debug(f"Deleting chat session: session_id={uid}") - self._delete(session, db.Session, uid=uid) + logger.debug(f"Deleting session: name={name}") + self._delete(db_session, db.Session, name=name, **kwargs) - def list_chat_sessions( + def list_sessions( self, name: str = None, user_id: str = None, @@ -1192,26 +1227,26 @@ def list_chat_sessions( created_after=None, last=0, output_mode: api_models.OutputMode = api_models.OutputMode.DETAILS, - session: sqlalchemy.orm.Session = None, + db_session: sqlalchemy.orm.Session = None, ): """ - List chat sessions from the database. + List sessions from the database. - :param name: The name to filter the chat sessions by. - :param user_id: The user ID to filter the chat sessions by. - :param workflow_id: The workflow ID to filter the chat sessions by. - :param created_after: The date to filter the chat sessions by. - :param last: The number of last chat sessions to return. - :param output_mode: The output mode. - :param session: The DB session to use. + :param name: The name to filter the chat sessions by. + :param user_id: The user ID to filter the chat sessions by. + :param workflow_id: The workflow ID to filter the chat sessions by. + :param created_after: The date to filter the chat sessions by. + :param last: The number of last chat sessions to return. + :param output_mode: The output mode. + :param db_session: The DB session to use. - :return: The list of chat sessions. + :return: The list of chat sessions. """ logger.debug( f"Getting chat sessions: user_id={user_id}, workflow_id={workflow_id} created>{created_after}," f" last={last}, mode={output_mode}" ) - session = self.get_db_session(session) + session = self.get_db_session(db_session) query = session.query(db.Session) if name: query = query.filter(db.Session.name == name) diff --git a/controller/src/controller/db/sqldb.py b/controller/src/controller/db/sqldb.py index 7c8ad2a..d609045 100644 --- a/controller/src/controller/db/sqldb.py +++ b/controller/src/controller/db/sqldb.py @@ -74,14 +74,14 @@ class BaseSchema(Base): Base class for all tables. We use this class to define common columns and methods for all tables. - :arg uid: unique identifier for each entry. - :arg name: entry's name. - :arg description: The entry's description. + :arg uid: Unique identifier for each entry. + :arg name: Entry's name. + :arg description: The entry's description. The following columns are automatically added to each table: - created: The entry's creation date. - updated: The entry's last update date. - - spec: A dictionary to store additional information. + - spec: A dictionary to store additional information. """ __abstract__ = True @@ -194,8 +194,8 @@ class User(BaseSchema): """ The User table which is used to define users. - :arg full_name: The user's full name. - :arg email: The user's email. + :arg full_name: The user's full name. + :arg email: The user's email. """ # Columns: @@ -341,8 +341,8 @@ class Dataset(VersionedOwnerBaseSchema): """ The Dataset table which is used to define datasets for the project. - :arg project_id: The project's id. - :arg task: The task of the dataset. + :arg project_id: The project's id. + :arg task: The task of the dataset. """ # Columns: @@ -385,9 +385,9 @@ class Model(VersionedOwnerBaseSchema): """ The Model table which is used to define models for the project. - :arg project_id: The project's id. - :arg model_type: The type of the model. Can be one of the values in genai_factory.schemas.model.ModelType. - :arg task: The task of the model. For example, "classification", "text-generation", etc. + :arg project_id: The project's id. + :arg model_type: The type of the model. Can be one of the values in genai_factory.schemas.model.ModelType. + :arg task: The task of the model. For example, "classification", "text-generation", etc. """ # Columns: @@ -449,7 +449,7 @@ class PromptTemplate(VersionedOwnerBaseSchema): The PromptTemplate table which is used to define prompt templates for the project. Each prompt template is associated with a model. - :arg project_id: The project's id. + :arg project_id: The project's id. """ # Columns: @@ -504,9 +504,9 @@ class Document(VersionedOwnerBaseSchema): """ The Document table which is used to define documents for the project. The documents are ingested into data sources. - :arg project_id: The project's id. - :arg path: The path to the document. Can be a remote file or a web page. - :arg origin: The origin location of the document. + :arg project_id: The project's id. + :arg path: The path to the document. Can be a remote file or a web page. + :arg origin: The origin location of the document. """ # Columns: @@ -615,7 +615,7 @@ class Session(OwnerBaseSchema): """ The Chat Session table which is used to define chat sessions of an application workflow per user. - :arg workflow_id: The workflow's id. + :arg workflow_id: The workflow's id. """ # Columns: diff --git a/controller/src/controller/main.py b/controller/src/controller/main.py index 19b0808..422651b 100644 --- a/controller/src/controller/main.py +++ b/controller/src/controller/main.py @@ -15,6 +15,7 @@ # main file with cli commands using python click library # include two click commands: 1. data ingestion (using the data loader), 2. query (using the agent) import json +from typing import Optional import click import yaml @@ -29,7 +30,6 @@ Project, QueryItem, User, - Workflow, ) @@ -45,7 +45,7 @@ def initdb(): Initialize the database tables (delete old tables). """ click.echo("Running Init DB") - session = client.get_db_session() + db_session = client.get_db_session() client.create_tables(True) # Create admin user: @@ -57,7 +57,7 @@ def initdb(): full_name="Guest User", is_admin=True, ), - session=session, + db_session=db_session, ).uid # Create project: @@ -68,7 +68,7 @@ def initdb(): description="Default Project", owner_id=user_id, ), - session=session, + db_session=db_session, ).uid # Create data source: @@ -81,23 +81,9 @@ def initdb(): project_id=project_id, data_source_type="vector", ), - session=session, + db_session=db_session, ) - - # Create Workflow: - click.echo("Creating default workflow") - client.create_workflow( - Workflow( - name="default", - description="Default Workflow", - owner_id=user_id, - project_id=project_id, - workflow_type="application", - deployment="http://localhost:8000/api/workflows/default", - ), - session=session, - ) - session.close() + db_session.close() @click.command("config") @@ -132,14 +118,12 @@ def ingest(path, project, name, loader, metadata, version, data_source, from_fil :param version: Version of the document :param data_source: Data source name :param from_file: Take the document paths from the file - - :return: None """ - session = client.get_db_session() - project = client.get_project(project_name=project, session=session) - data_source = client.list_data_sources( - project_id=project.uid, name=data_source, session=session - )[0] + db_session = client.get_db_session() + project = client.get_project(project_name=project, db_session=db_session) + data_source = client.get_data_source( + project_id=project.uid, name=data_source, db_session=db_session + ) # Create document from path: document = Document( @@ -153,7 +137,7 @@ def ingest(path, project, name, loader, metadata, version, data_source, from_fil # Add document to the database: response = client.create_document( document=document, - session=session, + db_session=db_session, ) document = response.to_dict(to_datestr=True) @@ -206,23 +190,21 @@ def infer( """ Run a chat query on the data source - :param question: The question to ask - :param project: The project name - :param workflow_name: The workflow name - :param filter: Filter Key value pair - :param data_source: Data source name - :param user: The name of the user - :param session: The session name - - :return: None + :param question: The question to ask + :param project: The project name + :param workflow_name: The workflow name + :param filter: Filter Key value pair + :param data_source: Data source name + :param user: The name of the user + :param session: The session name """ db_session = client.get_db_session() - project = client.get_project(project_name=project, session=db_session) + project = client.get_project(project_name=project, db_session=db_session) # Getting the workflow: - workflow = client.list_workflows( - project_id=project.uid, name=workflow_name, session=db_session - )[0] + workflow = client.get_workflow( + project_id=project.uid, name=workflow_name, db_session=db_session + ) path = workflow.get_infer_path() query = QueryItem( @@ -270,10 +252,8 @@ def list_users(user, email): """ List all the users in the database - :param user: username filter - :param email: email filter - - :return: None + :param user: Username filter + :param email: Email filter """ click.echo("Running List Users") @@ -294,13 +274,11 @@ def list_data_sources(owner, project, version, source_type, metadata): """ List all the data sources in the database - :param owner: owner filter - :param project: project filter - :param version: version filter - :param source_type: data source type filter - :param metadata: metadata filter (labels) - - :return: None + :param owner: Owner filter + :param project: Project filter + :param version: Version filter + :param source_type: Data source type filter + :param metadata: Metadata filter (labels) """ click.echo("Running List Collections") if owner: @@ -333,50 +311,31 @@ def update_data_source(name, project, owner, description, source_type, labels): """ Create or update a data source in the database - :param name: data source name - :param project: project name - :param owner: owner name - :param description: data source description - :param source_type: type of data source - :param labels: metadata labels - - :return: None + :param name: Data source name + :param project: Project name + :param owner: Owner name + :param description: Data source description + :param source_type: Type of data source + :param labels: Metadata labels """ + owner = owner or "guest" click.echo("Running Create or Update Collection") labels = fill_params(labels) - session = client.get_db_session() - # check if the collection exists, if it does, update it, otherwise create it - project = client.get_project(project_name=project, session=session) - data_source = client.list_data_sources( - project_id=project.uid, - data_source_name=name, - session=session, - ) - - if data_source is not None: - client.update_data_source( - session=session, - collection=DataSource( - project_id=project.uid, - name=name, - description=description, - data_source_type=source_type, - labels=labels, - ), - ).with_raise() - else: - client.create_data_source( - session=session, - data_source=DataSource( - project_id=project.uid, - name=name, - description=description, - owner_name=owner, - data_source_type=source_type, - labels=labels, - ), - ).with_raise() + db_session = client.get_db_session() + project = client.get_project(project_name=project, db_session=db_session) + + client.update_data_source( + db_session=db_session, + collection=DataSource( + project_id=project.uid, + name=name, + description=description, + data_source_type=source_type, + labels=labels, + owner_id=client.get_user(username=owner).uid, + ), + ).with_raise() @click.command("sessions") @@ -387,17 +346,15 @@ def list_sessions(user, last, created): """ List chat sessions - :param user: username filter - :param last: last n sessions - :param created: created after date - - :return: None + :param user: Username filter + :param last: Last n sessions + :param created: Created after date """ click.echo("Running List Sessions") if user: user = client.get_user(user_name=user).uid - data = client.list_chat_sessions( + data = client.list_sessions( user_id=user, created_after=created, last=last, output_mode="short" ) table = format_table_results(data) @@ -408,9 +365,9 @@ def sources_to_text(sources) -> str: """ Convert a list of sources to a text string. - :param sources: list of sources + :param sources: List of sources - :return: text string + :return: Text string """ if not sources: return "" @@ -423,9 +380,9 @@ def sources_to_md(sources) -> str: """ Convert a list of sources to a markdown string. - :param sources: list of sources + :param sources: List of sources - :return: markdown string + :return: Markdown string """ if not sources: return "" @@ -441,9 +398,9 @@ def get_title(metadata) -> str: """ Get the title from the metadata. - :param metadata: metadata dictionary + :param metadata: Metadata dictionary - :return: title string + :return: Title string """ if "chunk" in metadata: return f"{metadata.get('title', '')}-{metadata['chunk']}" @@ -452,14 +409,14 @@ def get_title(metadata) -> str: return metadata.get("title", "") -def fill_params(params, params_dict=None) -> dict: +def fill_params(params, params_dict=None) -> Optional[dict]: """ Fill the parameters dictionary from a list of key=value strings. - :param params: list of key=value strings - :param params_dict: dictionary to fill + :param params: List of key=value strings + :param params_dict: Dictionary to fill - :return: filled dictionary + :return: Filled dictionary """ params_dict = params_dict or {} for param in params: @@ -479,9 +436,9 @@ def format_table_results(table_results): """ Format the table results as a printed table. - :param table_results: table results dictionary + :param table_results: Table results dictionary - :return: formatted table string + :return: Formatted table string """ return tabulate(table_results, headers="keys", tablefmt="fancy_grid") diff --git a/genai_factory/src/genai_factory/schemas/session.py b/genai_factory/src/genai_factory/schemas/session.py index c200564..4d17cfc 100644 --- a/genai_factory/src/genai_factory/schemas/session.py +++ b/genai_factory/src/genai_factory/schemas/session.py @@ -22,7 +22,7 @@ class QueryItem(BaseModel): question: str - session_id: Optional[str] = None + session_name: Optional[str] = None filter: Optional[List[Tuple[str, str]]] = None data_source: Optional[str] = None