From c442f183bd83351c6311f402d9a3798c2b7c9116 Mon Sep 17 00:00:00 2001 From: yonishelach Date: Thu, 8 Aug 2024 16:44:13 +0300 Subject: [PATCH 01/10] Database tables --- controller/src/sqldb.py | 439 +++++++++++++++++++++++++++++----------- 1 file changed, 320 insertions(+), 119 deletions(-) diff --git a/controller/src/sqldb.py b/controller/src/sqldb.py index 08fea01..3484a9a 100644 --- a/controller/src/sqldb.py +++ b/controller/src/sqldb.py @@ -13,12 +13,30 @@ # limitations under the License. import datetime +import re +from typing import List, Optional -import sqlalchemy -from sqlalchemy import (JSON, Column, DateTime, ForeignKey, Index, Integer, - String, UniqueConstraint) +from sqlalchemy import ( + JSON, + Column, + ForeignKey, + Index, + Integer, + String, + Table, + UniqueConstraint, +) from sqlalchemy.ext.mutable import MutableDict -from sqlalchemy.orm import declarative_base, relationship +from sqlalchemy.orm import ( + Mapped, + declarative_base, + declared_attr, + mapped_column, + relationship, +) + +ID_LENGTH = 64 +TEXT_LENGTH = 1024 # Create a base class for declarative class definitions Base = declarative_base() @@ -51,146 +69,329 @@ def update_labels(obj, labels: dict): obj.labels.append(obj.Label(name=name, value=value, parent=obj.name)) -class User(Base): - __tablename__ = "users" +class BaseSchema(Base): + """ + Base class for all tables. + We use this class to define common columns and methods for all tables. - name = Column(String(255), primary_key=True, nullable=False) - email = Column(String(255), nullable=False, unique=True) - description = Column(String(255), nullable=True, default="") - full_name = Column(String(255), nullable=False) - created = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) - updated = Column( - DateTime, - default=datetime.datetime.utcnow, - onupdate=datetime.datetime.utcnow, - nullable=False, - ) - spec = Column(MutableDict.as_mutable(JSON), nullable=True) - Label = make_label(__tablename__) - labels = relationship(Label, cascade="all, delete-orphan") + :arg id: unique identifier for each entry. + :arg name: entry's name. + :arg description: The entry's description. + :arg owner_id: The entry's owner's id. + + The following columns are automatically added to each table: + - date_created: The entry's creation date. + - date_updated: The entry's last update date. + - spec: A dictionary to store additional information. + """ + __abstract__ = True -class ChatSessionContext(Base): - """Chat session context table CRUD""" + @declared_attr + def __tablename__(cls) -> str: + # Convert CamelCase class name to snake_case table name + return re.sub(r"(? Date: Thu, 8 Aug 2024 18:19:43 +0300 Subject: [PATCH 02/10] models --- controller/src/model.py | 137 +++++++++++++++++++++++++++++++--------- 1 file changed, 106 insertions(+), 31 deletions(-) diff --git a/controller/src/model.py b/controller/src/model.py index 854be10..8e65a7d 100644 --- a/controller/src/model.py +++ b/controller/src/model.py @@ -52,11 +52,10 @@ class ChatRole(str, Enum): class Message(BaseModel): role: ChatRole - content: str - html: Optional[str] = None + body: str + extra_data: Optional[dict] = None sources: Optional[List[dict]] = None - rating: Optional[int] = None - suggestion: Optional[str] = None + human_feedback: Optional[str] = None class Conversation(BaseModel): @@ -64,10 +63,10 @@ class Conversation(BaseModel): saved_index: int = 0 def __str__(self): - return "\n".join([f"{m.role}: {m.content}" for m in self.messages]) + return "\n".join([f"{m.role}: {m.body}" for m in self.messages]) - def add_message(self, role, content, sources=None): - self.messages.append(Message(role=role, content=content, sources=sources)) + def add_message(self, role, body, sources=None): + self.messages.append(Message(role=role, body=body, sources=sources)) def to_list(self): return self.dict()["messages"] @@ -90,17 +89,49 @@ class QueryItem(BaseModel): collection: Optional[str] = None +class OutputMode(str, Enum): + Names = "names" + Short = "short" + Dict = "dict" + Details = "details" + + +class DataSourceType(str, Enum): + relational = "relational" + vector = "vector" + graph = "graph" + key_value = "key-value" + column_family = "column-family" + storage = "storage" + other = "other" + + +class ModelType(str, Enum): + model = "model" + adapter = "adapter" + + +class WorkflowType(str, Enum): + ingestion = "ingestion" + application = "application" + data_processing = "data-processing" + training = "training" + evaluation = "evaluation" + + # ======================================================================================== metadata_fields = [ + "id", "name", "description", "labels", - "owner_name", + "owner_id", "created", "updated", "version", + "project_id", ] @@ -222,7 +253,11 @@ class BaseWithMetadata(Base): updated: Optional[Union[str, datetime]] = None -class BaseWithVerMetadata(BaseWithMetadata): +class BaseWithOwner(BaseWithMetadata): + owner_id: Optional[str] = None + + +class BaseWithVerMetadata(BaseWithOwner): version: Optional[str] = "" @@ -234,39 +269,79 @@ class User(BaseWithMetadata): full_name: Optional[str] = None features: Optional[dict[str, str]] = None policy: Optional[dict[str, str]] = None + is_admin: Optional[bool] = False -class DocCollection(BaseWithMetadata): - _top_level_fields = ["owner_name"] +class Project(BaseWithVerMetadata): + pass - owner_name: Optional[str] = None + +class DataSource(BaseWithVerMetadata): + _top_level_fields = ["data_source_type"] + + project_id: str + data_source_type: DataSourceType category: Optional[str] = None - db_args: Optional[dict[str, str]] = None + database_kwargs: Optional[dict[str, str]] = None -class ChatSession(BaseWithMetadata): - _extra_fields = ["history", "features", "state", "agent_name"] - _top_level_fields = ["username"] +class Dataset(BaseWithVerMetadata): + _top_level_fields = ["task"] - username: Optional[str] = None - agent_name: Optional[str] = None - history: Optional[List[Message]] = [] - features: Optional[dict[str, str]] = None - state: Optional[dict[str, str]] = None + project_id: str + task: str + sources: Optional[List[str]] = None + path: str + producer: Optional[str] = None - def to_conversation(self): - return Conversation.from_list(self.history) + +class Model(BaseWithVerMetadata): + _extra_fields = ["path", "producer", "deployment"] + _top_level_fields = ["model_type", "task"] + + project_id: str + model_type: ModelType + base_model: str + task: Optional[str] = None + path: Optional[str] = None + producer: Optional[str] = None + deployment: Optional[str] = None + + +class PromptTemplate(BaseWithVerMetadata): + _extra_fields = ["arguments"] + _top_level_fields = ["text"] + + project_id: str + text: str + arguments: Optional[List[str]] = None class Document(BaseWithVerMetadata): - collection: str - source: str + _top_level_fields = ["path", "origin"] + project_id: str + path: str origin: Optional[str] = None - num_chunks: Optional[int] = None -class OutputMode(str, Enum): - Names = "names" - Short = "short" - Dict = "dict" - Details = "details" +class Workflow(BaseWithVerMetadata): + _top_level_fields = ["workflow_type"] + + project_id: str + workflow_type: WorkflowType + workflow_function: Optional[str] = None + configuration: Optional[dict] = None + graph: Optional[dict] = None + deployment: Optional[str] = None + + +class ChatSession(BaseWithMetadata): + _extra_fields = ["history", "features", "state", "agent_name"] + _top_level_fields = ["username"] + + workflow_id: str + user_id: str + history: Optional[List[Message]] = [] + + def to_conversation(self): + return Conversation.from_list(self.history) From 98146e857de317ce5f4f61ff48d2064a053c777d Mon Sep 17 00:00:00 2001 From: yonishelach Date: Fri, 9 Aug 2024 11:35:01 +0300 Subject: [PATCH 03/10] client crud --- controller/src/model.py | 5 +- controller/src/sqlclient.py | 1061 +++++++++++++++++++++++++++++++---- 2 files changed, 948 insertions(+), 118 deletions(-) diff --git a/controller/src/model.py b/controller/src/model.py index 8e65a7d..8b43f4c 100644 --- a/controller/src/model.py +++ b/controller/src/model.py @@ -211,7 +211,7 @@ def merge_into_orm_object(self, orm_object): return orm_object - def to_orm_object(self, obj_class): + def to_orm_object(self, obj_class, uid=None): struct = self.to_dict(drop_none=False, short=False) obj_dict = { k: v @@ -225,6 +225,8 @@ def to_orm_object(self, obj_class): if k not in metadata_fields + self._top_level_fields } labels = obj_dict.pop("labels", None) + if uid: + obj_dict["id"] = uid obj = obj_class(**obj_dict) if labels: obj.labels.clear() @@ -246,6 +248,7 @@ def __str__(self): class BaseWithMetadata(Base): + id: str name: str description: Optional[str] = None labels: Optional[Dict[str, Union[str, None]]] = None diff --git a/controller/src/sqlclient.py b/controller/src/sqlclient.py index 2a386d9..6f34446 100644 --- a/controller/src/sqlclient.py +++ b/controller/src/sqlclient.py @@ -13,19 +13,24 @@ # limitations under the License. import datetime -from typing import Union +import uuid +from typing import List, Type, Union import sqlalchemy from sqlalchemy.orm import sessionmaker -from controller.src import model -from controller.src.model import ApiResponse - +import controller.src.sqldb as db +from controller.src import model as api_models from controller.src.config import config, logger -from controller.src.sqldb import Base, ChatSessionContext, DocumentCollection, User +from controller.src.model import ApiResponse +from controller.src.sqldb import Base class SqlClient: + """ + This is the SQL client that interact with the SQL database. + """ + def __init__(self, db_url: str, verbose: bool = False): self.db_url = db_url self.engine = sqlalchemy.create_engine( @@ -37,12 +42,32 @@ def __init__(self, db_url: str, verbose: bool = False): ) def get_db_session(self, session: sqlalchemy.orm.Session = None): + """ + Get a session from the session maker. + + :param session: The session to use. If None, a new session will be created. + + :return: The session. + """ return session or self._session_maker() def get_local_session(self): + """ + Get a local session from the local session maker. + + :return: The session. + """ return self._local_maker() - def create_tables(self, drop_old: bool = False, names: list = None): + def create_tables(self, drop_old: bool = False, names: list = None) -> ApiResponse: + """ + Create the tables in the database. + + :param drop_old: Whether to drop the old tables before creating the new ones. + :param names: The names of the tables to create. If None, all tables will be created. + + :return: A response object with the success status. + """ tables = None if names: tables = [Base.metadata.tables[name] for name in names] @@ -51,7 +76,65 @@ def create_tables(self, drop_old: bool = False, names: list = None): Base.metadata.create_all(self.engine, tables=tables, checkfirst=True) return ApiResponse(success=True) - def _update(self, session: sqlalchemy.orm.Session, db_class, api_object, **kwargs): + def _create(self, session: sqlalchemy.orm.Session, db_class, obj) -> ApiResponse: + """ + Create an object in the database. + This method generates a UID to the object and adds the object to the session and commits the transaction. + + :param session: The session to use. + :param db_class: The DB class of the object. + :param obj: The object to create. + + :return: A response object with the success status and the created object when successful. + """ + session = self.get_db_session(session) + try: + uid = uuid.uuid4().hex + db_object = obj.to_orm_object(db_class, uid=uid) + session.add(db_object) + session.commit() + return ApiResponse( + success=True, data=obj.__class__.from_orm_object(db_object) + ) + except sqlalchemy.exc.IntegrityError: + return ApiResponse( + success=False, error=f"{db_class} {obj.name} already exists" + ) + + def _get( + self, session: sqlalchemy.orm.Session, db_class, api_class, **kwargs + ) -> ApiResponse: + """ + Get an object from the database. + + :param session: The session to use. + :param db_class: The DB class of the object. + :param api_class: The API class of the object. + :param kwargs: The keyword arguments to filter the object. + + :return: A response object with the success status and the object when successful. + """ + session = self.get_db_session(session) + obj = session.query(db_class).filter_by(**kwargs).one_or_none() + if obj is None: + return ApiResponse( + success=False, error=f"{db_class} object ({kwargs}) not found" + ) + return ApiResponse(success=True, data=api_class.from_orm_object(obj)) + + def _update( + self, session: sqlalchemy.orm.Session, db_class, api_object, **kwargs + ) -> ApiResponse: + """ + Update an object in the database. + + :param session: The session to use. + :param db_class: The DB class of the object. + :param api_object: The API object with the new data. + :param kwargs: The keyword arguments to filter the object. + + :return: A response object with the success status and the updated object when successful. + """ session = self.get_db_session(session) obj = session.query(db_class).filter_by(**kwargs).one_or_none() if obj: @@ -66,7 +149,18 @@ def _update(self, session: sqlalchemy.orm.Session, db_class, api_object, **kwarg success=False, error=f"{db_class} object ({kwargs}) not found" ) - def _delete(self, session: sqlalchemy.orm.Session, db_class, **kwargs): + def _delete( + self, session: sqlalchemy.orm.Session, db_class, **kwargs + ) -> ApiResponse: + """ + Delete an object from the database. + + :param session: The session to use. + :param db_class: The DB class of the object. + :param kwargs: The keyword arguments to filter the object. + + :return: A response object with the success status. + """ session = self.get_db_session(session) query = session.query(db_class).filter_by(**kwargs) for obj in query: @@ -74,127 +168,832 @@ def _delete(self, session: sqlalchemy.orm.Session, db_class, **kwargs): session.commit() return ApiResponse(success=True) - def _get(self, session: sqlalchemy.orm.Session, db_class, api_class, **kwargs): - session = self.get_db_session(session) - obj = session.query(db_class).filter_by(**kwargs).one_or_none() - if obj is None: - return ApiResponse( - success=False, error=f"{db_class} object ({kwargs}) not found" - ) - return ApiResponse(success=True, data=api_class.from_orm_object(obj)) + def _list( + self, + session: sqlalchemy.orm.Session, + db_class: db.Base, + api_class: Type[api_models.Base], + output_mode: api_models.OutputMode, + labels_match: List[str] = None, + filters: list = None, + ) -> ApiResponse: + """ + List objects from the database. + + :param session: The session to use. + :param db_class: The DB class of the object. + :param api_class: The API class of the object. + :param output_mode: The output mode. + :param labels_match: The labels to match, filter the objects by labels. + :param filters: The filters to apply. - def _create(self, session: sqlalchemy.orm.Session, db_class, obj): + :return: A response object with the success status and the list of objects when successful. + """ session = self.get_db_session(session) - try: - db_object = obj.to_orm_object(db_class) - session.add(db_object) - session.commit() - return ApiResponse( - success=True, data=obj.__class__.from_orm_object(db_object) - ) - except sqlalchemy.exc.IntegrityError: - return ApiResponse( - success=False, error=f"{db_class} {obj.name} already exists" - ) - def get_user(self, username: str, session: sqlalchemy.orm.Session = None): - logger.debug(f"Getting user: username={username}") - return self._get(session, User, model.User, name=username) + query = session.query(db_class) + for filter_statement in filters: + query = query.filter(filter_statement) + # TODO: Implement labels_match + if labels_match: + logger.debug("Filtering projects by labels is not supported yet") + # query = self._filter_labels(query, sqldb.Project, labels_match) + pass + output = query.all() + logger.debug(f"output: {output}") + data = _process_output(output, api_class, output_mode) + return ApiResponse(success=True, data=data) + + def create_user( + self, user: Union[api_models.User, dict], session: sqlalchemy.orm.Session = None + ) -> ApiResponse: + """ + Create a new user in the database. + + :param user: The user object to create. + :param session: The session to use. - def create_user(self, user: model.User, session: sqlalchemy.orm.Session = None): + :return: A response object with the success status and the created user when successful. + """ logger.debug(f"Creating user: {user}") + if isinstance(user, dict): + user = api_models.User.from_dict(user) user.name = user.name or user.email - return self._create(session, User, user) + return self._create(session, db.User, user) - def update_user(self, user: model.User, session: sqlalchemy.orm.Session = None): + def get_user( + self, user_id: str, session: sqlalchemy.orm.Session = None + ) -> ApiResponse: + """ + Get a user from the database. + + :param user_id: The ID of the user to get. + :param session: The session to use. + + :return: A response object with the success status and the user when successful. + """ + logger.debug(f"Getting user: user_id={user_id}") + return self._get(session, db.User, api_models.User, id=user_id) + + def update_user( + self, user: Union[api_models.User, dict], session: sqlalchemy.orm.Session = None + ) -> ApiResponse: + """ + Update an existing user in the database. + + :param user: The user object with the new data. + :param session: The session to use. + + :return: A response object with the success status and the updated user when successful. + """ logger.debug(f"Updating user: {user}") - return self._update(session, User, user, name=user.name) + if isinstance(user, dict): + user = api_models.User.from_dict(user) + return self._update(session, db.User, user, name=user.name) + + def delete_user( + self, user_id: str, session: sqlalchemy.orm.Session = None + ) -> ApiResponse: + """ + Delete a user from the database. - def delete_user(self, username: str, session: sqlalchemy.orm.Session = None): - logger.debug(f"Deleting user: username={username}") - return self._delete(session, User, name=username) + :param user_id: The ID of the user to delete. + :param session: + :return: + """ + logger.debug(f"Deleting user: user_id={user_id}") + return self._delete(session, db.User, id=user_id) def list_users( self, email: str = None, full_name: str = None, labels_match: Union[list, str] = None, - output_mode: model.OutputMode = model.OutputMode.Details, + output_mode: api_models.OutputMode = api_models.OutputMode.Details, session: sqlalchemy.orm.Session = None, - ): + ) -> ApiResponse: + """ + List users from the database. + + :param email: The email to filter the users by. + :param full_name: The full name to filter the users by. + :param labels_match: The labels to match, filter the users by labels. + :param output_mode: The output mode. + :param session: The session to use. + + :return: A response object with the success status and the list of users when successful. + """ logger.debug( - f"Getting users: full_name~={full_name}, email={email}, mode={output_mode}" + f"Getting users: email={email}, full_name={full_name}, mode={output_mode}" ) - session = self.get_db_session(session) - query = session.query(User) + filters = [] if email: - query = query.filter(User.email == email) + filters.append(db.User.email == email) if full_name: - query = query.filter(User.full_name.like(f"%{full_name}%")) - data = _process_output(query.all(), model.User, output_mode) - return ApiResponse(success=True, data=data) + filters.append(db.User.full_name.like(f"%{full_name}%")) + return self._list( + session=session, + db_class=db.User, + api_class=api_models.User, + output_mode=output_mode, + labels_match=labels_match, + filters=filters, + ) + + def create_project( + self, + project: Union[api_models.Project, dict], + session: sqlalchemy.orm.Session = None, + ) -> ApiResponse: + """ + Create a new project in the database. + + :param project: The project object to create. + :param session: The session to use. + + :return: A response object with the success status and the created project when successful. + """ + logger.debug(f"Creating project: {project}") + if isinstance(project, dict): + project = api_models.Project.from_dict(project) + return self._create(session, db.Project, project) + + def get_project( + self, project_id: str, session: sqlalchemy.orm.Session = None + ) -> ApiResponse: + """ + Get a project from the database. + + :param project_id: The ID of the project to get. + :param session: The session to use. + + :return: A response object with the success status and the project when successful. + """ + logger.debug(f"Getting project: project_id={project_id}") + return self._get(session, db.Project, api_models.Project, id=project_id) + + def update_project( + self, + project: Union[api_models.Project, dict], + session: sqlalchemy.orm.Session = None, + ) -> ApiResponse: + """ + Update an existing project in the database. + + :param project: The project object with the new data. + :param session: The session to use. + + :return: A response object with the success status and the updated project when successful. + """ + logger.debug(f"Updating project: {project}") + if isinstance(project, dict): + project = api_models.Project.from_dict(project) + return self._update(session, db.Project, project, id=project.id) + + def delete_project( + self, project_id: str, session: sqlalchemy.orm.Session = None + ) -> ApiResponse: + """ + Delete a project from the database. + + :param project_id: The ID of the project to delete. + :param session: The session to use. + + :return: A response object with the success status. + """ + logger.debug(f"Deleting project: project_id={project_id}") + return self._delete(session, db.Project, id=project_id) + + def list_projects( + self, + owner_id: str = None, + version: str = None, + labels_match: Union[list, str] = None, + output_mode: api_models.OutputMode = api_models.OutputMode.Details, + session: sqlalchemy.orm.Session = None, + ) -> ApiResponse: + """ + List projects from the database. + + :param owner_id: The owner to filter the projects by. + :param version: The version to filter the projects by. + :param labels_match: The labels to match, filter the projects by labels. + :param output_mode: The output mode. + :param session: The session to use. + + :return: A response object with the success status and the list of projects when successful. + """ + logger.debug( + f"Getting projects: owner_id={owner_id}, version={version}, labels_match={labels_match}, mode={output_mode}" + ) + filters = [] + if owner_id: + filters.append(db.Project.owner_id == owner_id) + if version: + filters.append(db.Project.version == version) + return self._list( + session=session, + db_class=db.User, + api_class=api_models.User, + output_mode=output_mode, + labels_match=labels_match, + filters=filters, + ) + + def create_data_source( + self, + data_source: Union[api_models.DataSource, dict], + session: sqlalchemy.orm.Session = None, + ) -> ApiResponse: + """ + Create a new data source in the database. + + :param data_source: The data source object to create. + :param session: The session to use. + + :return: A response object with the success status and the created data source when successful. + """ + logger.debug(f"Creating data source: {data_source}") + if isinstance(data_source, dict): + data_source = api_models.DataSource.from_dict(data_source) + return self._create(session, db.DataSource, data_source) + + def get_data_source( + self, data_source_id: str, session: sqlalchemy.orm.Session = None + ) -> ApiResponse: + """ + Get a data source from the database. + + :param data_source_id: The ID of the data source to get. + :param session: The session to use. + + :return: A response object with the success status and the data source when successful. + """ + logger.debug(f"Getting data source: data_source_id={data_source_id}") + return self._get( + session, db.DataSource, api_models.DataSource, id=data_source_id + ) + + def update_data_source( + self, + data_source: Union[api_models.DataSource, dict], + session: sqlalchemy.orm.Session = None, + ) -> ApiResponse: + """ + Update an existing data source in the database. + + :param data_source: The data source object with the new data. + :param session: The session to use. + + :return: A response object with the success status and the updated data source when successful. + """ + logger.debug(f"Updating data source: {data_source}") + if isinstance(data_source, dict): + data_source = api_models.DataSource.from_dict(data_source) + return self._update(session, db.DataSource, data_source, id=data_source.id) + + def delete_data_source( + self, data_source_id: str, session: sqlalchemy.orm.Session = None + ) -> ApiResponse: + """ + Delete a data source from the database. + + :param data_source_id: The ID of the data source to delete. + :param session: The session to use. + + :return: A response object with the success status. + """ + logger.debug(f"Deleting data source: data_source_id={data_source_id}") + return self._delete(session, db.DataSource, id=data_source_id) + + def list_data_sources( + self, + owner_id: str = None, + version: str = None, + project_id: str = None, + data_source_type: Union[api_models.DataSourceType, str] = None, + labels_match: Union[list, str] = None, + output_mode: api_models.OutputMode = api_models.OutputMode.Details, + session: sqlalchemy.orm.Session = None, + ) -> ApiResponse: + """ + List data sources from the database. + + :param owner_id: The owner to filter the data sources by. + :param version: The version to filter the data sources by. + :param project_id: The project to filter the data sources by. + :param data_source_type: The data source type to filter the data sources by. + :param labels_match: The labels to match, filter the data sources by labels. + :param output_mode: The output mode. + :param session: The session to use. + + :return: A response object with the success status and the list of data sources when successful. + """ + logger.debug( + f"Getting collections: owner_id={owner_id}, version={version}, data_source_type={data_source_type}," + f" labels_match={labels_match}, mode={output_mode}" + ) + filters = [] + if owner_id: + filters.append(db.DataSource.owner_id == owner_id) + if version: + filters.append(db.DataSource.version == version) + if project_id: + filters.append(db.DataSource.project_id == project_id) + if data_source_type: + filters.append(db.DataSource.data_source_type == data_source_type) + return self._list( + session=session, + db_class=db.User, + api_class=api_models.User, + output_mode=output_mode, + labels_match=labels_match, + filters=filters, + ) + + def create_dataset( + self, + dataset: Union[api_models.Dataset, dict], + session: sqlalchemy.orm.Session = None, + ) -> ApiResponse: + """ + Create a new dataset in the database. + + :param dataset: The dataset object to create. + :param session: The session to use. + + :return: A response object with the success status and the created dataset when successful. + """ + logger.debug(f"Creating dataset: {dataset}") + if isinstance(dataset, dict): + dataset = api_models.Dataset.from_dict(dataset) + return self._create(session, db.Dataset, dataset) + + def get_dataset( + self, dataset_id: str, session: sqlalchemy.orm.Session = None + ) -> ApiResponse: + """ + Get a dataset from the database. + + :param dataset_id: The ID of the dataset to get. + :param session: The session to use. + + :return: A response object with the success status and the dataset when successful. + """ + logger.debug(f"Getting dataset: dataset_id={dataset_id}") + return self._get(session, db.Dataset, api_models.Dataset, id=dataset_id) + + def update_dataset( + self, + dataset: Union[api_models.Dataset, dict], + session: sqlalchemy.orm.Session = None, + ) -> ApiResponse: + """ + Update an existing dataset in the database. + + :param dataset: The dataset object with the new data. + :param session: The session to use. + + :return: A response object with the success status and the updated dataset when successful. + """ + logger.debug(f"Updating dataset: {dataset}") + if isinstance(dataset, dict): + dataset = api_models.Dataset.from_dict(dataset) + return self._update(session, db.Dataset, dataset, id=dataset.id) + + def delete_dataset( + self, dataset_id: str, session: sqlalchemy.orm.Session = None + ) -> ApiResponse: + """ + Delete a dataset from the database. + + :param dataset_id: The ID of the dataset to delete. + :param session: The session to use. + + :return: A response object with the success status. + """ + logger.debug(f"Deleting dataset: dataset_id={dataset_id}") + return self._delete(session, db.Dataset, id=dataset_id) + + def list_datasets( + self, + owner_id: str = None, + version: str = None, + project_id: str = None, + task: str = None, + labels_match: Union[list, str] = None, + output_mode: api_models.OutputMode = api_models.OutputMode.Details, + session: sqlalchemy.orm.Session = None, + ) -> ApiResponse: + """ + List datasets from the database. + + :param owner_id: The owner to filter the datasets by. + :param version: The version to filter the datasets by. + :param project_id: The project to filter the datasets by. + :param task: The task to filter the datasets by. + :param labels_match: The labels to match, filter the datasets by labels. + :param output_mode: The output mode. + :param session: The session to use. + + :return: A response object with the success status and the list of datasets when successful. + """ + logger.debug( + f"Getting datasets: owner_id={owner_id}, version={version}, task={task}, labels_match={labels_match}," + f" mode={output_mode}" + ) + filters = [] + if owner_id: + filters.append(db.Dataset.owner_id == owner_id) + if version: + filters.append(db.Dataset.version == version) + if project_id: + filters.append(db.Dataset.project_id == project_id) + if task: + filters.append(db.Dataset.task == task) + return self._list( + session=session, + db_class=db.Dataset, + api_class=api_models.Dataset, + output_mode=output_mode, + labels_match=labels_match, + filters=filters, + ) + + def create_model( + self, + model: Union[api_models.Model, dict], + session: sqlalchemy.orm.Session = None, + ) -> ApiResponse: + """ + Create a new model in the database. + + :param model: The model object to create. + :param session: The session to use. - def get_collection(self, name: str, session: sqlalchemy.orm.Session = None): - logger.debug(f"Getting collection: name={name}") - return self._get(session, DocumentCollection, model.DocCollection, name=name) - - def create_collection( - self, collection: model.DocCollection, session: sqlalchemy.orm.Session = None - ): - logger.debug(f"Creating collection: {collection}") - if isinstance(collection, dict): - collection = model.DocCollection.from_dict(collection) - return self._create(session, DocumentCollection, collection) - - def update_collection( - self, collection: model.DocCollection, session: sqlalchemy.orm.Session = None - ): - logger.debug(f"Updating collection: {collection}") - if isinstance(collection, dict): - collection = model.DocCollection.from_dict(collection) + :return: A response object with the success status and the created model when successful. + """ + logger.debug(f"Creating model: {model}") + if isinstance(model, dict): + model = api_models.Model.from_dict(model) + return self._create(session, db.Model, model) + + def get_model( + self, model_id: str, session: sqlalchemy.orm.Session = None + ) -> ApiResponse: + """ + Get a model from the database. + + :param model_id: The ID of the model to get. + :param session: The session to use. + + :return: A response object with the success status and the model when successful. + """ + logger.debug(f"Getting model: model_id={model_id}") + return self._get(session, db.Model, api_models.Model, id=model_id) + + def update_model( + self, + model: Union[api_models.Model, dict], + session: sqlalchemy.orm.Session = None, + ) -> ApiResponse: + """ + Update an existing model in the database. + + :param model: The model object with the new data. + :param session: The session to use. + + :return: A response object with the success status and the updated model when successful. + """ + logger.debug(f"Updating model: {model}") + if isinstance(model, dict): + model = api_models.Model.from_dict(model) + return self._update(session, db.Model, model, id=model.id) + + def delete_model( + self, model_id: str, session: sqlalchemy.orm.Session = None + ) -> ApiResponse: + """ + Delete a model from the database. + + :param model_id: The ID of the model to delete. + :param session: The session to use. + + :return: A response object with the success status. + """ + logger.debug(f"Deleting model: model_id={model_id}") + return self._delete(session, db.Model, id=model_id) + + def list_models( + self, + owner_id: str = None, + version: str = None, + project_id: str = None, + model_type: str = None, + task: str = None, + labels_match: Union[list, str] = None, + output_mode: api_models.OutputMode = api_models.OutputMode.Details, + session: sqlalchemy.orm.Session = None, + ) -> ApiResponse: + """ + List models from the database. + + :param owner_id: The owner to filter the models by. + :param version: The version to filter the models by. + :param project_id: The project to filter the models by. + :param model_type: The model type to filter the models by. + :param task: The task to filter the models by. + :param labels_match: The labels to match, filter the models by labels. + :param output_mode: The output mode. + :param session: The session to use. + + :return: A response object with the success status and the list of models when successful. + """ + logger.debug( + f"Getting models: owner_id={owner_id}, version={version}, project_id={project_id}," + f" model_type={model_type}, task={task}, labels_match={labels_match}, mode={output_mode}" + ) + filters = [] + if owner_id: + filters.append(db.Model.owner_id == owner_id) + if version: + filters.append(db.Model.version == version) + if project_id: + filters.append(db.Model.project_id == project_id) + if model_type: + filters.append(db.Model.model_type == model_type) + if task: + filters.append(db.Model.task == task) + return self._list( + session=session, + db_class=db.Model, + api_class=api_models.Model, + output_mode=output_mode, + labels_match=labels_match, + filters=filters, + ) + + def create_prompt_template( + self, + prompt_template: Union[api_models.PromptTemplate, dict], + session: sqlalchemy.orm.Session = None, + ) -> ApiResponse: + """ + Create a new prompt template in the database. + + :param prompt_template: The prompt template object to create. + :param session: The session to use. + + :return: A response object with the success status and the created prompt template when successful. + """ + logger.debug(f"Creating prompt template: {prompt_template}") + if isinstance(prompt_template, dict): + prompt_template = api_models.PromptTemplate.from_dict(prompt_template) + return self._create(session, db.PromptTemplate, prompt_template) + + def get_prompt_template( + self, prompt_template_id: str, session: sqlalchemy.orm.Session = None + ) -> ApiResponse: + """ + Get a prompt template from the database. + + :param prompt_template_id: The ID of the prompt template to get. + :param session: The session to use. + + :return: A response object with the success status and the prompt template when successful. + """ + logger.debug( + f"Getting prompt template: prompt_template_id={prompt_template_id}" + ) + return self._get( + session, db.PromptTemplate, api_models.PromptTemplate, id=prompt_template_id + ) + + def update_prompt_template( + self, + prompt_template: Union[api_models.PromptTemplate, dict], + session: sqlalchemy.orm.Session = None, + ) -> ApiResponse: + """ + Update an existing prompt template in the database. + + :param prompt_template: The prompt template object with the new data. + :param session: The session to use. + + :return: A response object with the success status and the updated prompt template when successful. + """ + logger.debug(f"Updating prompt template: {prompt_template}") + if isinstance(prompt_template, dict): + prompt_template = api_models.PromptTemplate.from_dict(prompt_template) return self._update( - session, DocumentCollection, collection, name=collection.name + session, db.PromptTemplate, prompt_template, id=prompt_template.id ) - def delete_collection(self, name: str, session: sqlalchemy.orm.Session = None): - logger.debug(f"Deleting collection: name={name}") - return self._delete(session, DocumentCollection, name=name) + def delete_prompt_template( + self, prompt_template_id: str, session: sqlalchemy.orm.Session = None + ) -> ApiResponse: + """ + Delete a prompt template from the database. - def list_collections( + :param prompt_template_id: The ID of the prompt template to delete. + :param session: The session to use. + + :return: A response object with the success status. + """ + logger.debug( + f"Deleting prompt template: prompt_template_id={prompt_template_id}" + ) + return self._delete(session, db.PromptTemplate, id=prompt_template_id) + + def list_prompt_templates( self, - owner: str = None, + owner_id: str = None, + version: str = None, + project_id: str = None, labels_match: Union[list, str] = None, - output_mode: model.OutputMode = model.OutputMode.Details, + output_mode: api_models.OutputMode = api_models.OutputMode.Details, session: sqlalchemy.orm.Session = None, - ): + ) -> ApiResponse: + """ + List prompt templates from the database. + + :param owner_id: The owner to filter the prompt templates by. + :param version: The version to filter the prompt templates by. + :param project_id: The project to filter the prompt templates by. + :param labels_match: The labels to match, filter the prompt templates by labels. + :param output_mode: The output mode. + :param session: The session to use. + + :return: A response object with the success status and the list of prompt templates when successful. + """ logger.debug( - f"Getting collections: owner={owner}, labels_match={labels_match}, mode={output_mode}" + f"Getting prompt templates: owner_id={owner_id}, version={version}, project_id={project_id}," + f" labels_match={labels_match}, mode={output_mode}" ) - session = self.get_db_session(session) - query = session.query(DocumentCollection) - if owner: - query = query.filter(DocumentCollection.owner_name == owner) - if labels_match: - pass - data = _process_output(query.all(), model.DocCollection, output_mode) - return ApiResponse(success=True, data=data) + filters = [] + if owner_id: + filters.append(db.PromptTemplate.owner_id == owner_id) + if version: + filters.append(db.PromptTemplate.version == version) + if project_id: + filters.append(db.PromptTemplate.project_id == project_id) + return self._list( + session=session, + db_class=db.PromptTemplate, + api_class=api_models.PromptTemplate, + output_mode=output_mode, + labels_match=labels_match, + filters=filters, + ) + + def create_document( + self, + document: Union[api_models.Document, dict], + session: sqlalchemy.orm.Session = None, + ) -> ApiResponse: + """ + Create a new document in the database. + + :param document: The document object to create. + :param session: The session to use. + + :return: A response object with the success status and the created document when successful. + """ + logger.debug(f"Creating document: {document}") + if isinstance(document, dict): + document = api_models.Document.from_dict(document) + return self._create(session, db.Document, document) + + def get_document( + self, document_id: str, session: sqlalchemy.orm.Session = None + ) -> ApiResponse: + """ + Get a document from the database. + + :param document_id: The ID of the document to get. + :param session: The session to use. + + :return: A response object with the success status and the document when successful. + """ + logger.debug(f"Getting document: document_id={document_id}") + return self._get(session, db.Document, api_models.Document, id=document_id) + + def update_document( + self, + document: Union[api_models.Document, dict], + session: sqlalchemy.orm.Session = None, + ) -> ApiResponse: + """ + Update an existing document in the database. + + :param document: The document object with the new data. + :param session: The session to use. + + :return: A response object with the success status and the updated document when successful. + """ + logger.debug(f"Updating document: {document}") + if isinstance(document, dict): + document = api_models.Document.from_dict(document) + return self._update(session, db.Document, document, id=document.id) - def get_session( + def delete_document( + self, document_id: str, session: sqlalchemy.orm.Session = None + ) -> ApiResponse: + """ + Delete a document from the database. + + :param document_id: The ID of the document to delete. + :param session: The session to use. + + :return: A response object with the success status. + """ + logger.debug(f"Deleting document: document_id={document_id}") + return self._delete(session, db.Document, id=document_id) + + def list_documents( + self, + owner_id: str = None, + version: str = None, + project_id: str = None, + labels_match: Union[list, str] = None, + output_mode: api_models.OutputMode = api_models.OutputMode.Details, + session: sqlalchemy.orm.Session = None, + ) -> ApiResponse: + """ + List documents from the database. + + :param owner_id: The owner to filter the documents by. + :param version: The version to filter the documents by. + :param project_id: The project to filter the documents by. + :param labels_match: The labels to match, filter the documents by labels. + :param output_mode: The output mode. + :param session: The session to use. + + :return: A response object with the success status and the list of documents when successful. + """ + logger.debug( + f"Getting documents: owner_id={owner_id}, version={version}, project_id={project_id}," + f" labels_match={labels_match}, mode={output_mode}" + ) + filters = [] + if owner_id: + filters.append(db.Document.owner_id == owner_id) + if version: + filters.append(db.Document.version == version) + if project_id: + filters.append(db.Document.project_id == project_id) + return self._list( + session=session, + db_class=db.Document, + api_class=api_models.Document, + output_mode=output_mode, + labels_match=labels_match, + filters=filters, + ) + + def create_chat_session( + self, + chat_session: Union[api_models.ChatSession, dict], + session: sqlalchemy.orm.Session = None, + ) -> ApiResponse: + """ + Create a new chat session in the database. + + :param chat_session: The chat session object to create. + :param session: The session to use. + + :return: A response object with the success status and the created chat session when successful. + """ + logger.debug(f"Creating chat session: {chat_session}") + if isinstance(chat_session, dict): + chat_session = api_models.ChatSession.from_dict(chat_session) + return self._create(session, db.Session, chat_session) + + def get_chat_session( self, session_id: str, - username: str = None, + user_id: str = None, session: sqlalchemy.orm.Session = None, - ): + ) -> ApiResponse: + """ + Get a chat session from the database. + + :param session_id: The ID of the chat session to get. + :param user_id: The ID of the user to get the last session for. + :param session: The DB session to use. + + :return: A response object with the success status and the chat session when successful. + """ logger.debug( - f"Getting chat session: session_id={session_id}, username={username}" + f"Getting chat session: session_id={session_id}, user_id={user_id}" ) if session_id: return self._get( - session, ChatSessionContext, model.ChatSession, name=session_id + session, db.Session, api_models.ChatSession, name=session_id ) - elif username: + elif user_id: # get the last session for the user - resp = self.list_sessions(username=username, last=1, session=session) + resp = self.list_chat_sessions(user_id=user_id, last=1, session=session) if resp.success: data = resp.data[0] if resp.data else None return ApiResponse(success=True, data=data) @@ -204,49 +1003,77 @@ def get_session( success=False, error="session_id or username must be provided" ) - def create_session( - self, chat_session: model.ChatSession, session: sqlalchemy.orm.Session = None - ): - logger.debug(f"Creating chat session: {chat_session}") - return self._create(session, ChatSessionContext, chat_session) + def update_chat_session( + self, + chat_session: Union[api_models.ChatSession, dict], + session: sqlalchemy.orm.Session = None, + ) -> ApiResponse: + """ + Update a chat session in the database. - def update_session( - self, chat_session: model.ChatSession, session: sqlalchemy.orm.Session = None - ): + :param chat_session: The chat session object with the new data. + :param session: The DB session to use. + + :return: A response object with the success status and the updated chat session when successful. + """ logger.debug(f"Updating chat session: {chat_session}") - return self._update( - session, ChatSessionContext, chat_session, name=chat_session.name - ) + return self._update(session, db.Session, chat_session, id=chat_session.id) + + def delete_chat_session( + self, session_id: str, session: sqlalchemy.orm.Session = None + ) -> ApiResponse: + """ + Delete a chat session from the database. - def delete_session(self, session_id: str, session: sqlalchemy.orm.Session = None): + :param session_id: The ID of the chat session to delete. + :param session: The DB session to use. + + :return: A response object with the success status. + """ logger.debug(f"Deleting chat session: session_id={session_id}") - return self._delete(session, ChatSessionContext, name=session_id) + return self._delete(session, db.Session, id=session_id) - def list_sessions( + def list_chat_sessions( self, - username: str = None, + user_id: str = None, + workflow_id: str = None, created_after=None, last=0, - output_mode: model.OutputMode = model.OutputMode.Details, + output_mode: api_models.OutputMode = api_models.OutputMode.Details, session: sqlalchemy.orm.Session = None, - ): + ) -> ApiResponse: + """ + List chat sessions from the database. + + :param user_id: The user ID to filter the chat sessions by. + :param workflow_id: The workflow ID to filter the chat sessions by. + :param created_after: The date to filter the chat sessions by. + :param last: The number of last chat sessions to return. + :param output_mode: The output mode. + :param session: The DB session to use. + + :return: A response object with the success status and the list of chat sessions when successful. + """ logger.debug( - f"Getting chat sessions: username={username}, created>{created_after}, last={last}, mode={output_mode}" + f"Getting chat sessions: user_id={user_id}, workflow_id={workflow_id} created>{created_after}," + f" last={last}, mode={output_mode}" ) session = self.get_db_session(session) - query = session.query(ChatSessionContext) - if username: - query = query.filter(ChatSessionContext.username == username) + query = session.query(db.Session) + if user_id: + query = query.filter(db.Session.user_id == user_id) + if workflow_id: + query = query.filter(db.Session.workflow_id == workflow_id) if created_after: if isinstance(created_after, str): created_after = datetime.datetime.strptime( created_after, "%Y-%m-%d %H:%M" ) - query = query.filter(ChatSessionContext.created >= created_after) - query = query.order_by(ChatSessionContext.updated.desc()) + query = query.filter(db.Session.created >= created_after) + query = query.order_by(db.Session.updated.desc()) if last > 0: query = query.limit(last) - data = _process_output(query.all(), model.ChatSession, output_mode) + data = _process_output(query.all(), api_models.ChatSession, output_mode) return ApiResponse(success=True, data=data) @@ -257,14 +1084,14 @@ def _dict_to_object(cls, d): def _process_output( - items, obj_class, mode: model.OutputMode = model.OutputMode.Details + items, obj_class, mode: api_models.OutputMode = api_models.OutputMode.Details ): - if mode == model.OutputMode.Names: + if mode == api_models.OutputMode.Names: return [item.name for item in items] items = [obj_class.from_orm_object(item) for item in items] - if mode == model.OutputMode.Details: + if mode == api_models.OutputMode.Details: return items - short = mode == model.OutputMode.Short + short = mode == api_models.OutputMode.Short return [item.to_dict(short=short) for item in items] From 9ee48721abf5ad4dad2c66d983fd594970826abb Mon Sep 17 00:00:00 2001 From: yonishelach Date: Tue, 13 Aug 2024 14:10:20 +0300 Subject: [PATCH 04/10] stabilized api and updated cli --- Makefile | 19 + controller/src/api.py | 1296 +++++++++++++++++++++++++++++++---- controller/src/main.py | 346 ++++++++-- controller/src/model.py | 44 +- controller/src/sqlclient.py | 349 ++++++++-- controller/src/sqldb.py | 295 +++++++- 6 files changed, 2041 insertions(+), 308 deletions(-) diff --git a/Makefile b/Makefile index 6907ac6..78bd434 100644 --- a/Makefile +++ b/Makefile @@ -24,3 +24,22 @@ controller: # Announce the server is running: @echo "GenAI Factory Controller is running in the background" + +.PHONY: fmt +fmt: ## Format the code using Ruff + @echo "Running ruff checks and fixes..." + python -m ruff check --fix-only + python -m ruff format + +.PHONY: lint +lint: fmt-check lint-imports ## Run lint on the code + +lint-imports: ## Validates import dependencies + @echo "Running import linter" + lint-imports + +.PHONY: fmt-check +fmt-check: ## Check the code (using ruff) + @echo "Running ruff checks..." + python -m ruff check --exit-non-zero-on-fix + python -m ruff format --check \ No newline at end of file diff --git a/controller/src/api.py b/controller/src/api.py index a32c782..bf2e8e4 100644 --- a/controller/src/api.py +++ b/controller/src/api.py @@ -12,16 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +import os from typing import List, Optional, Tuple, Union import requests -from fastapi import (APIRouter, Depends, FastAPI, File, Header, Request, - UploadFile) +from fastapi import APIRouter, Depends, FastAPI, Header, Request from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from controller.src.config import config -from controller.src.model import ChatSession, DocCollection, OutputMode, User, QueryItem +from controller.src.model import ( + ApiResponse, + ChatSession, + Dataset, + DataSource, + DataSourceType, + Document, + Model, + OutputMode, + Project, + PromptTemplate, + QueryItem, + User, + Workflow, +) from controller.src.sqlclient import client app = FastAPI() @@ -68,7 +82,9 @@ def get_auth_user( return AuthInfo(username="guest@example.com", token=token) -def _send_to_application(path: str, method: str = "POST", request=None, auth=None, **kwargs): +def _send_to_application( + path: str, method: str = "POST", request=None, auth=None, **kwargs +): """ Send a request to the application's API. @@ -80,7 +96,10 @@ def _send_to_application(path: str, method: str = "POST", request=None, auth=Non :return: The JSON response from the application. """ - url = f"{config.application_url}/api/{path}" + if config.application_url not in path: + url = f"{config.application_url}/api/{path}" + else: + url = path if isinstance(request, Request): # If the request is a FastAPI request, get the data from the body @@ -108,169 +127,1216 @@ def create_tables(drop_old: bool = False, names: list[str] = None): return client.create_tables(drop_old=drop_old, names=names) -@router.post("/pipeline/{name}/run") -def run_pipeline( - request: Request, item: QueryItem, name: str, auth=Depends(get_auth_user) -): - """This is the query command""" - return _send_to_application( - path=f"pipeline/{name}/run", - method="POST", - request=request, - auth=auth, +@router.post("/users") +def create_user( + user: User, + session=Depends(get_db), +) -> ApiResponse: + """ + Create a new user in the database. + + :param user: The user to create. + :param session: The database session. + + :return: The response from the database. + """ + return client.create_user(user=user, session=session) + + +@router.get("/users/{user_name}") +def get_user(user_name: str, email: str = None, session=Depends(get_db)) -> ApiResponse: + """ + Get a user from the database. + + :param user_name: The name of the user to get. + :param email: The email address to get the user by if the name is not provided. + :param session: The database session. + + :return: The user from the database. + """ + return client.get_user(user_name=user_name, email=email, session=session) + + +@router.put("/users/{user_name}") +def update_user( + user: User, + user_name: str, + session=Depends(get_db), +) -> ApiResponse: + """ + Update a user in the database. + + :param user: The user to update. + :param user_name: The name of the user to update. + :param session: The database session. + + :return: The response from the database. + """ + if user_name != user.name: + raise ValueError(f"User name does not match: {user_name} != {user.name}") + return client.update_user(user=user, session=session) + + +@router.delete("/users/{user_name}") +def delete_user(user_name: str, session=Depends(get_db)) -> ApiResponse: + """ + Delete a user from the database. + + :param user_name: The name of the user to delete. + :param session: The database session. + + :return: The response from the database. + """ + return client.delete_user(user_name=user_name, session=session) + + +@router.get("/users") +def list_users( + email: str = None, + full_name: str = None, + mode: OutputMode = OutputMode.Details, + session=Depends(get_db), +) -> ApiResponse: + """ + List users in the database. + + :param email: The email address to filter by. + :param full_name: The full name to filter by. + :param mode: The output mode. + :param session: The database session. + + :return: The response from the database. + """ + return client.list_users( + email=email, full_name=full_name, output_mode=mode, session=session ) -@router.post("/collections/{collection}/{loader}/ingest") -def ingest( - collection, path, loader, metadata, version, from_file, auth=Depends(get_auth_user) -): - """Ingest documents into the vector database""" - params = { - "path": path, - "from_file": from_file, - "version": version, - } - if metadata is not None: - params["metadata"] = json.dumps(metadata) +@router.post("/projects") +def create_project( + project: Project, + session=Depends(get_db), +) -> ApiResponse: + """ + Create a new project in the database. - return _send_to_application( - path=f"collections/{collection}/{loader}/ingest", - method="POST", - params=params, - auth=auth, + :param project: The project to create. + :param session: The database session. + + :return: The response from the database. + """ + return client.create_project(project=project, session=session) + + +@router.get("/projects/{project_name}") +def get_project(project_name: str, session=Depends(get_db)) -> ApiResponse: + """ + Get a project from the database. + + :param project_name: The name of the project to get. + :param session: The database session. + + :return: The project from the database. + """ + return client.get_project(project_name=project_name, session=session) + + +@router.put("/projects/{project_name}") +def update_project( + project: Project, + project_name: str, + session=Depends(get_db), +) -> ApiResponse: + """ + Update a project in the database. + + :param project: The project to update. + :param project_name: The name of the project to update. + :param session: The database session. + + :return: The response from the database. + """ + if project_name != project.name: + raise ValueError( + f"Project name does not match: {project_name} != {project.name}" + ) + return client.update_project(project=project, session=session) + + +@router.delete("/projects/{project_name}") +def delete_project(project_name: str, session=Depends(get_db)) -> ApiResponse: + """ + Delete a project from the database. + + :param project_name: The name of the project to delete. + :param session: The database session. + + :return: The response from the database. + """ + return client.delete_project(project_name=project_name, session=session) + + +@router.get("/projects") +def list_projects( + owner_name: str = None, + labels: Optional[List[Tuple[str, str]]] = None, + mode: OutputMode = OutputMode.Details, + session=Depends(get_db), +) -> ApiResponse: + """ + List projects in the database. + + :param owner_name: The name of the owner to filter by. + :param labels: The labels to filter by. + :param mode: The output mode. + :param session: The database session. + + :return: The response from the database. + """ + if owner_name is not None: + owner_id = client.get_user(user_name=owner_name, session=session).data["id"] + else: + owner_id = None + return client.list_projects( + owner_id=owner_id, labels_match=labels, output_mode=mode, session=session ) -@router.get("/collections") -def list_collections( - owner: str = None, +@router.post("projects/{project_name}/data_sources/{data_source_name}") +def create_data_source( + project_name: str, + data_source_name: str, + data_source: DataSource, + session=Depends(get_db), + auth: AuthInfo = Depends(get_auth_user), +) -> ApiResponse: + """ + Create a new data source in the database. + + :param project_name: The name of the project to create the data source in. + :param data_source_name: The name of the data source to create. + :param data_source: The data source to create. + :param session: The database session. + :param auth: The authentication information. + + :return: The response from the database. + """ + # If the owner ID is not provided, get it from the username + if data_source.owner_id is None: + data_source.owner_id = client.get_user( + user_name=auth.username, session=session + ).data["id"] + data_source.name = data_source_name + data_source.project_id = client.get_project( + project_name=project_name, session=session + ).data["id"] + return client.create_data_source(data_source=data_source, session=session) + + +@router.get("projects/{project_name}/data_sources/{data_source_name}") +def get_data_source( + project_name: str, data_source_name: str, session=Depends(get_db) +) -> ApiResponse: + """ + Get a data source from the database. + + :param project_name: The name of the project to get the data source from. + :param data_source_name: The name of the data source to get. + :param session: The database session. + + :return: The data source from the database. + """ + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.get_data_source( + project_id=project_id, data_source_name=data_source_name, session=session + ) + + +@router.put("projects/{project_name}/data_sources/{data_source_name}") +def update_data_source( + project_name: str, + data_source: DataSource, + data_source_name: str, + session=Depends(get_db), +) -> ApiResponse: + """ + Update a data source in the database. + + :param project_name: The name of the project to update the data source in. + :param data_source: The data source to update. + :param data_source_name: The name of the data source to update. + :param session: The database session. + + :return: The response from the database. + """ + data_source.project_id = client.get_project( + project_name=project_name, session=session + ).data["id"] + if data_source_name != data_source.name: + raise ValueError( + f"Data source name does not match: {data_source_name} != {data_source.name}" + ) + return client.update_data_source(data_source=data_source, session=session) + + +@router.delete("projects/{project_name}/data_sources/{data_source_id}") +def delete_data_source( + project_name: str, data_source_id: str, session=Depends(get_db) +) -> ApiResponse: + """ + Delete a data source from the database. + + :param project_name: The name of the project to delete the data source from. + :param data_source_id: The ID of the data source to delete. + :param session: The database session. + + :return: The response from the database. + """ + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.delete_data_source( + project_id=project_id, data_source_id=data_source_id, session=session + ) + + +@router.get("projects/{project_name}/data_sources") +def list_data_sources( + project_name: str, + version: str = None, + data_source_type: Union[DataSourceType, str] = None, labels: Optional[List[Tuple[str, str]]] = None, mode: OutputMode = OutputMode.Details, session=Depends(get_db), -): - return client.list_collections( - owner=owner, labels_match=labels, output_mode=mode, session=session + auth: AuthInfo = Depends(get_auth_user), +) -> ApiResponse: + """ + List data sources in the database. + + :param project_name: The name of the project to list the data sources from. + :param version: The version to filter by. + :param data_source_type: The data source type to filter by. + :param labels: The labels to filter by. + :param mode: The output mode. + :param session: The database session. + :param auth: The authentication information. + + :return: The response from the database. + """ + owner_id = client.get_user(user_name=auth.username, session=session).data["id"] + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.list_data_sources( + project_id=project_id, + owner_id=owner_id, + version=version, + data_source_type=data_source_type, + labels_match=labels, + output_mode=mode, + session=session, ) -@router.get("/collection/{name}") -def get_collection(name: str, session=Depends(get_db)): - return client.get_collection(name, session=session) +@router.post("/projects/{project_name}/datasets/{dataset_name}") +def create_dataset( + project_name: str, + dataset_name: str, + dataset: Dataset, + session=Depends(get_db), + auth: AuthInfo = Depends(get_auth_user), +) -> ApiResponse: + """ + Create a new dataset in the database. + + :param project_name: The name of the project to create the dataset in. + :param dataset_name: The name of the dataset to create. + :param dataset: The dataset to create. + :param session: The database session. + :param auth: The authentication information. + + :return: The response from the database. + """ + # If the owner ID is not provided, get it from the username + if dataset.owner_id is None: + dataset.owner_id = client.get_user( + user_name=auth.username, session=session + ).data["id"] + dataset.name = dataset_name + dataset.project_id = client.get_project( + project_name=project_name, session=session + ).data["id"] + return client.create_dataset(dataset=dataset, session=session) + + +@router.get("/projects/{project_name}/datasets/{dataset_name}") +def get_dataset( + project_name: str, dataset_name: str, session=Depends(get_db) +) -> ApiResponse: + """ + Get a dataset from the database. + + :param project_name: The name of the project to get the dataset from. + :param dataset_name: The name of the dataset to get. + :param session: The database session. + + :return: The dataset from the database. + """ + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.get_dataset( + project_id=project_id, dataset_name=dataset_name, session=session + ) -@router.post("/collection/{name}") -def create_collection( - request: Request, - name: str, - collection: DocCollection, +@router.put("/projects/{project_name}/datasets/{dataset_name}") +def update_dataset( + project_name: str, + dataset: Dataset, + dataset_name: str, + session=Depends(get_db), +) -> ApiResponse: + """ + Update a dataset in the database. + + :param project_name: The name of the project to update the dataset in. + :param dataset: The dataset to update. + :param dataset_name: The name of the dataset to update. + :param session: The database session. + + :return: The response from the database. + """ + dataset.project_id = client.get_project( + project_name=project_name, session=session + ).data["id"] + if dataset_name != dataset.name: + raise ValueError( + f"Dataset name does not match: {dataset_name} != {dataset.name}" + ) + return client.update_dataset(dataset=dataset, session=session) + + +@router.delete("/projects/{project_name}/datasets/{dataset_id}") +def delete_dataset( + project_name: str, dataset_id: str, session=Depends(get_db) +) -> ApiResponse: + """ + Delete a dataset from the database. + + :param project_name: The name of the project to delete the dataset from. + :param dataset_id: The ID of the dataset to delete. + :param session: The database session. + + :return: The response from the database. + """ + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.delete_dataset( + project_id=project_id, dataset_id=dataset_id, session=session + ) + + +@router.get("/projects/{project_name}/datasets") +def list_datasets( + project_name: str, + version: str = None, + task: str = None, + labels: Optional[List[Tuple[str, str]]] = None, + mode: OutputMode = OutputMode.Details, session=Depends(get_db), auth: AuthInfo = Depends(get_auth_user), -): - collection.owner_name = auth.username - return client.create_collection(collection, session=session) +) -> ApiResponse: + """ + List datasets in the database. + :param project_name: The name of the project to list the datasets from. + :param version: The version to filter by. + :param task: The task to filter by. + :param labels: The labels to filter by. + :param mode: The output mode. + :param session: The database session. + :param auth: The authentication information. -@router.get("/users") -def list_users( - email: str = None, - username: str = None, + :return: The response from the database. + """ + owner_id = client.get_user(user_name=auth.username, session=session).data["id"] + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.list_datasets( + project_id=project_id, + owner_id=owner_id, + version=version, + task=task, + labels_match=labels, + output_mode=mode, + session=session, + ) + + +@router.post("/projects/{project_name}/models/{model_name}") +def create_model( + project_name: str, + model_name: str, + model: Model, + session=Depends(get_db), + auth: AuthInfo = Depends(get_auth_user), +) -> ApiResponse: + """ + Create a new model in the database. + + :param project_name: The name of the project to create the model in. + :param model_name: The name of the model to create. + :param model: The model to create. + :param session: The database session. + :param auth: The authentication information. + + :return: The response from the database. + """ + # If the owner ID is not provided, get it from the username + if model.owner_id is None: + model.owner_id = client.get_user(user_name=auth.username, session=session).data[ + "id" + ] + model.name = model_name + model.project_id = client.get_project( + project_name=project_name, session=session + ).data["id"] + return client.create_model(model=model, session=session) + + +@router.get("/projects/{project_name}/models/{model_name}") +def get_model( + project_name: str, model_name: str, session=Depends(get_db) +) -> ApiResponse: + """ + Get a model from the database. + + :param project_name: The name of the project to get the model from. + :param model_name: The name of the model to get. + :param session: The database session. + + :return: The model from the database. + """ + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.get_model( + project_id=project_id, model_name=model_name, session=session + ) + + +@router.put("/projects/{project_name}/models/{model_name}") +def update_model( + project_name: str, + model: Model, + model_name: str, + session=Depends(get_db), +) -> ApiResponse: + """ + Update a model in the database. + + :param project_name: The name of the project to update the model in. + :param model: The model to update. + :param model_name: The name of the model to update. + :param session: The database session. + + :return: The response from the database. + """ + model.project_id = client.get_project( + project_name=project_name, session=session + ).data["id"] + if model_name != model.name: + raise ValueError(f"Model name does not match: {model_name} != {model.name}") + return client.update_model(model=model, session=session) + + +@router.delete("/projects/{project_name}/models/{model_id}") +def delete_model( + project_name: str, model_id: str, session=Depends(get_db) +) -> ApiResponse: + """ + Delete a model from the database. + + :param project_name: The name of the project to delete the model from. + :param model_id: The ID of the model to delete. + :param session: The database session. + + :return: The response from the database. + """ + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.delete_model( + project_id=project_id, model_id=model_id, session=session + ) + + +@router.get("/projects/{project_name}/models") +def list_models( + project_name: str, + version: str = None, + model_type: str = None, + labels: Optional[List[Tuple[str, str]]] = None, mode: OutputMode = OutputMode.Details, session=Depends(get_db), -): - return client.list_users( - email=email, full_name=username, output_mode=mode, session=session + auth: AuthInfo = Depends(get_auth_user), +) -> ApiResponse: + """ + List models in the database. + + :param project_name: The name of the project to list the models from. + :param version: The version to filter by. + :param model_type: The model type to filter by. + :param labels: The labels to filter by. + :param mode: The output mode. + :param session: The database session. + :param auth: The authentication information. + + :return: The response from the database. + """ + owner_id = client.get_user(user_name=auth.username, session=session).data["id"] + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.list_models( + project_id=project_id, + owner_id=owner_id, + version=version, + model_type=model_type, + labels_match=labels, + output_mode=mode, + session=session, ) -@router.get("/user/{username}") -def get_user(username: str, session=Depends(get_db)): - return client.get_user(username, session=session) +@router.post("/projects/{project_name}/prompt_templates/{prompt_name}") +def create_prompt( + project_name: str, + prompt_name: str, + prompt: PromptTemplate, + session=Depends(get_db), + auth: AuthInfo = Depends(get_auth_user), +) -> ApiResponse: + """ + Create a new prompt in the database. + + :param project_name: The name of the project to create the prompt in. + :param prompt_name: The name of the prompt to create. + :param prompt: The prompt to create. + :param session: The database session. + :param auth: The authentication information. + + :return: The response from the database. + """ + # If the owner ID is not provided, get it from the username + if prompt.owner_id is None: + prompt.owner_id = client.get_user( + user_name=auth.username, session=session + ).data["id"] + prompt.name = prompt_name + prompt.project_id = client.get_project( + project_name=project_name, session=session + ).data["id"] + return client.create_prompt_template(prompt=prompt, session=session) -@router.post("/user/{username}") -def create_user( - user: User, - username: str, +@router.get("/projects/{project_name}/prompt_templates/{prompt_name}") +def get_prompt( + project_name: str, prompt_name: str, session=Depends(get_db) +) -> ApiResponse: + """ + Get a prompt from the database. + + :param project_name: The name of the project to get the prompt from. + :param prompt_name: The name of the prompt to get. + :param session: The database session. + + :return: The prompt from the database. + """ + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.get_prompt( + project_id=project_id, prompt_name=prompt_name, session=session + ) + + +@router.put("/projects/{project_name}/prompt_templates/{prompt_name}") +def update_prompt( + project_name: str, + prompt: PromptTemplate, + prompt_name: str, session=Depends(get_db), -): - """This is the user command""" - return client.create_user(user, session=session) +) -> ApiResponse: + """ + Update a prompt in the database. + :param project_name: The name of the project to update the prompt in. + :param prompt: The prompt to update. + :param prompt_name: The name of the prompt to update. + :param session: The database session. -@router.delete("/user/{username}") -def delete_user(username: str, session=Depends(get_db)): - return client.delete_user(username, session=session) + :return: The response from the database. + """ + prompt.project_id = client.get_project( + project_name=project_name, session=session + ).data["id"] + if prompt_name != prompt.name: + raise ValueError(f"Prompt name does not match: {prompt_name} != {prompt.name}") + return client.update_prompt_template(prompt=prompt, session=session) -# get last user sessions, specify user and last -@router.get("/user/{username}/sessions") -def list_user_sessions( - username: str, - last: int = 0, - created: str = None, +@router.delete("/projects/{project_name}/prompt_templates/{prompt_template_id}") +def delete_prompt( + project_name: str, prompt_template_id: str, session=Depends(get_db) +) -> ApiResponse: + """ + Delete a prompt from the database. + + :param project_name: The name of the project to delete the prompt from. + :param prompt_template_id: The ID of the prompt to delete. + :param session: The database session. + + :return: The response from the database. + """ + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.delete_prompt_template( + project_id=project_id, prompt_template_id=prompt_template_id, session=session + ) + + +@router.get("/projects/{project_name}/prompt_templates") +def list_prompts( + project_name: str, + version: str = None, + labels: Optional[List[Tuple[str, str]]] = None, mode: OutputMode = OutputMode.Details, session=Depends(get_db), -): - return client.list_sessions( - username, created_after=created, last=last, output_mode=mode, session=session + auth: AuthInfo = Depends(get_auth_user), +) -> ApiResponse: + """ + List prompts in the database. + + :param project_name: The name of the project to list the prompts from. + :param version: The version to filter by. + :param labels: The labels to filter by. + :param mode: The output mode. + :param session: The database session. + :param auth: The authentication information. + + :return: The response from the database. + """ + owner_id = client.get_user(user_name=auth.username, session=session).data["id"] + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.list_prompt_templates( + project_id=project_id, + owner_id=owner_id, + version=version, + labels_match=labels, + output_mode=mode, + session=session, ) -@router.put("/user/{username}") -def update_user( - user: User, - username: str, +@router.post("/projects/{project_name}/documents/{document_name}") +def create_document( + project_name: str, + document_name: str, + document: Document, session=Depends(get_db), -): - return client.update_user(user, session=session) + auth: AuthInfo = Depends(get_auth_user), +) -> ApiResponse: + """ + Create a new document in the database. + + :param project_name: The name of the project to create the document in. + :param document_name: The name of the document to create. + :param document: The document to create. + :param session: The database session. + :param auth: The authentication information. + + :return: The response from the database. + """ + # If the owner ID is not provided, get it from the username + if document.owner_id is None: + document.owner_id = client.get_user( + user_name=auth.username, session=session + ).data["id"] + document.name = document_name + document.project_id = client.get_project( + project_name=project_name, session=session + ).data["id"] + return client.create_document(document=document, session=session) + + +@router.get("/projects/{project_name}/documents/{document_name}") +def get_document( + project_name: str, document_name: str, session=Depends(get_db) +) -> ApiResponse: + """ + Get a document from the database. + + :param project_name: The name of the project to get the document from. + :param document_name: The name of the document to get. + :param session: The database session. + + :return: The document from the database. + """ + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.get_document( + project_id=project_id, document_name=document_name, session=session + ) + + +@router.put("/projects/{project_name}/documents/{document_name}") +def update_document( + project_name: str, + document: Document, + document_name: str, + session=Depends(get_db), +) -> ApiResponse: + """ + Update a document in the database. + + :param project_name: The name of the project to update the document in. + :param document: The document to update. + :param document_name: The name of the document to update. + :param session: The database session. + + :return: The response from the database. + """ + document.project_id = client.get_project( + project_name=project_name, session=session + ).data["id"] + if document_name != document.name: + raise ValueError( + f"Document name does not match: {document_name} != {document.name}" + ) + return client.update_document(document=document, session=session) + + +@router.delete("/projects/{project_name}/documents/{document_id}") +def delete_document( + project_name: str, document_id: str, session=Depends(get_db) +) -> ApiResponse: + """ + Delete a document from the database. + + :param project_name: The name of the project to delete the document from. + :param document_id: The ID of the document to delete. + :param session: The database session. + + :return: The response from the database. + """ + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.delete_document( + project_id=project_id, document_id=document_id, session=session + ) + + +@router.get("/projects/{project_name}/documents") +def list_documents( + project_name: str, + version: str = None, + labels: Optional[List[Tuple[str, str]]] = None, + mode: OutputMode = OutputMode.Details, + session=Depends(get_db), + auth: AuthInfo = Depends(get_auth_user), +) -> ApiResponse: + """ + List documents in the database. + + :param project_name: The name of the project to list the documents from. + :param version: The version to filter by. + :param labels: The labels to filter by. + :param mode: The output mode. + :param session: The database session. + :param auth: The authentication information. + + :return: The response from the database. + """ + owner_id = client.get_user(user_name=auth.username, session=session).data["id"] + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.list_documents( + project_id=project_id, + owner_id=owner_id, + version=version, + labels_match=labels, + output_mode=mode, + session=session, + ) + + +@router.post("/projects/{project_name}/workflows/{workflow_name}") +def create_workflow( + project_name: str, + workflow_name: str, + workflow: Workflow, + session=Depends(get_db), + auth: AuthInfo = Depends(get_auth_user), +) -> ApiResponse: + """ + Create a new workflow in the database. + + :param project_name: The name of the project to create the workflow in. + :param workflow_name: The name of the workflow to create. + :param workflow: The workflow to create. + :param session: The database session. + :param auth: The authentication information. + + :return: The response from the database. + """ + # If the owner ID is not provided, get it from the username + if workflow.owner_id is None: + workflow.owner_id = client.get_user( + user_name=auth.username, session=session + ).data["id"] + workflow.name = workflow_name + workflow.project_id = client.get_project( + project_name=project_name, session=session + ).data["id"] + return client.create_workflow(workflow=workflow, session=session) + + +@router.get("/projects/{project_name}/workflows/{workflow_name}") +def get_workflow( + project_name: str, workflow_name: str, session=Depends(get_db) +) -> ApiResponse: + """ + Get a workflow from the database. + :param project_name: The name of the project to get the workflow from. + :param workflow_name: The name of the workflow to get. + :param session: The database session. -# add routs for chat sessions, list_sessions, get_session -@router.post("/session") -def create_session(chat_session: ChatSession, session=Depends(get_db)): - return client.create_session(chat_session, session=session) + :return: The workflow from the database. + """ + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.get_workflow( + project_id=project_id, workflow_name=workflow_name, session=session + ) + + +@router.put("/projects/{project_name}/workflows/{workflow_name}") +def update_workflow( + project_name: str, + workflow: Workflow, + workflow_name: str, + session=Depends(get_db), +) -> ApiResponse: + """ + Update a workflow in the database. + + :param project_name: The name of the project to update the workflow in. + :param workflow: The workflow to update. + :param workflow_name: The name of the workflow to update. + :param session: The database session. + + :return: The response from the database. + """ + workflow.project_id = client.get_project( + project_name=project_name, session=session + ).data["id"] + if workflow_name != workflow.name: + raise ValueError( + f"Workflow name does not match: {workflow_name} != {workflow.name}" + ) + return client.update_workflow(workflow=workflow, session=session) + + +@router.delete("/projects/{project_name}/workflows/{workflow_id}") +def delete_workflow( + project_name: str, workflow_id: str, session=Depends(get_db) +) -> ApiResponse: + """ + Delete a workflow from the database. + + :param project_name: The name of the project to delete the workflow from. + :param workflow_id: The ID of the workflow to delete. + :param session: The database session. + + :return: The response from the database. + """ + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.delete_workflow( + project_id=project_id, workflow_id=workflow_id, session=session + ) + + +@router.get("/projects/{project_name}/workflows") +def list_workflows( + project_name: str, + version: str = None, + labels: Optional[List[Tuple[str, str]]] = None, + mode: OutputMode = OutputMode.Details, + session=Depends(get_db), + auth: AuthInfo = Depends(get_auth_user), +) -> ApiResponse: + """ + List workflows in the database. + + :param project_name: The name of the project to list the workflows from. + :param version: The version to filter by. + :param labels: The labels to filter by. + :param mode: The output mode. + :param session: The database session. + :param auth: The authentication information. + + :return: The response from the database. + """ + owner_id = client.get_user(user_name=auth.username, session=session).data["id"] + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.list_workflows( + project_id=project_id, + owner_id=owner_id, + version=version, + labels_match=labels, + output_mode=mode, + session=session, + ) + + +@router.post("/users/{user_name}/sessions/{session_name}") +def create_session( + user_name: str, + session_name: str, + chat_session: ChatSession, + session=Depends(get_db), +) -> ApiResponse: + """ + Create a new session in the database. + + :param user_name: The name of the user to create the session for. + :param session_name: The name of the session to create. + :param chat_session: The session to create. + :param session: The database session. + + :return: The response from the database. + """ + chat_session.owner_id = client.get_user(user_name=user_name, session=session).data[ + "id" + ] + return client.create_chat_session(chat_session=chat_session, session=session) + + +@router.get("/users/{user_name}/sessions/{session_name}") +def get_session( + user_name: str, session_name: str, session=Depends(get_db) +) -> ApiResponse: + """ + Get a session from the database. If the session ID is "$last", get the last session for the user. + + :param user_name: The name of the user to get the session for. + :param session_name: The name of the session to get. + :param session: The database session. + + :return: The session from the database. + """ + user_id = None + if session_name == "$last": + user_id = client.get_user(user_name=user_name, session=session).data["id"] + session_name = None + return client.get_chat_session( + session_name=session_name, user_id=user_id, session=session + ) + + +@router.put("/users/{user_name}/sessions/{session_name}") +def update_session( + user_name: str, + chat_session: ChatSession, + session=Depends(get_db), +) -> ApiResponse: + """ + Update a session in the database. + + :param user_name: The name of the user to update the session for. + :param chat_session: The session to update. + :param session: The database session. + + :return: The response from the database. + """ + chat_session.owner_id = client.get_user(user_name=user_name, session=session).data[ + "id" + ] + return client.update_chat_session(chat_session=chat_session, session=session) -@router.get("/sessions") -async def list_sessions( - username: str = None, +@router.get("/users/{user_name}/sessions") +def list_sessions( + user_name: str, last: int = 0, created: str = None, + workflow_id: str = None, mode: OutputMode = OutputMode.Details, session=Depends(get_db), - auth=Depends(get_auth_user), -): - user = None if username and username == "all" else (username or auth.username) - return client.list_sessions( - user, created_after=created, last=last, output_mode=mode, session=session +) -> ApiResponse: + """ + List sessions in the database. + + :param user_name: The name of the user to list the sessions for. + :param last: The number of sessions to get. + :param created: The date to filter by. + :param workflow_id: The ID of the workflow to filter by. + :param mode: The output mode. + :param session: The database session. + + :return: The response from the database. + """ + user_id = client.get_user(user_name=user_name, session=session).data["id"] + return client.list_chat_sessions( + user_id=user_id, + last=last, + created_after=created, + workflow_id=workflow_id, + output_mode=mode, + session=session, ) -@router.get("/session/{session_id}") -async def get_session( - session_id: str, session=Depends(get_db), auth=Depends(get_auth_user) -): - user = None - if session_id == "$last": - user = auth.username - session_id = None - return client.get_session(session_id, user, session=session) +@router.post("/projects/{project_name}/workflows/{workflow_name}/infer") +def infer_workflow( + project_name: str, + workflow_name: str, + query: QueryItem, + session=Depends(get_db), + auth: AuthInfo = Depends(get_auth_user), +) -> ApiResponse: + """ + Run application workflow. + + :param project_name: The name of the project to run the workflow in. + :param workflow_name: The name of the workflow to run. + :param query: The query to run. + :param session: The database session. + :param auth: The authentication information. + + :return: The response from the database. + """ + # Get workflow from the database + workflow = Workflow.from_dict( + get_workflow(project_name, workflow_name, session).data + ) + path = Workflow.get_infer_path(workflow) + + data = { + "item": query.dict(), + "workflow": workflow.to_dict(short=True), + } + + # Sent the event to the application's workflow: + return _send_to_application( + path=path, + method="POST", + data=json.dumps(data), + auth=auth, + ) + + +# @router.post("/pipeline/{name}/run") +# def run_pipeline( +# request: Request, item: QueryItem, name: str, auth=Depends(get_auth_user) +# ): +# """This is the query command""" +# +# return _send_to_application( +# path=f"pipeline/{name}/run", +# method="POST", +# request=request, +# auth=auth, +# ) -@router.put("/session/{session_id}") -async def update_session( - session_id: str, chat_session: ChatSession, session=Depends(get_db) +@router.post("/projects/{project_name}/data_sources/{data_source_name}/ingest") +def ingest( + project_name, + data_source_name, + loader: str, + path: str, + metadata=None, + version: str = None, + from_file: bool = False, + session=Depends(get_db), + auth=Depends(get_auth_user), ): - chat_session.name = session_id - return client.update_session(chat_session, session=session) + """ + Ingest document into the vector database. + + :param project_name: The name of the project to ingest the documents into. + :param data_source_name: The name of the data source to ingest the documents into. + :param loader: The data loader type to use. + :param path: The path to the document to ingest. + :param metadata: The metadata to associate with the documents. + :param version: The version of the documents. + :param from_file: Whether the documents are from a file. + :param session: The database session. + :param auth: The authentication information. + + :return: The response from the application. + """ + data_source = get_data_source( + project_name=project_name, data_source_name=data_source_name, session=session + ).data + data_source = DataSource.from_dict(data_source) + + # Create document from path: + document = Document( + name=os.path.basename(path), + version=version, + path=path, + owner_id=data_source.owner_id, + ) + # Add document to the database: + document = create_document( + project_name=project_name, + document_name=document.name, + document=document, + session=session, + auth=auth, + ).data -@router.post("/transcribe") -async def transcribe_file(file: UploadFile = File(...)): - file_contents = await file.read() - file_handler = file.file - return transcribe_file(file_handler) + # Send ingest to application: + params = { + "loader": loader, + "from_file": from_file, + } + + data = { + "document": document, + "database_kwargs": data_source.database_kwargs, + } + if metadata is not None: + params["metadata"] = json.dumps(metadata) + + return _send_to_application( + path=f"data_sources/{data_source_name}/ingest", + method="POST", + data=json.dumps(data), + params=params, + auth=auth, + ) # Include the router in the main app diff --git a/controller/src/main.py b/controller/src/main.py index 02f18ef..06b8a99 100644 --- a/controller/src/main.py +++ b/controller/src/main.py @@ -20,10 +20,17 @@ import yaml from tabulate import tabulate -from controller.src.sqlclient import client -from controller.src.model import User, DocCollection, QueryItem -from controller.src.config import config import controller.src.api as api +from controller.src.config import config +from controller.src.model import ( + DataSource, + Document, + Project, + QueryItem, + User, + Workflow, +) +from controller.src.sqlclient import client @click.group() @@ -34,25 +41,59 @@ def cli(): # click command for initializing the database tables @click.command() def initdb(): - """Initialize the database (delete old tables)""" + """ + Initialize the database tables (delete old tables). + """ click.echo("Running Init DB") session = client.get_db_session() client.create_tables(True) - # create a guest user, and the default document collection - client.create_user( + + # Create admin user: + click.echo("Creating admin user") + user_id = client.create_user( User( name="guest", - email="guest@any.com", + email="guest@example.com", full_name="Guest User", + is_admin=True, + ), + session=session, + ).data["id"] + + # Create project: + click.echo("Creating default project") + project_id = client.create_project( + Project( + name="default", + description="Default Project", + owner_id=user_id, + ), + session=session, + ).data["id"] + + # Create data source: + click.echo("Creating default data source") + client.create_data_source( + DataSource( + name="default", + description="Default Data Source", + owner_id=user_id, + project_id=project_id, + data_source_type="vector", ), session=session, ) - client.create_collection( - DocCollection( + + # Create Workflow: + click.echo("Creating default workflow") + client.create_workflow( + Workflow( name="default", - description="Default Collection", - owner_name="guest", - category="vector", + description="Default Workflow", + owner_id=user_id, + project_id=project_id, + workflow_type="application", + deployment="http://localhost:8000", ), session=session, ) @@ -67,32 +108,77 @@ def print_config(): @click.command() +@click.argument("project", type=str) @click.argument("path", type=str) +@click.option("-n", "--name", type=str, help="Document name", default=None) @click.option("-l", "--loader", type=str, help="Type of data loader") @click.option( "-m", "--metadata", type=(str, str), multiple=True, help="Metadata Key value pair" ) @click.option("-v", "--version", type=str, help="document version") -@click.option("-c", "--collection", type=str, help="Vector DB collection name") +@click.option("-d", "--data-source", type=str, help="Data source name") @click.option( "-f", "--from-file", is_flag=True, help="Take the document paths from the file" ) -def ingest(path, loader, metadata, version, collection, from_file): - """Ingest documents into the vector database""" +def ingest(project, path, name, loader, metadata, version, data_source, from_file): + """ + Ingest data into the data source. + + :param project: The project name to which the document belongs. + :param path: Path to the document + :param name: Name of the document + :param loader: Type of data loader, web, .csv, .md, .pdf, .txt, etc. + :param metadata: Metadata Key value pair labels + :param version: Version of the document + :param data_source: Data source name + :param from_file: Take the document paths from the file + + :return: None + """ + session = client.get_db_session() + project = client.get_project(project_name=project, session=session).data + project = Project.from_dict(project) + data_source = client.get_data_source( + project_id=project.id, + data_source_name=data_source or "default", + session=session, + ).data + data_source = DataSource.from_dict(data_source) + + # Create document from path: + document = Document( + name=name or path, + version=version, + path=path, + owner_id=data_source.owner_id, + project_id=project.id, + ) + + # Add document to the database: + document = client.create_document( + document=document, + session=session, + ).data + + # Send ingest to application: params = { - "path": path, + "loader": loader, "from_file": from_file, - "version": version, } + + data = { + "document": document, + "database_kwargs": data_source.database_kwargs, + } + if metadata: - print(metadata) params["metadata"] = json.dumps({metadata[0]: metadata[1]}) - collection = collection or "default" click.echo(f"Running Data Ingestion from: {path} with loader: {loader}") response = api._send_to_application( - path=f"collections/{collection}/{loader}/ingest", + path=f"data_sources/{data_source.name}/ingest", method="POST", + data=json.dumps(data), params=params, ) if response["status"] == "ok": @@ -103,6 +189,10 @@ def ingest(path, loader, metadata, version, collection, from_file): @click.command() @click.argument("question", type=str) +@click.option("-p", "--project", type=str, default="default", help="Project name") +@click.option( + "-n", "--workflow-name", type=str, default="default", help="Workflow name" +) @click.option( "-f", "--filter", @@ -110,31 +200,56 @@ def ingest(path, loader, metadata, version, collection, from_file): multiple=True, help="Search filter Key value pair", ) -@click.option("-c", "--collection", type=str, help="Vector DB collection name") +@click.option("-c", "--data-source", type=str, help="Data Source name") @click.option("-u", "--user", type=str, help="Username") @click.option("-s", "--session", type=str, help="Session ID") -@click.option( - "-n", "--pipeline-name", type=str, default="default", help="Pipeline name" -) -def query(question, filter, collection, user, session, pipeline_name): - """Run a chat query on the vector database collection""" - click.echo(f"Running Query for: {question}") - search_args = [filter] if filter else None - query_item = QueryItem( +def infer( + question: str, project: str, workflow_name: str, filter, data_source, user, session +): + """ + Run a chat query on the data source + + :param question: The question to ask + :param project: The project name + :param workflow_name: The workflow name + :param filter: Filter Key value pair + :param data_source: Data source name + :param user: The name of the user + :param session: The session name + + :return: None + """ + db_session = client.get_db_session() + + project = client.get_project(project_name=project, session=db_session).data + # Getting the workflow: + workflow = client.get_workflow( + project_id=project["id"], workflow_name=workflow_name, session=db_session + ).data + workflow = Workflow.from_dict(workflow) + path = Workflow.get_infer_path(workflow) + + query = QueryItem( question=question, - session_id=session, - filter=search_args, - collection=collection, + session_name=session, + filter=filter, + data_source=data_source, ) - data = json.dumps(query_item.dict()) + data = { + "item": query.dict(), + "workflow": workflow.dict(), + } headers = {"x_username": user} if user else {} + + # Sent the event to the application's workflow: response = api._send_to_application( - path=f"pipeline/{pipeline_name}/run", + path=path, method="POST", - data=data, + data=json.dumps(data), headers=headers, ) + result = response["data"] click.echo(result["answer"]) click.echo(sources_to_text(result["sources"])) @@ -156,81 +271,151 @@ def update(): @click.option("-u", "--user", type=str, help="user name filter") @click.option("-e", "--email", type=str, help="email filter") def list_users(user, email): - """List users""" + """ + List all the users in the database + + :param user: username filter + :param email: email filter + + :return: None + """ click.echo("Running List Users") - data = client.list_users(email, user, output_mode="short") + data = client.list_users(email, user, output_mode="short").data table = format_table_results(data) click.echo(table) -# add a command to list document collections, similar to the list users command -@click.command("collections") +@click.command("data-sources") @click.option("-o", "--owner", type=str, help="owner filter") +@click.option("-p", "--project", type=str, help="project filter") +@click.option("-v", "--version", type=str, help="version filter") +@click.option("-t", "--source-type", type=str, help="data source type filter") @click.option( "-m", "--metadata", type=(str, str), multiple=True, help="metadata filter" ) -def list_collections(owner, metadata): - """List document collections""" +def list_data_sources(owner, project, version, source_type, metadata): + """ + List all the data sources in the database + + :param owner: owner filter + :param project: project filter + :param version: version filter + :param source_type: data source type filter + :param metadata: metadata filter (labels) + + :return: None + """ click.echo("Running List Collections") - - data = client.list_collections(owner, metadata, output_mode="short") + if owner: + owner = client.get_user(username=owner).data["id"] + if project: + project = client.get_project(project_name=project).data["id"] + + data = client.list_data_sources( + owner_id=owner, + project_id=project, + version=version, + data_source_type=source_type, + labels_match=metadata, + output_mode="short", + ).data table = format_table_results(data) click.echo(table) -@click.command("collection") +@click.command("data-source") @click.argument("name", type=str) +@click.option("-p", "--project", type=str, help="project name", default="default") @click.option("-o", "--owner", type=str, help="owner name") @click.option("-d", "--description", type=str, help="collection description") -@click.option("-c", "--category", type=str, help="collection category") +@click.option("-c", "--source-type", type=str, help="data source type") @click.option( "-l", "--labels", multiple=True, default=[], help="metadata labels filter" ) -def update_collection(name, owner, description, category, labels): - """Create or update a document collection""" +def update_data_source(name, project, owner, description, source_type, labels): + """ + Create or update a data source in the database + + :param name: data source name + :param project: project name + :param owner: owner name + :param description: data source description + :param source_type: type of data source + :param labels: metadata labels + + :return: None + """ click.echo("Running Create or Update Collection") labels = fill_params(labels) session = client.get_db_session() # check if the collection exists, if it does, update it, otherwise create it - collection_exists = client.get_collection(name, session=session).success + project = client.get_project(project_name=project, session=session).data + collection_exists = client.get_data_source( + project_id=project["id"], + data_source_name=name, + session=session, + ).success + if collection_exists: - client.update_collection( + client.update_data_source( session=session, - collection=DocCollection( - name=name, description=description, category=category, labels=labels + collection=DataSource( + project_id=project["id"], + name=name, + description=description, + data_source_type=source_type, + labels=labels, ), ).with_raise() else: client.create_collection( session=session, - collection=DocCollection( + collection=DataSource( + project_id=project["id"], name=name, description=description, owner_name=owner, - category=category, + data_source_type=source_type, labels=labels, ), ).with_raise() -# add a command to list chat sessions, similar to the list_users command @click.command("sessions") @click.option("-u", "--user", type=str, help="user name filter") @click.option("-l", "--last", type=int, default=0, help="last n sessions") @click.option("-c", "--created", type=str, help="created after date") def list_sessions(user, last, created): - """List chat sessions""" + """ + List chat sessions + + :param user: username filter + :param last: last n sessions + :param created: created after date + + :return: None + """ click.echo("Running List Sessions") - data = client.list_sessions(user, created, last, output_mode="short") - table = format_table_results(data["data"]) + if user: + user = client.get_user(user_name=user).data["id"] + data = client.list_chat_sessions( + user_id=user, created_after=created, last=last, output_mode="short" + ).data + table = format_table_results(data) click.echo(table) -def sources_to_text(sources): - """Convert a list of sources to a string.""" +def sources_to_text(sources) -> str: + """ + Convert a list of sources to a text string. + + :param sources: list of sources + + :return: text string + """ if not sources: return "" return "\nSource documents:\n" + "\n".join( @@ -238,8 +423,14 @@ def sources_to_text(sources): ) -def sources_to_md(sources): - """Convert a list of sources to a Markdown string.""" +def sources_to_md(sources) -> str: + """ + Convert a list of sources to a markdown string. + + :param sources: list of sources + + :return: markdown string + """ if not sources: return "" sources = { @@ -250,8 +441,14 @@ def sources_to_md(sources): ) -def get_title(metadata): - """Get title from metadata.""" +def get_title(metadata) -> str: + """ + Get the title from the metadata. + + :param metadata: metadata dictionary + + :return: title string + """ if "chunk" in metadata: return f"{metadata.get('title', '')}-{metadata['chunk']}" if "page" in metadata: @@ -259,7 +456,15 @@ def get_title(metadata): return metadata.get("title", "") -def fill_params(params, params_dict=None): +def fill_params(params, params_dict=None) -> dict: + """ + Fill the parameters dictionary from a list of key=value strings. + + :param params: list of key=value strings + :param params_dict: dictionary to fill + + :return: filled dictionary + """ params_dict = params_dict or {} for param in params: i = param.find("=") @@ -275,21 +480,28 @@ def fill_params(params, params_dict=None): def format_table_results(table_results): + """ + Format the table results as a printed table. + + :param table_results: table results dictionary + + :return: formatted table string + """ return tabulate(table_results, headers="keys", tablefmt="fancy_grid") cli.add_command(ingest) -cli.add_command(query) +cli.add_command(infer) cli.add_command(initdb) cli.add_command(print_config) cli.add_command(list) list.add_command(list_users) -list.add_command(list_collections) +list.add_command(list_data_sources) list.add_command(list_sessions) cli.add_command(update) -update.add_command(update_collection) +update.add_command(update_data_source) if __name__ == "__main__": cli() diff --git a/controller/src/model.py b/controller/src/model.py index 8b43f4c..5068830 100644 --- a/controller/src/model.py +++ b/controller/src/model.py @@ -15,7 +15,7 @@ from datetime import datetime from enum import Enum from http.client import HTTPException -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Type, Union import yaml from pydantic import BaseModel @@ -26,7 +26,7 @@ # from llmapps.app.schema import Conversation, Message class ApiResponse(BaseModel): success: bool - data: Optional[Union[list, BaseModel, dict]] = None + data: Optional[Union[list, Type[BaseModel], dict]] = None error: Optional[str] = None def with_raise(self, format=None) -> "ApiResponse": @@ -52,7 +52,7 @@ class ChatRole(str, Enum): class Message(BaseModel): role: ChatRole - body: str + content: str extra_data: Optional[dict] = None sources: Optional[List[dict]] = None human_feedback: Optional[str] = None @@ -63,10 +63,10 @@ class Conversation(BaseModel): saved_index: int = 0 def __str__(self): - return "\n".join([f"{m.role}: {m.body}" for m in self.messages]) + return "\n".join([f"{m.role}: {m.content}" for m in self.messages]) - def add_message(self, role, body, sources=None): - self.messages.append(Message(role=role, body=body, sources=sources)) + def add_message(self, role, content, sources=None): + self.messages.append(Message(role=role, content=content, sources=sources)) def to_list(self): return self.dict()["messages"] @@ -84,9 +84,9 @@ def from_list(cls, data: list): class QueryItem(BaseModel): question: str - session_id: Optional[str] = None + session_name: Optional[str] = None filter: Optional[List[Tuple[str, str]]] = None - collection: Optional[str] = None + data_source: Optional[str] = None class OutputMode(str, Enum): @@ -248,8 +248,8 @@ def __str__(self): class BaseWithMetadata(Base): - id: str name: str + id: Optional[str] = None description: Optional[str] = None labels: Optional[Dict[str, Union[str, None]]] = None created: Optional[Union[str, datetime]] = None @@ -282,16 +282,16 @@ class Project(BaseWithVerMetadata): class DataSource(BaseWithVerMetadata): _top_level_fields = ["data_source_type"] - project_id: str data_source_type: DataSourceType + project_id: Optional[str] = None category: Optional[str] = None - database_kwargs: Optional[dict[str, str]] = None + database_kwargs: Optional[dict[str, str]] = {} class Dataset(BaseWithVerMetadata): _top_level_fields = ["task"] - project_id: str + project_id: Optional[str] = None task: str sources: Optional[List[str]] = None path: str @@ -302,9 +302,9 @@ class Model(BaseWithVerMetadata): _extra_fields = ["path", "producer", "deployment"] _top_level_fields = ["model_type", "task"] - project_id: str model_type: ModelType base_model: str + project_id: Optional[str] = None task: Optional[str] = None path: Optional[str] = None producer: Optional[str] = None @@ -315,35 +315,37 @@ class PromptTemplate(BaseWithVerMetadata): _extra_fields = ["arguments"] _top_level_fields = ["text"] - project_id: str text: str + project_id: Optional[str] = None arguments: Optional[List[str]] = None class Document(BaseWithVerMetadata): _top_level_fields = ["path", "origin"] - project_id: str path: str + project_id: Optional[str] = None origin: Optional[str] = None class Workflow(BaseWithVerMetadata): _top_level_fields = ["workflow_type"] - project_id: str workflow_type: WorkflowType + deployment: str + project_id: Optional[str] = None workflow_function: Optional[str] = None configuration: Optional[dict] = None graph: Optional[dict] = None - deployment: Optional[str] = None + + def get_infer_path(self): + return f"{self.deployment}/api/workflows/{self.name}/infer" -class ChatSession(BaseWithMetadata): - _extra_fields = ["history", "features", "state", "agent_name"] - _top_level_fields = ["username"] +class ChatSession(BaseWithOwner): + _extra_fields = ["history"] + _top_level_fields = ["workflow_id"] workflow_id: str - user_id: str history: Optional[List[Message]] = [] def to_conversation(self): diff --git a/controller/src/sqlclient.py b/controller/src/sqlclient.py index 6f34446..bf974bd 100644 --- a/controller/src/sqlclient.py +++ b/controller/src/sqlclient.py @@ -122,6 +122,22 @@ def _get( ) return ApiResponse(success=True, data=api_class.from_orm_object(obj)) + # def _get_by_name(self, session: sqlalchemy.orm.Session, db_class, api_class, name: str) -> ApiResponse: + # """ + # Get an object from the database by name. + # + # :param session: The session to use. + # :param db_class: The DB class of the object. + # :param api_class: The API class of the object. + # + # :return: A response object with the success status and the object when successful. + # """ + # session = self.get_db_session(session) + # obj = session.query(db_class).filter_by(name=name).one_or_none() + # if obj is None: + # return ApiResponse(success=False, error=f"{db_class} object ({name}) not found") + # return ApiResponse(success=True, data=api_class.from_orm_object(obj)) + def _update( self, session: sqlalchemy.orm.Session, db_class, api_object, **kwargs ) -> ApiResponse: @@ -222,18 +238,36 @@ def create_user( return self._create(session, db.User, user) def get_user( - self, user_id: str, session: sqlalchemy.orm.Session = None + self, + user_id: str = None, + user_name: str = None, + email: str = None, + session: sqlalchemy.orm.Session = None, ) -> ApiResponse: """ Get a user from the database. + Either user_id or user_name or email must be provided. - :param user_id: The ID of the user to get. - :param session: The session to use. + :param user_id: The ID of the user to get. + :param user_name: The name of the user to get. + :param email: The email of the user to get. + :param session: The session to use. :return: A response object with the success status and the user when successful. """ - logger.debug(f"Getting user: user_id={user_id}") - return self._get(session, db.User, api_models.User, id=user_id) + args = {} + if email: + args["email"] = email + elif user_name: + args["name"] = user_name + elif user_id: + args["id"] = user_id + else: + return ApiResponse( + success=False, error="user_id or user_name or email must be provided" + ) + logger.debug(f"Getting user: user_id={user_id}, user_name={user_name}") + return self._get(session, db.User, api_models.User, **args) def update_user( self, user: Union[api_models.User, dict], session: sqlalchemy.orm.Session = None @@ -252,17 +286,17 @@ def update_user( return self._update(session, db.User, user, name=user.name) def delete_user( - self, user_id: str, session: sqlalchemy.orm.Session = None + self, user_name: str, session: sqlalchemy.orm.Session = None ) -> ApiResponse: """ Delete a user from the database. - :param user_id: The ID of the user to delete. + :param user_name: The name of the user to delete. :param session: :return: """ - logger.debug(f"Deleting user: user_id={user_id}") - return self._delete(session, db.User, id=user_id) + logger.debug(f"Deleting user: user_name={user_name}") + return self._delete(session, db.User, name=user_name) def list_users( self, @@ -319,18 +353,18 @@ def create_project( return self._create(session, db.Project, project) def get_project( - self, project_id: str, session: sqlalchemy.orm.Session = None + self, project_name: str, session: sqlalchemy.orm.Session = None ) -> ApiResponse: """ Get a project from the database. - :param project_id: The ID of the project to get. - :param session: The session to use. + :param project_name: The name of the project to get. + :param session: The session to use. :return: A response object with the success status and the project when successful. """ - logger.debug(f"Getting project: project_id={project_id}") - return self._get(session, db.Project, api_models.Project, id=project_id) + logger.debug(f"Getting project: project_name={project_name}") + return self._get(session, db.Project, api_models.Project, name=project_name) def update_project( self, @@ -348,21 +382,21 @@ def update_project( logger.debug(f"Updating project: {project}") if isinstance(project, dict): project = api_models.Project.from_dict(project) - return self._update(session, db.Project, project, id=project.id) + return self._update(session, db.Project, project, name=project.name) def delete_project( - self, project_id: str, session: sqlalchemy.orm.Session = None + self, project_name: str, session: sqlalchemy.orm.Session = None ) -> ApiResponse: """ Delete a project from the database. - :param project_id: The ID of the project to delete. - :param session: The session to use. + :param project_name: The name of the project to delete. + :param session: The session to use. :return: A response object with the success status. """ - logger.debug(f"Deleting project: project_id={project_id}") - return self._delete(session, db.Project, id=project_id) + logger.debug(f"Deleting project: project_name={project_name}") + return self._delete(session, db.Project, name=project_name) def list_projects( self, @@ -411,7 +445,7 @@ def create_data_source( :param data_source: The data source object to create. :param session: The session to use. - :return: A response object with the success status and the created data source when successful. + :return: A response object with the success status and the created data source when successful. """ logger.debug(f"Creating data source: {data_source}") if isinstance(data_source, dict): @@ -419,19 +453,27 @@ def create_data_source( return self._create(session, db.DataSource, data_source) def get_data_source( - self, data_source_id: str, session: sqlalchemy.orm.Session = None + self, + project_id: str, + data_source_name: str, + session: sqlalchemy.orm.Session = None, ) -> ApiResponse: """ Get a data source from the database. - :param data_source_id: The ID of the data source to get. - :param session: The session to use. + :param project_id: The ID of the project to get the data source from. + :param data_source_name: The ID of the data source to get. + :param session: The session to use. - :return: A response object with the success status and the data source when successful. + :return: A response object with the success status and the data source when successful. """ - logger.debug(f"Getting data source: data_source_id={data_source_id}") + logger.debug(f"Getting data source: data_source_name={data_source_name}") return self._get( - session, db.DataSource, api_models.DataSource, id=data_source_id + session, + db.DataSource, + api_models.DataSource, + name=data_source_name, + project_id=project_id, ) def update_data_source( @@ -450,21 +492,27 @@ def update_data_source( logger.debug(f"Updating data source: {data_source}") if isinstance(data_source, dict): data_source = api_models.DataSource.from_dict(data_source) - return self._update(session, db.DataSource, data_source, id=data_source.id) + return self._update(session, db.DataSource, data_source) def delete_data_source( - self, data_source_id: str, session: sqlalchemy.orm.Session = None + self, + project_id: str, + data_source_id: str, + session: sqlalchemy.orm.Session = None, ) -> ApiResponse: """ Delete a data source from the database. + :param project_id: The ID of the project to delete the data source from. :param data_source_id: The ID of the data source to delete. :param session: The session to use. :return: A response object with the success status. """ logger.debug(f"Deleting data source: data_source_id={data_source_id}") - return self._delete(session, db.DataSource, id=data_source_id) + return self._delete( + session, db.DataSource, project_id=project_id, id=data_source_id + ) def list_data_sources( self, @@ -504,8 +552,8 @@ def list_data_sources( filters.append(db.DataSource.data_source_type == data_source_type) return self._list( session=session, - db_class=db.User, - api_class=api_models.User, + db_class=db.DataSource, + api_class=api_models.DataSource, output_mode=output_mode, labels_match=labels_match, filters=filters, @@ -530,18 +578,25 @@ def create_dataset( return self._create(session, db.Dataset, dataset) def get_dataset( - self, dataset_id: str, session: sqlalchemy.orm.Session = None + self, project_id: str, dataset_id: str, session: sqlalchemy.orm.Session = None ) -> ApiResponse: """ Get a dataset from the database. - :param dataset_id: The ID of the dataset to get. - :param session: The session to use. + :param project_id: The ID of the project to get the dataset from. + :param dataset_id: The ID of the dataset to get. + :param session: The session to use. :return: A response object with the success status and the dataset when successful. """ logger.debug(f"Getting dataset: dataset_id={dataset_id}") - return self._get(session, db.Dataset, api_models.Dataset, id=dataset_id) + return self._get( + session, + db.Dataset, + api_models.Dataset, + id=dataset_id, + project_id=project_id, + ) def update_dataset( self, @@ -562,18 +617,19 @@ def update_dataset( return self._update(session, db.Dataset, dataset, id=dataset.id) def delete_dataset( - self, dataset_id: str, session: sqlalchemy.orm.Session = None + self, project_id: str, dataset_id: str, session: sqlalchemy.orm.Session = None ) -> ApiResponse: """ Delete a dataset from the database. - :param dataset_id: The ID of the dataset to delete. - :param session: The session to use. + :param project_id: The ID of the project to delete the dataset from. + :param dataset_id: The ID of the dataset to delete. + :param session: The session to use. :return: A response object with the success status. """ logger.debug(f"Deleting dataset: dataset_id={dataset_id}") - return self._delete(session, db.Dataset, id=dataset_id) + return self._delete(session, db.Dataset, project_id=project_id, id=dataset_id) def list_datasets( self, @@ -628,7 +684,7 @@ def create_model( """ Create a new model in the database. - :param model: The model object to create. + :param model: The model object to create. :param session: The session to use. :return: A response object with the success status and the created model when successful. @@ -639,18 +695,21 @@ def create_model( return self._create(session, db.Model, model) def get_model( - self, model_id: str, session: sqlalchemy.orm.Session = None + self, project_id: str, model_id: str, session: sqlalchemy.orm.Session = None ) -> ApiResponse: """ Get a model from the database. - :param model_id: The ID of the model to get. - :param session: The session to use. + :param project_id: The ID of the project to get the model from. + :param model_id: The ID of the model to get. + :param session: The session to use. :return: A response object with the success status and the model when successful. """ logger.debug(f"Getting model: model_id={model_id}") - return self._get(session, db.Model, api_models.Model, id=model_id) + return self._get( + session, db.Model, api_models.Model, project_id=project_id, id=model_id + ) def update_model( self, @@ -660,7 +719,7 @@ def update_model( """ Update an existing model in the database. - :param model: The model object with the new data. + :param model: The model object with the new data. :param session: The session to use. :return: A response object with the success status and the updated model when successful. @@ -671,18 +730,19 @@ def update_model( return self._update(session, db.Model, model, id=model.id) def delete_model( - self, model_id: str, session: sqlalchemy.orm.Session = None + self, project_id: str, model_id: str, session: sqlalchemy.orm.Session = None ) -> ApiResponse: """ Delete a model from the database. - :param model_id: The ID of the model to delete. - :param session: The session to use. + :param project_id: The ID of the project to delete the model from. + :param model_id: The ID of the model to delete. + :param session: The session to use. :return: A response object with the success status. """ logger.debug(f"Deleting model: model_id={model_id}") - return self._delete(session, db.Model, id=model_id) + return self._delete(session, db.Model, project_id=project_id, id=model_id) def list_models( self, @@ -742,7 +802,7 @@ def create_prompt_template( Create a new prompt template in the database. :param prompt_template: The prompt template object to create. - :param session: The session to use. + :param session: The session to use. :return: A response object with the success status and the created prompt template when successful. """ @@ -752,13 +812,17 @@ def create_prompt_template( return self._create(session, db.PromptTemplate, prompt_template) def get_prompt_template( - self, prompt_template_id: str, session: sqlalchemy.orm.Session = None + self, + project_id: str, + prompt_template_id: str, + session: sqlalchemy.orm.Session = None, ) -> ApiResponse: """ Get a prompt template from the database. - :param prompt_template_id: The ID of the prompt template to get. - :param session: The session to use. + :param project_id: The ID of the project to get the prompt template from. + :param prompt_template_id: The ID of the prompt template to get. + :param session: The session to use. :return: A response object with the success status and the prompt template when successful. """ @@ -766,7 +830,11 @@ def get_prompt_template( f"Getting prompt template: prompt_template_id={prompt_template_id}" ) return self._get( - session, db.PromptTemplate, api_models.PromptTemplate, id=prompt_template_id + session, + db.PromptTemplate, + api_models.PromptTemplate, + project_id=project_id, + id=prompt_template_id, ) def update_prompt_template( @@ -778,7 +846,7 @@ def update_prompt_template( Update an existing prompt template in the database. :param prompt_template: The prompt template object with the new data. - :param session: The session to use. + :param session: The session to use. :return: A response object with the success status and the updated prompt template when successful. """ @@ -790,20 +858,26 @@ def update_prompt_template( ) def delete_prompt_template( - self, prompt_template_id: str, session: sqlalchemy.orm.Session = None + self, + project_id: str, + prompt_template_id: str, + session: sqlalchemy.orm.Session = None, ) -> ApiResponse: """ Delete a prompt template from the database. - :param prompt_template_id: The ID of the prompt template to delete. - :param session: The session to use. + :param project_id: The ID of the project to delete the prompt template from. + :param prompt_template_id: The ID of the prompt template to delete. + :param session: The session to use. :return: A response object with the success status. """ logger.debug( f"Deleting prompt template: prompt_template_id={prompt_template_id}" ) - return self._delete(session, db.PromptTemplate, id=prompt_template_id) + return self._delete( + session, db.PromptTemplate, project_id=project_id, id=prompt_template_id + ) def list_prompt_templates( self, @@ -865,18 +939,25 @@ def create_document( return self._create(session, db.Document, document) def get_document( - self, document_id: str, session: sqlalchemy.orm.Session = None + self, project_id: str, document_id: str, session: sqlalchemy.orm.Session = None ) -> ApiResponse: """ Get a document from the database. + :param project_id: The ID of the project to get the document from. :param document_id: The ID of the document to get. :param session: The session to use. :return: A response object with the success status and the document when successful. """ logger.debug(f"Getting document: document_id={document_id}") - return self._get(session, db.Document, api_models.Document, id=document_id) + return self._get( + session, + db.Document, + api_models.Document, + project_id=project_id, + id=document_id, + ) def update_document( self, @@ -897,18 +978,19 @@ def update_document( return self._update(session, db.Document, document, id=document.id) def delete_document( - self, document_id: str, session: sqlalchemy.orm.Session = None + self, project_id: str, document_id: str, session: sqlalchemy.orm.Session = None ) -> ApiResponse: """ Delete a document from the database. + :param project_id: The ID of the project to delete the document from. :param document_id: The ID of the document to delete. - :param session: The session to use. + :param session: The session to use. :return: A response object with the success status. """ logger.debug(f"Deleting document: document_id={document_id}") - return self._delete(session, db.Document, id=document_id) + return self._delete(session, db.Document, project_id=project_id, id=document_id) def list_documents( self, @@ -951,6 +1033,125 @@ def list_documents( filters=filters, ) + def create_workflow( + self, + workflow: Union[api_models.Workflow, dict], + session: sqlalchemy.orm.Session = None, + ) -> ApiResponse: + """ + Create a new workflow in the database. + + :param workflow: The workflow object to create. + :param session: The session to use. + + :return: A response object with the success status and the created workflow when successful. + """ + logger.debug(f"Creating workflow: {workflow}") + if isinstance(workflow, dict): + workflow = api_models.Workflow.from_dict(workflow) + return self._create(session, db.Workflow, workflow) + + def get_workflow( + self, + project_id: str, + workflow_name: str, + session: sqlalchemy.orm.Session = None, + ) -> ApiResponse: + """ + Get a workflow from the database. + + :param project_id: The ID of the project to get the workflow from. + :param workflow_name: The name of the workflow to get. + :param session: The session to use. + + :return: A response object with the success status and the workflow when successful. + """ + logger.debug(f"Getting workflow: workflow_name={workflow_name}") + return self._get( + session, + db.Workflow, + api_models.Workflow, + project_id=project_id, + name=workflow_name, + ) + + def update_workflow( + self, + workflow: Union[api_models.Workflow, dict], + session: sqlalchemy.orm.Session = None, + ) -> ApiResponse: + """ + Update an existing workflow in the database. + + :param workflow: The workflow object with the new data. + :param session: The session to use. + + :return: A response object with the success status and the updated workflow when successful. + """ + logger.debug(f"Updating workflow: {workflow}") + if isinstance(workflow, dict): + workflow = api_models.Workflow.from_dict(workflow) + return self._update(session, db.Workflow, workflow, id=workflow.id) + + def delete_workflow( + self, workflow_id: str, session: sqlalchemy.orm.Session = None + ) -> ApiResponse: + """ + Delete a workflow from the database. + + :param workflow_id: The ID of the workflow to delete. + :param session: The session to use. + + :return: A response object with the success status. + """ + logger.debug(f"Deleting workflow: workflow_id={workflow_id}") + return self._delete(session, db.Workflow, id=workflow_id) + + def list_workflows( + self, + owner_id: str = None, + version: str = None, + project_id: str = None, + workflow_type: Union[api_models.WorkflowType, str] = None, + labels_match: Union[list, str] = None, + output_mode: api_models.OutputMode = api_models.OutputMode.Details, + session: sqlalchemy.orm.Session = None, + ) -> ApiResponse: + """ + List workflows from the database. + + :param owner_id: The owner to filter the workflows by. + :param version: The version to filter the workflows by. + :param project_id: The project to filter the workflows by. + :param workflow_type: The workflow type to filter the workflows by. + :param labels_match: The labels to match, filter the workflows by labels. + :param output_mode: The output mode. + :param session: The session to use. + + :return: A response object with the success status and the list of workflows when successful. + """ + logger.debug( + f"Getting workflows: owner_id={owner_id}, version={version}, project_id={project_id}," + f" workflow_type={workflow_type}, labels_match={labels_match}, mode={output_mode}" + ) + filters = [] + if owner_id: + filters.append(db.Workflow.owner_id == owner_id) + if version: + filters.append(db.Workflow.version == version) + if project_id: + filters.append(db.Workflow.project_id == project_id) + if workflow_type: + filters.append(db.Workflow.workflow_type == workflow_type) + return self._list( + session=session, + db_class=db.Workflow, + api_class=api_models.Workflow, + output_mode=output_mode, + labels_match=labels_match, + filters=filters, + ) + def create_chat_session( self, chat_session: Union[api_models.ChatSession, dict], @@ -971,25 +1172,25 @@ def create_chat_session( def get_chat_session( self, - session_id: str, + session_name: str = None, user_id: str = None, session: sqlalchemy.orm.Session = None, ) -> ApiResponse: """ Get a chat session from the database. - :param session_id: The ID of the chat session to get. - :param user_id: The ID of the user to get the last session for. - :param session: The DB session to use. + :param session_name: The ID of the chat session to get. + :param user_id: The ID of the user to get the last session for. + :param session: The DB session to use. :return: A response object with the success status and the chat session when successful. """ logger.debug( - f"Getting chat session: session_id={session_id}, user_id={user_id}" + f"Getting chat session: session_name={session_name}, user_id={user_id}" ) - if session_id: + if session_name: return self._get( - session, db.Session, api_models.ChatSession, name=session_id + session, db.Session, api_models.ChatSession, name=session_name ) elif user_id: # get the last session for the user @@ -1017,7 +1218,7 @@ def update_chat_session( :return: A response object with the success status and the updated chat session when successful. """ logger.debug(f"Updating chat session: {chat_session}") - return self._update(session, db.Session, chat_session, id=chat_session.id) + return self._update(session, db.Session, chat_session, name=chat_session.name) def delete_chat_session( self, session_id: str, session: sqlalchemy.orm.Session = None @@ -1061,7 +1262,7 @@ def list_chat_sessions( session = self.get_db_session(session) query = session.query(db.Session) if user_id: - query = query.filter(db.Session.user_id == user_id) + query = query.filter(db.Session.owner_id == user_id) if workflow_id: query = query.filter(db.Session.workflow_id == workflow_id) if created_after: diff --git a/controller/src/sqldb.py b/controller/src/sqldb.py index 3484a9a..d97ffdc 100644 --- a/controller/src/sqldb.py +++ b/controller/src/sqldb.py @@ -77,11 +77,10 @@ class BaseSchema(Base): :arg id: unique identifier for each entry. :arg name: entry's name. :arg description: The entry's description. - :arg owner_id: The entry's owner's id. The following columns are automatically added to each table: - - date_created: The entry's creation date. - - date_updated: The entry's last update date. + - created: The entry's creation date. + - updated: The entry's last update date. - spec: A dictionary to store additional information. """ @@ -102,22 +101,20 @@ def labels(cls): # Columns: id: Mapped[str] = mapped_column(String(ID_LENGTH), primary_key=True) - name: Mapped[str] + name: Mapped[str] = mapped_column(String(255), unique=True) description: Mapped[Optional[str]] - date_created: Mapped[datetime.datetime] = mapped_column( - default=datetime.datetime.utcnow - ) - date_updated: Mapped[Optional[datetime.datetime]] = mapped_column( + created: Mapped[datetime.datetime] = mapped_column(default=datetime.datetime.utcnow) + updated: Mapped[Optional[datetime.datetime]] = mapped_column( default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, ) spec = Column(MutableDict.as_mutable(JSON), nullable=True) - def __init__(self, id, name, description=None, owner_id=None, labels=None): + def __init__(self, id, name, spec, description=None, labels=None): self.id = id self.name = name + self.spec = spec self.description = description - self.owner_id = owner_id self.labels = labels or [] @@ -134,6 +131,10 @@ class OwnerBaseSchema(BaseSchema): String(ID_LENGTH), ForeignKey("user.id") ) + def __init__(self, id, name, spec, description=None, owner_id=None, labels=None): + super().__init__(id, name, spec, description, labels) + self.owner_id = owner_id + class VersionedBaseSchema(OwnerBaseSchema): """ @@ -144,10 +145,12 @@ class VersionedBaseSchema(OwnerBaseSchema): """ __abstract__ = True - version: Mapped[str] = mapped_column(String(255), primary_key=True) + version: Mapped[str] = mapped_column(String(255), primary_key=True, default="") - def __init__(self, id, name, version, description=None, owner_id=None, labels=None): - super().__init__(id, name, description, owner_id, labels) + def __init__( + self, id, name, spec, version, description=None, owner_id=None, labels=None + ): + super().__init__(id, name, spec, description, owner_id, labels) self.version = version @@ -203,10 +206,28 @@ class User(BaseSchema): # many-to-many relationship with projects: projects: Mapped[List["Project"]] = relationship( - back_populates="users", secondary=user_project + back_populates="users", + secondary=user_project, + primaryjoin="User.id == user_project.c.user_id", + secondaryjoin="and_(Project.id == user_project.c.project_id," + " Project.version == user_project.c.project_version)", + foreign_keys=[ + user_project.c.user_id, + user_project.c.project_id, + user_project.c.project_version, + ], ) # one-to-many relationship with sessions: - sessions: Mapped[List["Session"]] = relationship(back_populates="user") + sessions: Mapped[List["Session"]] = relationship( + back_populates="user", foreign_keys="Session.owner_id" + ) + + def __init__(self, id, name, email, full_name, spec, description=None, labels=None): + super().__init__( + id=id, name=name, description=description, spec=spec, labels=labels + ) + self.email = email + self.full_name = full_name class Project(VersionedBaseSchema): @@ -218,7 +239,15 @@ class Project(VersionedBaseSchema): # many-to-many relationship with user: users: Mapped[List["User"]] = relationship( - back_populates="projects", secondary=user_project + back_populates="projects", + secondary=user_project, + primaryjoin="and_(Project.id == user_project.c.project_id, Project.version == user_project.c.project_version)", + secondaryjoin="User.id == user_project.c.user_id", + foreign_keys=[ + user_project.c.user_id, + user_project.c.project_id, + user_project.c.project_version, + ], ) # one-to-many relationships: @@ -231,15 +260,17 @@ class Project(VersionedBaseSchema): workflows: Mapped[List["Workflow"]] = relationship(**relationship_args) def __init__( - self, - id, - name, - version=None, - description=None, - owner_id=None, - labels=None, + self, id, name, spec, version, description=None, owner_id=None, labels=None ): - super().__init__(id, name, version, description, owner_id, labels) + super().__init__( + id=id, + name=name, + version=version, + spec=spec, + description=description, + owner_id=owner_id, + labels=labels, + ) update_labels(self, {"_GENAI_FACTORY": True}) @@ -262,9 +293,44 @@ class DataSource(VersionedBaseSchema): project: Mapped["Project"] = relationship(back_populates="data_sources") # many-to-many relationship with documents: documents: Mapped[List["Document"]] = relationship( - back_populates="data_sources", secondary=ingestions + back_populates="data_sources", + secondary=ingestions, + primaryjoin="and_(DataSource.id == ingestions.c.data_source_id," + " DataSource.version == ingestions.c.data_source_version)", + secondaryjoin="and_(Document.id == ingestions.c.document_id," + " Document.version == ingestions.c.document_version)", + foreign_keys=[ + ingestions.c.data_source_id, + ingestions.c.data_source_version, + ingestions.c.document_id, + ingestions.c.document_version, + ], ) + def __init__( + self, + id, + name, + spec, + version, + project_id, + data_source_type, + description=None, + owner_id=None, + labels=None, + ): + super().__init__( + id=id, + name=name, + version=version, + spec=spec, + description=description, + owner_id=owner_id, + labels=labels, + ) + self.project_id = project_id + self.data_source_type = data_source_type + class Dataset(VersionedBaseSchema): """ @@ -283,6 +349,30 @@ class Dataset(VersionedBaseSchema): # Many-to-one relationship with projects: project: Mapped["Project"] = relationship(back_populates="datasets") + def __init__( + self, + id, + name, + spec, + version, + project_id, + task, + description=None, + owner_id=None, + labels=None, + ): + super().__init__( + id=id, + name=name, + version=version, + spec=spec, + description=description, + owner_id=owner_id, + labels=labels, + ) + self.project_id = project_id + self.task = task + class Model(VersionedBaseSchema): """ @@ -304,9 +394,46 @@ class Model(VersionedBaseSchema): project: Mapped["Project"] = relationship(back_populates="models") # many-to-many relationship with prompt_templates: prompt_templates: Mapped[List["PromptTemplate"]] = relationship( - back_populates="models", secondary=model_prompt_template + back_populates="models", + secondary=model_prompt_template, + primaryjoin="and_(Model.id == model_prompt_template.c.model_id," + " Model.version == model_prompt_template.c.model_version)", + secondaryjoin="and_(PromptTemplate.id == model_prompt_template.c.prompt_id," + " PromptTemplate.version == model_prompt_template.c.prompt_version)", + foreign_keys=[ + model_prompt_template.c.model_id, + model_prompt_template.c.model_version, + model_prompt_template.c.prompt_id, + model_prompt_template.c.prompt_version, + ], ) + def __init__( + self, + id, + name, + spec, + version, + project_id, + model_type, + task, + description=None, + owner_id=None, + labels=None, + ): + super().__init__( + id=id, + name=name, + version=version, + spec=spec, + description=description, + owner_id=owner_id, + labels=labels, + ) + self.project_id = project_id + self.model_type = model_type + self.task = task + class PromptTemplate(VersionedBaseSchema): """ @@ -325,9 +452,42 @@ class PromptTemplate(VersionedBaseSchema): project: Mapped["Project"] = relationship(back_populates="prompt_templates") # many-to-many relationship with the 'Model' table models: Mapped[List["Model"]] = relationship( - back_populates="prompt_templates", secondary=model_prompt_template + back_populates="prompt_templates", + secondary=model_prompt_template, + primaryjoin="and_(PromptTemplate.id == model_prompt_template.c.prompt_id," + " PromptTemplate.version == model_prompt_template.c.prompt_version)", + secondaryjoin="and_(Model.id == model_prompt_template.c.model_id," + " Model.version == model_prompt_template.c.model_version)", + foreign_keys=[ + model_prompt_template.c.prompt_id, + model_prompt_template.c.prompt_version, + model_prompt_template.c.model_id, + model_prompt_template.c.model_version, + ], ) + def __init__( + self, + id, + name, + spec, + version, + project_id, + description=None, + owner_id=None, + labels=None, + ): + super().__init__( + id=id, + name=name, + version=version, + spec=spec, + description=description, + owner_id=owner_id, + labels=labels, + ) + self.project_id = project_id + class Document(VersionedBaseSchema): """ @@ -349,9 +509,45 @@ class Document(VersionedBaseSchema): project: Mapped["Project"] = relationship(back_populates="documents") # many-to-many relationship with ingestion: data_sources: Mapped[List["DataSource"]] = relationship( - back_populates="documents", secondary=ingestions + back_populates="documents", + secondary=ingestions, + primaryjoin="and_(Document.id == ingestions.c.document_id, Document.version == ingestions.c.document_version)", + secondaryjoin="and_(DataSource.id == ingestions.c.data_source_id," + " DataSource.version == ingestions.c.data_source_version)", + foreign_keys=[ + ingestions.c.document_id, + ingestions.c.document_version, + ingestions.c.data_source_id, + ingestions.c.data_source_version, + ], ) + def __init__( + self, + id, + name, + spec, + version, + project_id, + path, + origin, + description=None, + owner_id=None, + labels=None, + ): + super().__init__( + id=id, + name=name, + version=version, + spec=spec, + description=description, + owner_id=owner_id, + labels=labels, + ) + self.project_id = project_id + self.path = path + self.origin = origin + class Workflow(VersionedBaseSchema): """ @@ -374,24 +570,61 @@ class Workflow(VersionedBaseSchema): # one-to-many relationship with sessions: sessions: Mapped[List["Session"]] = relationship(back_populates="workflow") + def __init__( + self, + id, + name, + spec, + version, + project_id, + workflow_type, + description=None, + owner_id=None, + labels=None, + ): + super().__init__( + id=id, + name=name, + version=version, + spec=spec, + description=description, + owner_id=owner_id, + labels=labels, + ) + self.project_id = project_id + self.workflow_type = workflow_type + class Session(OwnerBaseSchema): """ The Chat Session table which is used to define chat sessions of an application workflow per user. :arg workflow_id: The workflow's id. - :arg user_id: The user's id. """ # Columns: workflow_id: Mapped[str] = mapped_column( String(ID_LENGTH), ForeignKey("workflow.id") ) - user_id: Mapped[str] = mapped_column(String(ID_LENGTH), ForeignKey("user.id")) # Relationships: # Many-to-one relationship with workflows: workflow: Mapped["Workflow"] = relationship(back_populates="sessions") # Many-to-one relationship with users: - user: Mapped["User"] = relationship(back_populates="sessions") + user: Mapped["User"] = relationship( + back_populates="sessions", foreign_keys="Session.owner_id" + ) + + def __init__( + self, id, name, spec, workflow_id, description=None, owner_id=None, labels=None + ): + super().__init__( + id=id, + name=name, + spec=spec, + description=description, + owner_id=owner_id, + labels=labels, + ) + self.workflow_id = workflow_id From 9d354adb684ff62299b752c11168ce6c17f571a1 Mon Sep 17 00:00:00 2001 From: yonishelach Date: Tue, 13 Aug 2024 16:29:00 +0300 Subject: [PATCH 05/10] arrange code in directories --- README.md | 9 +- controller/Dockerfile | 2 +- controller/src/api.py | 1343 ----------------- controller/src/api/__init__.py | 13 + controller/src/api/api.py | 90 ++ controller/src/api/endpoints/__init__.py | 13 + controller/src/api/endpoints/base.py | 24 + controller/src/api/endpoints/data_sources.py | 243 +++ controller/src/api/endpoints/datasets.py | 161 ++ controller/src/api/endpoints/documents.py | 158 ++ controller/src/api/endpoints/models.py | 159 ++ controller/src/api/endpoints/projects.py | 113 ++ .../src/api/endpoints/prompt_templates.py | 156 ++ controller/src/api/endpoints/sessions.py | 119 ++ controller/src/api/endpoints/users.py | 106 ++ controller/src/api/endpoints/workflows.py | 211 +++ controller/src/api/utils.py | 90 ++ controller/src/config.py | 1 + controller/src/db/__init__.py | 22 + controller/src/{ => db}/sqlclient.py | 18 +- controller/src/{ => db}/sqldb.py | 0 controller/src/main.py | 12 +- controller/src/schemas/__init__.py | 24 + controller/src/{model.py => schemas/base.py} | 205 +-- controller/src/schemas/data_source.py | 37 + controller/src/schemas/dataset.py | 27 + controller/src/schemas/document.py | 24 + controller/src/schemas/model.py | 36 + controller/src/schemas/project.py | 19 + controller/src/schemas/prompt_template.py | 26 + controller/src/schemas/session.py | 51 + controller/src/schemas/user.py | 28 + controller/src/schemas/workflow.py | 40 + pyproject.toml | 25 + 34 files changed, 2055 insertions(+), 1550 deletions(-) delete mode 100644 controller/src/api.py create mode 100644 controller/src/api/__init__.py create mode 100644 controller/src/api/api.py create mode 100644 controller/src/api/endpoints/__init__.py create mode 100644 controller/src/api/endpoints/base.py create mode 100644 controller/src/api/endpoints/data_sources.py create mode 100644 controller/src/api/endpoints/datasets.py create mode 100644 controller/src/api/endpoints/documents.py create mode 100644 controller/src/api/endpoints/models.py create mode 100644 controller/src/api/endpoints/projects.py create mode 100644 controller/src/api/endpoints/prompt_templates.py create mode 100644 controller/src/api/endpoints/sessions.py create mode 100644 controller/src/api/endpoints/users.py create mode 100644 controller/src/api/endpoints/workflows.py create mode 100644 controller/src/api/utils.py create mode 100644 controller/src/db/__init__.py rename controller/src/{ => db}/sqlclient.py (98%) rename controller/src/{ => db}/sqldb.py (100%) create mode 100644 controller/src/schemas/__init__.py rename controller/src/{model.py => schemas/base.py} (57%) create mode 100644 controller/src/schemas/data_source.py create mode 100644 controller/src/schemas/dataset.py create mode 100644 controller/src/schemas/document.py create mode 100644 controller/src/schemas/model.py create mode 100644 controller/src/schemas/project.py create mode 100644 controller/src/schemas/prompt_template.py create mode 100644 controller/src/schemas/session.py create mode 100644 controller/src/schemas/user.py create mode 100644 controller/src/schemas/workflow.py create mode 100644 pyproject.toml diff --git a/README.md b/README.md index 641b2ba..3159061 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ python -m controller.src.main ingest -l web https://milvus.io/docs/overview.md To ask a question: ```shell -python -m controller.src.main query "What is a vector?" +python -m controller.src.main infer "What is Milvus?" ``` @@ -64,8 +64,9 @@ Options: Commands: config Print the config as a yaml file - ingest Ingest documents into the vector database - initdb Initialize the database (delete old tables) + infer Run a chat query on the data source + ingest Ingest data into the data source. + initdb Initialize the database tables (delete old tables). list List the different objects in the database (by category) - query Run a chat query on the vector database collection + update Create or update an object in the database ``` diff --git a/controller/Dockerfile b/controller/Dockerfile index cd42172..c1a51b2 100644 --- a/controller/Dockerfile +++ b/controller/Dockerfile @@ -42,4 +42,4 @@ RUN pip install -r /controller/requirements.txt RUN python -m controller.src.main initdb # Run the controller's API server: -CMD ["uvicorn", "controller.src.api:app", "--port", "8001"] +CMD ["uvicorn", "controller.src.api.api:app", "--port", "8001"] diff --git a/controller/src/api.py b/controller/src/api.py deleted file mode 100644 index bf2e8e4..0000000 --- a/controller/src/api.py +++ /dev/null @@ -1,1343 +0,0 @@ -# Copyright 2023 Iguazio -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json -import os -from typing import List, Optional, Tuple, Union - -import requests -from fastapi import APIRouter, Depends, FastAPI, Header, Request -from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel - -from controller.src.config import config -from controller.src.model import ( - ApiResponse, - ChatSession, - Dataset, - DataSource, - DataSourceType, - Document, - Model, - OutputMode, - Project, - PromptTemplate, - QueryItem, - User, - Workflow, -) -from controller.src.sqlclient import client - -app = FastAPI() - -# Add CORS middleware, remove in production -origins = ["*"] # React app -app.add_middleware( - CORSMiddleware, - allow_origins=origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -# Create a router with a prefix -router = APIRouter(prefix="/api") - - -def get_db(): - db_session = None - try: - db_session = client.get_local_session() - yield db_session - finally: - if db_session: - db_session.close() - - -class AuthInfo(BaseModel): - username: str - token: str - roles: List[str] = [] - - -# placeholder for extracting the Auth info from the request -def get_auth_user( - request: Request, x_username: Union[str, None] = Header(None) -) -> AuthInfo: - """Get the user from the database""" - token = request.cookies.get("Authorization", "") - if x_username: - return AuthInfo(username=x_username, token=token) - else: - return AuthInfo(username="guest@example.com", token=token) - - -def _send_to_application( - path: str, method: str = "POST", request=None, auth=None, **kwargs -): - """ - Send a request to the application's API. - - :param path: The API path to send the request to. - :param method: The HTTP method to use: GET, POST, PUT, DELETE, etc. - :param request: The FastAPI request object. If provided, the data will be taken from the body of the request. - :param auth: The authentication information to use. If provided, the username will be added to the headers. - :param kwargs: Additional keyword arguments to pass in the request function. For example, headers, params, etc. - - :return: The JSON response from the application. - """ - if config.application_url not in path: - url = f"{config.application_url}/api/{path}" - else: - url = path - - if isinstance(request, Request): - # If the request is a FastAPI request, get the data from the body - kwargs["data"] = request._body.decode("utf-8") - if auth is not None: - kwargs["headers"] = {"x_username": auth.username} - - response = requests.request( - method=method, - url=url, - **kwargs, - ) - - # Check the response - if response.status_code == 200: - # If the request was successful, return the JSON response - return response.json() - else: - # If the request failed, raise an exception - response.raise_for_status() - - -@router.post("/tables") -def create_tables(drop_old: bool = False, names: list[str] = None): - return client.create_tables(drop_old=drop_old, names=names) - - -@router.post("/users") -def create_user( - user: User, - session=Depends(get_db), -) -> ApiResponse: - """ - Create a new user in the database. - - :param user: The user to create. - :param session: The database session. - - :return: The response from the database. - """ - return client.create_user(user=user, session=session) - - -@router.get("/users/{user_name}") -def get_user(user_name: str, email: str = None, session=Depends(get_db)) -> ApiResponse: - """ - Get a user from the database. - - :param user_name: The name of the user to get. - :param email: The email address to get the user by if the name is not provided. - :param session: The database session. - - :return: The user from the database. - """ - return client.get_user(user_name=user_name, email=email, session=session) - - -@router.put("/users/{user_name}") -def update_user( - user: User, - user_name: str, - session=Depends(get_db), -) -> ApiResponse: - """ - Update a user in the database. - - :param user: The user to update. - :param user_name: The name of the user to update. - :param session: The database session. - - :return: The response from the database. - """ - if user_name != user.name: - raise ValueError(f"User name does not match: {user_name} != {user.name}") - return client.update_user(user=user, session=session) - - -@router.delete("/users/{user_name}") -def delete_user(user_name: str, session=Depends(get_db)) -> ApiResponse: - """ - Delete a user from the database. - - :param user_name: The name of the user to delete. - :param session: The database session. - - :return: The response from the database. - """ - return client.delete_user(user_name=user_name, session=session) - - -@router.get("/users") -def list_users( - email: str = None, - full_name: str = None, - mode: OutputMode = OutputMode.Details, - session=Depends(get_db), -) -> ApiResponse: - """ - List users in the database. - - :param email: The email address to filter by. - :param full_name: The full name to filter by. - :param mode: The output mode. - :param session: The database session. - - :return: The response from the database. - """ - return client.list_users( - email=email, full_name=full_name, output_mode=mode, session=session - ) - - -@router.post("/projects") -def create_project( - project: Project, - session=Depends(get_db), -) -> ApiResponse: - """ - Create a new project in the database. - - :param project: The project to create. - :param session: The database session. - - :return: The response from the database. - """ - return client.create_project(project=project, session=session) - - -@router.get("/projects/{project_name}") -def get_project(project_name: str, session=Depends(get_db)) -> ApiResponse: - """ - Get a project from the database. - - :param project_name: The name of the project to get. - :param session: The database session. - - :return: The project from the database. - """ - return client.get_project(project_name=project_name, session=session) - - -@router.put("/projects/{project_name}") -def update_project( - project: Project, - project_name: str, - session=Depends(get_db), -) -> ApiResponse: - """ - Update a project in the database. - - :param project: The project to update. - :param project_name: The name of the project to update. - :param session: The database session. - - :return: The response from the database. - """ - if project_name != project.name: - raise ValueError( - f"Project name does not match: {project_name} != {project.name}" - ) - return client.update_project(project=project, session=session) - - -@router.delete("/projects/{project_name}") -def delete_project(project_name: str, session=Depends(get_db)) -> ApiResponse: - """ - Delete a project from the database. - - :param project_name: The name of the project to delete. - :param session: The database session. - - :return: The response from the database. - """ - return client.delete_project(project_name=project_name, session=session) - - -@router.get("/projects") -def list_projects( - owner_name: str = None, - labels: Optional[List[Tuple[str, str]]] = None, - mode: OutputMode = OutputMode.Details, - session=Depends(get_db), -) -> ApiResponse: - """ - List projects in the database. - - :param owner_name: The name of the owner to filter by. - :param labels: The labels to filter by. - :param mode: The output mode. - :param session: The database session. - - :return: The response from the database. - """ - if owner_name is not None: - owner_id = client.get_user(user_name=owner_name, session=session).data["id"] - else: - owner_id = None - return client.list_projects( - owner_id=owner_id, labels_match=labels, output_mode=mode, session=session - ) - - -@router.post("projects/{project_name}/data_sources/{data_source_name}") -def create_data_source( - project_name: str, - data_source_name: str, - data_source: DataSource, - session=Depends(get_db), - auth: AuthInfo = Depends(get_auth_user), -) -> ApiResponse: - """ - Create a new data source in the database. - - :param project_name: The name of the project to create the data source in. - :param data_source_name: The name of the data source to create. - :param data_source: The data source to create. - :param session: The database session. - :param auth: The authentication information. - - :return: The response from the database. - """ - # If the owner ID is not provided, get it from the username - if data_source.owner_id is None: - data_source.owner_id = client.get_user( - user_name=auth.username, session=session - ).data["id"] - data_source.name = data_source_name - data_source.project_id = client.get_project( - project_name=project_name, session=session - ).data["id"] - return client.create_data_source(data_source=data_source, session=session) - - -@router.get("projects/{project_name}/data_sources/{data_source_name}") -def get_data_source( - project_name: str, data_source_name: str, session=Depends(get_db) -) -> ApiResponse: - """ - Get a data source from the database. - - :param project_name: The name of the project to get the data source from. - :param data_source_name: The name of the data source to get. - :param session: The database session. - - :return: The data source from the database. - """ - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.get_data_source( - project_id=project_id, data_source_name=data_source_name, session=session - ) - - -@router.put("projects/{project_name}/data_sources/{data_source_name}") -def update_data_source( - project_name: str, - data_source: DataSource, - data_source_name: str, - session=Depends(get_db), -) -> ApiResponse: - """ - Update a data source in the database. - - :param project_name: The name of the project to update the data source in. - :param data_source: The data source to update. - :param data_source_name: The name of the data source to update. - :param session: The database session. - - :return: The response from the database. - """ - data_source.project_id = client.get_project( - project_name=project_name, session=session - ).data["id"] - if data_source_name != data_source.name: - raise ValueError( - f"Data source name does not match: {data_source_name} != {data_source.name}" - ) - return client.update_data_source(data_source=data_source, session=session) - - -@router.delete("projects/{project_name}/data_sources/{data_source_id}") -def delete_data_source( - project_name: str, data_source_id: str, session=Depends(get_db) -) -> ApiResponse: - """ - Delete a data source from the database. - - :param project_name: The name of the project to delete the data source from. - :param data_source_id: The ID of the data source to delete. - :param session: The database session. - - :return: The response from the database. - """ - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.delete_data_source( - project_id=project_id, data_source_id=data_source_id, session=session - ) - - -@router.get("projects/{project_name}/data_sources") -def list_data_sources( - project_name: str, - version: str = None, - data_source_type: Union[DataSourceType, str] = None, - labels: Optional[List[Tuple[str, str]]] = None, - mode: OutputMode = OutputMode.Details, - session=Depends(get_db), - auth: AuthInfo = Depends(get_auth_user), -) -> ApiResponse: - """ - List data sources in the database. - - :param project_name: The name of the project to list the data sources from. - :param version: The version to filter by. - :param data_source_type: The data source type to filter by. - :param labels: The labels to filter by. - :param mode: The output mode. - :param session: The database session. - :param auth: The authentication information. - - :return: The response from the database. - """ - owner_id = client.get_user(user_name=auth.username, session=session).data["id"] - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.list_data_sources( - project_id=project_id, - owner_id=owner_id, - version=version, - data_source_type=data_source_type, - labels_match=labels, - output_mode=mode, - session=session, - ) - - -@router.post("/projects/{project_name}/datasets/{dataset_name}") -def create_dataset( - project_name: str, - dataset_name: str, - dataset: Dataset, - session=Depends(get_db), - auth: AuthInfo = Depends(get_auth_user), -) -> ApiResponse: - """ - Create a new dataset in the database. - - :param project_name: The name of the project to create the dataset in. - :param dataset_name: The name of the dataset to create. - :param dataset: The dataset to create. - :param session: The database session. - :param auth: The authentication information. - - :return: The response from the database. - """ - # If the owner ID is not provided, get it from the username - if dataset.owner_id is None: - dataset.owner_id = client.get_user( - user_name=auth.username, session=session - ).data["id"] - dataset.name = dataset_name - dataset.project_id = client.get_project( - project_name=project_name, session=session - ).data["id"] - return client.create_dataset(dataset=dataset, session=session) - - -@router.get("/projects/{project_name}/datasets/{dataset_name}") -def get_dataset( - project_name: str, dataset_name: str, session=Depends(get_db) -) -> ApiResponse: - """ - Get a dataset from the database. - - :param project_name: The name of the project to get the dataset from. - :param dataset_name: The name of the dataset to get. - :param session: The database session. - - :return: The dataset from the database. - """ - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.get_dataset( - project_id=project_id, dataset_name=dataset_name, session=session - ) - - -@router.put("/projects/{project_name}/datasets/{dataset_name}") -def update_dataset( - project_name: str, - dataset: Dataset, - dataset_name: str, - session=Depends(get_db), -) -> ApiResponse: - """ - Update a dataset in the database. - - :param project_name: The name of the project to update the dataset in. - :param dataset: The dataset to update. - :param dataset_name: The name of the dataset to update. - :param session: The database session. - - :return: The response from the database. - """ - dataset.project_id = client.get_project( - project_name=project_name, session=session - ).data["id"] - if dataset_name != dataset.name: - raise ValueError( - f"Dataset name does not match: {dataset_name} != {dataset.name}" - ) - return client.update_dataset(dataset=dataset, session=session) - - -@router.delete("/projects/{project_name}/datasets/{dataset_id}") -def delete_dataset( - project_name: str, dataset_id: str, session=Depends(get_db) -) -> ApiResponse: - """ - Delete a dataset from the database. - - :param project_name: The name of the project to delete the dataset from. - :param dataset_id: The ID of the dataset to delete. - :param session: The database session. - - :return: The response from the database. - """ - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.delete_dataset( - project_id=project_id, dataset_id=dataset_id, session=session - ) - - -@router.get("/projects/{project_name}/datasets") -def list_datasets( - project_name: str, - version: str = None, - task: str = None, - labels: Optional[List[Tuple[str, str]]] = None, - mode: OutputMode = OutputMode.Details, - session=Depends(get_db), - auth: AuthInfo = Depends(get_auth_user), -) -> ApiResponse: - """ - List datasets in the database. - - :param project_name: The name of the project to list the datasets from. - :param version: The version to filter by. - :param task: The task to filter by. - :param labels: The labels to filter by. - :param mode: The output mode. - :param session: The database session. - :param auth: The authentication information. - - :return: The response from the database. - """ - owner_id = client.get_user(user_name=auth.username, session=session).data["id"] - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.list_datasets( - project_id=project_id, - owner_id=owner_id, - version=version, - task=task, - labels_match=labels, - output_mode=mode, - session=session, - ) - - -@router.post("/projects/{project_name}/models/{model_name}") -def create_model( - project_name: str, - model_name: str, - model: Model, - session=Depends(get_db), - auth: AuthInfo = Depends(get_auth_user), -) -> ApiResponse: - """ - Create a new model in the database. - - :param project_name: The name of the project to create the model in. - :param model_name: The name of the model to create. - :param model: The model to create. - :param session: The database session. - :param auth: The authentication information. - - :return: The response from the database. - """ - # If the owner ID is not provided, get it from the username - if model.owner_id is None: - model.owner_id = client.get_user(user_name=auth.username, session=session).data[ - "id" - ] - model.name = model_name - model.project_id = client.get_project( - project_name=project_name, session=session - ).data["id"] - return client.create_model(model=model, session=session) - - -@router.get("/projects/{project_name}/models/{model_name}") -def get_model( - project_name: str, model_name: str, session=Depends(get_db) -) -> ApiResponse: - """ - Get a model from the database. - - :param project_name: The name of the project to get the model from. - :param model_name: The name of the model to get. - :param session: The database session. - - :return: The model from the database. - """ - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.get_model( - project_id=project_id, model_name=model_name, session=session - ) - - -@router.put("/projects/{project_name}/models/{model_name}") -def update_model( - project_name: str, - model: Model, - model_name: str, - session=Depends(get_db), -) -> ApiResponse: - """ - Update a model in the database. - - :param project_name: The name of the project to update the model in. - :param model: The model to update. - :param model_name: The name of the model to update. - :param session: The database session. - - :return: The response from the database. - """ - model.project_id = client.get_project( - project_name=project_name, session=session - ).data["id"] - if model_name != model.name: - raise ValueError(f"Model name does not match: {model_name} != {model.name}") - return client.update_model(model=model, session=session) - - -@router.delete("/projects/{project_name}/models/{model_id}") -def delete_model( - project_name: str, model_id: str, session=Depends(get_db) -) -> ApiResponse: - """ - Delete a model from the database. - - :param project_name: The name of the project to delete the model from. - :param model_id: The ID of the model to delete. - :param session: The database session. - - :return: The response from the database. - """ - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.delete_model( - project_id=project_id, model_id=model_id, session=session - ) - - -@router.get("/projects/{project_name}/models") -def list_models( - project_name: str, - version: str = None, - model_type: str = None, - labels: Optional[List[Tuple[str, str]]] = None, - mode: OutputMode = OutputMode.Details, - session=Depends(get_db), - auth: AuthInfo = Depends(get_auth_user), -) -> ApiResponse: - """ - List models in the database. - - :param project_name: The name of the project to list the models from. - :param version: The version to filter by. - :param model_type: The model type to filter by. - :param labels: The labels to filter by. - :param mode: The output mode. - :param session: The database session. - :param auth: The authentication information. - - :return: The response from the database. - """ - owner_id = client.get_user(user_name=auth.username, session=session).data["id"] - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.list_models( - project_id=project_id, - owner_id=owner_id, - version=version, - model_type=model_type, - labels_match=labels, - output_mode=mode, - session=session, - ) - - -@router.post("/projects/{project_name}/prompt_templates/{prompt_name}") -def create_prompt( - project_name: str, - prompt_name: str, - prompt: PromptTemplate, - session=Depends(get_db), - auth: AuthInfo = Depends(get_auth_user), -) -> ApiResponse: - """ - Create a new prompt in the database. - - :param project_name: The name of the project to create the prompt in. - :param prompt_name: The name of the prompt to create. - :param prompt: The prompt to create. - :param session: The database session. - :param auth: The authentication information. - - :return: The response from the database. - """ - # If the owner ID is not provided, get it from the username - if prompt.owner_id is None: - prompt.owner_id = client.get_user( - user_name=auth.username, session=session - ).data["id"] - prompt.name = prompt_name - prompt.project_id = client.get_project( - project_name=project_name, session=session - ).data["id"] - return client.create_prompt_template(prompt=prompt, session=session) - - -@router.get("/projects/{project_name}/prompt_templates/{prompt_name}") -def get_prompt( - project_name: str, prompt_name: str, session=Depends(get_db) -) -> ApiResponse: - """ - Get a prompt from the database. - - :param project_name: The name of the project to get the prompt from. - :param prompt_name: The name of the prompt to get. - :param session: The database session. - - :return: The prompt from the database. - """ - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.get_prompt( - project_id=project_id, prompt_name=prompt_name, session=session - ) - - -@router.put("/projects/{project_name}/prompt_templates/{prompt_name}") -def update_prompt( - project_name: str, - prompt: PromptTemplate, - prompt_name: str, - session=Depends(get_db), -) -> ApiResponse: - """ - Update a prompt in the database. - - :param project_name: The name of the project to update the prompt in. - :param prompt: The prompt to update. - :param prompt_name: The name of the prompt to update. - :param session: The database session. - - :return: The response from the database. - """ - prompt.project_id = client.get_project( - project_name=project_name, session=session - ).data["id"] - if prompt_name != prompt.name: - raise ValueError(f"Prompt name does not match: {prompt_name} != {prompt.name}") - return client.update_prompt_template(prompt=prompt, session=session) - - -@router.delete("/projects/{project_name}/prompt_templates/{prompt_template_id}") -def delete_prompt( - project_name: str, prompt_template_id: str, session=Depends(get_db) -) -> ApiResponse: - """ - Delete a prompt from the database. - - :param project_name: The name of the project to delete the prompt from. - :param prompt_template_id: The ID of the prompt to delete. - :param session: The database session. - - :return: The response from the database. - """ - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.delete_prompt_template( - project_id=project_id, prompt_template_id=prompt_template_id, session=session - ) - - -@router.get("/projects/{project_name}/prompt_templates") -def list_prompts( - project_name: str, - version: str = None, - labels: Optional[List[Tuple[str, str]]] = None, - mode: OutputMode = OutputMode.Details, - session=Depends(get_db), - auth: AuthInfo = Depends(get_auth_user), -) -> ApiResponse: - """ - List prompts in the database. - - :param project_name: The name of the project to list the prompts from. - :param version: The version to filter by. - :param labels: The labels to filter by. - :param mode: The output mode. - :param session: The database session. - :param auth: The authentication information. - - :return: The response from the database. - """ - owner_id = client.get_user(user_name=auth.username, session=session).data["id"] - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.list_prompt_templates( - project_id=project_id, - owner_id=owner_id, - version=version, - labels_match=labels, - output_mode=mode, - session=session, - ) - - -@router.post("/projects/{project_name}/documents/{document_name}") -def create_document( - project_name: str, - document_name: str, - document: Document, - session=Depends(get_db), - auth: AuthInfo = Depends(get_auth_user), -) -> ApiResponse: - """ - Create a new document in the database. - - :param project_name: The name of the project to create the document in. - :param document_name: The name of the document to create. - :param document: The document to create. - :param session: The database session. - :param auth: The authentication information. - - :return: The response from the database. - """ - # If the owner ID is not provided, get it from the username - if document.owner_id is None: - document.owner_id = client.get_user( - user_name=auth.username, session=session - ).data["id"] - document.name = document_name - document.project_id = client.get_project( - project_name=project_name, session=session - ).data["id"] - return client.create_document(document=document, session=session) - - -@router.get("/projects/{project_name}/documents/{document_name}") -def get_document( - project_name: str, document_name: str, session=Depends(get_db) -) -> ApiResponse: - """ - Get a document from the database. - - :param project_name: The name of the project to get the document from. - :param document_name: The name of the document to get. - :param session: The database session. - - :return: The document from the database. - """ - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.get_document( - project_id=project_id, document_name=document_name, session=session - ) - - -@router.put("/projects/{project_name}/documents/{document_name}") -def update_document( - project_name: str, - document: Document, - document_name: str, - session=Depends(get_db), -) -> ApiResponse: - """ - Update a document in the database. - - :param project_name: The name of the project to update the document in. - :param document: The document to update. - :param document_name: The name of the document to update. - :param session: The database session. - - :return: The response from the database. - """ - document.project_id = client.get_project( - project_name=project_name, session=session - ).data["id"] - if document_name != document.name: - raise ValueError( - f"Document name does not match: {document_name} != {document.name}" - ) - return client.update_document(document=document, session=session) - - -@router.delete("/projects/{project_name}/documents/{document_id}") -def delete_document( - project_name: str, document_id: str, session=Depends(get_db) -) -> ApiResponse: - """ - Delete a document from the database. - - :param project_name: The name of the project to delete the document from. - :param document_id: The ID of the document to delete. - :param session: The database session. - - :return: The response from the database. - """ - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.delete_document( - project_id=project_id, document_id=document_id, session=session - ) - - -@router.get("/projects/{project_name}/documents") -def list_documents( - project_name: str, - version: str = None, - labels: Optional[List[Tuple[str, str]]] = None, - mode: OutputMode = OutputMode.Details, - session=Depends(get_db), - auth: AuthInfo = Depends(get_auth_user), -) -> ApiResponse: - """ - List documents in the database. - - :param project_name: The name of the project to list the documents from. - :param version: The version to filter by. - :param labels: The labels to filter by. - :param mode: The output mode. - :param session: The database session. - :param auth: The authentication information. - - :return: The response from the database. - """ - owner_id = client.get_user(user_name=auth.username, session=session).data["id"] - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.list_documents( - project_id=project_id, - owner_id=owner_id, - version=version, - labels_match=labels, - output_mode=mode, - session=session, - ) - - -@router.post("/projects/{project_name}/workflows/{workflow_name}") -def create_workflow( - project_name: str, - workflow_name: str, - workflow: Workflow, - session=Depends(get_db), - auth: AuthInfo = Depends(get_auth_user), -) -> ApiResponse: - """ - Create a new workflow in the database. - - :param project_name: The name of the project to create the workflow in. - :param workflow_name: The name of the workflow to create. - :param workflow: The workflow to create. - :param session: The database session. - :param auth: The authentication information. - - :return: The response from the database. - """ - # If the owner ID is not provided, get it from the username - if workflow.owner_id is None: - workflow.owner_id = client.get_user( - user_name=auth.username, session=session - ).data["id"] - workflow.name = workflow_name - workflow.project_id = client.get_project( - project_name=project_name, session=session - ).data["id"] - return client.create_workflow(workflow=workflow, session=session) - - -@router.get("/projects/{project_name}/workflows/{workflow_name}") -def get_workflow( - project_name: str, workflow_name: str, session=Depends(get_db) -) -> ApiResponse: - """ - Get a workflow from the database. - - :param project_name: The name of the project to get the workflow from. - :param workflow_name: The name of the workflow to get. - :param session: The database session. - - :return: The workflow from the database. - """ - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.get_workflow( - project_id=project_id, workflow_name=workflow_name, session=session - ) - - -@router.put("/projects/{project_name}/workflows/{workflow_name}") -def update_workflow( - project_name: str, - workflow: Workflow, - workflow_name: str, - session=Depends(get_db), -) -> ApiResponse: - """ - Update a workflow in the database. - - :param project_name: The name of the project to update the workflow in. - :param workflow: The workflow to update. - :param workflow_name: The name of the workflow to update. - :param session: The database session. - - :return: The response from the database. - """ - workflow.project_id = client.get_project( - project_name=project_name, session=session - ).data["id"] - if workflow_name != workflow.name: - raise ValueError( - f"Workflow name does not match: {workflow_name} != {workflow.name}" - ) - return client.update_workflow(workflow=workflow, session=session) - - -@router.delete("/projects/{project_name}/workflows/{workflow_id}") -def delete_workflow( - project_name: str, workflow_id: str, session=Depends(get_db) -) -> ApiResponse: - """ - Delete a workflow from the database. - - :param project_name: The name of the project to delete the workflow from. - :param workflow_id: The ID of the workflow to delete. - :param session: The database session. - - :return: The response from the database. - """ - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.delete_workflow( - project_id=project_id, workflow_id=workflow_id, session=session - ) - - -@router.get("/projects/{project_name}/workflows") -def list_workflows( - project_name: str, - version: str = None, - labels: Optional[List[Tuple[str, str]]] = None, - mode: OutputMode = OutputMode.Details, - session=Depends(get_db), - auth: AuthInfo = Depends(get_auth_user), -) -> ApiResponse: - """ - List workflows in the database. - - :param project_name: The name of the project to list the workflows from. - :param version: The version to filter by. - :param labels: The labels to filter by. - :param mode: The output mode. - :param session: The database session. - :param auth: The authentication information. - - :return: The response from the database. - """ - owner_id = client.get_user(user_name=auth.username, session=session).data["id"] - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.list_workflows( - project_id=project_id, - owner_id=owner_id, - version=version, - labels_match=labels, - output_mode=mode, - session=session, - ) - - -@router.post("/users/{user_name}/sessions/{session_name}") -def create_session( - user_name: str, - session_name: str, - chat_session: ChatSession, - session=Depends(get_db), -) -> ApiResponse: - """ - Create a new session in the database. - - :param user_name: The name of the user to create the session for. - :param session_name: The name of the session to create. - :param chat_session: The session to create. - :param session: The database session. - - :return: The response from the database. - """ - chat_session.owner_id = client.get_user(user_name=user_name, session=session).data[ - "id" - ] - return client.create_chat_session(chat_session=chat_session, session=session) - - -@router.get("/users/{user_name}/sessions/{session_name}") -def get_session( - user_name: str, session_name: str, session=Depends(get_db) -) -> ApiResponse: - """ - Get a session from the database. If the session ID is "$last", get the last session for the user. - - :param user_name: The name of the user to get the session for. - :param session_name: The name of the session to get. - :param session: The database session. - - :return: The session from the database. - """ - user_id = None - if session_name == "$last": - user_id = client.get_user(user_name=user_name, session=session).data["id"] - session_name = None - return client.get_chat_session( - session_name=session_name, user_id=user_id, session=session - ) - - -@router.put("/users/{user_name}/sessions/{session_name}") -def update_session( - user_name: str, - chat_session: ChatSession, - session=Depends(get_db), -) -> ApiResponse: - """ - Update a session in the database. - - :param user_name: The name of the user to update the session for. - :param chat_session: The session to update. - :param session: The database session. - - :return: The response from the database. - """ - chat_session.owner_id = client.get_user(user_name=user_name, session=session).data[ - "id" - ] - return client.update_chat_session(chat_session=chat_session, session=session) - - -@router.get("/users/{user_name}/sessions") -def list_sessions( - user_name: str, - last: int = 0, - created: str = None, - workflow_id: str = None, - mode: OutputMode = OutputMode.Details, - session=Depends(get_db), -) -> ApiResponse: - """ - List sessions in the database. - - :param user_name: The name of the user to list the sessions for. - :param last: The number of sessions to get. - :param created: The date to filter by. - :param workflow_id: The ID of the workflow to filter by. - :param mode: The output mode. - :param session: The database session. - - :return: The response from the database. - """ - user_id = client.get_user(user_name=user_name, session=session).data["id"] - return client.list_chat_sessions( - user_id=user_id, - last=last, - created_after=created, - workflow_id=workflow_id, - output_mode=mode, - session=session, - ) - - -@router.post("/projects/{project_name}/workflows/{workflow_name}/infer") -def infer_workflow( - project_name: str, - workflow_name: str, - query: QueryItem, - session=Depends(get_db), - auth: AuthInfo = Depends(get_auth_user), -) -> ApiResponse: - """ - Run application workflow. - - :param project_name: The name of the project to run the workflow in. - :param workflow_name: The name of the workflow to run. - :param query: The query to run. - :param session: The database session. - :param auth: The authentication information. - - :return: The response from the database. - """ - # Get workflow from the database - workflow = Workflow.from_dict( - get_workflow(project_name, workflow_name, session).data - ) - path = Workflow.get_infer_path(workflow) - - data = { - "item": query.dict(), - "workflow": workflow.to_dict(short=True), - } - - # Sent the event to the application's workflow: - return _send_to_application( - path=path, - method="POST", - data=json.dumps(data), - auth=auth, - ) - - -# @router.post("/pipeline/{name}/run") -# def run_pipeline( -# request: Request, item: QueryItem, name: str, auth=Depends(get_auth_user) -# ): -# """This is the query command""" -# -# return _send_to_application( -# path=f"pipeline/{name}/run", -# method="POST", -# request=request, -# auth=auth, -# ) - - -@router.post("/projects/{project_name}/data_sources/{data_source_name}/ingest") -def ingest( - project_name, - data_source_name, - loader: str, - path: str, - metadata=None, - version: str = None, - from_file: bool = False, - session=Depends(get_db), - auth=Depends(get_auth_user), -): - """ - Ingest document into the vector database. - - :param project_name: The name of the project to ingest the documents into. - :param data_source_name: The name of the data source to ingest the documents into. - :param loader: The data loader type to use. - :param path: The path to the document to ingest. - :param metadata: The metadata to associate with the documents. - :param version: The version of the documents. - :param from_file: Whether the documents are from a file. - :param session: The database session. - :param auth: The authentication information. - - :return: The response from the application. - """ - data_source = get_data_source( - project_name=project_name, data_source_name=data_source_name, session=session - ).data - data_source = DataSource.from_dict(data_source) - - # Create document from path: - document = Document( - name=os.path.basename(path), - version=version, - path=path, - owner_id=data_source.owner_id, - ) - - # Add document to the database: - document = create_document( - project_name=project_name, - document_name=document.name, - document=document, - session=session, - auth=auth, - ).data - - # Send ingest to application: - params = { - "loader": loader, - "from_file": from_file, - } - - data = { - "document": document, - "database_kwargs": data_source.database_kwargs, - } - if metadata is not None: - params["metadata"] = json.dumps(metadata) - - return _send_to_application( - path=f"data_sources/{data_source_name}/ingest", - method="POST", - data=json.dumps(data), - params=params, - auth=auth, - ) - - -# Include the router in the main app -app.include_router(router) diff --git a/controller/src/api/__init__.py b/controller/src/api/__init__.py new file mode 100644 index 0000000..84f1919 --- /dev/null +++ b/controller/src/api/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/controller/src/api/api.py b/controller/src/api/api.py new file mode 100644 index 0000000..ea58d21 --- /dev/null +++ b/controller/src/api/api.py @@ -0,0 +1,90 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from fastapi import APIRouter, FastAPI +from fastapi.middleware.cors import CORSMiddleware + +from controller.src.api.endpoints import ( + base, + data_sources, + datasets, + documents, + models, + projects, + prompt_templates, + sessions, + users, + workflows, +) + +app = FastAPI() + +# Add CORS middleware, remove in production +origins = ["*"] # React app +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Create a router with a prefix +api_router = APIRouter(prefix="/api") + + +# Include the routers for the different endpoints +api_router.include_router( + base.router, + tags=["base"], +) +api_router.include_router( + users.router, + tags=["users"], +) +api_router.include_router( + projects.router, + tags=["projects"], +) +api_router.include_router( + data_sources.router, + tags=["data_sources"], +) +api_router.include_router( + datasets.router, + tags=["datasets"], +) +api_router.include_router( + models.router, + tags=["models"], +) +api_router.include_router( + prompt_templates.router, + tags=["prompt_templates"], +) +api_router.include_router( + documents.router, + tags=["documents"], +) +api_router.include_router( + workflows.router, + tags=["workflows"], +) +api_router.include_router( + sessions.router, + tags=["chat_sessions"], +) + +# Include the router in the main app +app.include_router(api_router) diff --git a/controller/src/api/endpoints/__init__.py b/controller/src/api/endpoints/__init__.py new file mode 100644 index 0000000..84f1919 --- /dev/null +++ b/controller/src/api/endpoints/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/controller/src/api/endpoints/base.py b/controller/src/api/endpoints/base.py new file mode 100644 index 0000000..d95279e --- /dev/null +++ b/controller/src/api/endpoints/base.py @@ -0,0 +1,24 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from fastapi import APIRouter + +from controller.src.db import client + +router = APIRouter() + + +@router.post("/tables") +def create_tables(drop_old: bool = False, names: list[str] = None): + return client.create_tables(drop_old=drop_old, names=names) diff --git a/controller/src/api/endpoints/data_sources.py b/controller/src/api/endpoints/data_sources.py new file mode 100644 index 0000000..542ece6 --- /dev/null +++ b/controller/src/api/endpoints/data_sources.py @@ -0,0 +1,243 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os +from typing import List, Optional, Tuple, Union + +from fastapi import APIRouter, Depends + +from controller.src.api.utils import ( + AuthInfo, + _send_to_application, + get_auth_user, + get_db, +) +from controller.src.db import client +from controller.src.schemas import ( + ApiResponse, + DataSource, + DataSourceType, + Document, + OutputMode, +) + +router = APIRouter(prefix="/projects/{project_name}") + + +@router.post("/data_sources/{data_source_name}") +def create_data_source( + project_name: str, + data_source_name: str, + data_source: DataSource, + session=Depends(get_db), + auth: AuthInfo = Depends(get_auth_user), +) -> ApiResponse: + """ + Create a new data source in the database. + + :param project_name: The name of the project to create the data source in. + :param data_source_name: The name of the data source to create. + :param data_source: The data source to create. + :param session: The database session. + :param auth: The authentication information. + + :return: The response from the database. + """ + # If the owner ID is not provided, get it from the username + if data_source.owner_id is None: + data_source.owner_id = client.get_user( + user_name=auth.username, session=session + ).data["id"] + data_source.name = data_source_name + data_source.project_id = client.get_project( + project_name=project_name, session=session + ).data["id"] + return client.create_data_source(data_source=data_source, session=session) + + +@router.get("/data_sources/{data_source_name}") +def get_data_source( + project_name: str, data_source_name: str, session=Depends(get_db) +) -> ApiResponse: + """ + Get a data source from the database. + + :param project_name: The name of the project to get the data source from. + :param data_source_name: The name of the data source to get. + :param session: The database session. + + :return: The data source from the database. + """ + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.get_data_source( + project_id=project_id, data_source_name=data_source_name, session=session + ) + + +@router.put("/data_sources/{data_source_name}") +def update_data_source( + project_name: str, + data_source: DataSource, + data_source_name: str, + session=Depends(get_db), +) -> ApiResponse: + """ + Update a data source in the database. + + :param project_name: The name of the project to update the data source in. + :param data_source: The data source to update. + :param data_source_name: The name of the data source to update. + :param session: The database session. + + :return: The response from the database. + """ + data_source.project_id = client.get_project( + project_name=project_name, session=session + ).data["id"] + if data_source_name != data_source.name: + raise ValueError( + f"Data source name does not match: {data_source_name} != {data_source.name}" + ) + return client.update_data_source(data_source=data_source, session=session) + + +@router.delete("/data_sources/{data_source_id}") +def delete_data_source( + project_name: str, data_source_id: str, session=Depends(get_db) +) -> ApiResponse: + """ + Delete a data source from the database. + + :param project_name: The name of the project to delete the data source from. + :param data_source_id: The ID of the data source to delete. + :param session: The database session. + + :return: The response from the database. + """ + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.delete_data_source( + project_id=project_id, data_source_id=data_source_id, session=session + ) + + +@router.get("/data_sources") +def list_data_sources( + project_name: str, + version: str = None, + data_source_type: Union[DataSourceType, str] = None, + labels: Optional[List[Tuple[str, str]]] = None, + mode: OutputMode = OutputMode.Details, + session=Depends(get_db), + auth: AuthInfo = Depends(get_auth_user), +) -> ApiResponse: + """ + List data sources in the database. + + :param project_name: The name of the project to list the data sources from. + :param version: The version to filter by. + :param data_source_type: The data source type to filter by. + :param labels: The labels to filter by. + :param mode: The output mode. + :param session: The database session. + :param auth: The authentication information. + + :return: The response from the database. + """ + owner_id = client.get_user(user_name=auth.username, session=session).data["id"] + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.list_data_sources( + project_id=project_id, + owner_id=owner_id, + version=version, + data_source_type=data_source_type, + labels_match=labels, + output_mode=mode, + session=session, + ) + + +@router.post("/data_sources/{data_source_name}/ingest") +def ingest( + project_name, + data_source_name, + loader: str, + path: str, + metadata=None, + version: str = None, + from_file: bool = False, + session=Depends(get_db), + auth=Depends(get_auth_user), +): + """ + Ingest document into the vector database. + + :param project_name: The name of the project to ingest the documents into. + :param data_source_name: The name of the data source to ingest the documents into. + :param loader: The data loader type to use. + :param path: The path to the document to ingest. + :param metadata: The metadata to associate with the documents. + :param version: The version of the documents. + :param from_file: Whether the documents are from a file. + :param session: The database session. + :param auth: The authentication information. + + :return: The response from the application. + """ + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + data_source = client.get_data_source( + project_id=project_id, data_source_name=data_source_name, session=session + ).data + + data_source = DataSource.from_dict(data_source) + + # Create document from path: + document = Document( + name=os.path.basename(path), + version=version, + path=path, + owner_id=data_source.owner_id, + project_id=project_id, + ) + + # Add document to the database: + document = client.create_document(document=document, session=session).data + + # Send ingest to application: + params = { + "loader": loader, + "from_file": from_file, + } + + data = { + "document": document, + "database_kwargs": data_source.database_kwargs, + } + if metadata is not None: + params["metadata"] = json.dumps(metadata) + + return _send_to_application( + path=f"data_sources/{data_source_name}/ingest", + method="POST", + data=json.dumps(data), + params=params, + auth=auth, + ) diff --git a/controller/src/api/endpoints/datasets.py b/controller/src/api/endpoints/datasets.py new file mode 100644 index 0000000..0dd2065 --- /dev/null +++ b/controller/src/api/endpoints/datasets.py @@ -0,0 +1,161 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple + +from fastapi import APIRouter, Depends + +from controller.src.api.utils import AuthInfo, get_auth_user, get_db +from controller.src.db import client +from controller.src.schemas import ApiResponse, Dataset, OutputMode + +router = APIRouter(prefix="/projects/{project_name}") + + +@router.post("/datasets/{dataset_name}") +def create_dataset( + project_name: str, + dataset_name: str, + dataset: Dataset, + session=Depends(get_db), + auth: AuthInfo = Depends(get_auth_user), +) -> ApiResponse: + """ + Create a new dataset in the database. + + :param project_name: The name of the project to create the dataset in. + :param dataset_name: The name of the dataset to create. + :param dataset: The dataset to create. + :param session: The database session. + :param auth: The authentication information. + + :return: The response from the database. + """ + # If the owner ID is not provided, get it from the username + if dataset.owner_id is None: + dataset.owner_id = client.get_user( + user_name=auth.username, session=session + ).data["id"] + dataset.name = dataset_name + dataset.project_id = client.get_project( + project_name=project_name, session=session + ).data["id"] + return client.create_dataset(dataset=dataset, session=session) + + +@router.get("/datasets/{dataset_name}") +def get_dataset( + project_name: str, dataset_name: str, session=Depends(get_db) +) -> ApiResponse: + """ + Get a dataset from the database. + + :param project_name: The name of the project to get the dataset from. + :param dataset_name: The name of the dataset to get. + :param session: The database session. + + :return: The dataset from the database. + """ + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.get_dataset( + project_id=project_id, dataset_name=dataset_name, session=session + ) + + +@router.put("/datasets/{dataset_name}") +def update_dataset( + project_name: str, + dataset: Dataset, + dataset_name: str, + session=Depends(get_db), +) -> ApiResponse: + """ + Update a dataset in the database. + + :param project_name: The name of the project to update the dataset in. + :param dataset: The dataset to update. + :param dataset_name: The name of the dataset to update. + :param session: The database session. + + :return: The response from the database. + """ + dataset.project_id = client.get_project( + project_name=project_name, session=session + ).data["id"] + if dataset_name != dataset.name: + raise ValueError( + f"Dataset name does not match: {dataset_name} != {dataset.name}" + ) + return client.update_dataset(dataset=dataset, session=session) + + +@router.delete("/datasets/{dataset_id}") +def delete_dataset( + project_name: str, dataset_id: str, session=Depends(get_db) +) -> ApiResponse: + """ + Delete a dataset from the database. + + :param project_name: The name of the project to delete the dataset from. + :param dataset_id: The ID of the dataset to delete. + :param session: The database session. + + :return: The response from the database. + """ + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.delete_dataset( + project_id=project_id, dataset_id=dataset_id, session=session + ) + + +@router.get("/datasets") +def list_datasets( + project_name: str, + version: str = None, + task: str = None, + labels: Optional[List[Tuple[str, str]]] = None, + mode: OutputMode = OutputMode.Details, + session=Depends(get_db), + auth: AuthInfo = Depends(get_auth_user), +) -> ApiResponse: + """ + List datasets in the database. + + :param project_name: The name of the project to list the datasets from. + :param version: The version to filter by. + :param task: The task to filter by. + :param labels: The labels to filter by. + :param mode: The output mode. + :param session: The database session. + :param auth: The authentication information. + + :return: The response from the database. + """ + owner_id = client.get_user(user_name=auth.username, session=session).data["id"] + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.list_datasets( + project_id=project_id, + owner_id=owner_id, + version=version, + task=task, + labels_match=labels, + output_mode=mode, + session=session, + ) diff --git a/controller/src/api/endpoints/documents.py b/controller/src/api/endpoints/documents.py new file mode 100644 index 0000000..80f83ef --- /dev/null +++ b/controller/src/api/endpoints/documents.py @@ -0,0 +1,158 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple + +from fastapi import APIRouter, Depends + +from controller.src.api.utils import AuthInfo, get_auth_user, get_db +from controller.src.db import client +from controller.src.schemas import ApiResponse, Document, OutputMode + +router = APIRouter(prefix="/projects/{project_name}") + + +@router.post("/documents/{document_name}") +def create_document( + project_name: str, + document_name: str, + document: Document, + session=Depends(get_db), + auth: AuthInfo = Depends(get_auth_user), +) -> ApiResponse: + """ + Create a new document in the database. + + :param project_name: The name of the project to create the document in. + :param document_name: The name of the document to create. + :param document: The document to create. + :param session: The database session. + :param auth: The authentication information. + + :return: The response from the database. + """ + # If the owner ID is not provided, get it from the username + if document.owner_id is None: + document.owner_id = client.get_user( + user_name=auth.username, session=session + ).data["id"] + document.name = document_name + document.project_id = client.get_project( + project_name=project_name, session=session + ).data["id"] + return client.create_document(document=document, session=session) + + +@router.get("/documents/{document_name}") +def get_document( + project_name: str, document_name: str, session=Depends(get_db) +) -> ApiResponse: + """ + Get a document from the database. + + :param project_name: The name of the project to get the document from. + :param document_name: The name of the document to get. + :param session: The database session. + + :return: The document from the database. + """ + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.get_document( + project_id=project_id, document_name=document_name, session=session + ) + + +@router.put("/documents/{document_name}") +def update_document( + project_name: str, + document: Document, + document_name: str, + session=Depends(get_db), +) -> ApiResponse: + """ + Update a document in the database. + + :param project_name: The name of the project to update the document in. + :param document: The document to update. + :param document_name: The name of the document to update. + :param session: The database session. + + :return: The response from the database. + """ + document.project_id = client.get_project( + project_name=project_name, session=session + ).data["id"] + if document_name != document.name: + raise ValueError( + f"Document name does not match: {document_name} != {document.name}" + ) + return client.update_document(document=document, session=session) + + +@router.delete("/documents/{document_id}") +def delete_document( + project_name: str, document_id: str, session=Depends(get_db) +) -> ApiResponse: + """ + Delete a document from the database. + + :param project_name: The name of the project to delete the document from. + :param document_id: The ID of the document to delete. + :param session: The database session. + + :return: The response from the database. + """ + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.delete_document( + project_id=project_id, document_id=document_id, session=session + ) + + +@router.get("/documents") +def list_documents( + project_name: str, + version: str = None, + labels: Optional[List[Tuple[str, str]]] = None, + mode: OutputMode = OutputMode.Details, + session=Depends(get_db), + auth: AuthInfo = Depends(get_auth_user), +) -> ApiResponse: + """ + List documents in the database. + + :param project_name: The name of the project to list the documents from. + :param version: The version to filter by. + :param labels: The labels to filter by. + :param mode: The output mode. + :param session: The database session. + :param auth: The authentication information. + + :return: The response from the database. + """ + owner_id = client.get_user(user_name=auth.username, session=session).data["id"] + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.list_documents( + project_id=project_id, + owner_id=owner_id, + version=version, + labels_match=labels, + output_mode=mode, + session=session, + ) diff --git a/controller/src/api/endpoints/models.py b/controller/src/api/endpoints/models.py new file mode 100644 index 0000000..4086e55 --- /dev/null +++ b/controller/src/api/endpoints/models.py @@ -0,0 +1,159 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple + +from fastapi import APIRouter, Depends + +from controller.src.api.utils import AuthInfo, get_auth_user, get_db +from controller.src.db import client +from controller.src.schemas import ApiResponse, Model, OutputMode + +router = APIRouter(prefix="/projects/{project_name}") + + +@router.post("/models/{model_name}") +def create_model( + project_name: str, + model_name: str, + model: Model, + session=Depends(get_db), + auth: AuthInfo = Depends(get_auth_user), +) -> ApiResponse: + """ + Create a new model in the database. + + :param project_name: The name of the project to create the model in. + :param model_name: The name of the model to create. + :param model: The model to create. + :param session: The database session. + :param auth: The authentication information. + + :return: The response from the database. + """ + # If the owner ID is not provided, get it from the username + if model.owner_id is None: + model.owner_id = client.get_user(user_name=auth.username, session=session).data[ + "id" + ] + model.name = model_name + model.project_id = client.get_project( + project_name=project_name, session=session + ).data["id"] + return client.create_model(model=model, session=session) + + +@router.get("/models/{model_name}") +def get_model( + project_name: str, model_name: str, session=Depends(get_db) +) -> ApiResponse: + """ + Get a model from the database. + + :param project_name: The name of the project to get the model from. + :param model_name: The name of the model to get. + :param session: The database session. + + :return: The model from the database. + """ + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.get_model( + project_id=project_id, model_name=model_name, session=session + ) + + +@router.put("/models/{model_name}") +def update_model( + project_name: str, + model: Model, + model_name: str, + session=Depends(get_db), +) -> ApiResponse: + """ + Update a model in the database. + + :param project_name: The name of the project to update the model in. + :param model: The model to update. + :param model_name: The name of the model to update. + :param session: The database session. + + :return: The response from the database. + """ + model.project_id = client.get_project( + project_name=project_name, session=session + ).data["id"] + if model_name != model.name: + raise ValueError(f"Model name does not match: {model_name} != {model.name}") + return client.update_model(model=model, session=session) + + +@router.delete("/models/{model_id}") +def delete_model( + project_name: str, model_id: str, session=Depends(get_db) +) -> ApiResponse: + """ + Delete a model from the database. + + :param project_name: The name of the project to delete the model from. + :param model_id: The ID of the model to delete. + :param session: The database session. + + :return: The response from the database. + """ + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.delete_model( + project_id=project_id, model_id=model_id, session=session + ) + + +@router.get("/models") +def list_models( + project_name: str, + version: str = None, + model_type: str = None, + labels: Optional[List[Tuple[str, str]]] = None, + mode: OutputMode = OutputMode.Details, + session=Depends(get_db), + auth: AuthInfo = Depends(get_auth_user), +) -> ApiResponse: + """ + List models in the database. + + :param project_name: The name of the project to list the models from. + :param version: The version to filter by. + :param model_type: The model type to filter by. + :param labels: The labels to filter by. + :param mode: The output mode. + :param session: The database session. + :param auth: The authentication information. + + :return: The response from the database. + """ + owner_id = client.get_user(user_name=auth.username, session=session).data["id"] + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.list_models( + project_id=project_id, + owner_id=owner_id, + version=version, + model_type=model_type, + labels_match=labels, + output_mode=mode, + session=session, + ) diff --git a/controller/src/api/endpoints/projects.py b/controller/src/api/endpoints/projects.py new file mode 100644 index 0000000..6a5343b --- /dev/null +++ b/controller/src/api/endpoints/projects.py @@ -0,0 +1,113 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple + +from fastapi import APIRouter, Depends + +from controller.src.api.utils import get_db +from controller.src.db import client +from controller.src.schemas import ApiResponse, OutputMode, Project + +router = APIRouter() + + +@router.post("/projects") +def create_project( + project: Project, + session=Depends(get_db), +) -> ApiResponse: + """ + Create a new project in the database. + + :param project: The project to create. + :param session: The database session. + + :return: The response from the database. + """ + return client.create_project(project=project, session=session) + + +@router.get("/projects/{project_name}") +def get_project(project_name: str, session=Depends(get_db)) -> ApiResponse: + """ + Get a project from the database. + + :param project_name: The name of the project to get. + :param session: The database session. + + :return: The project from the database. + """ + return client.get_project(project_name=project_name, session=session) + + +@router.put("/projects/{project_name}") +def update_project( + project: Project, + project_name: str, + session=Depends(get_db), +) -> ApiResponse: + """ + Update a project in the database. + + :param project: The project to update. + :param project_name: The name of the project to update. + :param session: The database session. + + :return: The response from the database. + """ + if project_name != project.name: + raise ValueError( + f"Project name does not match: {project_name} != {project.name}" + ) + return client.update_project(project=project, session=session) + + +@router.delete("/projects/{project_name}") +def delete_project(project_name: str, session=Depends(get_db)) -> ApiResponse: + """ + Delete a project from the database. + + :param project_name: The name of the project to delete. + :param session: The database session. + + :return: The response from the database. + """ + return client.delete_project(project_name=project_name, session=session) + + +@router.get("/projects") +def list_projects( + owner_name: str = None, + labels: Optional[List[Tuple[str, str]]] = None, + mode: OutputMode = OutputMode.Details, + session=Depends(get_db), +) -> ApiResponse: + """ + List projects in the database. + + :param owner_name: The name of the owner to filter by. + :param labels: The labels to filter by. + :param mode: The output mode. + :param session: The database session. + + :return: The response from the database. + """ + if owner_name is not None: + owner_id = client.get_user(user_name=owner_name, session=session).data["id"] + else: + owner_id = None + return client.list_projects( + owner_id=owner_id, labels_match=labels, output_mode=mode, session=session + ) diff --git a/controller/src/api/endpoints/prompt_templates.py b/controller/src/api/endpoints/prompt_templates.py new file mode 100644 index 0000000..5e8a2f6 --- /dev/null +++ b/controller/src/api/endpoints/prompt_templates.py @@ -0,0 +1,156 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple + +from fastapi import APIRouter, Depends + +from controller.src.api.utils import AuthInfo, get_auth_user, get_db +from controller.src.db import client +from controller.src.schemas import ApiResponse, OutputMode, PromptTemplate + +router = APIRouter(prefix="/projects/{project_name}") + + +@router.post("/prompt_templates/{prompt_name}") +def create_prompt( + project_name: str, + prompt_name: str, + prompt: PromptTemplate, + session=Depends(get_db), + auth: AuthInfo = Depends(get_auth_user), +) -> ApiResponse: + """ + Create a new prompt in the database. + + :param project_name: The name of the project to create the prompt in. + :param prompt_name: The name of the prompt to create. + :param prompt: The prompt to create. + :param session: The database session. + :param auth: The authentication information. + + :return: The response from the database. + """ + # If the owner ID is not provided, get it from the username + if prompt.owner_id is None: + prompt.owner_id = client.get_user( + user_name=auth.username, session=session + ).data["id"] + prompt.name = prompt_name + prompt.project_id = client.get_project( + project_name=project_name, session=session + ).data["id"] + return client.create_prompt_template(prompt=prompt, session=session) + + +@router.get("/prompt_templates/{prompt_name}") +def get_prompt( + project_name: str, prompt_name: str, session=Depends(get_db) +) -> ApiResponse: + """ + Get a prompt from the database. + + :param project_name: The name of the project to get the prompt from. + :param prompt_name: The name of the prompt to get. + :param session: The database session. + + :return: The prompt from the database. + """ + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.get_prompt( + project_id=project_id, prompt_name=prompt_name, session=session + ) + + +@router.put("/prompt_templates/{prompt_name}") +def update_prompt( + project_name: str, + prompt: PromptTemplate, + prompt_name: str, + session=Depends(get_db), +) -> ApiResponse: + """ + Update a prompt in the database. + + :param project_name: The name of the project to update the prompt in. + :param prompt: The prompt to update. + :param prompt_name: The name of the prompt to update. + :param session: The database session. + + :return: The response from the database. + """ + prompt.project_id = client.get_project( + project_name=project_name, session=session + ).data["id"] + if prompt_name != prompt.name: + raise ValueError(f"Prompt name does not match: {prompt_name} != {prompt.name}") + return client.update_prompt_template(prompt=prompt, session=session) + + +@router.delete("/prompt_templates/{prompt_template_id}") +def delete_prompt( + project_name: str, prompt_template_id: str, session=Depends(get_db) +) -> ApiResponse: + """ + Delete a prompt from the database. + + :param project_name: The name of the project to delete the prompt from. + :param prompt_template_id: The ID of the prompt to delete. + :param session: The database session. + + :return: The response from the database. + """ + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.delete_prompt_template( + project_id=project_id, prompt_template_id=prompt_template_id, session=session + ) + + +@router.get("/prompt_templates") +def list_prompts( + project_name: str, + version: str = None, + labels: Optional[List[Tuple[str, str]]] = None, + mode: OutputMode = OutputMode.Details, + session=Depends(get_db), + auth: AuthInfo = Depends(get_auth_user), +) -> ApiResponse: + """ + List prompts in the database. + + :param project_name: The name of the project to list the prompts from. + :param version: The version to filter by. + :param labels: The labels to filter by. + :param mode: The output mode. + :param session: The database session. + :param auth: The authentication information. + + :return: The response from the database. + """ + owner_id = client.get_user(user_name=auth.username, session=session).data["id"] + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.list_prompt_templates( + project_id=project_id, + owner_id=owner_id, + version=version, + labels_match=labels, + output_mode=mode, + session=session, + ) diff --git a/controller/src/api/endpoints/sessions.py b/controller/src/api/endpoints/sessions.py new file mode 100644 index 0000000..9ceb2f8 --- /dev/null +++ b/controller/src/api/endpoints/sessions.py @@ -0,0 +1,119 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from fastapi import APIRouter, Depends + +from controller.src.api.utils import get_db +from controller.src.db import client +from controller.src.schemas import ApiResponse, ChatSession, OutputMode + +router = APIRouter(prefix="/users/{user_name}") + + +@router.post("/sessions/{session_name}") +def create_session( + user_name: str, + session_name: str, + chat_session: ChatSession, + session=Depends(get_db), +) -> ApiResponse: + """ + Create a new session in the database. + + :param user_name: The name of the user to create the session for. + :param session_name: The name of the session to create. + :param chat_session: The session to create. + :param session: The database session. + + :return: The response from the database. + """ + chat_session.owner_id = client.get_user(user_name=user_name, session=session).data[ + "id" + ] + return client.create_chat_session(chat_session=chat_session, session=session) + + +@router.get("/sessions/{session_name}") +def get_session( + user_name: str, session_name: str, session=Depends(get_db) +) -> ApiResponse: + """ + Get a session from the database. If the session ID is "$last", get the last session for the user. + + :param user_name: The name of the user to get the session for. + :param session_name: The name of the session to get. + :param session: The database session. + + :return: The session from the database. + """ + user_id = None + if session_name == "$last": + user_id = client.get_user(user_name=user_name, session=session).data["id"] + session_name = None + return client.get_chat_session( + session_name=session_name, user_id=user_id, session=session + ) + + +@router.put("/sessions/{session_name}") +def update_session( + user_name: str, + chat_session: ChatSession, + session=Depends(get_db), +) -> ApiResponse: + """ + Update a session in the database. + + :param user_name: The name of the user to update the session for. + :param chat_session: The session to update. + :param session: The database session. + + :return: The response from the database. + """ + chat_session.owner_id = client.get_user(user_name=user_name, session=session).data[ + "id" + ] + return client.update_chat_session(chat_session=chat_session, session=session) + + +@router.get("/sessions") +def list_sessions( + user_name: str, + last: int = 0, + created: str = None, + workflow_id: str = None, + mode: OutputMode = OutputMode.Details, + session=Depends(get_db), +) -> ApiResponse: + """ + List sessions in the database. + + :param user_name: The name of the user to list the sessions for. + :param last: The number of sessions to get. + :param created: The date to filter by. + :param workflow_id: The ID of the workflow to filter by. + :param mode: The output mode. + :param session: The database session. + + :return: The response from the database. + """ + user_id = client.get_user(user_name=user_name, session=session).data["id"] + return client.list_chat_sessions( + user_id=user_id, + last=last, + created_after=created, + workflow_id=workflow_id, + output_mode=mode, + session=session, + ) diff --git a/controller/src/api/endpoints/users.py b/controller/src/api/endpoints/users.py new file mode 100644 index 0000000..0117c76 --- /dev/null +++ b/controller/src/api/endpoints/users.py @@ -0,0 +1,106 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from fastapi import APIRouter, Depends + +from controller.src.api.utils import get_db +from controller.src.db import client +from controller.src.schemas import ApiResponse, OutputMode, User + +router = APIRouter() + + +@router.post("/users") +def create_user( + user: User, + session=Depends(get_db), +) -> ApiResponse: + """ + Create a new user in the database. + + :param user: The user to create. + :param session: The database session. + + :return: The response from the database. + """ + return client.create_user(user=user, session=session) + + +@router.get("/users/{user_name}") +def get_user(user_name: str, email: str = None, session=Depends(get_db)) -> ApiResponse: + """ + Get a user from the database. + + :param user_name: The name of the user to get. + :param email: The email address to get the user by if the name is not provided. + :param session: The database session. + + :return: The user from the database. + """ + return client.get_user(user_name=user_name, email=email, session=session) + + +@router.put("/users/{user_name}") +def update_user( + user: User, + user_name: str, + session=Depends(get_db), +) -> ApiResponse: + """ + Update a user in the database. + + :param user: The user to update. + :param user_name: The name of the user to update. + :param session: The database session. + + :return: The response from the database. + """ + if user_name != user.name: + raise ValueError(f"User name does not match: {user_name} != {user.name}") + return client.update_user(user=user, session=session) + + +@router.delete("/users/{user_name}") +def delete_user(user_name: str, session=Depends(get_db)) -> ApiResponse: + """ + Delete a user from the database. + + :param user_name: The name of the user to delete. + :param session: The database session. + + :return: The response from the database. + """ + return client.delete_user(user_name=user_name, session=session) + + +@router.get("/users/users") +def list_users( + email: str = None, + full_name: str = None, + mode: OutputMode = OutputMode.Details, + session=Depends(get_db), +) -> ApiResponse: + """ + List users in the database. + + :param email: The email address to filter by. + :param full_name: The full name to filter by. + :param mode: The output mode. + :param session: The database session. + + :return: The response from the database. + """ + return client.list_users( + email=email, full_name=full_name, output_mode=mode, session=session + ) diff --git a/controller/src/api/endpoints/workflows.py b/controller/src/api/endpoints/workflows.py new file mode 100644 index 0000000..2e96652 --- /dev/null +++ b/controller/src/api/endpoints/workflows.py @@ -0,0 +1,211 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import List, Optional, Tuple, Union + +from fastapi import APIRouter, Depends + +from controller.src.api.utils import ( + AuthInfo, + _send_to_application, + get_auth_user, + get_db, +) +from controller.src.db import client +from controller.src.schemas import ( + ApiResponse, + OutputMode, + QueryItem, + Workflow, + WorkflowType, +) + +router = APIRouter(prefix="/projects/{project_name}") + + +@router.post("/workflows/{workflow_name}") +def create_workflow( + project_name: str, + workflow_name: str, + workflow: Workflow, + session=Depends(get_db), + auth: AuthInfo = Depends(get_auth_user), +) -> ApiResponse: + """ + Create a new workflow in the database. + + :param project_name: The name of the project to create the workflow in. + :param workflow_name: The name of the workflow to create. + :param workflow: The workflow to create. + :param session: The database session. + :param auth: The authentication information. + + :return: The response from the database. + """ + # If the owner ID is not provided, get it from the username + if workflow.owner_id is None: + workflow.owner_id = client.get_user( + user_name=auth.username, session=session + ).data["id"] + workflow.name = workflow_name + workflow.project_id = client.get_project( + project_name=project_name, session=session + ).data["id"] + return client.create_workflow(workflow=workflow, session=session) + + +@router.get("/workflows/{workflow_name}") +def get_workflow( + project_name: str, workflow_name: str, session=Depends(get_db) +) -> ApiResponse: + """ + Get a workflow from the database. + + :param project_name: The name of the project to get the workflow from. + :param workflow_name: The name of the workflow to get. + :param session: The database session. + + :return: The workflow from the database. + """ + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.get_workflow( + project_id=project_id, workflow_name=workflow_name, session=session + ) + + +@router.put("/workflows/{workflow_name}") +def update_workflow( + project_name: str, + workflow: Workflow, + workflow_name: str, + session=Depends(get_db), +) -> ApiResponse: + """ + Update a workflow in the database. + + :param project_name: The name of the project to update the workflow in. + :param workflow: The workflow to update. + :param workflow_name: The name of the workflow to update. + :param session: The database session. + + :return: The response from the database. + """ + workflow.project_id = client.get_project( + project_name=project_name, session=session + ).data["id"] + if workflow_name != workflow.name: + raise ValueError( + f"Workflow name does not match: {workflow_name} != {workflow.name}" + ) + return client.update_workflow(workflow=workflow, session=session) + + +@router.delete("/workflows/{workflow_id}") +def delete_workflow( + project_name: str, workflow_id: str, session=Depends(get_db) +) -> ApiResponse: + """ + Delete a workflow from the database. + + :param project_name: The name of the project to delete the workflow from. + :param workflow_id: The ID of the workflow to delete. + :param session: The database session. + + :return: The response from the database. + """ + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.delete_workflow( + project_id=project_id, workflow_id=workflow_id, session=session + ) + + +@router.get("/workflows") +def list_workflows( + project_name: str, + version: str = None, + workflow_type: Union[WorkflowType, str] = None, + labels: Optional[List[Tuple[str, str]]] = None, + mode: OutputMode = OutputMode.Details, + session=Depends(get_db), + auth: AuthInfo = Depends(get_auth_user), +) -> ApiResponse: + """ + List workflows in the database. + + :param project_name: The name of the project to list the workflows from. + :param version: The version to filter by. + :param workflow_type: The workflow type to filter by. + :param labels: The labels to filter by. + :param mode: The output mode. + :param session: The database session. + :param auth: The authentication information. + + :return: The response from the database. + """ + owner_id = client.get_user(user_name=auth.username, session=session).data["id"] + project_id = client.get_project(project_name=project_name, session=session).data[ + "id" + ] + return client.list_workflows( + project_id=project_id, + owner_id=owner_id, + version=version, + labels_match=labels, + output_mode=mode, + session=session, + ) + + +@router.post("/workflows/{workflow_name}/infer") +def infer_workflow( + project_name: str, + workflow_name: str, + query: QueryItem, + session=Depends(get_db), + auth: AuthInfo = Depends(get_auth_user), +) -> ApiResponse: + """ + Run application workflow. + + :param project_name: The name of the project to run the workflow in. + :param workflow_name: The name of the workflow to run. + :param query: The query to run. + :param session: The database session. + :param auth: The authentication information. + + :return: The response from the database. + """ + # Get workflow from the database + workflow = Workflow.from_dict( + get_workflow(project_name, workflow_name, session).data + ) + path = Workflow.get_infer_path(workflow) + + data = { + "item": query.dict(), + "workflow": workflow.to_dict(short=True), + } + + # Sent the event to the application's workflow: + return _send_to_application( + path=path, + method="POST", + data=json.dumps(data), + auth=auth, + ) diff --git a/controller/src/api/utils.py b/controller/src/api/utils.py new file mode 100644 index 0000000..e076940 --- /dev/null +++ b/controller/src/api/utils.py @@ -0,0 +1,90 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union + +import requests +from fastapi import Header, Request +from pydantic import BaseModel + +from controller.src.config import config +from controller.src.db import client + + +def get_db(): + db_session = None + try: + db_session = client.get_local_session() + yield db_session + finally: + if db_session: + db_session.close() + + +class AuthInfo(BaseModel): + username: str + token: str + roles: List[str] = [] + + +# placeholder for extracting the Auth info from the request +def get_auth_user( + request: Request, x_username: Union[str, None] = Header(None) +) -> AuthInfo: + """Get the user from the database""" + token = request.cookies.get("Authorization", "") + if x_username: + return AuthInfo(username=x_username, token=token) + else: + return AuthInfo(username="guest@example.com", token=token) + + +def _send_to_application( + path: str, method: str = "POST", request=None, auth=None, **kwargs +): + """ + Send a request to the application's API. + + :param path: The API path to send the request to. + :param method: The HTTP method to use: GET, POST, PUT, DELETE, etc. + :param request: The FastAPI request object. If provided, the data will be taken from the body of the request. + :param auth: The authentication information to use. If provided, the username will be added to the headers. + :param kwargs: Additional keyword arguments to pass in the request function. For example, headers, params, etc. + + :return: The JSON response from the application. + """ + if config.application_url not in path: + url = f"{config.application_url}/api/{path}" + else: + url = path + + if isinstance(request, Request): + # If the request is a FastAPI request, get the data from the body + kwargs["data"] = request._body.decode("utf-8") + if auth is not None: + kwargs["headers"] = {"x_username": auth.username} + + response = requests.request( + method=method, + url=url, + **kwargs, + ) + + # Check the response + if response.status_code == 200: + # If the request was successful, return the JSON response + return response.json() + else: + # If the request failed, raise an exception + response.raise_for_status() diff --git a/controller/src/config.py b/controller/src/config.py index ea289c1..f80342b 100644 --- a/controller/src/config.py +++ b/controller/src/config.py @@ -35,6 +35,7 @@ class CtrlConfig(BaseModel): verbose: bool = True log_level: str = "DEBUG" # SQL Database + db_type = "sql" sql_connection_str: str = default_db_path application_url: str = "http://localhost:8000" diff --git a/controller/src/db/__init__.py b/controller/src/db/__init__.py new file mode 100644 index 0000000..28da560 --- /dev/null +++ b/controller/src/db/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from controller.src.config import config + +from .sqlclient import SqlClient + +client = None + +if config.db_type == "sql": + client = SqlClient(config.sql_connection_str, verbose=config.verbose) diff --git a/controller/src/sqlclient.py b/controller/src/db/sqlclient.py similarity index 98% rename from controller/src/sqlclient.py rename to controller/src/db/sqlclient.py index bf974bd..f15fe78 100644 --- a/controller/src/sqlclient.py +++ b/controller/src/db/sqlclient.py @@ -19,11 +19,10 @@ import sqlalchemy from sqlalchemy.orm import sessionmaker -import controller.src.sqldb as db -from controller.src import model as api_models -from controller.src.config import config, logger -from controller.src.model import ApiResponse -from controller.src.sqldb import Base +import controller.src.db.sqldb as db +import controller.src.schemas as api_models +from controller.src.config import logger +from controller.src.schemas import ApiResponse class SqlClient: @@ -70,10 +69,10 @@ def create_tables(self, drop_old: bool = False, names: list = None) -> ApiRespon """ tables = None if names: - tables = [Base.metadata.tables[name] for name in names] + tables = [db.Base.metadata.tables[name] for name in names] if drop_old: - Base.metadata.drop_all(self.engine, tables=tables) - Base.metadata.create_all(self.engine, tables=tables, checkfirst=True) + db.Base.metadata.drop_all(self.engine, tables=tables) + db.Base.metadata.create_all(self.engine, tables=tables, checkfirst=True) return ApiResponse(success=True) def _create(self, session: sqlalchemy.orm.Session, db_class, obj) -> ApiResponse: @@ -1294,6 +1293,3 @@ def _process_output( return items short = mode == api_models.OutputMode.Short return [item.to_dict(short=short) for item in items] - - -client = SqlClient(config.sql_connection_str, verbose=config.verbose) diff --git a/controller/src/sqldb.py b/controller/src/db/sqldb.py similarity index 100% rename from controller/src/sqldb.py rename to controller/src/db/sqldb.py diff --git a/controller/src/main.py b/controller/src/main.py index 06b8a99..1cac4fe 100644 --- a/controller/src/main.py +++ b/controller/src/main.py @@ -22,7 +22,8 @@ import controller.src.api as api from controller.src.config import config -from controller.src.model import ( +from controller.src.db import client +from controller.src.schemas import ( DataSource, Document, Project, @@ -30,7 +31,6 @@ User, Workflow, ) -from controller.src.sqlclient import client @click.group() @@ -108,8 +108,8 @@ def print_config(): @click.command() -@click.argument("project", type=str) @click.argument("path", type=str) +@click.option("-p", "--project", type=str, help="Project name", default="default") @click.option("-n", "--name", type=str, help="Document name", default=None) @click.option("-l", "--loader", type=str, help="Type of data loader") @click.option( @@ -120,12 +120,12 @@ def print_config(): @click.option( "-f", "--from-file", is_flag=True, help="Take the document paths from the file" ) -def ingest(project, path, name, loader, metadata, version, data_source, from_file): +def ingest(path, project, name, loader, metadata, version, data_source, from_file): """ Ingest data into the data source. - :param project: The project name to which the document belongs. :param path: Path to the document + :param project: The project name to which the document belongs. :param name: Name of the document :param loader: Type of data loader, web, .csv, .md, .pdf, .txt, etc. :param metadata: Metadata Key value pair labels @@ -201,7 +201,7 @@ def ingest(project, path, name, loader, metadata, version, data_source, from_fil help="Search filter Key value pair", ) @click.option("-c", "--data-source", type=str, help="Data Source name") -@click.option("-u", "--user", type=str, help="Username") +@click.option("-u", "--user", type=str, help="Username", default="guest") @click.option("-s", "--session", type=str, help="Session ID") def infer( question: str, project: str, workflow_name: str, filter, data_source, user, session diff --git a/controller/src/schemas/__init__.py b/controller/src/schemas/__init__.py new file mode 100644 index 0000000..61289f2 --- /dev/null +++ b/controller/src/schemas/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base import ApiResponse, Base, OutputMode +from .data_source import DataSource, DataSourceType +from .dataset import Dataset +from .document import Document +from .model import Model, ModelType +from .project import Project +from .prompt_template import PromptTemplate +from .session import ChatSession, QueryItem +from .user import User +from .workflow import Workflow, WorkflowType diff --git a/controller/src/model.py b/controller/src/schemas/base.py similarity index 57% rename from controller/src/model.py rename to controller/src/schemas/base.py index 5068830..58bb5c9 100644 --- a/controller/src/model.py +++ b/controller/src/schemas/base.py @@ -15,113 +15,11 @@ from datetime import datetime from enum import Enum from http.client import HTTPException -from typing import Dict, List, Optional, Tuple, Type, Union +from typing import Dict, Optional, Type, Union import yaml from pydantic import BaseModel - -# ============================== from llmapps/app/schema.py ============================== -# Temporary: This was copied to here to avoid import from the app like this: -# from llmapps.app.schema import Conversation, Message -class ApiResponse(BaseModel): - success: bool - data: Optional[Union[list, Type[BaseModel], dict]] = None - error: Optional[str] = None - - def with_raise(self, format=None) -> "ApiResponse": - if not self.success: - format = format or "API call failed: %s" - raise ValueError(format % self.error) - return self - - def with_raise_http(self, format=None) -> "ApiResponse": - if not self.success: - format = format or "API call failed: %s" - raise HTTPException(status_code=400, detail=format % self.error) - return self - - -class ChatRole(str, Enum): - Human = "Human" - AI = "AI" - System = "System" - User = "User" # for co-pilot user (vs Human?) - Agent = "Agent" # for co-pilot agent - - -class Message(BaseModel): - role: ChatRole - content: str - extra_data: Optional[dict] = None - sources: Optional[List[dict]] = None - human_feedback: Optional[str] = None - - -class Conversation(BaseModel): - messages: list[Message] = [] - saved_index: int = 0 - - def __str__(self): - return "\n".join([f"{m.role}: {m.content}" for m in self.messages]) - - def add_message(self, role, content, sources=None): - self.messages.append(Message(role=role, content=content, sources=sources)) - - def to_list(self): - return self.dict()["messages"] - # return self.model_dump(mode="json")["messages"] - - def to_dict(self): - return self.dict()["messages"] - # return self.model_dump(mode="json")["messages"] - - @classmethod - def from_list(cls, data: list): - return cls.parse_obj({"messages": data or []}) - # return cls.model_validate({"messages": data or []}) - - -class QueryItem(BaseModel): - question: str - session_name: Optional[str] = None - filter: Optional[List[Tuple[str, str]]] = None - data_source: Optional[str] = None - - -class OutputMode(str, Enum): - Names = "names" - Short = "short" - Dict = "dict" - Details = "details" - - -class DataSourceType(str, Enum): - relational = "relational" - vector = "vector" - graph = "graph" - key_value = "key-value" - column_family = "column-family" - storage = "storage" - other = "other" - - -class ModelType(str, Enum): - model = "model" - adapter = "adapter" - - -class WorkflowType(str, Enum): - ingestion = "ingestion" - application = "application" - data_processing = "data-processing" - training = "training" - evaluation = "evaluation" - - -# ======================================================================================== - - metadata_fields = [ "id", "name", @@ -264,89 +162,26 @@ class BaseWithVerMetadata(BaseWithOwner): version: Optional[str] = "" -class User(BaseWithMetadata): - _extra_fields = ["policy", "features"] - _top_level_fields = ["email", "full_name"] - - email: str - full_name: Optional[str] = None - features: Optional[dict[str, str]] = None - policy: Optional[dict[str, str]] = None - is_admin: Optional[bool] = False - - -class Project(BaseWithVerMetadata): - pass - - -class DataSource(BaseWithVerMetadata): - _top_level_fields = ["data_source_type"] - - data_source_type: DataSourceType - project_id: Optional[str] = None - category: Optional[str] = None - database_kwargs: Optional[dict[str, str]] = {} - - -class Dataset(BaseWithVerMetadata): - _top_level_fields = ["task"] - - project_id: Optional[str] = None - task: str - sources: Optional[List[str]] = None - path: str - producer: Optional[str] = None - - -class Model(BaseWithVerMetadata): - _extra_fields = ["path", "producer", "deployment"] - _top_level_fields = ["model_type", "task"] - - model_type: ModelType - base_model: str - project_id: Optional[str] = None - task: Optional[str] = None - path: Optional[str] = None - producer: Optional[str] = None - deployment: Optional[str] = None - - -class PromptTemplate(BaseWithVerMetadata): - _extra_fields = ["arguments"] - _top_level_fields = ["text"] - - text: str - project_id: Optional[str] = None - arguments: Optional[List[str]] = None - - -class Document(BaseWithVerMetadata): - _top_level_fields = ["path", "origin"] - path: str - project_id: Optional[str] = None - origin: Optional[str] = None - - -class Workflow(BaseWithVerMetadata): - _top_level_fields = ["workflow_type"] - - workflow_type: WorkflowType - deployment: str - project_id: Optional[str] = None - workflow_function: Optional[str] = None - configuration: Optional[dict] = None - graph: Optional[dict] = None - - def get_infer_path(self): - return f"{self.deployment}/api/workflows/{self.name}/infer" +class ApiResponse(BaseModel): + success: bool + data: Optional[Union[list, Type[BaseModel], dict]] = None + error: Optional[str] = None + def with_raise(self, format=None) -> "ApiResponse": + if not self.success: + format = format or "API call failed: %s" + raise ValueError(format % self.error) + return self -class ChatSession(BaseWithOwner): - _extra_fields = ["history"] - _top_level_fields = ["workflow_id"] + def with_raise_http(self, format=None) -> "ApiResponse": + if not self.success: + format = format or "API call failed: %s" + raise HTTPException(status_code=400, detail=format % self.error) + return self - workflow_id: str - history: Optional[List[Message]] = [] - def to_conversation(self): - return Conversation.from_list(self.history) +class OutputMode(str, Enum): + Names = "names" + Short = "short" + Dict = "dict" + Details = "details" diff --git a/controller/src/schemas/data_source.py b/controller/src/schemas/data_source.py new file mode 100644 index 0000000..1cf0754 --- /dev/null +++ b/controller/src/schemas/data_source.py @@ -0,0 +1,37 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum +from typing import Optional + +from controller.src.schemas.base import BaseWithVerMetadata + + +class DataSourceType(str, Enum): + relational = "relational" + vector = "vector" + graph = "graph" + key_value = "key-value" + column_family = "column-family" + storage = "storage" + other = "other" + + +class DataSource(BaseWithVerMetadata): + _top_level_fields = ["data_source_type"] + + data_source_type: DataSourceType + project_id: Optional[str] = None + category: Optional[str] = None + database_kwargs: Optional[dict[str, str]] = {} diff --git a/controller/src/schemas/dataset.py b/controller/src/schemas/dataset.py new file mode 100644 index 0000000..ca8e70f --- /dev/null +++ b/controller/src/schemas/dataset.py @@ -0,0 +1,27 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +from controller.src.schemas.base import BaseWithVerMetadata + + +class Dataset(BaseWithVerMetadata): + _top_level_fields = ["task"] + + project_id: Optional[str] = None + task: str + sources: Optional[List[str]] = None + path: str + producer: Optional[str] = None diff --git a/controller/src/schemas/document.py b/controller/src/schemas/document.py new file mode 100644 index 0000000..9a68861 --- /dev/null +++ b/controller/src/schemas/document.py @@ -0,0 +1,24 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from controller.src.schemas.base import BaseWithVerMetadata + + +class Document(BaseWithVerMetadata): + _top_level_fields = ["path", "origin"] + path: str + project_id: Optional[str] = None + origin: Optional[str] = None diff --git a/controller/src/schemas/model.py b/controller/src/schemas/model.py new file mode 100644 index 0000000..2c18064 --- /dev/null +++ b/controller/src/schemas/model.py @@ -0,0 +1,36 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum +from typing import Optional + +from controller.src.schemas.base import BaseWithVerMetadata + + +class ModelType(str, Enum): + model = "model" + adapter = "adapter" + + +class Model(BaseWithVerMetadata): + _extra_fields = ["path", "producer", "deployment"] + _top_level_fields = ["model_type", "task"] + + model_type: ModelType + base_model: str + project_id: Optional[str] = None + task: Optional[str] = None + path: Optional[str] = None + producer: Optional[str] = None + deployment: Optional[str] = None diff --git a/controller/src/schemas/project.py b/controller/src/schemas/project.py new file mode 100644 index 0000000..ae996e9 --- /dev/null +++ b/controller/src/schemas/project.py @@ -0,0 +1,19 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from controller.src.schemas.base import BaseWithVerMetadata + + +class Project(BaseWithVerMetadata): + pass diff --git a/controller/src/schemas/prompt_template.py b/controller/src/schemas/prompt_template.py new file mode 100644 index 0000000..57ad3e7 --- /dev/null +++ b/controller/src/schemas/prompt_template.py @@ -0,0 +1,26 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +from controller.src.schemas.base import BaseWithVerMetadata + + +class PromptTemplate(BaseWithVerMetadata): + _extra_fields = ["arguments"] + _top_level_fields = ["text"] + + text: str + project_id: Optional[str] = None + arguments: Optional[List[str]] = None diff --git a/controller/src/schemas/session.py b/controller/src/schemas/session.py new file mode 100644 index 0000000..f54e8fc --- /dev/null +++ b/controller/src/schemas/session.py @@ -0,0 +1,51 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum +from typing import List, Optional, Tuple + +from pydantic import BaseModel + +from controller.src.schemas.base import BaseWithOwner + + +class QueryItem(BaseModel): + question: str + session_name: Optional[str] = None + filter: Optional[List[Tuple[str, str]]] = None + data_source: Optional[str] = None + + +class ChatRole(str, Enum): + Human = "Human" + AI = "AI" + System = "System" + User = "User" # for co-pilot user (vs Human?) + Agent = "Agent" # for co-pilot agent + + +class Message(BaseModel): + role: ChatRole + content: str + extra_data: Optional[dict] = None + sources: Optional[List[dict]] = None + human_feedback: Optional[str] = None + + +class ChatSession(BaseWithOwner): + _extra_fields = ["history"] + _top_level_fields = ["workflow_id"] + + workflow_id: str + history: Optional[List[Message]] = [] diff --git a/controller/src/schemas/user.py b/controller/src/schemas/user.py new file mode 100644 index 0000000..d29ead1 --- /dev/null +++ b/controller/src/schemas/user.py @@ -0,0 +1,28 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from controller.src.schemas.base import BaseWithMetadata + + +class User(BaseWithMetadata): + _extra_fields = ["policy", "features"] + _top_level_fields = ["email", "full_name"] + + email: str + full_name: Optional[str] = None + features: Optional[dict[str, str]] = None + policy: Optional[dict[str, str]] = None + is_admin: Optional[bool] = False diff --git a/controller/src/schemas/workflow.py b/controller/src/schemas/workflow.py new file mode 100644 index 0000000..20b3d58 --- /dev/null +++ b/controller/src/schemas/workflow.py @@ -0,0 +1,40 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum +from typing import Optional + +from controller.src.schemas.base import BaseWithVerMetadata + + +class WorkflowType(str, Enum): + ingestion = "ingestion" + application = "application" + data_processing = "data-processing" + training = "training" + evaluation = "evaluation" + + +class Workflow(BaseWithVerMetadata): + _top_level_fields = ["workflow_type"] + + workflow_type: WorkflowType + deployment: str + project_id: Optional[str] = None + workflow_function: Optional[str] = None + configuration: Optional[dict] = None + graph: Optional[dict] = None + + def get_infer_path(self): + return f"{self.deployment}/api/workflows/{self.name}/infer" diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..8c16310 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,25 @@ +[tool.ruff] +extend-include = ["*.ipynb"] +target-version = "py39" + +[tool.ruff.lint] +select = [ + "F", # pyflakes + "W", # pycodestyle + "E", # pycodestyle + "I", # isort +] +exclude = ["*.ipynb"] + +[tool.ruff.lint.pycodestyle] +max-line-length = 120 + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401"] + + +[tool.importlinter] +root_packages = [ + "controller.src", +] +include_external_packages = true \ No newline at end of file From f6835a1c058e4274085ae8dd3dce1b699e82d233 Mon Sep 17 00:00:00 2001 From: yonishelach Date: Wed, 14 Aug 2024 10:54:53 +0300 Subject: [PATCH 06/10] update _send_to_app import from cli --- controller/src/main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/controller/src/main.py b/controller/src/main.py index 1cac4fe..e29f998 100644 --- a/controller/src/main.py +++ b/controller/src/main.py @@ -20,7 +20,7 @@ import yaml from tabulate import tabulate -import controller.src.api as api +from controller.src.api.utils import _send_to_application from controller.src.config import config from controller.src.db import client from controller.src.schemas import ( @@ -175,7 +175,7 @@ def ingest(path, project, name, loader, metadata, version, data_source, from_fil params["metadata"] = json.dumps({metadata[0]: metadata[1]}) click.echo(f"Running Data Ingestion from: {path} with loader: {loader}") - response = api._send_to_application( + response = _send_to_application( path=f"data_sources/{data_source.name}/ingest", method="POST", data=json.dumps(data), @@ -243,7 +243,7 @@ def infer( headers = {"x_username": user} if user else {} # Sent the event to the application's workflow: - response = api._send_to_application( + response = _send_to_application( path=path, method="POST", data=json.dumps(data), From 5680710529276d6066506461e4f653dac258cb0b Mon Sep 17 00:00:00 2001 From: yonishelach Date: Wed, 14 Aug 2024 15:52:42 +0300 Subject: [PATCH 07/10] fix cli --- controller/src/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/controller/src/main.py b/controller/src/main.py index e29f998..49ff615 100644 --- a/controller/src/main.py +++ b/controller/src/main.py @@ -159,7 +159,7 @@ def ingest(path, project, name, loader, metadata, version, data_source, from_fil document=document, session=session, ).data - + document = Document.from_dict(document).to_dict(to_datestr=True) # Send ingest to application: params = { "loader": loader, From b3ca439d710eb2028ca93e89abc0d9f58e580906 Mon Sep 17 00:00:00 2001 From: yonishelach Date: Thu, 15 Aug 2024 11:32:05 +0300 Subject: [PATCH 08/10] fixed path of create endpoints, added delete session, removed field from data source model --- controller/src/api/endpoints/data_sources.py | 5 +--- controller/src/api/endpoints/datasets.py | 5 +--- controller/src/api/endpoints/documents.py | 5 +--- controller/src/api/endpoints/models.py | 5 +--- .../src/api/endpoints/prompt_templates.py | 5 +--- controller/src/api/endpoints/sessions.py | 23 ++++++++++++++++--- controller/src/api/endpoints/users.py | 2 +- controller/src/api/endpoints/workflows.py | 6 ++--- controller/src/main.py | 7 +++--- controller/src/schemas/data_source.py | 1 - 10 files changed, 32 insertions(+), 32 deletions(-) diff --git a/controller/src/api/endpoints/data_sources.py b/controller/src/api/endpoints/data_sources.py index 542ece6..abd9944 100644 --- a/controller/src/api/endpoints/data_sources.py +++ b/controller/src/api/endpoints/data_sources.py @@ -35,10 +35,9 @@ router = APIRouter(prefix="/projects/{project_name}") -@router.post("/data_sources/{data_source_name}") +@router.post("/data_sources") def create_data_source( project_name: str, - data_source_name: str, data_source: DataSource, session=Depends(get_db), auth: AuthInfo = Depends(get_auth_user), @@ -47,7 +46,6 @@ def create_data_source( Create a new data source in the database. :param project_name: The name of the project to create the data source in. - :param data_source_name: The name of the data source to create. :param data_source: The data source to create. :param session: The database session. :param auth: The authentication information. @@ -59,7 +57,6 @@ def create_data_source( data_source.owner_id = client.get_user( user_name=auth.username, session=session ).data["id"] - data_source.name = data_source_name data_source.project_id = client.get_project( project_name=project_name, session=session ).data["id"] diff --git a/controller/src/api/endpoints/datasets.py b/controller/src/api/endpoints/datasets.py index 0dd2065..353b0b5 100644 --- a/controller/src/api/endpoints/datasets.py +++ b/controller/src/api/endpoints/datasets.py @@ -23,10 +23,9 @@ router = APIRouter(prefix="/projects/{project_name}") -@router.post("/datasets/{dataset_name}") +@router.post("/datasets") def create_dataset( project_name: str, - dataset_name: str, dataset: Dataset, session=Depends(get_db), auth: AuthInfo = Depends(get_auth_user), @@ -35,7 +34,6 @@ def create_dataset( Create a new dataset in the database. :param project_name: The name of the project to create the dataset in. - :param dataset_name: The name of the dataset to create. :param dataset: The dataset to create. :param session: The database session. :param auth: The authentication information. @@ -47,7 +45,6 @@ def create_dataset( dataset.owner_id = client.get_user( user_name=auth.username, session=session ).data["id"] - dataset.name = dataset_name dataset.project_id = client.get_project( project_name=project_name, session=session ).data["id"] diff --git a/controller/src/api/endpoints/documents.py b/controller/src/api/endpoints/documents.py index 80f83ef..490917b 100644 --- a/controller/src/api/endpoints/documents.py +++ b/controller/src/api/endpoints/documents.py @@ -23,10 +23,9 @@ router = APIRouter(prefix="/projects/{project_name}") -@router.post("/documents/{document_name}") +@router.post("/documents") def create_document( project_name: str, - document_name: str, document: Document, session=Depends(get_db), auth: AuthInfo = Depends(get_auth_user), @@ -35,7 +34,6 @@ def create_document( Create a new document in the database. :param project_name: The name of the project to create the document in. - :param document_name: The name of the document to create. :param document: The document to create. :param session: The database session. :param auth: The authentication information. @@ -47,7 +45,6 @@ def create_document( document.owner_id = client.get_user( user_name=auth.username, session=session ).data["id"] - document.name = document_name document.project_id = client.get_project( project_name=project_name, session=session ).data["id"] diff --git a/controller/src/api/endpoints/models.py b/controller/src/api/endpoints/models.py index 4086e55..00ab82b 100644 --- a/controller/src/api/endpoints/models.py +++ b/controller/src/api/endpoints/models.py @@ -23,10 +23,9 @@ router = APIRouter(prefix="/projects/{project_name}") -@router.post("/models/{model_name}") +@router.post("/models") def create_model( project_name: str, - model_name: str, model: Model, session=Depends(get_db), auth: AuthInfo = Depends(get_auth_user), @@ -35,7 +34,6 @@ def create_model( Create a new model in the database. :param project_name: The name of the project to create the model in. - :param model_name: The name of the model to create. :param model: The model to create. :param session: The database session. :param auth: The authentication information. @@ -47,7 +45,6 @@ def create_model( model.owner_id = client.get_user(user_name=auth.username, session=session).data[ "id" ] - model.name = model_name model.project_id = client.get_project( project_name=project_name, session=session ).data["id"] diff --git a/controller/src/api/endpoints/prompt_templates.py b/controller/src/api/endpoints/prompt_templates.py index 5e8a2f6..2a8cbf6 100644 --- a/controller/src/api/endpoints/prompt_templates.py +++ b/controller/src/api/endpoints/prompt_templates.py @@ -23,10 +23,9 @@ router = APIRouter(prefix="/projects/{project_name}") -@router.post("/prompt_templates/{prompt_name}") +@router.post("/prompt_templates") def create_prompt( project_name: str, - prompt_name: str, prompt: PromptTemplate, session=Depends(get_db), auth: AuthInfo = Depends(get_auth_user), @@ -35,7 +34,6 @@ def create_prompt( Create a new prompt in the database. :param project_name: The name of the project to create the prompt in. - :param prompt_name: The name of the prompt to create. :param prompt: The prompt to create. :param session: The database session. :param auth: The authentication information. @@ -47,7 +45,6 @@ def create_prompt( prompt.owner_id = client.get_user( user_name=auth.username, session=session ).data["id"] - prompt.name = prompt_name prompt.project_id = client.get_project( project_name=project_name, session=session ).data["id"] diff --git a/controller/src/api/endpoints/sessions.py b/controller/src/api/endpoints/sessions.py index 9ceb2f8..1d67198 100644 --- a/controller/src/api/endpoints/sessions.py +++ b/controller/src/api/endpoints/sessions.py @@ -21,10 +21,9 @@ router = APIRouter(prefix="/users/{user_name}") -@router.post("/sessions/{session_name}") +@router.post("/sessions") def create_session( user_name: str, - session_name: str, chat_session: ChatSession, session=Depends(get_db), ) -> ApiResponse: @@ -32,7 +31,6 @@ def create_session( Create a new session in the database. :param user_name: The name of the user to create the session for. - :param session_name: The name of the session to create. :param chat_session: The session to create. :param session: The database session. @@ -87,6 +85,25 @@ def update_session( return client.update_chat_session(chat_session=chat_session, session=session) +@router.delete("/sessions/{session_id}") +def delete_session( + user_name: str, session_id: str, session=Depends(get_db) +) -> ApiResponse: + """ + Delete a session from the database. + + :param user_name: The name of the user to delete the session for. + :param session_id: The ID of the session to delete. + :param session: The database session. + + :return: The response from the database. + """ + user_id = client.get_user(user_name=user_name, session=session).data["id"] + return client.delete_chat_session( + session_name=session_id, user_id=user_id, session=session + ) + + @router.get("/sessions") def list_sessions( user_name: str, diff --git a/controller/src/api/endpoints/users.py b/controller/src/api/endpoints/users.py index 0117c76..09525d9 100644 --- a/controller/src/api/endpoints/users.py +++ b/controller/src/api/endpoints/users.py @@ -84,7 +84,7 @@ def delete_user(user_name: str, session=Depends(get_db)) -> ApiResponse: return client.delete_user(user_name=user_name, session=session) -@router.get("/users/users") +@router.get("/users") def list_users( email: str = None, full_name: str = None, diff --git a/controller/src/api/endpoints/workflows.py b/controller/src/api/endpoints/workflows.py index 2e96652..cda5fa8 100644 --- a/controller/src/api/endpoints/workflows.py +++ b/controller/src/api/endpoints/workflows.py @@ -35,10 +35,9 @@ router = APIRouter(prefix="/projects/{project_name}") -@router.post("/workflows/{workflow_name}") +@router.post("/workflows") def create_workflow( project_name: str, - workflow_name: str, workflow: Workflow, session=Depends(get_db), auth: AuthInfo = Depends(get_auth_user), @@ -47,7 +46,6 @@ def create_workflow( Create a new workflow in the database. :param project_name: The name of the project to create the workflow in. - :param workflow_name: The name of the workflow to create. :param workflow: The workflow to create. :param session: The database session. :param auth: The authentication information. @@ -59,7 +57,6 @@ def create_workflow( workflow.owner_id = client.get_user( user_name=auth.username, session=session ).data["id"] - workflow.name = workflow_name workflow.project_id = client.get_project( project_name=project_name, session=session ).data["id"] @@ -166,6 +163,7 @@ def list_workflows( project_id=project_id, owner_id=owner_id, version=version, + workflow_type=workflow_type, labels_match=labels, output_mode=mode, session=session, diff --git a/controller/src/main.py b/controller/src/main.py index 49ff615..70a6145 100644 --- a/controller/src/main.py +++ b/controller/src/main.py @@ -155,11 +155,12 @@ def ingest(path, project, name, loader, metadata, version, data_source, from_fil ) # Add document to the database: - document = client.create_document( + response = client.create_document( document=document, session=session, - ).data - document = Document.from_dict(document).to_dict(to_datestr=True) + ) + document = Document.from_dict(response.data).to_dict(to_datestr=True) + # Send ingest to application: params = { "loader": loader, diff --git a/controller/src/schemas/data_source.py b/controller/src/schemas/data_source.py index 1cf0754..5b46060 100644 --- a/controller/src/schemas/data_source.py +++ b/controller/src/schemas/data_source.py @@ -33,5 +33,4 @@ class DataSource(BaseWithVerMetadata): data_source_type: DataSourceType project_id: Optional[str] = None - category: Optional[str] = None database_kwargs: Optional[dict[str, str]] = {} From c93abbd780c97931a12dab06178819ff4401d607 Mon Sep 17 00:00:00 2001 From: yonishelach Date: Tue, 20 Aug 2024 15:36:59 +0300 Subject: [PATCH 09/10] guy's rc --- controller/Dockerfile | 2 +- controller/src/api/__init__.py | 72 +++ controller/src/api/api.py | 90 --- controller/src/api/endpoints/base.py | 24 - controller/src/api/endpoints/data_sources.py | 180 +++--- controller/src/api/endpoints/datasets.py | 131 +++-- controller/src/api/endpoints/documents.py | 129 +++-- controller/src/api/endpoints/models.py | 132 +++-- controller/src/api/endpoints/projects.py | 72 ++- .../src/api/endpoints/prompt_templates.py | 137 +++-- controller/src/api/endpoints/sessions.py | 123 ++-- controller/src/api/endpoints/users.py | 71 ++- controller/src/api/endpoints/workflows.py | 193 +++--- controller/src/db/sqlclient.py | 548 ++++++++---------- controller/src/db/sqldb.py | 250 ++++---- controller/src/main.py | 65 +-- controller/src/schemas/__init__.py | 2 +- controller/src/schemas/base.py | 34 +- controller/src/schemas/data_source.py | 16 +- controller/src/schemas/dataset.py | 6 +- controller/src/schemas/document.py | 4 +- controller/src/schemas/model.py | 14 +- controller/src/schemas/prompt_template.py | 4 +- controller/src/schemas/session.py | 20 +- controller/src/schemas/user.py | 6 +- controller/src/schemas/workflow.py | 25 +- 26 files changed, 1242 insertions(+), 1108 deletions(-) delete mode 100644 controller/src/api/api.py delete mode 100644 controller/src/api/endpoints/base.py diff --git a/controller/Dockerfile b/controller/Dockerfile index c1a51b2..cd42172 100644 --- a/controller/Dockerfile +++ b/controller/Dockerfile @@ -42,4 +42,4 @@ RUN pip install -r /controller/requirements.txt RUN python -m controller.src.main initdb # Run the controller's API server: -CMD ["uvicorn", "controller.src.api.api:app", "--port", "8001"] +CMD ["uvicorn", "controller.src.api:app", "--port", "8001"] diff --git a/controller/src/api/__init__.py b/controller/src/api/__init__.py index 84f1919..eb605b4 100644 --- a/controller/src/api/__init__.py +++ b/controller/src/api/__init__.py @@ -11,3 +11,75 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from fastapi import APIRouter, FastAPI +from fastapi.middleware.cors import CORSMiddleware + +from controller.src.api.endpoints import ( + data_sources, + datasets, + documents, + models, + projects, + prompt_templates, + sessions, + users, + workflows, +) + +app = FastAPI() + +# Add CORS middleware, remove in production +origins = ["*"] # React app +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Create a router with a prefix +api_router = APIRouter(prefix="/api") + + +# Include the routers for the different endpoints +api_router.include_router( + users.router, + tags=["users"], +) +api_router.include_router( + projects.router, + tags=["projects"], +) +api_router.include_router( + data_sources.router, + tags=["data_sources"], +) +api_router.include_router( + datasets.router, + tags=["datasets"], +) +api_router.include_router( + models.router, + tags=["models"], +) +api_router.include_router( + prompt_templates.router, + tags=["prompt_templates"], +) +api_router.include_router( + documents.router, + tags=["documents"], +) +api_router.include_router( + workflows.router, + tags=["workflows"], +) +api_router.include_router( + sessions.router, + tags=["chat_sessions"], +) + +# Include the router in the main app +app.include_router(api_router) diff --git a/controller/src/api/api.py b/controller/src/api/api.py deleted file mode 100644 index ea58d21..0000000 --- a/controller/src/api/api.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2023 Iguazio -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from fastapi import APIRouter, FastAPI -from fastapi.middleware.cors import CORSMiddleware - -from controller.src.api.endpoints import ( - base, - data_sources, - datasets, - documents, - models, - projects, - prompt_templates, - sessions, - users, - workflows, -) - -app = FastAPI() - -# Add CORS middleware, remove in production -origins = ["*"] # React app -app.add_middleware( - CORSMiddleware, - allow_origins=origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -# Create a router with a prefix -api_router = APIRouter(prefix="/api") - - -# Include the routers for the different endpoints -api_router.include_router( - base.router, - tags=["base"], -) -api_router.include_router( - users.router, - tags=["users"], -) -api_router.include_router( - projects.router, - tags=["projects"], -) -api_router.include_router( - data_sources.router, - tags=["data_sources"], -) -api_router.include_router( - datasets.router, - tags=["datasets"], -) -api_router.include_router( - models.router, - tags=["models"], -) -api_router.include_router( - prompt_templates.router, - tags=["prompt_templates"], -) -api_router.include_router( - documents.router, - tags=["documents"], -) -api_router.include_router( - workflows.router, - tags=["workflows"], -) -api_router.include_router( - sessions.router, - tags=["chat_sessions"], -) - -# Include the router in the main app -app.include_router(api_router) diff --git a/controller/src/api/endpoints/base.py b/controller/src/api/endpoints/base.py deleted file mode 100644 index d95279e..0000000 --- a/controller/src/api/endpoints/base.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright 2023 Iguazio -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from fastapi import APIRouter - -from controller.src.db import client - -router = APIRouter() - - -@router.post("/tables") -def create_tables(drop_old: bool = False, names: list[str] = None): - return client.create_tables(drop_old=drop_old, names=names) diff --git a/controller/src/api/endpoints/data_sources.py b/controller/src/api/endpoints/data_sources.py index abd9944..ac5f19d 100644 --- a/controller/src/api/endpoints/data_sources.py +++ b/controller/src/api/endpoints/data_sources.py @@ -25,7 +25,7 @@ ) from controller.src.db import client from controller.src.schemas import ( - ApiResponse, + APIResponse, DataSource, DataSourceType, Document, @@ -40,48 +40,52 @@ def create_data_source( project_name: str, data_source: DataSource, session=Depends(get_db), - auth: AuthInfo = Depends(get_auth_user), -) -> ApiResponse: +) -> APIResponse: """ Create a new data source in the database. :param project_name: The name of the project to create the data source in. :param data_source: The data source to create. :param session: The database session. - :param auth: The authentication information. :return: The response from the database. """ - # If the owner ID is not provided, get it from the username - if data_source.owner_id is None: - data_source.owner_id = client.get_user( - user_name=auth.username, session=session - ).data["id"] - data_source.project_id = client.get_project( - project_name=project_name, session=session - ).data["id"] - return client.create_data_source(data_source=data_source, session=session) - - -@router.get("/data_sources/{data_source_name}") + try: + data = client.create_data_source(data_source=data_source, session=session) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to create data source {data_source.name} in project {project_name}: {e}", + ) + + +@router.get("/data_sources/{uid}") def get_data_source( - project_name: str, data_source_name: str, session=Depends(get_db) -) -> ApiResponse: + project_name: str, uid: str, session=Depends(get_db) +) -> APIResponse: """ Get a data source from the database. - :param project_name: The name of the project to get the data source from. - :param data_source_name: The name of the data source to get. - :param session: The database session. + :param project_name: The name of the project to get the data source from. + :param uid: The uid of the data source to get. + :param session: The database session. :return: The data source from the database. """ - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.get_data_source( - project_id=project_id, data_source_name=data_source_name, session=session - ) + project_id = client.get_project(project_name=project_name, session=session).uid + try: + data = client.get_data_source(project_id=project_id, uid=uid, session=session) + if data is None: + return APIResponse( + success=False, error=f"Data source with uid = {uid} not found" + ) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to get data source {uid} in project {project_name}: {e}", + ) @router.put("/data_sources/{data_source_name}") @@ -90,7 +94,7 @@ def update_data_source( data_source: DataSource, data_source_name: str, session=Depends(get_db), -) -> ApiResponse: +) -> APIResponse: """ Update a data source in the database. @@ -101,51 +105,56 @@ def update_data_source( :return: The response from the database. """ - data_source.project_id = client.get_project( - project_name=project_name, session=session - ).data["id"] - if data_source_name != data_source.name: - raise ValueError( - f"Data source name does not match: {data_source_name} != {data_source.name}" + try: + data = client.update_data_source(data_source=data_source, session=session) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to update data source {data_source_name} in project {project_name}: {e}", ) - return client.update_data_source(data_source=data_source, session=session) -@router.delete("/data_sources/{data_source_id}") +@router.delete("/data_sources/{uid}") def delete_data_source( - project_name: str, data_source_id: str, session=Depends(get_db) -) -> ApiResponse: + project_name: str, uid: str, session=Depends(get_db) +) -> APIResponse: """ Delete a data source from the database. - :param project_name: The name of the project to delete the data source from. - :param data_source_id: The ID of the data source to delete. - :param session: The database session. + :param project_name: The name of the project to delete the data source from. + :param uid: The ID of the data source to delete. + :param session: The database session. :return: The response from the database. """ - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.delete_data_source( - project_id=project_id, data_source_id=data_source_id, session=session - ) + project_id = client.get_project(project_name=project_name, session=session).uid + try: + client.delete_data_source(project_id=project_id, uid=uid, session=session) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to delete data source {uid} in project {project_name}: {e}", + ) + return APIResponse(success=True) @router.get("/data_sources") def list_data_sources( project_name: str, + name: str = None, version: str = None, data_source_type: Union[DataSourceType, str] = None, labels: Optional[List[Tuple[str, str]]] = None, - mode: OutputMode = OutputMode.Details, + mode: OutputMode = OutputMode.DETAILS, session=Depends(get_db), auth: AuthInfo = Depends(get_auth_user), -) -> ApiResponse: +) -> APIResponse: """ List data sources in the database. :param project_name: The name of the project to list the data sources from. + :param name: The name to filter by. :param version: The version to filter by. :param data_source_type: The data source type to filter by. :param labels: The labels to filter by. @@ -155,25 +164,31 @@ def list_data_sources( :return: The response from the database. """ - owner_id = client.get_user(user_name=auth.username, session=session).data["id"] - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.list_data_sources( - project_id=project_id, - owner_id=owner_id, - version=version, - data_source_type=data_source_type, - labels_match=labels, - output_mode=mode, - session=session, - ) + owner_id = client.get_user(user_name=auth.username, session=session).uid + project_id = client.get_project(project_name=project_name, session=session).uid + try: + data = client.list_data_sources( + project_id=project_id, + name=name, + owner_id=owner_id, + version=version, + data_source_type=data_source_type, + labels_match=labels, + output_mode=mode, + session=session, + ) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to list data sources in project {project_name}: {e}", + ) -@router.post("/data_sources/{data_source_name}/ingest") +@router.post("/data_sources/{uid}/ingest") def ingest( project_name, - data_source_name, + uid, loader: str, path: str, metadata=None, @@ -186,7 +201,7 @@ def ingest( Ingest document into the vector database. :param project_name: The name of the project to ingest the documents into. - :param data_source_name: The name of the data source to ingest the documents into. + :param uid: The UID of the data source to ingest the documents into. :param loader: The data loader type to use. :param path: The path to the document to ingest. :param metadata: The metadata to associate with the documents. @@ -197,14 +212,10 @@ def ingest( :return: The response from the application. """ - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] + project_id = client.get_project(project_name=project_name, session=session).uid data_source = client.get_data_source( - project_id=project_id, data_source_name=data_source_name, session=session - ).data - - data_source = DataSource.from_dict(data_source) + project_id=project_id, uid=uid, session=session + ) # Create document from path: document = Document( @@ -216,7 +227,7 @@ def ingest( ) # Add document to the database: - document = client.create_document(document=document, session=session).data + document = client.create_document(document=document, session=session) # Send ingest to application: params = { @@ -225,16 +236,23 @@ def ingest( } data = { - "document": document, + "document": document.to_dict(), "database_kwargs": data_source.database_kwargs, } if metadata is not None: params["metadata"] = json.dumps(metadata) - return _send_to_application( - path=f"data_sources/{data_source_name}/ingest", - method="POST", - data=json.dumps(data), - params=params, - auth=auth, - ) + try: + data = _send_to_application( + path=f"data_sources/{data_source.name}/ingest", + method="POST", + data=json.dumps(data), + params=params, + auth=auth, + ) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to ingest document into data source {data_source.name}: {e}", + ) diff --git a/controller/src/api/endpoints/datasets.py b/controller/src/api/endpoints/datasets.py index 353b0b5..f662ed1 100644 --- a/controller/src/api/endpoints/datasets.py +++ b/controller/src/api/endpoints/datasets.py @@ -18,7 +18,7 @@ from controller.src.api.utils import AuthInfo, get_auth_user, get_db from controller.src.db import client -from controller.src.schemas import ApiResponse, Dataset, OutputMode +from controller.src.schemas import APIResponse, Dataset, OutputMode router = APIRouter(prefix="/projects/{project_name}") @@ -28,48 +28,50 @@ def create_dataset( project_name: str, dataset: Dataset, session=Depends(get_db), - auth: AuthInfo = Depends(get_auth_user), -) -> ApiResponse: +) -> APIResponse: """ Create a new dataset in the database. :param project_name: The name of the project to create the dataset in. :param dataset: The dataset to create. :param session: The database session. - :param auth: The authentication information. :return: The response from the database. """ - # If the owner ID is not provided, get it from the username - if dataset.owner_id is None: - dataset.owner_id = client.get_user( - user_name=auth.username, session=session - ).data["id"] - dataset.project_id = client.get_project( - project_name=project_name, session=session - ).data["id"] - return client.create_dataset(dataset=dataset, session=session) - - -@router.get("/datasets/{dataset_name}") -def get_dataset( - project_name: str, dataset_name: str, session=Depends(get_db) -) -> ApiResponse: + try: + data = client.create_dataset(dataset=dataset, session=session) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to create dataset {dataset.name} in project {project_name}: {e}", + ) + + +@router.get("/datasets/{uid}") +def get_dataset(project_name: str, uid: str, session=Depends(get_db)) -> APIResponse: """ Get a dataset from the database. :param project_name: The name of the project to get the dataset from. - :param dataset_name: The name of the dataset to get. + :param uid: The name of the dataset to get. :param session: The database session. :return: The dataset from the database. """ - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.get_dataset( - project_id=project_id, dataset_name=dataset_name, session=session - ) + project_id = client.get_project(project_name=project_name, session=session).uid + try: + data = client.get_dataset(project_id=project_id, uid=uid, session=session) + if data is None: + return APIResponse( + success=False, error=f"Dataset with uid = {uid} not found" + ) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to get dataset {uid} in project {project_name}: {e}", + ) @router.put("/datasets/{dataset_name}") @@ -78,7 +80,7 @@ def update_dataset( dataset: Dataset, dataset_name: str, session=Depends(get_db), -) -> ApiResponse: +) -> APIResponse: """ Update a dataset in the database. @@ -89,51 +91,54 @@ def update_dataset( :return: The response from the database. """ - dataset.project_id = client.get_project( - project_name=project_name, session=session - ).data["id"] - if dataset_name != dataset.name: - raise ValueError( - f"Dataset name does not match: {dataset_name} != {dataset.name}" + try: + data = client.update_dataset(dataset=dataset, session=session) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to update dataset {dataset_name} in project {project_name}: {e}", ) - return client.update_dataset(dataset=dataset, session=session) -@router.delete("/datasets/{dataset_id}") -def delete_dataset( - project_name: str, dataset_id: str, session=Depends(get_db) -) -> ApiResponse: +@router.delete("/datasets/{uid}") +def delete_dataset(project_name: str, uid: str, session=Depends(get_db)) -> APIResponse: """ Delete a dataset from the database. :param project_name: The name of the project to delete the dataset from. - :param dataset_id: The ID of the dataset to delete. + :param uid: The UID of the dataset to delete. :param session: The database session. :return: The response from the database. """ - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.delete_dataset( - project_id=project_id, dataset_id=dataset_id, session=session - ) + project_id = client.get_project(project_name=project_name, session=session).uid + try: + client.delete_dataset(project_id=project_id, uid=uid, session=session) + return APIResponse(success=True) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to delete dataset {uid} in project {project_name}: {e}", + ) @router.get("/datasets") def list_datasets( project_name: str, + name: str = None, version: str = None, task: str = None, labels: Optional[List[Tuple[str, str]]] = None, - mode: OutputMode = OutputMode.Details, + mode: OutputMode = OutputMode.DETAILS, session=Depends(get_db), auth: AuthInfo = Depends(get_auth_user), -) -> ApiResponse: +) -> APIResponse: """ List datasets in the database. :param project_name: The name of the project to list the datasets from. + :param name: The name to filter by. :param version: The version to filter by. :param task: The task to filter by. :param labels: The labels to filter by. @@ -143,16 +148,22 @@ def list_datasets( :return: The response from the database. """ - owner_id = client.get_user(user_name=auth.username, session=session).data["id"] - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.list_datasets( - project_id=project_id, - owner_id=owner_id, - version=version, - task=task, - labels_match=labels, - output_mode=mode, - session=session, - ) + owner_id = client.get_user(user_name=auth.username, session=session).uid + project_id = client.get_project(project_name=project_name, session=session).uid + try: + data = client.list_datasets( + project_id=project_id, + name=name, + owner_id=owner_id, + version=version, + task=task, + labels_match=labels, + output_mode=mode, + session=session, + ) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to list datasets in project {project_name}: {e}", + ) diff --git a/controller/src/api/endpoints/documents.py b/controller/src/api/endpoints/documents.py index 490917b..3f3149a 100644 --- a/controller/src/api/endpoints/documents.py +++ b/controller/src/api/endpoints/documents.py @@ -18,7 +18,7 @@ from controller.src.api.utils import AuthInfo, get_auth_user, get_db from controller.src.db import client -from controller.src.schemas import ApiResponse, Document, OutputMode +from controller.src.schemas import APIResponse, Document, OutputMode router = APIRouter(prefix="/projects/{project_name}") @@ -28,48 +28,50 @@ def create_document( project_name: str, document: Document, session=Depends(get_db), - auth: AuthInfo = Depends(get_auth_user), -) -> ApiResponse: +) -> APIResponse: """ Create a new document in the database. :param project_name: The name of the project to create the document in. :param document: The document to create. :param session: The database session. - :param auth: The authentication information. :return: The response from the database. """ - # If the owner ID is not provided, get it from the username - if document.owner_id is None: - document.owner_id = client.get_user( - user_name=auth.username, session=session - ).data["id"] - document.project_id = client.get_project( - project_name=project_name, session=session - ).data["id"] - return client.create_document(document=document, session=session) - - -@router.get("/documents/{document_name}") -def get_document( - project_name: str, document_name: str, session=Depends(get_db) -) -> ApiResponse: + try: + data = client.create_document(document=document, session=session) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to create document {document.name} in project {project_name}: {e}", + ) + + +@router.get("/documents/{uid}") +def get_document(project_name: str, uid: str, session=Depends(get_db)) -> APIResponse: """ Get a document from the database. :param project_name: The name of the project to get the document from. - :param document_name: The name of the document to get. + :param uid: The UID of the document to get. :param session: The database session. :return: The document from the database. """ - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.get_document( - project_id=project_id, document_name=document_name, session=session - ) + project_id = client.get_project(project_name=project_name, session=session).uid + try: + data = client.get_document(project_id=project_id, uid=uid, session=session) + if data is None: + return APIResponse( + success=False, error=f"Document with uid = {uid} not found" + ) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to get document {uid} in project {project_name}: {e}", + ) @router.put("/documents/{document_name}") @@ -78,7 +80,7 @@ def update_document( document: Document, document_name: str, session=Depends(get_db), -) -> ApiResponse: +) -> APIResponse: """ Update a document in the database. @@ -89,50 +91,55 @@ def update_document( :return: The response from the database. """ - document.project_id = client.get_project( - project_name=project_name, session=session - ).data["id"] - if document_name != document.name: - raise ValueError( - f"Document name does not match: {document_name} != {document.name}" + try: + data = client.update_document(document=document, session=session) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to update document {document.name} in project {project_name}: {e}", ) - return client.update_document(document=document, session=session) -@router.delete("/documents/{document_id}") +@router.delete("/documents/{uid}") def delete_document( - project_name: str, document_id: str, session=Depends(get_db) -) -> ApiResponse: + project_name: str, uid: str, session=Depends(get_db) +) -> APIResponse: """ Delete a document from the database. :param project_name: The name of the project to delete the document from. - :param document_id: The ID of the document to delete. + :param uid: The UID of the document to delete. :param session: The database session. :return: The response from the database. """ - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.delete_document( - project_id=project_id, document_id=document_id, session=session - ) + project_id = client.get_project(project_name=project_name, session=session).uid + try: + client.delete_document(project_id=project_id, uid=uid, session=session) + return APIResponse(success=True) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to delete document {uid} in project {project_name}: {e}", + ) @router.get("/documents") def list_documents( project_name: str, + name: str = None, version: str = None, labels: Optional[List[Tuple[str, str]]] = None, - mode: OutputMode = OutputMode.Details, + mode: OutputMode = OutputMode.DETAILS, session=Depends(get_db), auth: AuthInfo = Depends(get_auth_user), -) -> ApiResponse: +) -> APIResponse: """ List documents in the database. :param project_name: The name of the project to list the documents from. + :param name: The name to filter by. :param version: The version to filter by. :param labels: The labels to filter by. :param mode: The output mode. @@ -141,15 +148,21 @@ def list_documents( :return: The response from the database. """ - owner_id = client.get_user(user_name=auth.username, session=session).data["id"] - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.list_documents( - project_id=project_id, - owner_id=owner_id, - version=version, - labels_match=labels, - output_mode=mode, - session=session, - ) + owner_id = client.get_user(user_name=auth.username, session=session).uid + project_id = client.get_project(project_name=project_name, session=session).uid + try: + data = client.list_documents( + project_id=project_id, + name=name, + owner_id=owner_id, + version=version, + labels_match=labels, + output_mode=mode, + session=session, + ) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to list documents in project {project_name}: {e}", + ) diff --git a/controller/src/api/endpoints/models.py b/controller/src/api/endpoints/models.py index 00ab82b..1c5e72f 100644 --- a/controller/src/api/endpoints/models.py +++ b/controller/src/api/endpoints/models.py @@ -18,7 +18,7 @@ from controller.src.api.utils import AuthInfo, get_auth_user, get_db from controller.src.db import client -from controller.src.schemas import ApiResponse, Model, OutputMode +from controller.src.schemas import APIResponse, Model, OutputMode router = APIRouter(prefix="/projects/{project_name}") @@ -28,48 +28,48 @@ def create_model( project_name: str, model: Model, session=Depends(get_db), - auth: AuthInfo = Depends(get_auth_user), -) -> ApiResponse: +) -> APIResponse: """ Create a new model in the database. :param project_name: The name of the project to create the model in. :param model: The model to create. :param session: The database session. - :param auth: The authentication information. :return: The response from the database. """ - # If the owner ID is not provided, get it from the username - if model.owner_id is None: - model.owner_id = client.get_user(user_name=auth.username, session=session).data[ - "id" - ] - model.project_id = client.get_project( - project_name=project_name, session=session - ).data["id"] - return client.create_model(model=model, session=session) - - -@router.get("/models/{model_name}") -def get_model( - project_name: str, model_name: str, session=Depends(get_db) -) -> ApiResponse: + try: + data = client.create_model(model=model, session=session) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to create model {model.name} in project {project_name}: {e}", + ) + + +@router.get("/models/{uid}") +def get_model(project_name: str, uid: str, session=Depends(get_db)) -> APIResponse: """ Get a model from the database. :param project_name: The name of the project to get the model from. - :param model_name: The name of the model to get. + :param uid: The UID of the model to get. :param session: The database session. :return: The model from the database. """ - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.get_model( - project_id=project_id, model_name=model_name, session=session - ) + project_id = client.get_project(project_name=project_name, session=session).uid + try: + data = client.get_model(project_id=project_id, uid=uid, session=session) + if data is None: + return APIResponse(success=False, error=f"Model with uid = {uid} not found") + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to get model {uid} in project {project_name}: {e}", + ) @router.put("/models/{model_name}") @@ -78,7 +78,7 @@ def update_model( model: Model, model_name: str, session=Depends(get_db), -) -> ApiResponse: +) -> APIResponse: """ Update a model in the database. @@ -89,49 +89,54 @@ def update_model( :return: The response from the database. """ - model.project_id = client.get_project( - project_name=project_name, session=session - ).data["id"] - if model_name != model.name: - raise ValueError(f"Model name does not match: {model_name} != {model.name}") - return client.update_model(model=model, session=session) - - -@router.delete("/models/{model_id}") -def delete_model( - project_name: str, model_id: str, session=Depends(get_db) -) -> ApiResponse: + try: + data = client.update_model(model=model, session=session) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to update model {model.name} in project {project_name}: {e}", + ) + + +@router.delete("/models/{uid}") +def delete_model(project_name: str, uid: str, session=Depends(get_db)) -> APIResponse: """ Delete a model from the database. :param project_name: The name of the project to delete the model from. - :param model_id: The ID of the model to delete. + :param uid: The ID of the model to delete. :param session: The database session. :return: The response from the database. """ - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.delete_model( - project_id=project_id, model_id=model_id, session=session - ) + project_id = client.get_project(project_name=project_name, session=session).uid + try: + client.delete_model(project_id=project_id, uid=uid, session=session) + return APIResponse(success=True) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to delete model {uid} in project {project_name}: {e}", + ) @router.get("/models") def list_models( project_name: str, + name: str = None, version: str = None, model_type: str = None, labels: Optional[List[Tuple[str, str]]] = None, - mode: OutputMode = OutputMode.Details, + mode: OutputMode = OutputMode.DETAILS, session=Depends(get_db), auth: AuthInfo = Depends(get_auth_user), -) -> ApiResponse: +) -> APIResponse: """ List models in the database. :param project_name: The name of the project to list the models from. + :param name: The name to filter by. :param version: The version to filter by. :param model_type: The model type to filter by. :param labels: The labels to filter by. @@ -141,16 +146,21 @@ def list_models( :return: The response from the database. """ - owner_id = client.get_user(user_name=auth.username, session=session).data["id"] - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.list_models( - project_id=project_id, - owner_id=owner_id, - version=version, - model_type=model_type, - labels_match=labels, - output_mode=mode, - session=session, - ) + owner_id = client.get_user(user_name=auth.username, session=session).uid + project_id = client.get_project(project_name=project_name, session=session).uid + try: + data = client.list_models( + project_id=project_id, + name=name, + owner_id=owner_id, + version=version, + model_type=model_type, + labels_match=labels, + output_mode=mode, + session=session, + ) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, error=f"Failed to list models in project {project_name}: {e}" + ) diff --git a/controller/src/api/endpoints/projects.py b/controller/src/api/endpoints/projects.py index 6a5343b..486ed65 100644 --- a/controller/src/api/endpoints/projects.py +++ b/controller/src/api/endpoints/projects.py @@ -18,7 +18,7 @@ from controller.src.api.utils import get_db from controller.src.db import client -from controller.src.schemas import ApiResponse, OutputMode, Project +from controller.src.schemas import APIResponse, OutputMode, Project router = APIRouter() @@ -27,7 +27,7 @@ def create_project( project: Project, session=Depends(get_db), -) -> ApiResponse: +) -> APIResponse: """ Create a new project in the database. @@ -36,11 +36,17 @@ def create_project( :return: The response from the database. """ - return client.create_project(project=project, session=session) + try: + data = client.create_project(project=project, session=session) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, error=f"Failed to create project {project.name}: {e}" + ) @router.get("/projects/{project_name}") -def get_project(project_name: str, session=Depends(get_db)) -> ApiResponse: +def get_project(project_name: str, session=Depends(get_db)) -> APIResponse: """ Get a project from the database. @@ -49,7 +55,17 @@ def get_project(project_name: str, session=Depends(get_db)) -> ApiResponse: :return: The project from the database. """ - return client.get_project(project_name=project_name, session=session) + try: + data = client.get_project(project_name=project_name, session=session) + if data is None: + return APIResponse( + success=False, error=f"Project with name {project_name} not found" + ) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, error=f"Failed to get project {project_name}: {e}" + ) @router.put("/projects/{project_name}") @@ -57,7 +73,7 @@ def update_project( project: Project, project_name: str, session=Depends(get_db), -) -> ApiResponse: +) -> APIResponse: """ Update a project in the database. @@ -67,15 +83,17 @@ def update_project( :return: The response from the database. """ - if project_name != project.name: - raise ValueError( - f"Project name does not match: {project_name} != {project.name}" + try: + data = client.update_project(project=project, session=session) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, error=f"Failed to update project {project_name}: {e}" ) - return client.update_project(project=project, session=session) @router.delete("/projects/{project_name}") -def delete_project(project_name: str, session=Depends(get_db)) -> ApiResponse: +def delete_project(project_name: str, session=Depends(get_db)) -> APIResponse: """ Delete a project from the database. @@ -84,19 +102,29 @@ def delete_project(project_name: str, session=Depends(get_db)) -> ApiResponse: :return: The response from the database. """ - return client.delete_project(project_name=project_name, session=session) + project = client.get_project(project_name=project_name, session=session) + + try: + client.delete_project(uid=project.uid, session=session) + return APIResponse(success=True) + except Exception as e: + return APIResponse( + success=False, error=f"Failed to delete project {project_name}: {e}" + ) @router.get("/projects") def list_projects( + name: str = None, owner_name: str = None, labels: Optional[List[Tuple[str, str]]] = None, - mode: OutputMode = OutputMode.Details, + mode: OutputMode = OutputMode.DETAILS, session=Depends(get_db), -) -> ApiResponse: +) -> APIResponse: """ List projects in the database. + :param name: The name of the project to filter by. :param owner_name: The name of the owner to filter by. :param labels: The labels to filter by. :param mode: The output mode. @@ -105,9 +133,17 @@ def list_projects( :return: The response from the database. """ if owner_name is not None: - owner_id = client.get_user(user_name=owner_name, session=session).data["id"] + owner_id = client.get_user(user_name=owner_name, session=session).uid else: owner_id = None - return client.list_projects( - owner_id=owner_id, labels_match=labels, output_mode=mode, session=session - ) + try: + data = client.list_projects( + owner_id=owner_id, + labels_match=labels, + output_mode=mode, + session=session, + name=name, + ) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse(success=False, error=f"Failed to list projects: {e}") diff --git a/controller/src/api/endpoints/prompt_templates.py b/controller/src/api/endpoints/prompt_templates.py index 2a8cbf6..b0c0957 100644 --- a/controller/src/api/endpoints/prompt_templates.py +++ b/controller/src/api/endpoints/prompt_templates.py @@ -18,7 +18,7 @@ from controller.src.api.utils import AuthInfo, get_auth_user, get_db from controller.src.db import client -from controller.src.schemas import ApiResponse, OutputMode, PromptTemplate +from controller.src.schemas import APIResponse, OutputMode, PromptTemplate router = APIRouter(prefix="/projects/{project_name}") @@ -28,48 +28,50 @@ def create_prompt( project_name: str, prompt: PromptTemplate, session=Depends(get_db), - auth: AuthInfo = Depends(get_auth_user), -) -> ApiResponse: +) -> APIResponse: """ Create a new prompt in the database. :param project_name: The name of the project to create the prompt in. :param prompt: The prompt to create. :param session: The database session. - :param auth: The authentication information. :return: The response from the database. """ - # If the owner ID is not provided, get it from the username - if prompt.owner_id is None: - prompt.owner_id = client.get_user( - user_name=auth.username, session=session - ).data["id"] - prompt.project_id = client.get_project( - project_name=project_name, session=session - ).data["id"] - return client.create_prompt_template(prompt=prompt, session=session) - - -@router.get("/prompt_templates/{prompt_name}") -def get_prompt( - project_name: str, prompt_name: str, session=Depends(get_db) -) -> ApiResponse: + try: + data = client.create_prompt_template(prompt=prompt, session=session) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to create prompt {prompt.name} in project {project_name}: {e}", + ) + + +@router.get("/prompt_templates/{uid}") +def get_prompt(project_name: str, uid: str, session=Depends(get_db)) -> APIResponse: """ Get a prompt from the database. :param project_name: The name of the project to get the prompt from. - :param prompt_name: The name of the prompt to get. + :param uid: The UID of the prompt to get. :param session: The database session. :return: The prompt from the database. """ - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.get_prompt( - project_id=project_id, prompt_name=prompt_name, session=session - ) + project_id = client.get_project(project_name=project_name, session=session).uid + try: + data = client.get_prompt(project_id=project_id, uid=uid, session=session) + if data is None: + return APIResponse( + success=False, error=f"Prompt with uid = {uid} not found" + ) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to get prompt {uid} in project {project_name}: {e}", + ) @router.put("/prompt_templates/{prompt_name}") @@ -78,7 +80,7 @@ def update_prompt( prompt: PromptTemplate, prompt_name: str, session=Depends(get_db), -) -> ApiResponse: +) -> APIResponse: """ Update a prompt in the database. @@ -89,48 +91,53 @@ def update_prompt( :return: The response from the database. """ - prompt.project_id = client.get_project( - project_name=project_name, session=session - ).data["id"] - if prompt_name != prompt.name: - raise ValueError(f"Prompt name does not match: {prompt_name} != {prompt.name}") - return client.update_prompt_template(prompt=prompt, session=session) - - -@router.delete("/prompt_templates/{prompt_template_id}") -def delete_prompt( - project_name: str, prompt_template_id: str, session=Depends(get_db) -) -> ApiResponse: + try: + data = client.update_prompt_template(prompt=prompt, session=session) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to update prompt {prompt_name} in project {project_name}: {e}", + ) + + +@router.delete("/prompt_templates/{uid}") +def delete_prompt(project_name: str, uid: str, session=Depends(get_db)) -> APIResponse: """ Delete a prompt from the database. - :param project_name: The name of the project to delete the prompt from. - :param prompt_template_id: The ID of the prompt to delete. - :param session: The database session. + :param project_name: The name of the project to delete the prompt from. + :param uid: The UID of the prompt to delete. + :param session: The database session. :return: The response from the database. """ - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.delete_prompt_template( - project_id=project_id, prompt_template_id=prompt_template_id, session=session - ) + project_id = client.get_project(project_name=project_name, session=session).uid + try: + client.delete_prompt_template(project_id=project_id, uid=uid, session=session) + return APIResponse(success=True) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to delete prompt {uid} in project {project_name}: {e}", + ) @router.get("/prompt_templates") def list_prompts( project_name: str, + name: str = None, version: str = None, labels: Optional[List[Tuple[str, str]]] = None, - mode: OutputMode = OutputMode.Details, + mode: OutputMode = OutputMode.DETAILS, session=Depends(get_db), auth: AuthInfo = Depends(get_auth_user), -) -> ApiResponse: +) -> APIResponse: """ List prompts in the database. :param project_name: The name of the project to list the prompts from. + :param name: The name to filter by. :param version: The version to filter by. :param labels: The labels to filter by. :param mode: The output mode. @@ -139,15 +146,21 @@ def list_prompts( :return: The response from the database. """ - owner_id = client.get_user(user_name=auth.username, session=session).data["id"] - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.list_prompt_templates( - project_id=project_id, - owner_id=owner_id, - version=version, - labels_match=labels, - output_mode=mode, - session=session, - ) + owner_id = client.get_user(user_name=auth.username, session=session).uid + project_id = client.get_project(project_name=project_name, session=session).uid + try: + data = client.list_prompt_templates( + project_id=project_id, + name=name, + owner_id=owner_id, + version=version, + labels_match=labels, + output_mode=mode, + session=session, + ) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to list prompts in project {project_name}: {e}", + ) diff --git a/controller/src/api/endpoints/sessions.py b/controller/src/api/endpoints/sessions.py index 1d67198..9accb23 100644 --- a/controller/src/api/endpoints/sessions.py +++ b/controller/src/api/endpoints/sessions.py @@ -16,7 +16,7 @@ from controller.src.api.utils import get_db from controller.src.db import client -from controller.src.schemas import ApiResponse, ChatSession, OutputMode +from controller.src.schemas import APIResponse, ChatSession, OutputMode router = APIRouter(prefix="/users/{user_name}") @@ -26,7 +26,7 @@ def create_session( user_name: str, chat_session: ChatSession, session=Depends(get_db), -) -> ApiResponse: +) -> APIResponse: """ Create a new session in the database. @@ -36,32 +36,43 @@ def create_session( :return: The response from the database. """ - chat_session.owner_id = client.get_user(user_name=user_name, session=session).data[ - "id" - ] - return client.create_chat_session(chat_session=chat_session, session=session) - - -@router.get("/sessions/{session_name}") -def get_session( - user_name: str, session_name: str, session=Depends(get_db) -) -> ApiResponse: + try: + data = client.create_chat_session(chat_session=chat_session, session=session) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to create session {chat_session.uid} for user {user_name}: {e}", + ) + + +@router.get("/sessions/{uid}") +def get_session(user_name: str, uid: str, session=Depends(get_db)) -> APIResponse: """ Get a session from the database. If the session ID is "$last", get the last session for the user. :param user_name: The name of the user to get the session for. - :param session_name: The name of the session to get. + :param uid: The UID of the session to get. if "$last" bring the last user's session. :param session: The database session. :return: The session from the database. """ user_id = None - if session_name == "$last": - user_id = client.get_user(user_name=user_name, session=session).data["id"] - session_name = None - return client.get_chat_session( - session_name=session_name, user_id=user_id, session=session - ) + if uid == "$last": + user_id = client.get_user(user_name=user_name, session=session).uid + uid = None + try: + data = client.get_chat_session(uid=uid, user_id=user_id, session=session) + if data is None: + return APIResponse( + success=False, error=f"Session with uid = {uid} not found" + ) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to get session {uid} for user {user_name}: {e}", + ) @router.put("/sessions/{session_name}") @@ -69,7 +80,7 @@ def update_session( user_name: str, chat_session: ChatSession, session=Depends(get_db), -) -> ApiResponse: +) -> APIResponse: """ Update a session in the database. @@ -79,44 +90,53 @@ def update_session( :return: The response from the database. """ - chat_session.owner_id = client.get_user(user_name=user_name, session=session).data[ - "id" - ] - return client.update_chat_session(chat_session=chat_session, session=session) - - -@router.delete("/sessions/{session_id}") -def delete_session( - user_name: str, session_id: str, session=Depends(get_db) -) -> ApiResponse: + try: + data = client.update_chat_session(chat_session=chat_session, session=session) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to update session {chat_session.uid} for user {user_name}: {e}", + ) + + +@router.delete("/sessions/{uid}") +def delete_session(user_name: str, uid: str, session=Depends(get_db)) -> APIResponse: """ Delete a session from the database. - :param user_name: The name of the user to delete the session for. - :param session_id: The ID of the session to delete. - :param session: The database session. + :param user_name: The name of the user to delete the session for. + :param uid: The UID of the session to delete. + :param session: The database session. :return: The response from the database. """ - user_id = client.get_user(user_name=user_name, session=session).data["id"] - return client.delete_chat_session( - session_name=session_id, user_id=user_id, session=session - ) + user_id = client.get_user(user_name=user_name, session=session).uid + try: + client.delete_chat_session(uid=uid, user_id=user_id, session=session) + return APIResponse(success=True) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to delete session {uid} for user {user_name}: {e}", + ) @router.get("/sessions") def list_sessions( user_name: str, + name: str = None, last: int = 0, created: str = None, workflow_id: str = None, - mode: OutputMode = OutputMode.Details, + mode: OutputMode = OutputMode.DETAILS, session=Depends(get_db), -) -> ApiResponse: +) -> APIResponse: """ List sessions in the database. :param user_name: The name of the user to list the sessions for. + :param name: The name of the session to filter by. :param last: The number of sessions to get. :param created: The date to filter by. :param workflow_id: The ID of the workflow to filter by. @@ -125,12 +145,19 @@ def list_sessions( :return: The response from the database. """ - user_id = client.get_user(user_name=user_name, session=session).data["id"] - return client.list_chat_sessions( - user_id=user_id, - last=last, - created_after=created, - workflow_id=workflow_id, - output_mode=mode, - session=session, - ) + user_id = client.get_user(user_name=user_name, session=session).uid + try: + data = client.list_chat_sessions( + user_id=user_id, + name=name, + last=last, + created_after=created, + workflow_id=workflow_id, + output_mode=mode, + session=session, + ) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, error=f"Failed to list sessions for user {user_name}: {e}" + ) diff --git a/controller/src/api/endpoints/users.py b/controller/src/api/endpoints/users.py index 09525d9..5892dc9 100644 --- a/controller/src/api/endpoints/users.py +++ b/controller/src/api/endpoints/users.py @@ -16,7 +16,7 @@ from controller.src.api.utils import get_db from controller.src.db import client -from controller.src.schemas import ApiResponse, OutputMode, User +from controller.src.schemas import APIResponse, OutputMode, User router = APIRouter() @@ -25,7 +25,7 @@ def create_user( user: User, session=Depends(get_db), -) -> ApiResponse: +) -> APIResponse: """ Create a new user in the database. @@ -34,11 +34,17 @@ def create_user( :return: The response from the database. """ - return client.create_user(user=user, session=session) + try: + data = client.create_user(user=user, session=session) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, error=f"Failed to create user {user.name}: {e}" + ) @router.get("/users/{user_name}") -def get_user(user_name: str, email: str = None, session=Depends(get_db)) -> ApiResponse: +def get_user(user_name: str, email: str = None, session=Depends(get_db)) -> APIResponse: """ Get a user from the database. @@ -48,7 +54,19 @@ def get_user(user_name: str, email: str = None, session=Depends(get_db)) -> ApiR :return: The user from the database. """ - return client.get_user(user_name=user_name, email=email, session=session) + try: + data = client.get_user(user_name=user_name, email=email, session=session) + if data is None: + return APIResponse( + success=False, + error=f"User with name = {user_name}, email = {email} not found", + ) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to get user with name = {user_name}, email = {email}: {e}", + ) @router.put("/users/{user_name}") @@ -56,7 +74,7 @@ def update_user( user: User, user_name: str, session=Depends(get_db), -) -> ApiResponse: +) -> APIResponse: """ Update a user in the database. @@ -66,13 +84,17 @@ def update_user( :return: The response from the database. """ - if user_name != user.name: - raise ValueError(f"User name does not match: {user_name} != {user.name}") - return client.update_user(user=user, session=session) + try: + data = client.update_user(user=user, session=session) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, error=f"Failed to update user {user_name}: {e}" + ) @router.delete("/users/{user_name}") -def delete_user(user_name: str, session=Depends(get_db)) -> ApiResponse: +def delete_user(user_name: str, session=Depends(get_db)) -> APIResponse: """ Delete a user from the database. @@ -81,19 +103,28 @@ def delete_user(user_name: str, session=Depends(get_db)) -> ApiResponse: :return: The response from the database. """ - return client.delete_user(user_name=user_name, session=session) + user = client.get_user(user_name=user_name, session=session) + try: + client.delete_user(uid=user.uid, session=session) + return APIResponse(success=True) + except Exception as e: + return APIResponse( + success=False, error=f"Failed to delete user {user_name}: {e}" + ) @router.get("/users") def list_users( + name: str = None, email: str = None, full_name: str = None, - mode: OutputMode = OutputMode.Details, + mode: OutputMode = OutputMode.DETAILS, session=Depends(get_db), -) -> ApiResponse: +) -> APIResponse: """ List users in the database. + :param name: The name to filter by. :param email: The email address to filter by. :param full_name: The full name to filter by. :param mode: The output mode. @@ -101,6 +132,14 @@ def list_users( :return: The response from the database. """ - return client.list_users( - email=email, full_name=full_name, output_mode=mode, session=session - ) + try: + data = client.list_users( + name=name, + email=email, + full_name=full_name, + output_mode=mode, + session=session, + ) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse(success=False, error=f"Failed to list users: {e}") diff --git a/controller/src/api/endpoints/workflows.py b/controller/src/api/endpoints/workflows.py index cda5fa8..ca2890c 100644 --- a/controller/src/api/endpoints/workflows.py +++ b/controller/src/api/endpoints/workflows.py @@ -25,7 +25,8 @@ ) from controller.src.db import client from controller.src.schemas import ( - ApiResponse, + APIResponse, + ChatSession, OutputMode, QueryItem, Workflow, @@ -40,48 +41,50 @@ def create_workflow( project_name: str, workflow: Workflow, session=Depends(get_db), - auth: AuthInfo = Depends(get_auth_user), -) -> ApiResponse: +) -> APIResponse: """ Create a new workflow in the database. :param project_name: The name of the project to create the workflow in. :param workflow: The workflow to create. :param session: The database session. - :param auth: The authentication information. :return: The response from the database. """ - # If the owner ID is not provided, get it from the username - if workflow.owner_id is None: - workflow.owner_id = client.get_user( - user_name=auth.username, session=session - ).data["id"] - workflow.project_id = client.get_project( - project_name=project_name, session=session - ).data["id"] - return client.create_workflow(workflow=workflow, session=session) - - -@router.get("/workflows/{workflow_name}") -def get_workflow( - project_name: str, workflow_name: str, session=Depends(get_db) -) -> ApiResponse: + try: + data = client.create_workflow(workflow=workflow, session=session) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to create workflow {workflow.name} in project {project_name}: {e}", + ) + + +@router.get("/workflows/{uid}") +def get_workflow(project_name: str, uid: str, session=Depends(get_db)) -> APIResponse: """ Get a workflow from the database. :param project_name: The name of the project to get the workflow from. - :param workflow_name: The name of the workflow to get. + :param uid: The UID of the workflow to get. :param session: The database session. :return: The workflow from the database. """ - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.get_workflow( - project_id=project_id, workflow_name=workflow_name, session=session - ) + project_id = client.get_project(project_name=project_name, session=session).uid + try: + data = client.get_workflow(project_id=project_id, uid=uid, session=session) + if data is None: + return APIResponse( + success=False, error=f"Workflow with uid = {uid} not found" + ) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to get workflow {uid} in project {project_name}: {e}", + ) @router.put("/workflows/{workflow_name}") @@ -90,7 +93,7 @@ def update_workflow( workflow: Workflow, workflow_name: str, session=Depends(get_db), -) -> ApiResponse: +) -> APIResponse: """ Update a workflow in the database. @@ -101,51 +104,55 @@ def update_workflow( :return: The response from the database. """ - workflow.project_id = client.get_project( - project_name=project_name, session=session - ).data["id"] - if workflow_name != workflow.name: - raise ValueError( - f"Workflow name does not match: {workflow_name} != {workflow.name}" + try: + data = client.update_workflow(workflow=workflow, session=session) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to update workflow {workflow_name} in project {project_name}: {e}", ) - return client.update_workflow(workflow=workflow, session=session) -@router.delete("/workflows/{workflow_id}") +@router.delete("/workflows/{uid}") def delete_workflow( - project_name: str, workflow_id: str, session=Depends(get_db) -) -> ApiResponse: + project_name: str, uid: str, session=Depends(get_db) +) -> APIResponse: """ Delete a workflow from the database. :param project_name: The name of the project to delete the workflow from. - :param workflow_id: The ID of the workflow to delete. + :param uid: The UID of the workflow to delete. :param session: The database session. :return: The response from the database. """ - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.delete_workflow( - project_id=project_id, workflow_id=workflow_id, session=session - ) + try: + client.delete_workflow(uid=uid, session=session) + return APIResponse(success=True) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to delete workflow {uid} in project {project_name}: {e}", + ) @router.get("/workflows") def list_workflows( project_name: str, + name: str = None, version: str = None, workflow_type: Union[WorkflowType, str] = None, labels: Optional[List[Tuple[str, str]]] = None, - mode: OutputMode = OutputMode.Details, + mode: OutputMode = OutputMode.DETAILS, session=Depends(get_db), auth: AuthInfo = Depends(get_auth_user), -) -> ApiResponse: +) -> APIResponse: """ List workflows in the database. :param project_name: The name of the project to list the workflows from. + :param name: The name to filter by. :param version: The version to filter by. :param workflow_type: The workflow type to filter by. :param labels: The labels to filter by. @@ -155,34 +162,40 @@ def list_workflows( :return: The response from the database. """ - owner_id = client.get_user(user_name=auth.username, session=session).data["id"] - project_id = client.get_project(project_name=project_name, session=session).data[ - "id" - ] - return client.list_workflows( - project_id=project_id, - owner_id=owner_id, - version=version, - workflow_type=workflow_type, - labels_match=labels, - output_mode=mode, - session=session, - ) - - -@router.post("/workflows/{workflow_name}/infer") + owner_id = client.get_user(user_name=auth.username, session=session).uid + project_id = client.get_project(project_name=project_name, session=session).uid + try: + data = client.list_workflows( + name=name, + project_id=project_id, + owner_id=owner_id, + version=version, + workflow_type=workflow_type, + labels_match=labels, + output_mode=mode, + session=session, + ) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to list workflows in project {project_name}: {e}", + ) + + +@router.post("/workflows/{uid}/infer") def infer_workflow( project_name: str, - workflow_name: str, + uid: str, query: QueryItem, session=Depends(get_db), auth: AuthInfo = Depends(get_auth_user), -) -> ApiResponse: +) -> APIResponse: """ Run application workflow. :param project_name: The name of the project to run the workflow in. - :param workflow_name: The name of the workflow to run. + :param uid: The UID of the workflow to run. :param query: The query to run. :param session: The database session. :param auth: The authentication information. @@ -190,20 +203,48 @@ def infer_workflow( :return: The response from the database. """ # Get workflow from the database - workflow = Workflow.from_dict( - get_workflow(project_name, workflow_name, session).data - ) - path = Workflow.get_infer_path(workflow) - + project_id = client.get_project(project_name=project_name, session=session).uid + workflow = client.get_workflow(project_id=project_id, uid=uid, session=session) + path = workflow.get_infer_path() + + if query.session_id: + # Get session by id: + session = client.get_chat_session(uid=query.session_id, session=session) + if session is None: + # If not id found, get session by name: + session_name = query.session_id + session = client.list_chat_sessions(name=session_name, session=session) + # If not name found, create a new session: + if session: + session = session[0] + else: + session = client.create_chat_session( + chat_session=ChatSession( + name=session_name, + workflow_id=uid, + owner_id=client.get_user( + user_name=auth.username, session=session + ).uid, + ), + ) + query.session_id = session.uid + # Prepare the data to send to the application's workflow data = { "item": query.dict(), "workflow": workflow.to_dict(short=True), } # Sent the event to the application's workflow: - return _send_to_application( - path=path, - method="POST", - data=json.dumps(data), - auth=auth, - ) + try: + data = _send_to_application( + path=path, + method="POST", + data=json.dumps(data), + auth=auth, + ) + return APIResponse(success=True, data=data) + except Exception as e: + return APIResponse( + success=False, + error=f"Failed to infer workflow {uid} in project {project_name}: {e}", + ) diff --git a/controller/src/db/sqlclient.py b/controller/src/db/sqlclient.py index f15fe78..f5080c2 100644 --- a/controller/src/db/sqlclient.py +++ b/controller/src/db/sqlclient.py @@ -22,7 +22,6 @@ import controller.src.db.sqldb as db import controller.src.schemas as api_models from controller.src.config import logger -from controller.src.schemas import ApiResponse class SqlClient: @@ -58,7 +57,7 @@ def get_local_session(self): """ return self._local_maker() - def create_tables(self, drop_old: bool = False, names: list = None) -> ApiResponse: + def create_tables(self, drop_old: bool = False, names: list = None) -> None: """ Create the tables in the database. @@ -73,9 +72,10 @@ def create_tables(self, drop_old: bool = False, names: list = None) -> ApiRespon if drop_old: db.Base.metadata.drop_all(self.engine, tables=tables) db.Base.metadata.create_all(self.engine, tables=tables, checkfirst=True) - return ApiResponse(success=True) - def _create(self, session: sqlalchemy.orm.Session, db_class, obj) -> ApiResponse: + def _create( + self, session: sqlalchemy.orm.Session, db_class, obj + ) -> Type[api_models.Base]: """ Create an object in the database. This method generates a UID to the object and adds the object to the session and commits the transaction. @@ -84,25 +84,19 @@ def _create(self, session: sqlalchemy.orm.Session, db_class, obj) -> ApiResponse :param db_class: The DB class of the object. :param obj: The object to create. - :return: A response object with the success status and the created object when successful. + :return: The created object. """ session = self.get_db_session(session) - try: - uid = uuid.uuid4().hex - db_object = obj.to_orm_object(db_class, uid=uid) - session.add(db_object) - session.commit() - return ApiResponse( - success=True, data=obj.__class__.from_orm_object(db_object) - ) - except sqlalchemy.exc.IntegrityError: - return ApiResponse( - success=False, error=f"{db_class} {obj.name} already exists" - ) + # try: + uid = uuid.uuid4().hex + db_object = obj.to_orm_object(db_class, uid=uid) + session.add(db_object) + session.commit() + return obj.__class__.from_orm_object(db_object) def _get( self, session: sqlalchemy.orm.Session, db_class, api_class, **kwargs - ) -> ApiResponse: + ) -> Union[Type[api_models.Base], None]: """ Get an object from the database. @@ -111,35 +105,16 @@ def _get( :param api_class: The API class of the object. :param kwargs: The keyword arguments to filter the object. - :return: A response object with the success status and the object when successful. + :return: the object. """ session = self.get_db_session(session) obj = session.query(db_class).filter_by(**kwargs).one_or_none() - if obj is None: - return ApiResponse( - success=False, error=f"{db_class} object ({kwargs}) not found" - ) - return ApiResponse(success=True, data=api_class.from_orm_object(obj)) - - # def _get_by_name(self, session: sqlalchemy.orm.Session, db_class, api_class, name: str) -> ApiResponse: - # """ - # Get an object from the database by name. - # - # :param session: The session to use. - # :param db_class: The DB class of the object. - # :param api_class: The API class of the object. - # - # :return: A response object with the success status and the object when successful. - # """ - # session = self.get_db_session(session) - # obj = session.query(db_class).filter_by(name=name).one_or_none() - # if obj is None: - # return ApiResponse(success=False, error=f"{db_class} object ({name}) not found") - # return ApiResponse(success=True, data=api_class.from_orm_object(obj)) + if obj: + return api_class.from_orm_object(obj) def _update( self, session: sqlalchemy.orm.Session, db_class, api_object, **kwargs - ) -> ApiResponse: + ) -> Type[api_models.Base]: """ Update an object in the database. @@ -148,7 +123,7 @@ def _update( :param api_object: The API object with the new data. :param kwargs: The keyword arguments to filter the object. - :return: A response object with the success status and the updated object when successful. + :return: The updated object. """ session = self.get_db_session(session) obj = session.query(db_class).filter_by(**kwargs).one_or_none() @@ -156,32 +131,23 @@ def _update( api_object.merge_into_orm_object(obj) session.add(obj) session.commit() - return ApiResponse( - success=True, data=api_object.__class__.from_orm_object(obj) - ) + return api_object.__class__.from_orm_object(obj) else: - return ApiResponse( - success=False, error=f"{db_class} object ({kwargs}) not found" - ) + raise ValueError(f"{db_class} object ({kwargs}) not found") - def _delete( - self, session: sqlalchemy.orm.Session, db_class, **kwargs - ) -> ApiResponse: + def _delete(self, session: sqlalchemy.orm.Session, db_class, **kwargs) -> None: """ Delete an object from the database. :param session: The session to use. :param db_class: The DB class of the object. :param kwargs: The keyword arguments to filter the object. - - :return: A response object with the success status. """ session = self.get_db_session(session) query = session.query(db_class).filter_by(**kwargs) for obj in query: session.delete(obj) session.commit() - return ApiResponse(success=True) def _list( self, @@ -191,7 +157,7 @@ def _list( output_mode: api_models.OutputMode, labels_match: List[str] = None, filters: list = None, - ) -> ApiResponse: + ) -> List: """ List objects from the database. @@ -202,7 +168,7 @@ def _list( :param labels_match: The labels to match, filter the objects by labels. :param filters: The filters to apply. - :return: A response object with the success status and the list of objects when successful. + :return: A list of the desired objects. """ session = self.get_db_session(session) @@ -216,19 +182,18 @@ def _list( pass output = query.all() logger.debug(f"output: {output}") - data = _process_output(output, api_class, output_mode) - return ApiResponse(success=True, data=data) + return _process_output(output, api_class, output_mode) def create_user( self, user: Union[api_models.User, dict], session: sqlalchemy.orm.Session = None - ) -> ApiResponse: + ): """ Create a new user in the database. :param user: The user object to create. :param session: The session to use. - :return: A response object with the success status and the created user when successful. + :return: The created user. """ logger.debug(f"Creating user: {user}") if isinstance(user, dict): @@ -242,17 +207,17 @@ def get_user( user_name: str = None, email: str = None, session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ Get a user from the database. Either user_id or user_name or email must be provided. - :param user_id: The ID of the user to get. + :param user_id: The UID of the user to get. :param user_name: The name of the user to get. :param email: The email of the user to get. :param session: The session to use. - :return: A response object with the success status and the user when successful. + :return: The user. """ args = {} if email: @@ -260,66 +225,65 @@ def get_user( elif user_name: args["name"] = user_name elif user_id: - args["id"] = user_id + args["uid"] = user_id else: - return ApiResponse( - success=False, error="user_id or user_name or email must be provided" - ) + raise ValueError("Either user_id or user_name or email must be provided") logger.debug(f"Getting user: user_id={user_id}, user_name={user_name}") return self._get(session, db.User, api_models.User, **args) def update_user( self, user: Union[api_models.User, dict], session: sqlalchemy.orm.Session = None - ) -> ApiResponse: + ): """ Update an existing user in the database. :param user: The user object with the new data. :param session: The session to use. - :return: A response object with the success status and the updated user when successful. + :return: The updated user. """ logger.debug(f"Updating user: {user}") if isinstance(user, dict): user = api_models.User.from_dict(user) - return self._update(session, db.User, user, name=user.name) + return self._update(session, db.User, user, uid=user.uid) - def delete_user( - self, user_name: str, session: sqlalchemy.orm.Session = None - ) -> ApiResponse: + def delete_user(self, uid: str, session: sqlalchemy.orm.Session = None): """ Delete a user from the database. - :param user_name: The name of the user to delete. - :param session: - :return: + :param uid: The UID of the user to delete. + :param session: The session to use. """ - logger.debug(f"Deleting user: user_name={user_name}") - return self._delete(session, db.User, name=user_name) + logger.debug(f"Deleting user: user_uid={uid}") + self._delete(session, db.User, uid=uid) def list_users( self, + name: str = None, email: str = None, full_name: str = None, labels_match: Union[list, str] = None, - output_mode: api_models.OutputMode = api_models.OutputMode.Details, + output_mode: api_models.OutputMode = api_models.OutputMode.DETAILS, session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ List users from the database. + :param name: The name to filter the users by. :param email: The email to filter the users by. :param full_name: The full name to filter the users by. :param labels_match: The labels to match, filter the users by labels. :param output_mode: The output mode. :param session: The session to use. - :return: A response object with the success status and the list of users when successful. + :return: List of users. """ logger.debug( f"Getting users: email={email}, full_name={full_name}, mode={output_mode}" ) filters = [] + if name: + filters.append(db.User.name == name) if email: filters.append(db.User.email == email) if full_name: @@ -337,30 +301,28 @@ def create_project( self, project: Union[api_models.Project, dict], session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ Create a new project in the database. :param project: The project object to create. :param session: The session to use. - :return: A response object with the success status and the created project when successful. + :return: The created project. """ logger.debug(f"Creating project: {project}") if isinstance(project, dict): project = api_models.Project.from_dict(project) return self._create(session, db.Project, project) - def get_project( - self, project_name: str, session: sqlalchemy.orm.Session = None - ) -> ApiResponse: + def get_project(self, project_name: str, session: sqlalchemy.orm.Session = None): """ Get a project from the database. :param project_name: The name of the project to get. :param session: The session to use. - :return: A response object with the success status and the project when successful. + :return: The requested project. """ logger.debug(f"Getting project: project_name={project_name}") return self._get(session, db.Project, api_models.Project, name=project_name) @@ -369,57 +331,57 @@ def update_project( self, project: Union[api_models.Project, dict], session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ Update an existing project in the database. :param project: The project object with the new data. :param session: The session to use. - :return: A response object with the success status and the updated project when successful. + :return: The updated project. """ logger.debug(f"Updating project: {project}") if isinstance(project, dict): project = api_models.Project.from_dict(project) - return self._update(session, db.Project, project, name=project.name) + return self._update(session, db.Project, project, uid=project.uid) - def delete_project( - self, project_name: str, session: sqlalchemy.orm.Session = None - ) -> ApiResponse: + def delete_project(self, uid: str, session: sqlalchemy.orm.Session = None): """ Delete a project from the database. - :param project_name: The name of the project to delete. - :param session: The session to use. - - :return: A response object with the success status. + :param uid: The UID of the project to delete. + :param session: The session to use. """ - logger.debug(f"Deleting project: project_name={project_name}") - return self._delete(session, db.Project, name=project_name) + logger.debug(f"Deleting project: project_uid={uid}") + self._delete(session, db.Project, uid=uid) def list_projects( self, + name: str = None, owner_id: str = None, version: str = None, labels_match: Union[list, str] = None, - output_mode: api_models.OutputMode = api_models.OutputMode.Details, + output_mode: api_models.OutputMode = api_models.OutputMode.DETAILS, session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ List projects from the database. + :param name: The name to filter the projects by. :param owner_id: The owner to filter the projects by. :param version: The version to filter the projects by. :param labels_match: The labels to match, filter the projects by labels. :param output_mode: The output mode. :param session: The session to use. - :return: A response object with the success status and the list of projects when successful. + :return: List of projects. """ logger.debug( f"Getting projects: owner_id={owner_id}, version={version}, labels_match={labels_match}, mode={output_mode}" ) filters = [] + if name: + filters.append(db.Project.name == name) if owner_id: filters.append(db.Project.owner_id == owner_id) if version: @@ -437,14 +399,14 @@ def create_data_source( self, data_source: Union[api_models.DataSource, dict], session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ Create a new data source in the database. :param data_source: The data source object to create. :param session: The session to use. - :return: A response object with the success status and the created data source when successful. + :return: The created data source. """ logger.debug(f"Creating data source: {data_source}") if isinstance(data_source, dict): @@ -454,24 +416,24 @@ def create_data_source( def get_data_source( self, project_id: str, - data_source_name: str, + uid: str, session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ Get a data source from the database. - :param project_id: The ID of the project to get the data source from. - :param data_source_name: The ID of the data source to get. - :param session: The session to use. + :param project_id: The ID of the project to get the data source from. + :param uid: The UID of the data source to get. + :param session: The session to use. - :return: A response object with the success status and the data source when successful. + :return: The requested data source. """ - logger.debug(f"Getting data source: data_source_name={data_source_name}") + logger.debug(f"Getting data source: data_source_uid={uid}") return self._get( session, db.DataSource, api_models.DataSource, - name=data_source_name, + uid=uid, project_id=project_id, ) @@ -479,53 +441,53 @@ def update_data_source( self, data_source: Union[api_models.DataSource, dict], session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ Update an existing data source in the database. :param data_source: The data source object with the new data. :param session: The session to use. - :return: A response object with the success status and the updated data source when successful. + :return: The updated data source. """ logger.debug(f"Updating data source: {data_source}") if isinstance(data_source, dict): data_source = api_models.DataSource.from_dict(data_source) - return self._update(session, db.DataSource, data_source) + return self._update(session, db.DataSource, data_source, uid=data_source.uid) def delete_data_source( self, project_id: str, - data_source_id: str, + uid: str, session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ Delete a data source from the database. - :param project_id: The ID of the project to delete the data source from. - :param data_source_id: The ID of the data source to delete. - :param session: The session to use. + :param project_id: The ID of the project to delete the data source from. + :param uid: The ID of the data source to delete. + :param session: The session to use. :return: A response object with the success status. """ - logger.debug(f"Deleting data source: data_source_id={data_source_id}") - return self._delete( - session, db.DataSource, project_id=project_id, id=data_source_id - ) + logger.debug(f"Deleting data source: data_source_id={uid}") + self._delete(session, db.DataSource, project_id=project_id, uid=uid) def list_data_sources( self, + name: str = None, owner_id: str = None, version: str = None, project_id: str = None, data_source_type: Union[api_models.DataSourceType, str] = None, labels_match: Union[list, str] = None, - output_mode: api_models.OutputMode = api_models.OutputMode.Details, + output_mode: api_models.OutputMode = api_models.OutputMode.DETAILS, session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ List data sources from the database. + :param name: The name to filter the data sources by. :param owner_id: The owner to filter the data sources by. :param version: The version to filter the data sources by. :param project_id: The project to filter the data sources by. @@ -534,13 +496,15 @@ def list_data_sources( :param output_mode: The output mode. :param session: The session to use. - :return: A response object with the success status and the list of data sources when successful. + :return: List of data sources. """ logger.debug( - f"Getting collections: owner_id={owner_id}, version={version}, data_source_type={data_source_type}," - f" labels_match={labels_match}, mode={output_mode}" + f"Getting data sources: name={name}, owner_id={owner_id}, version={version}," + f" data_source_type={data_source_type}, labels_match={labels_match}, mode={output_mode}" ) filters = [] + if name: + filters.append(db.DataSource.name == name) if owner_id: filters.append(db.DataSource.owner_id == owner_id) if version: @@ -562,14 +526,14 @@ def create_dataset( self, dataset: Union[api_models.Dataset, dict], session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ Create a new dataset in the database. :param dataset: The dataset object to create. :param session: The session to use. - :return: A response object with the success status and the created dataset when successful. + :return: The created dataset. """ logger.debug(f"Creating dataset: {dataset}") if isinstance(dataset, dict): @@ -577,23 +541,23 @@ def create_dataset( return self._create(session, db.Dataset, dataset) def get_dataset( - self, project_id: str, dataset_id: str, session: sqlalchemy.orm.Session = None - ) -> ApiResponse: + self, project_id: str, uid: str, session: sqlalchemy.orm.Session = None + ): """ Get a dataset from the database. :param project_id: The ID of the project to get the dataset from. - :param dataset_id: The ID of the dataset to get. + :param uid: The UID of the dataset to get. :param session: The session to use. - :return: A response object with the success status and the dataset when successful. + :return: The requested dataset. """ - logger.debug(f"Getting dataset: dataset_id={dataset_id}") + logger.debug(f"Getting dataset: dataset_id={uid}") return self._get( session, db.Dataset, api_models.Dataset, - id=dataset_id, + uid=uid, project_id=project_id, ) @@ -601,48 +565,48 @@ def update_dataset( self, dataset: Union[api_models.Dataset, dict], session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ Update an existing dataset in the database. :param dataset: The dataset object with the new data. :param session: The session to use. - :return: A response object with the success status and the updated dataset when successful. + :return: The updated dataset. """ logger.debug(f"Updating dataset: {dataset}") if isinstance(dataset, dict): dataset = api_models.Dataset.from_dict(dataset) - return self._update(session, db.Dataset, dataset, id=dataset.id) + return self._update(session, db.Dataset, dataset, uid=dataset.uid) def delete_dataset( - self, project_id: str, dataset_id: str, session: sqlalchemy.orm.Session = None - ) -> ApiResponse: + self, project_id: str, uid: str, session: sqlalchemy.orm.Session = None + ): """ Delete a dataset from the database. :param project_id: The ID of the project to delete the dataset from. - :param dataset_id: The ID of the dataset to delete. + :param uid: The ID of the dataset to delete. :param session: The session to use. - - :return: A response object with the success status. """ - logger.debug(f"Deleting dataset: dataset_id={dataset_id}") - return self._delete(session, db.Dataset, project_id=project_id, id=dataset_id) + logger.debug(f"Deleting dataset: dataset_id={uid}") + self._delete(session, db.Dataset, project_id=project_id, uid=uid) def list_datasets( self, + name: str = None, owner_id: str = None, version: str = None, project_id: str = None, task: str = None, labels_match: Union[list, str] = None, - output_mode: api_models.OutputMode = api_models.OutputMode.Details, + output_mode: api_models.OutputMode = api_models.OutputMode.DETAILS, session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ List datasets from the database. + :param name: The name to filter the datasets by. :param owner_id: The owner to filter the datasets by. :param version: The version to filter the datasets by. :param project_id: The project to filter the datasets by. @@ -651,13 +615,15 @@ def list_datasets( :param output_mode: The output mode. :param session: The session to use. - :return: A response object with the success status and the list of datasets when successful. + :return: The list of datasets. """ logger.debug( f"Getting datasets: owner_id={owner_id}, version={version}, task={task}, labels_match={labels_match}," f" mode={output_mode}" ) filters = [] + if name: + filters.append(db.Dataset.name == name) if owner_id: filters.append(db.Dataset.owner_id == owner_id) if version: @@ -679,14 +645,14 @@ def create_model( self, model: Union[api_models.Model, dict], session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ Create a new model in the database. :param model: The model object to create. :param session: The session to use. - :return: A response object with the success status and the created model when successful. + :return: The created model. """ logger.debug(f"Creating model: {model}") if isinstance(model, dict): @@ -694,69 +660,69 @@ def create_model( return self._create(session, db.Model, model) def get_model( - self, project_id: str, model_id: str, session: sqlalchemy.orm.Session = None - ) -> ApiResponse: + self, project_id: str, uid: str, session: sqlalchemy.orm.Session = None + ): """ Get a model from the database. :param project_id: The ID of the project to get the model from. - :param model_id: The ID of the model to get. + :param uid: The UID of the model to get. :param session: The session to use. - :return: A response object with the success status and the model when successful. + :return: The requested model. """ - logger.debug(f"Getting model: model_id={model_id}") + logger.debug(f"Getting model: model_id={uid}") return self._get( - session, db.Model, api_models.Model, project_id=project_id, id=model_id + session, db.Model, api_models.Model, project_id=project_id, uid=uid ) def update_model( self, model: Union[api_models.Model, dict], session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ Update an existing model in the database. :param model: The model object with the new data. :param session: The session to use. - :return: A response object with the success status and the updated model when successful. + :return: The updated model. """ logger.debug(f"Updating model: {model}") if isinstance(model, dict): model = api_models.Model.from_dict(model) - return self._update(session, db.Model, model, id=model.id) + return self._update(session, db.Model, model, uid=model.uid) def delete_model( - self, project_id: str, model_id: str, session: sqlalchemy.orm.Session = None - ) -> ApiResponse: + self, project_id: str, uid: str, session: sqlalchemy.orm.Session = None + ): """ Delete a model from the database. :param project_id: The ID of the project to delete the model from. - :param model_id: The ID of the model to delete. + :param uid: The UID of the model to delete. :param session: The session to use. - - :return: A response object with the success status. """ - logger.debug(f"Deleting model: model_id={model_id}") - return self._delete(session, db.Model, project_id=project_id, id=model_id) + logger.debug(f"Deleting model: model_id={uid}") + self._delete(session, db.Model, project_id=project_id, uid=uid) def list_models( self, + name: str = None, owner_id: str = None, version: str = None, project_id: str = None, model_type: str = None, task: str = None, labels_match: Union[list, str] = None, - output_mode: api_models.OutputMode = api_models.OutputMode.Details, + output_mode: api_models.OutputMode = api_models.OutputMode.DETAILS, session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ List models from the database. + :param name: The name to filter the models by. :param owner_id: The owner to filter the models by. :param version: The version to filter the models by. :param project_id: The project to filter the models by. @@ -766,13 +732,15 @@ def list_models( :param output_mode: The output mode. :param session: The session to use. - :return: A response object with the success status and the list of models when successful. + :return: The list of models. """ logger.debug( f"Getting models: owner_id={owner_id}, version={version}, project_id={project_id}," f" model_type={model_type}, task={task}, labels_match={labels_match}, mode={output_mode}" ) filters = [] + if name: + filters.append(db.Model.name == name) if owner_id: filters.append(db.Model.owner_id == owner_id) if version: @@ -796,14 +764,14 @@ def create_prompt_template( self, prompt_template: Union[api_models.PromptTemplate, dict], session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ Create a new prompt template in the database. :param prompt_template: The prompt template object to create. :param session: The session to use. - :return: A response object with the success status and the created prompt template when successful. + :return: The created prompt template. """ logger.debug(f"Creating prompt template: {prompt_template}") if isinstance(prompt_template, dict): @@ -813,83 +781,77 @@ def create_prompt_template( def get_prompt_template( self, project_id: str, - prompt_template_id: str, + uid: str, session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ Get a prompt template from the database. - :param project_id: The ID of the project to get the prompt template from. - :param prompt_template_id: The ID of the prompt template to get. - :param session: The session to use. + :param project_id: The ID of the project to get the prompt template from. + :param uid: The UID of the prompt template to get. + :param session: The session to use. - :return: A response object with the success status and the prompt template when successful. + :return: The requested prompt template. """ - logger.debug( - f"Getting prompt template: prompt_template_id={prompt_template_id}" - ) + logger.debug(f"Getting prompt template: prompt_template_id={uid}") return self._get( session, db.PromptTemplate, api_models.PromptTemplate, project_id=project_id, - id=prompt_template_id, + uid=uid, ) def update_prompt_template( self, prompt_template: Union[api_models.PromptTemplate, dict], session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ Update an existing prompt template in the database. :param prompt_template: The prompt template object with the new data. :param session: The session to use. - :return: A response object with the success status and the updated prompt template when successful. + :return: The updated prompt template. """ logger.debug(f"Updating prompt template: {prompt_template}") if isinstance(prompt_template, dict): prompt_template = api_models.PromptTemplate.from_dict(prompt_template) return self._update( - session, db.PromptTemplate, prompt_template, id=prompt_template.id + session, db.PromptTemplate, prompt_template, uid=prompt_template.uid ) def delete_prompt_template( self, project_id: str, - prompt_template_id: str, + uid: str, session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ Delete a prompt template from the database. - :param project_id: The ID of the project to delete the prompt template from. - :param prompt_template_id: The ID of the prompt template to delete. - :param session: The session to use. - - :return: A response object with the success status. + :param project_id: The ID of the project to delete the prompt template from. + :param uid: The ID of the prompt template to delete. + :param session: The session to use. """ - logger.debug( - f"Deleting prompt template: prompt_template_id={prompt_template_id}" - ) - return self._delete( - session, db.PromptTemplate, project_id=project_id, id=prompt_template_id - ) + logger.debug(f"Deleting prompt template: prompt_template_id={uid}") + self._delete(session, db.PromptTemplate, project_id=project_id, uid=uid) def list_prompt_templates( self, + name: str = None, owner_id: str = None, version: str = None, project_id: str = None, labels_match: Union[list, str] = None, - output_mode: api_models.OutputMode = api_models.OutputMode.Details, + output_mode: api_models.OutputMode = api_models.OutputMode.DETAILS, session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ List prompt templates from the database. + :param name: The name to filter the prompt templates by. :param owner_id: The owner to filter the prompt templates by. :param version: The version to filter the prompt templates by. :param project_id: The project to filter the prompt templates by. @@ -897,13 +859,15 @@ def list_prompt_templates( :param output_mode: The output mode. :param session: The session to use. - :return: A response object with the success status and the list of prompt templates when successful. + :return: The list of prompt templates. """ logger.debug( f"Getting prompt templates: owner_id={owner_id}, version={version}, project_id={project_id}," f" labels_match={labels_match}, mode={output_mode}" ) filters = [] + if name: + filters.append(db.PromptTemplate.name == name) if owner_id: filters.append(db.PromptTemplate.owner_id == owner_id) if version: @@ -923,14 +887,14 @@ def create_document( self, document: Union[api_models.Document, dict], session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ Create a new document in the database. :param document: The document object to create. :param session: The session to use. - :return: A response object with the success status and the created document when successful. + :return: The created document. """ logger.debug(f"Creating document: {document}") if isinstance(document, dict): @@ -938,71 +902,71 @@ def create_document( return self._create(session, db.Document, document) def get_document( - self, project_id: str, document_id: str, session: sqlalchemy.orm.Session = None - ) -> ApiResponse: + self, project_id: str, uid: str, session: sqlalchemy.orm.Session = None + ): """ Get a document from the database. - :param project_id: The ID of the project to get the document from. - :param document_id: The ID of the document to get. - :param session: The session to use. + :param project_id: The ID of the project to get the document from. + :param uid: The UID of the document to get. + :param session: The session to use. - :return: A response object with the success status and the document when successful. + :return: The requested document. """ - logger.debug(f"Getting document: document_id={document_id}") + logger.debug(f"Getting document: document_id={uid}") return self._get( session, db.Document, api_models.Document, project_id=project_id, - id=document_id, + uid=uid, ) def update_document( self, document: Union[api_models.Document, dict], session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ Update an existing document in the database. - :param document: The document object with the new data. - :param session: The session to use. + :param document: The document object with the new data. + :param session: The session to use. - :return: A response object with the success status and the updated document when successful. + :return: The updated document. """ logger.debug(f"Updating document: {document}") if isinstance(document, dict): document = api_models.Document.from_dict(document) - return self._update(session, db.Document, document, id=document.id) + return self._update(session, db.Document, document, uid=document.uid) def delete_document( - self, project_id: str, document_id: str, session: sqlalchemy.orm.Session = None - ) -> ApiResponse: + self, project_id: str, uid: str, session: sqlalchemy.orm.Session = None + ): """ Delete a document from the database. :param project_id: The ID of the project to delete the document from. - :param document_id: The ID of the document to delete. + :param uid: The UID of the document to delete. :param session: The session to use. - - :return: A response object with the success status. """ - logger.debug(f"Deleting document: document_id={document_id}") - return self._delete(session, db.Document, project_id=project_id, id=document_id) + logger.debug(f"Deleting document: document_id={uid}") + self._delete(session, db.Document, project_id=project_id, uid=uid) def list_documents( self, + name: str = None, owner_id: str = None, version: str = None, project_id: str = None, labels_match: Union[list, str] = None, - output_mode: api_models.OutputMode = api_models.OutputMode.Details, + output_mode: api_models.OutputMode = api_models.OutputMode.DETAILS, session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ List documents from the database. + :param name: The name to filter the documents by. :param owner_id: The owner to filter the documents by. :param version: The version to filter the documents by. :param project_id: The project to filter the documents by. @@ -1010,13 +974,15 @@ def list_documents( :param output_mode: The output mode. :param session: The session to use. - :return: A response object with the success status and the list of documents when successful. + :return: The list of documents. """ logger.debug( f"Getting documents: owner_id={owner_id}, version={version}, project_id={project_id}," f" labels_match={labels_match}, mode={output_mode}" ) filters = [] + if name: + filters.append(db.Document.name == name) if owner_id: filters.append(db.Document.owner_id == owner_id) if version: @@ -1036,14 +1002,14 @@ def create_workflow( self, workflow: Union[api_models.Workflow, dict], session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ Create a new workflow in the database. :param workflow: The workflow object to create. :param session: The session to use. - :return: A response object with the success status and the created workflow when successful. + :return: The created workflow. """ logger.debug(f"Creating workflow: {workflow}") if isinstance(workflow, dict): @@ -1053,72 +1019,70 @@ def create_workflow( def get_workflow( self, project_id: str, - workflow_name: str, + uid: str, session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ Get a workflow from the database. :param project_id: The ID of the project to get the workflow from. - :param workflow_name: The name of the workflow to get. + :param uid: The UID of the workflow to get. :param session: The session to use. - :return: A response object with the success status and the workflow when successful. + :return: The requested workflow. """ - logger.debug(f"Getting workflow: workflow_name={workflow_name}") + logger.debug(f"Getting workflow: workflow_uid={uid}") return self._get( session, db.Workflow, api_models.Workflow, project_id=project_id, - name=workflow_name, + uid=uid, ) def update_workflow( self, workflow: Union[api_models.Workflow, dict], session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ Update an existing workflow in the database. :param workflow: The workflow object with the new data. :param session: The session to use. - :return: A response object with the success status and the updated workflow when successful. + :return: The updated workflow. """ logger.debug(f"Updating workflow: {workflow}") if isinstance(workflow, dict): workflow = api_models.Workflow.from_dict(workflow) - return self._update(session, db.Workflow, workflow, id=workflow.id) + return self._update(session, db.Workflow, workflow, uid=workflow.uid) - def delete_workflow( - self, workflow_id: str, session: sqlalchemy.orm.Session = None - ) -> ApiResponse: + def delete_workflow(self, uid: str, session: sqlalchemy.orm.Session = None): """ Delete a workflow from the database. - :param workflow_id: The ID of the workflow to delete. - :param session: The session to use. - - :return: A response object with the success status. + :param uid: The ID of the workflow to delete. + :param session: The session to use. """ - logger.debug(f"Deleting workflow: workflow_id={workflow_id}") - return self._delete(session, db.Workflow, id=workflow_id) + logger.debug(f"Deleting workflow: workflow_id={uid}") + self._delete(session, db.Workflow, uid=uid) def list_workflows( self, + name: str = None, owner_id: str = None, version: str = None, project_id: str = None, workflow_type: Union[api_models.WorkflowType, str] = None, labels_match: Union[list, str] = None, - output_mode: api_models.OutputMode = api_models.OutputMode.Details, + output_mode: api_models.OutputMode = api_models.OutputMode.DETAILS, session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ List workflows from the database. + :param name: The name to filter the workflows by. :param owner_id: The owner to filter the workflows by. :param version: The version to filter the workflows by. :param project_id: The project to filter the workflows by. @@ -1127,13 +1091,15 @@ def list_workflows( :param output_mode: The output mode. :param session: The session to use. - :return: A response object with the success status and the list of workflows when successful. + :return: The list of workflows. """ logger.debug( - f"Getting workflows: owner_id={owner_id}, version={version}, project_id={project_id}," + f"Getting workflows: name={name}, owner_id={owner_id}, version={version}, project_id={project_id}," f" workflow_type={workflow_type}, labels_match={labels_match}, mode={output_mode}" ) filters = [] + if name: + filters.append(db.Workflow.name == name) if owner_id: filters.append(db.Workflow.owner_id == owner_id) if version: @@ -1155,14 +1121,14 @@ def create_chat_session( self, chat_session: Union[api_models.ChatSession, dict], session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ Create a new chat session in the database. :param chat_session: The chat session object to create. :param session: The session to use. - :return: A response object with the success status and the created chat session when successful. + :return: The created chat session. """ logger.debug(f"Creating chat session: {chat_session}") if isinstance(chat_session, dict): @@ -1171,80 +1137,67 @@ def create_chat_session( def get_chat_session( self, - session_name: str = None, + uid: str = None, user_id: str = None, session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ Get a chat session from the database. - :param session_name: The ID of the chat session to get. - :param user_id: The ID of the user to get the last session for. - :param session: The DB session to use. + :param uid: The ID of the chat session to get. + :param user_id: The UID of the user to get the last session for. + :param session: The DB session to use. - :return: A response object with the success status and the chat session when successful. + :return: The requested chat session. """ - logger.debug( - f"Getting chat session: session_name={session_name}, user_id={user_id}" - ) - if session_name: - return self._get( - session, db.Session, api_models.ChatSession, name=session_name - ) + logger.debug(f"Getting chat session: session_uid={uid}, user_id={user_id}") + if uid: + return self._get(session, db.Session, api_models.ChatSession, uid=uid) elif user_id: # get the last session for the user - resp = self.list_chat_sessions(user_id=user_id, last=1, session=session) - if resp.success: - data = resp.data[0] if resp.data else None - return ApiResponse(success=True, data=data) - return resp - else: - return ApiResponse( - success=False, error="session_id or username must be provided" - ) + return self.list_chat_sessions(user_id=user_id, last=1, session=session)[0] + raise ValueError("session_name or user_id must be provided") def update_chat_session( self, chat_session: Union[api_models.ChatSession, dict], session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ Update a chat session in the database. :param chat_session: The chat session object with the new data. :param session: The DB session to use. - :return: A response object with the success status and the updated chat session when successful. + :return: The updated chat session. """ logger.debug(f"Updating chat session: {chat_session}") - return self._update(session, db.Session, chat_session, name=chat_session.name) + return self._update(session, db.Session, chat_session, uid=chat_session.uid) - def delete_chat_session( - self, session_id: str, session: sqlalchemy.orm.Session = None - ) -> ApiResponse: + def delete_chat_session(self, uid: str, session: sqlalchemy.orm.Session = None): """ Delete a chat session from the database. - :param session_id: The ID of the chat session to delete. - :param session: The DB session to use. - - :return: A response object with the success status. + :param uid: The UID of the chat session to delete. + :param session: The DB session to use. """ - logger.debug(f"Deleting chat session: session_id={session_id}") - return self._delete(session, db.Session, id=session_id) + logger.debug(f"Deleting chat session: session_id={uid}") + self._delete(session, db.Session, uid=uid) def list_chat_sessions( self, + name: str = None, user_id: str = None, workflow_id: str = None, created_after=None, last=0, - output_mode: api_models.OutputMode = api_models.OutputMode.Details, + output_mode: api_models.OutputMode = api_models.OutputMode.DETAILS, session: sqlalchemy.orm.Session = None, - ) -> ApiResponse: + ): """ List chat sessions from the database. + :param name: The name to filter the chat sessions by. :param user_id: The user ID to filter the chat sessions by. :param workflow_id: The workflow ID to filter the chat sessions by. :param created_after: The date to filter the chat sessions by. @@ -1252,7 +1205,7 @@ def list_chat_sessions( :param output_mode: The output mode. :param session: The DB session to use. - :return: A response object with the success status and the list of chat sessions when successful. + :return: The list of chat sessions. """ logger.debug( f"Getting chat sessions: user_id={user_id}, workflow_id={workflow_id} created>{created_after}," @@ -1260,6 +1213,8 @@ def list_chat_sessions( ) session = self.get_db_session(session) query = session.query(db.Session) + if name: + query = query.filter(db.Session.name == name) if user_id: query = query.filter(db.Session.owner_id == user_id) if workflow_id: @@ -1273,8 +1228,7 @@ def list_chat_sessions( query = query.order_by(db.Session.updated.desc()) if last > 0: query = query.limit(last) - data = _process_output(query.all(), api_models.ChatSession, output_mode) - return ApiResponse(success=True, data=data) + return _process_output(query.all(), api_models.ChatSession, output_mode) def _dict_to_object(cls, d): @@ -1284,12 +1238,12 @@ def _dict_to_object(cls, d): def _process_output( - items, obj_class, mode: api_models.OutputMode = api_models.OutputMode.Details -): - if mode == api_models.OutputMode.Names: + items, obj_class, mode: api_models.OutputMode = api_models.OutputMode.DETAILS +) -> Union[list, dict]: + if mode == api_models.OutputMode.NAMES: return [item.name for item in items] items = [obj_class.from_orm_object(item) for item in items] - if mode == api_models.OutputMode.Details: + if mode == api_models.OutputMode.DETAILS: return items - short = mode == api_models.OutputMode.Short + short = mode == api_models.OutputMode.SHORT return [item.to_dict(short=short) for item in items] diff --git a/controller/src/db/sqldb.py b/controller/src/db/sqldb.py index d97ffdc..468814d 100644 --- a/controller/src/db/sqldb.py +++ b/controller/src/db/sqldb.py @@ -50,7 +50,7 @@ class Label(Base): Index(f"idx_{table}_labels_name_value", "name", "value"), ) - id = Column(Integer, primary_key=True) + uid = Column(Integer, primary_key=True) name = Column(String(255, None)) # in mysql collation="utf8_bin" value = Column(String(255, collation=None)) parent = Column(Integer, ForeignKey(f"{table}.name")) @@ -74,9 +74,9 @@ class BaseSchema(Base): Base class for all tables. We use this class to define common columns and methods for all tables. - :arg id: unique identifier for each entry. - :arg name: entry's name. - :arg description: The entry's description. + :arg uid: unique identifier for each entry. + :arg name: entry's name. + :arg description: The entry's description. The following columns are automatically added to each table: - created: The entry's creation date. @@ -100,7 +100,7 @@ def labels(cls): return relationship(cls.Label, cascade="all, delete-orphan") # Columns: - id: Mapped[str] = mapped_column(String(ID_LENGTH), primary_key=True) + uid: Mapped[str] = mapped_column(String(ID_LENGTH), primary_key=True) name: Mapped[str] = mapped_column(String(255), unique=True) description: Mapped[Optional[str]] created: Mapped[datetime.datetime] = mapped_column(default=datetime.datetime.utcnow) @@ -110,8 +110,8 @@ def labels(cls): ) spec = Column(MutableDict.as_mutable(JSON), nullable=True) - def __init__(self, id, name, spec, description=None, labels=None): - self.id = id + def __init__(self, uid, name, spec, description=None, labels=None): + self.uid = uid self.name = name self.spec = spec self.description = description @@ -123,66 +123,66 @@ class OwnerBaseSchema(BaseSchema): Base class for all tables with owner. We use this class to define common columns and methods for all tables with owner. - :arg owner_id: The entry's owner's id. + :arg owner_id: The entry's owner's id. """ __abstract__ = True owner_id: Mapped[Optional[str]] = mapped_column( - String(ID_LENGTH), ForeignKey("user.id") + String(ID_LENGTH), ForeignKey("user.uid") ) - def __init__(self, id, name, spec, description=None, owner_id=None, labels=None): - super().__init__(id, name, spec, description, labels) + def __init__(self, uid, name, spec, description=None, owner_id=None, labels=None): + super().__init__(uid, name, spec, description, labels) self.owner_id = owner_id -class VersionedBaseSchema(OwnerBaseSchema): +class VersionedOwnerBaseSchema(OwnerBaseSchema): """ Base class for all versioned tables. We use this class to define common columns and methods for all versioned tables. - :arg version: The entry's version. This is the primary key for the table with id. + :arg version: The entry's version. This is the primary key for the table with uid. """ __abstract__ = True version: Mapped[str] = mapped_column(String(255), primary_key=True, default="") def __init__( - self, id, name, spec, version, description=None, owner_id=None, labels=None + self, uid, name, spec, version, description=None, owner_id=None, labels=None ): - super().__init__(id, name, spec, description, owner_id, labels) + super().__init__(uid, name, spec, description, owner_id, labels) self.version = version # Association table between users and projects for many-to-many relationship -user_project = Table( - "user_project", +user_to_project = Table( + "user_to_project", Base.metadata, - Column("user_id", ForeignKey("user.id")), - Column("project_id", ForeignKey("project.id")), + Column("user_id", ForeignKey("user.uid")), + Column("project_id", ForeignKey("project.uid")), Column("project_version", ForeignKey("project.version")), ) # Association table between models and prompt templates for many-to-many relationship -model_prompt_template = Table( - "model_prompt_template", +prompt_template_to_model = Table( + "prompt_template_to_model", Base.metadata, - Column("prompt_id", String(ID_LENGTH), ForeignKey("prompt_template.id")), + Column("prompt_id", String(ID_LENGTH), ForeignKey("prompt_template.uid")), Column( "prompt_version", String(TEXT_LENGTH), ForeignKey("prompt_template.version") ), - Column("model_id", String(ID_LENGTH), ForeignKey("model.id")), + Column("model_id", String(ID_LENGTH), ForeignKey("model.uid")), Column("model_version", String(TEXT_LENGTH), ForeignKey("model.version")), Column("generation_config", JSON), ) # Association table between documents and data sources (ingestions) for many-to-many relationship -ingestions = Table( - "ingestions", +document_to_data_source = Table( + "document_to_data_source", Base.metadata, - Column("document_id", String(ID_LENGTH), ForeignKey("document.id")), + Column("document_id", String(ID_LENGTH), ForeignKey("document.uid")), Column("document_version", String(TEXT_LENGTH), ForeignKey("document.version")), - Column("data_source_id", String(ID_LENGTH), ForeignKey("data_source.id")), + Column("data_source_id", String(ID_LENGTH), ForeignKey("data_source.uid")), Column( "data_source_version", String(TEXT_LENGTH), ForeignKey("data_source.version") ), @@ -194,8 +194,8 @@ class User(BaseSchema): """ The User table which is used to define users. - :arg full_name: The user's full name. - :arg email: The user's email. + :arg full_name: The user's full name. + :arg email: The user's email. """ # Columns: @@ -207,14 +207,14 @@ class User(BaseSchema): # many-to-many relationship with projects: projects: Mapped[List["Project"]] = relationship( back_populates="users", - secondary=user_project, - primaryjoin="User.id == user_project.c.user_id", - secondaryjoin="and_(Project.id == user_project.c.project_id," - " Project.version == user_project.c.project_version)", + secondary=user_to_project, + primaryjoin="User.uid == user_to_project.c.user_id", + secondaryjoin="and_(Project.uid == user_to_project.c.project_id," + " Project.version == user_to_project.c.project_version)", foreign_keys=[ - user_project.c.user_id, - user_project.c.project_id, - user_project.c.project_version, + user_to_project.c.user_id, + user_to_project.c.project_id, + user_to_project.c.project_version, ], ) # one-to-many relationship with sessions: @@ -222,15 +222,17 @@ class User(BaseSchema): back_populates="user", foreign_keys="Session.owner_id" ) - def __init__(self, id, name, email, full_name, spec, description=None, labels=None): + def __init__( + self, uid, name, email, full_name, spec, description=None, labels=None + ): super().__init__( - id=id, name=name, description=description, spec=spec, labels=labels + uid=uid, name=name, description=description, spec=spec, labels=labels ) self.email = email self.full_name = full_name -class Project(VersionedBaseSchema): +class Project(VersionedOwnerBaseSchema): """ The Project table which is used as a workspace. The other tables are associated with a project. """ @@ -240,13 +242,14 @@ class Project(VersionedBaseSchema): # many-to-many relationship with user: users: Mapped[List["User"]] = relationship( back_populates="projects", - secondary=user_project, - primaryjoin="and_(Project.id == user_project.c.project_id, Project.version == user_project.c.project_version)", - secondaryjoin="User.id == user_project.c.user_id", + secondary=user_to_project, + primaryjoin="and_(Project.uid == user_to_project.c.project_id," + " Project.version == user_to_project.c.project_version)", + secondaryjoin="User.uid == user_to_project.c.user_id", foreign_keys=[ - user_project.c.user_id, - user_project.c.project_id, - user_project.c.project_version, + user_to_project.c.user_id, + user_to_project.c.project_id, + user_to_project.c.project_version, ], ) @@ -260,10 +263,10 @@ class Project(VersionedBaseSchema): workflows: Mapped[List["Workflow"]] = relationship(**relationship_args) def __init__( - self, id, name, spec, version, description=None, owner_id=None, labels=None + self, uid, name, spec, version, description=None, owner_id=None, labels=None ): super().__init__( - id=id, + uid=uid, name=name, version=version, spec=spec, @@ -274,7 +277,7 @@ def __init__( update_labels(self, {"_GENAI_FACTORY": True}) -class DataSource(VersionedBaseSchema): +class DataSource(VersionedOwnerBaseSchema): """ The DataSource table which is used to define data sources for the project. @@ -284,7 +287,9 @@ class DataSource(VersionedBaseSchema): """ # Columns: - project_id: Mapped[str] = mapped_column(String(ID_LENGTH), ForeignKey("project.id")) + project_id: Mapped[str] = mapped_column( + String(ID_LENGTH), ForeignKey("project.uid") + ) data_source_type: Mapped[str] # Relationships: @@ -294,22 +299,22 @@ class DataSource(VersionedBaseSchema): # many-to-many relationship with documents: documents: Mapped[List["Document"]] = relationship( back_populates="data_sources", - secondary=ingestions, - primaryjoin="and_(DataSource.id == ingestions.c.data_source_id," - " DataSource.version == ingestions.c.data_source_version)", - secondaryjoin="and_(Document.id == ingestions.c.document_id," - " Document.version == ingestions.c.document_version)", + secondary=document_to_data_source, + primaryjoin="and_(DataSource.uid == document_to_data_source.c.data_source_id," + " DataSource.version == document_to_data_source.c.data_source_version)", + secondaryjoin="and_(Document.uid == document_to_data_source.c.document_id," + " Document.version == document_to_data_source.c.document_version)", foreign_keys=[ - ingestions.c.data_source_id, - ingestions.c.data_source_version, - ingestions.c.document_id, - ingestions.c.document_version, + document_to_data_source.c.data_source_id, + document_to_data_source.c.data_source_version, + document_to_data_source.c.document_id, + document_to_data_source.c.document_version, ], ) def __init__( self, - id, + uid, name, spec, version, @@ -320,7 +325,7 @@ def __init__( labels=None, ): super().__init__( - id=id, + uid=uid, name=name, version=version, spec=spec, @@ -332,7 +337,7 @@ def __init__( self.data_source_type = data_source_type -class Dataset(VersionedBaseSchema): +class Dataset(VersionedOwnerBaseSchema): """ The Dataset table which is used to define datasets for the project. @@ -341,7 +346,9 @@ class Dataset(VersionedBaseSchema): """ # Columns: - project_id: Mapped[str] = mapped_column(String(ID_LENGTH), ForeignKey("project.id")) + project_id: Mapped[str] = mapped_column( + String(ID_LENGTH), ForeignKey("project.uid") + ) task: Mapped[Optional[str]] # Relationships: @@ -351,7 +358,7 @@ class Dataset(VersionedBaseSchema): def __init__( self, - id, + uid, name, spec, version, @@ -362,7 +369,7 @@ def __init__( labels=None, ): super().__init__( - id=id, + uid=uid, name=name, version=version, spec=spec, @@ -374,7 +381,7 @@ def __init__( self.task = task -class Model(VersionedBaseSchema): +class Model(VersionedOwnerBaseSchema): """ The Model table which is used to define models for the project. @@ -384,7 +391,9 @@ class Model(VersionedBaseSchema): """ # Columns: - project_id: Mapped[str] = mapped_column(String(ID_LENGTH), ForeignKey("project.id")) + project_id: Mapped[str] = mapped_column( + String(ID_LENGTH), ForeignKey("project.uid") + ) model_type: Mapped[str] task: Mapped[Optional[str]] @@ -395,22 +404,22 @@ class Model(VersionedBaseSchema): # many-to-many relationship with prompt_templates: prompt_templates: Mapped[List["PromptTemplate"]] = relationship( back_populates="models", - secondary=model_prompt_template, - primaryjoin="and_(Model.id == model_prompt_template.c.model_id," - " Model.version == model_prompt_template.c.model_version)", - secondaryjoin="and_(PromptTemplate.id == model_prompt_template.c.prompt_id," - " PromptTemplate.version == model_prompt_template.c.prompt_version)", + secondary=prompt_template_to_model, + primaryjoin="and_(Model.uid == prompt_template_to_model.c.model_id," + " Model.version == prompt_template_to_model.c.model_version)", + secondaryjoin="and_(PromptTemplate.uid == prompt_template_to_model.c.prompt_id," + " PromptTemplate.version == prompt_template_to_model.c.prompt_version)", foreign_keys=[ - model_prompt_template.c.model_id, - model_prompt_template.c.model_version, - model_prompt_template.c.prompt_id, - model_prompt_template.c.prompt_version, + prompt_template_to_model.c.model_id, + prompt_template_to_model.c.model_version, + prompt_template_to_model.c.prompt_id, + prompt_template_to_model.c.prompt_version, ], ) def __init__( self, - id, + uid, name, spec, version, @@ -422,7 +431,7 @@ def __init__( labels=None, ): super().__init__( - id=id, + uid=uid, name=name, version=version, spec=spec, @@ -435,16 +444,18 @@ def __init__( self.task = task -class PromptTemplate(VersionedBaseSchema): +class PromptTemplate(VersionedOwnerBaseSchema): """ The PromptTemplate table which is used to define prompt templates for the project. Each prompt template is associated with a model. - :arg project_id: The project's id. + :arg project_id: The project's id. """ # Columns: - project_id: Mapped[str] = mapped_column(String(ID_LENGTH), ForeignKey("project.id")) + project_id: Mapped[str] = mapped_column( + String(ID_LENGTH), ForeignKey("project.uid") + ) # Relationships: @@ -453,22 +464,22 @@ class PromptTemplate(VersionedBaseSchema): # many-to-many relationship with the 'Model' table models: Mapped[List["Model"]] = relationship( back_populates="prompt_templates", - secondary=model_prompt_template, - primaryjoin="and_(PromptTemplate.id == model_prompt_template.c.prompt_id," - " PromptTemplate.version == model_prompt_template.c.prompt_version)", - secondaryjoin="and_(Model.id == model_prompt_template.c.model_id," - " Model.version == model_prompt_template.c.model_version)", + secondary=prompt_template_to_model, + primaryjoin="and_(PromptTemplate.uid == prompt_template_to_model.c.prompt_id," + " PromptTemplate.version == prompt_template_to_model.c.prompt_version)", + secondaryjoin="and_(Model.uid == prompt_template_to_model.c.model_id," + " Model.version == prompt_template_to_model.c.model_version)", foreign_keys=[ - model_prompt_template.c.prompt_id, - model_prompt_template.c.prompt_version, - model_prompt_template.c.model_id, - model_prompt_template.c.model_version, + prompt_template_to_model.c.prompt_id, + prompt_template_to_model.c.prompt_version, + prompt_template_to_model.c.model_id, + prompt_template_to_model.c.model_version, ], ) def __init__( self, - id, + uid, name, spec, version, @@ -478,7 +489,7 @@ def __init__( labels=None, ): super().__init__( - id=id, + uid=uid, name=name, version=version, spec=spec, @@ -489,17 +500,19 @@ def __init__( self.project_id = project_id -class Document(VersionedBaseSchema): +class Document(VersionedOwnerBaseSchema): """ The Document table which is used to define documents for the project. The documents are ingested into data sources. - :arg project_id: The project's id. - :arg path: The path to the document. Can be a remote file or a web page. - :arg origin: The origin location of the document. + :arg project_id: The project's id. + :arg path: The path to the document. Can be a remote file or a web page. + :arg origin: The origin location of the document. """ # Columns: - project_id: Mapped[str] = mapped_column(String(ID_LENGTH), ForeignKey("project.id")) + project_id: Mapped[str] = mapped_column( + String(ID_LENGTH), ForeignKey("project.uid") + ) path: Mapped[str] origin: Mapped[Optional[str]] @@ -510,21 +523,22 @@ class Document(VersionedBaseSchema): # many-to-many relationship with ingestion: data_sources: Mapped[List["DataSource"]] = relationship( back_populates="documents", - secondary=ingestions, - primaryjoin="and_(Document.id == ingestions.c.document_id, Document.version == ingestions.c.document_version)", - secondaryjoin="and_(DataSource.id == ingestions.c.data_source_id," - " DataSource.version == ingestions.c.data_source_version)", + secondary=document_to_data_source, + primaryjoin="and_(Document.uid == document_to_data_source.c.document_id," + " Document.version == document_to_data_source.c.document_version)", + secondaryjoin="and_(DataSource.uid == document_to_data_source.c.data_source_id," + " DataSource.version == document_to_data_source.c.data_source_version)", foreign_keys=[ - ingestions.c.document_id, - ingestions.c.document_version, - ingestions.c.data_source_id, - ingestions.c.data_source_version, + document_to_data_source.c.document_id, + document_to_data_source.c.document_version, + document_to_data_source.c.data_source_id, + document_to_data_source.c.data_source_version, ], ) def __init__( self, - id, + uid, name, spec, version, @@ -536,7 +550,7 @@ def __init__( labels=None, ): super().__init__( - id=id, + uid=uid, name=name, version=version, spec=spec, @@ -549,18 +563,20 @@ def __init__( self.origin = origin -class Workflow(VersionedBaseSchema): +class Workflow(VersionedOwnerBaseSchema): """ The Workflow table which is used to define workflows for the project. All workflows are a DAG of steps, each with its dedicated task. - :arg project_id: The project's id. - :arg workflow_type: The type of the workflow. - Can be one of the values in controller.src.schemas.workflow.WorkflowType. + :arg project_id: The project's id. + :arg workflow_type: The type of the workflow. + Can be one of the values in controller.src.schemas.workflow.WorkflowType. """ # Columns: - project_id: Mapped[str] = mapped_column(String(ID_LENGTH), ForeignKey("project.id")) + project_id: Mapped[str] = mapped_column( + String(ID_LENGTH), ForeignKey("project.uid") + ) workflow_type: Mapped[str] # Relationships: @@ -572,7 +588,7 @@ class Workflow(VersionedBaseSchema): def __init__( self, - id, + uid, name, spec, version, @@ -583,7 +599,7 @@ def __init__( labels=None, ): super().__init__( - id=id, + uid=uid, name=name, version=version, spec=spec, @@ -599,12 +615,12 @@ class Session(OwnerBaseSchema): """ The Chat Session table which is used to define chat sessions of an application workflow per user. - :arg workflow_id: The workflow's id. + :arg workflow_id: The workflow's id. """ # Columns: workflow_id: Mapped[str] = mapped_column( - String(ID_LENGTH), ForeignKey("workflow.id") + String(ID_LENGTH), ForeignKey("workflow.uid") ) # Relationships: @@ -617,10 +633,10 @@ class Session(OwnerBaseSchema): ) def __init__( - self, id, name, spec, workflow_id, description=None, owner_id=None, labels=None + self, uid, name, spec, workflow_id, description=None, owner_id=None, labels=None ): super().__init__( - id=id, + uid=uid, name=name, spec=spec, description=description, diff --git a/controller/src/main.py b/controller/src/main.py index 70a6145..b51ef8f 100644 --- a/controller/src/main.py +++ b/controller/src/main.py @@ -58,7 +58,7 @@ def initdb(): is_admin=True, ), session=session, - ).data["id"] + ).uid # Create project: click.echo("Creating default project") @@ -69,7 +69,7 @@ def initdb(): owner_id=user_id, ), session=session, - ).data["id"] + ).uid # Create data source: click.echo("Creating default data source") @@ -93,7 +93,7 @@ def initdb(): owner_id=user_id, project_id=project_id, workflow_type="application", - deployment="http://localhost:8000", + deployment="http://localhost:8000/api/workflows/default", ), session=session, ) @@ -136,14 +136,10 @@ def ingest(path, project, name, loader, metadata, version, data_source, from_fil :return: None """ session = client.get_db_session() - project = client.get_project(project_name=project, session=session).data - project = Project.from_dict(project) - data_source = client.get_data_source( - project_id=project.id, - data_source_name=data_source or "default", - session=session, - ).data - data_source = DataSource.from_dict(data_source) + project = client.get_project(project_name=project, session=session) + data_source = client.list_data_sources( + project_id=project.uid, name=data_source, session=session + )[0] # Create document from path: document = Document( @@ -151,7 +147,7 @@ def ingest(path, project, name, loader, metadata, version, data_source, from_fil version=version, path=path, owner_id=data_source.owner_id, - project_id=project.id, + project_id=project.uid, ) # Add document to the database: @@ -159,7 +155,7 @@ def ingest(path, project, name, loader, metadata, version, data_source, from_fil document=document, session=session, ) - document = Document.from_dict(response.data).to_dict(to_datestr=True) + document = response.to_dict(to_datestr=True) # Send ingest to application: params = { @@ -222,13 +218,12 @@ def infer( """ db_session = client.get_db_session() - project = client.get_project(project_name=project, session=db_session).data + project = client.get_project(project_name=project, session=db_session) # Getting the workflow: - workflow = client.get_workflow( - project_id=project["id"], workflow_name=workflow_name, session=db_session - ).data - workflow = Workflow.from_dict(workflow) - path = Workflow.get_infer_path(workflow) + workflow = client.list_workflows( + project_id=project.uid, name=workflow_name, session=db_session + )[0] + path = workflow.get_infer_path() query = QueryItem( question=question, @@ -239,7 +234,7 @@ def infer( data = { "item": query.dict(), - "workflow": workflow.dict(), + "workflow": workflow.to_dict(short=True), } headers = {"x_username": user} if user else {} @@ -282,7 +277,7 @@ def list_users(user, email): """ click.echo("Running List Users") - data = client.list_users(email, user, output_mode="short").data + data = client.list_users(email, user, output_mode="short") table = format_table_results(data) click.echo(table) @@ -309,9 +304,9 @@ def list_data_sources(owner, project, version, source_type, metadata): """ click.echo("Running List Collections") if owner: - owner = client.get_user(username=owner).data["id"] + owner = client.get_user(username=owner).uid if project: - project = client.get_project(project_name=project).data["id"] + project = client.get_project(project_name=project).uid data = client.list_data_sources( owner_id=owner, @@ -320,7 +315,7 @@ def list_data_sources(owner, project, version, source_type, metadata): data_source_type=source_type, labels_match=metadata, output_mode="short", - ).data + ) table = format_table_results(data) click.echo(table) @@ -352,18 +347,18 @@ def update_data_source(name, project, owner, description, source_type, labels): session = client.get_db_session() # check if the collection exists, if it does, update it, otherwise create it - project = client.get_project(project_name=project, session=session).data - collection_exists = client.get_data_source( - project_id=project["id"], + project = client.get_project(project_name=project, session=session) + data_source = client.list_data_sources( + project_id=project.uid, data_source_name=name, session=session, - ).success + ) - if collection_exists: + if data_source is not None: client.update_data_source( session=session, collection=DataSource( - project_id=project["id"], + project_id=project.uid, name=name, description=description, data_source_type=source_type, @@ -371,10 +366,10 @@ def update_data_source(name, project, owner, description, source_type, labels): ), ).with_raise() else: - client.create_collection( + client.create_data_source( session=session, - collection=DataSource( - project_id=project["id"], + data_source=DataSource( + project_id=project.uid, name=name, description=description, owner_name=owner, @@ -401,10 +396,10 @@ def list_sessions(user, last, created): click.echo("Running List Sessions") if user: - user = client.get_user(user_name=user).data["id"] + user = client.get_user(user_name=user).uid data = client.list_chat_sessions( user_id=user, created_after=created, last=last, output_mode="short" - ).data + ) table = format_table_results(data) click.echo(table) diff --git a/controller/src/schemas/__init__.py b/controller/src/schemas/__init__.py index 61289f2..68d09b8 100644 --- a/controller/src/schemas/__init__.py +++ b/controller/src/schemas/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base import ApiResponse, Base, OutputMode +from .base import APIResponse, Base, OutputMode from .data_source import DataSource, DataSourceType from .dataset import Dataset from .document import Document diff --git a/controller/src/schemas/base.py b/controller/src/schemas/base.py index 58bb5c9..e9319b6 100644 --- a/controller/src/schemas/base.py +++ b/controller/src/schemas/base.py @@ -21,7 +21,7 @@ from pydantic import BaseModel metadata_fields = [ - "id", + "uid", "name", "description", "labels", @@ -124,7 +124,7 @@ def to_orm_object(self, obj_class, uid=None): } labels = obj_dict.pop("labels", None) if uid: - obj_dict["id"] = uid + obj_dict["uid"] = uid obj = obj_class(**obj_dict) if labels: obj.labels.clear() @@ -147,33 +147,33 @@ def __str__(self): class BaseWithMetadata(Base): name: str - id: Optional[str] = None - description: Optional[str] = None - labels: Optional[Dict[str, Union[str, None]]] = None - created: Optional[Union[str, datetime]] = None - updated: Optional[Union[str, datetime]] = None + uid: Optional[str] + description: Optional[str] + labels: Optional[Dict[str, Union[str, None]]] + created: Optional[Union[str, datetime]] + updated: Optional[Union[str, datetime]] class BaseWithOwner(BaseWithMetadata): - owner_id: Optional[str] = None + owner_id: str class BaseWithVerMetadata(BaseWithOwner): version: Optional[str] = "" -class ApiResponse(BaseModel): +class APIResponse(BaseModel): success: bool - data: Optional[Union[list, Type[BaseModel], dict]] = None - error: Optional[str] = None + data: Optional[Union[list, Type[BaseModel], dict]] + error: Optional[str] - def with_raise(self, format=None) -> "ApiResponse": + def with_raise(self, format=None) -> "APIResponse": if not self.success: format = format or "API call failed: %s" raise ValueError(format % self.error) return self - def with_raise_http(self, format=None) -> "ApiResponse": + def with_raise_http(self, format=None) -> "APIResponse": if not self.success: format = format or "API call failed: %s" raise HTTPException(status_code=400, detail=format % self.error) @@ -181,7 +181,7 @@ def with_raise_http(self, format=None) -> "ApiResponse": class OutputMode(str, Enum): - Names = "names" - Short = "short" - Dict = "dict" - Details = "details" + NAMES = "names" + SHORT = "short" + DICT = "dict" + DETAILS = "details" diff --git a/controller/src/schemas/data_source.py b/controller/src/schemas/data_source.py index 5b46060..b8682b8 100644 --- a/controller/src/schemas/data_source.py +++ b/controller/src/schemas/data_source.py @@ -19,18 +19,18 @@ class DataSourceType(str, Enum): - relational = "relational" - vector = "vector" - graph = "graph" - key_value = "key-value" - column_family = "column-family" - storage = "storage" - other = "other" + RELATIONAL = "relational" + VECTOR = "vector" + GRAPH = "graph" + KEY_VALUE = "key-value" + COLUMN_FAMILY = "column-family" + STORAGE = "storage" + OTHER = "other" class DataSource(BaseWithVerMetadata): _top_level_fields = ["data_source_type"] data_source_type: DataSourceType - project_id: Optional[str] = None + project_id: str database_kwargs: Optional[dict[str, str]] = {} diff --git a/controller/src/schemas/dataset.py b/controller/src/schemas/dataset.py index ca8e70f..a11eea1 100644 --- a/controller/src/schemas/dataset.py +++ b/controller/src/schemas/dataset.py @@ -20,8 +20,8 @@ class Dataset(BaseWithVerMetadata): _top_level_fields = ["task"] - project_id: Optional[str] = None task: str - sources: Optional[List[str]] = None path: str - producer: Optional[str] = None + project_id: str + sources: Optional[List[str]] + producer: Optional[str] diff --git a/controller/src/schemas/document.py b/controller/src/schemas/document.py index 9a68861..297c19d 100644 --- a/controller/src/schemas/document.py +++ b/controller/src/schemas/document.py @@ -20,5 +20,5 @@ class Document(BaseWithVerMetadata): _top_level_fields = ["path", "origin"] path: str - project_id: Optional[str] = None - origin: Optional[str] = None + project_id: str + origin: Optional[str] diff --git a/controller/src/schemas/model.py b/controller/src/schemas/model.py index 2c18064..48742ac 100644 --- a/controller/src/schemas/model.py +++ b/controller/src/schemas/model.py @@ -19,8 +19,8 @@ class ModelType(str, Enum): - model = "model" - adapter = "adapter" + MODEL = "model" + ADAPTER = "adapter" class Model(BaseWithVerMetadata): @@ -29,8 +29,8 @@ class Model(BaseWithVerMetadata): model_type: ModelType base_model: str - project_id: Optional[str] = None - task: Optional[str] = None - path: Optional[str] = None - producer: Optional[str] = None - deployment: Optional[str] = None + project_id: str + task: Optional[str] + path: Optional[str] + producer: Optional[str] + deployment: Optional[str] diff --git a/controller/src/schemas/prompt_template.py b/controller/src/schemas/prompt_template.py index 57ad3e7..00efb34 100644 --- a/controller/src/schemas/prompt_template.py +++ b/controller/src/schemas/prompt_template.py @@ -22,5 +22,5 @@ class PromptTemplate(BaseWithVerMetadata): _top_level_fields = ["text"] text: str - project_id: Optional[str] = None - arguments: Optional[List[str]] = None + project_id: str + arguments: Optional[List[str]] diff --git a/controller/src/schemas/session.py b/controller/src/schemas/session.py index f54e8fc..0d82202 100644 --- a/controller/src/schemas/session.py +++ b/controller/src/schemas/session.py @@ -22,25 +22,25 @@ class QueryItem(BaseModel): question: str - session_name: Optional[str] = None - filter: Optional[List[Tuple[str, str]]] = None - data_source: Optional[str] = None + session_id: Optional[str] + filter: Optional[List[Tuple[str, str]]] + data_source: Optional[str] class ChatRole(str, Enum): - Human = "Human" + HUMAN = "Human" AI = "AI" - System = "System" - User = "User" # for co-pilot user (vs Human?) - Agent = "Agent" # for co-pilot agent + SYSTEM = "System" + USER = "User" # for co-pilot user (vs Human?) + AGENT = "Agent" # for co-pilot agent class Message(BaseModel): role: ChatRole content: str - extra_data: Optional[dict] = None - sources: Optional[List[dict]] = None - human_feedback: Optional[str] = None + extra_data: Optional[dict] + sources: Optional[List[dict]] + human_feedback: Optional[str] class ChatSession(BaseWithOwner): diff --git a/controller/src/schemas/user.py b/controller/src/schemas/user.py index d29ead1..2c084d4 100644 --- a/controller/src/schemas/user.py +++ b/controller/src/schemas/user.py @@ -22,7 +22,7 @@ class User(BaseWithMetadata): _top_level_fields = ["email", "full_name"] email: str - full_name: Optional[str] = None - features: Optional[dict[str, str]] = None - policy: Optional[dict[str, str]] = None + full_name: Optional[str] + features: Optional[dict[str, str]] + policy: Optional[dict[str, str]] is_admin: Optional[bool] = False diff --git a/controller/src/schemas/workflow.py b/controller/src/schemas/workflow.py index 20b3d58..fdcaea2 100644 --- a/controller/src/schemas/workflow.py +++ b/controller/src/schemas/workflow.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os from enum import Enum from typing import Optional @@ -19,22 +20,24 @@ class WorkflowType(str, Enum): - ingestion = "ingestion" - application = "application" - data_processing = "data-processing" - training = "training" - evaluation = "evaluation" + INGESTION = "ingestion" + APPLICATION = "application" + DATA_PROCESSING = "data-processing" + TRAINING = "training" + EVALUATION = "evaluation" class Workflow(BaseWithVerMetadata): _top_level_fields = ["workflow_type"] workflow_type: WorkflowType - deployment: str - project_id: Optional[str] = None - workflow_function: Optional[str] = None - configuration: Optional[dict] = None - graph: Optional[dict] = None + project_id: str + deployment: Optional[str] + workflow_function: Optional[str] + configuration: Optional[dict] + graph: Optional[dict] def get_infer_path(self): - return f"{self.deployment}/api/workflows/{self.name}/infer" + if self.deployment is None: + return None + return os.path.join(self.deployment, "infer") From 57796ebddab83d004744c0d773f97a65a8c547c8 Mon Sep 17 00:00:00 2001 From: yonishelach Date: Sun, 25 Aug 2024 11:36:43 +0300 Subject: [PATCH 10/10] workflow graph - list of dicts & fix infer bug --- controller/src/api/endpoints/workflows.py | 18 ++++++++++-------- controller/src/schemas/workflow.py | 4 ++-- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/controller/src/api/endpoints/workflows.py b/controller/src/api/endpoints/workflows.py index ca2890c..768e02b 100644 --- a/controller/src/api/endpoints/workflows.py +++ b/controller/src/api/endpoints/workflows.py @@ -162,7 +162,9 @@ def list_workflows( :return: The response from the database. """ - owner_id = client.get_user(user_name=auth.username, session=session).uid + owner_id = client.get_user( + user_name=auth.username, email=auth.username, session=session + ).uid project_id = client.get_project(project_name=project_name, session=session).uid try: data = client.list_workflows( @@ -209,16 +211,16 @@ def infer_workflow( if query.session_id: # Get session by id: - session = client.get_chat_session(uid=query.session_id, session=session) - if session is None: + chat_session = client.get_chat_session(uid=query.session_id, session=session) + if chat_session is None: # If not id found, get session by name: session_name = query.session_id - session = client.list_chat_sessions(name=session_name, session=session) + chat_session = client.list_chat_sessions(name=session_name, session=session) # If not name found, create a new session: - if session: - session = session[0] + if chat_session: + chat_session = chat_session[0] else: - session = client.create_chat_session( + chat_session = client.create_chat_session( chat_session=ChatSession( name=session_name, workflow_id=uid, @@ -227,7 +229,7 @@ def infer_workflow( ).uid, ), ) - query.session_id = session.uid + query.session_id = chat_session.uid # Prepare the data to send to the application's workflow data = { "item": query.dict(), diff --git a/controller/src/schemas/workflow.py b/controller/src/schemas/workflow.py index fdcaea2..668f6fd 100644 --- a/controller/src/schemas/workflow.py +++ b/controller/src/schemas/workflow.py @@ -14,7 +14,7 @@ import os from enum import Enum -from typing import Optional +from typing import List, Optional from controller.src.schemas.base import BaseWithVerMetadata @@ -35,7 +35,7 @@ class Workflow(BaseWithVerMetadata): deployment: Optional[str] workflow_function: Optional[str] configuration: Optional[dict] - graph: Optional[dict] + graph: Optional[List[dict]] def get_infer_path(self): if self.deployment is None: