From e50f5639374123db8233092ee7864f39420cde75 Mon Sep 17 00:00:00 2001 From: Yonatan Shelach <92271540+yonishelach@users.noreply.github.com> Date: Wed, 4 Sep 2024 15:27:26 +0300 Subject: [PATCH] [Controller] Fix Optional fields default value in pydantic models (#13) --- controller/Dockerfile | 2 +- controller/src/config.py | 2 +- controller/src/schemas/base.py | 18 +++++++++--------- controller/src/schemas/data_source.py | 3 +-- controller/src/schemas/dataset.py | 6 +++--- controller/src/schemas/document.py | 4 +--- controller/src/schemas/model.py | 9 ++++----- controller/src/schemas/prompt_template.py | 4 ++-- controller/src/schemas/session.py | 16 ++++++++-------- controller/src/schemas/user.py | 10 ++++------ controller/src/schemas/workflow.py | 10 +++++----- 11 files changed, 39 insertions(+), 45 deletions(-) diff --git a/controller/Dockerfile b/controller/Dockerfile index cd42172..559ebc9 100644 --- a/controller/Dockerfile +++ b/controller/Dockerfile @@ -42,4 +42,4 @@ RUN pip install -r /controller/requirements.txt RUN python -m controller.src.main initdb # Run the controller's API server: -CMD ["uvicorn", "controller.src.api:app", "--port", "8001"] +CMD ["uvicorn", "controller.src.api:app", "--port", "8001", "--reload"] diff --git a/controller/src/config.py b/controller/src/config.py index f80342b..3e28d6d 100644 --- a/controller/src/config.py +++ b/controller/src/config.py @@ -35,7 +35,7 @@ class CtrlConfig(BaseModel): verbose: bool = True log_level: str = "DEBUG" # SQL Database - db_type = "sql" + db_type: str = "sql" sql_connection_str: str = default_db_path application_url: str = "http://localhost:8000" diff --git a/controller/src/schemas/base.py b/controller/src/schemas/base.py index e9319b6..ad9ecdb 100644 --- a/controller/src/schemas/base.py +++ b/controller/src/schemas/base.py @@ -15,7 +15,7 @@ from datetime import datetime from enum import Enum from http.client import HTTPException -from typing import Dict, Optional, Type, Union +from typing import Dict, Type, Union import yaml from pydantic import BaseModel @@ -147,11 +147,11 @@ def __str__(self): class BaseWithMetadata(Base): name: str - uid: Optional[str] - description: Optional[str] - labels: Optional[Dict[str, Union[str, None]]] - created: Optional[Union[str, datetime]] - updated: Optional[Union[str, datetime]] + uid: str = None + description: str = None + labels: Dict[str, Union[str, None]] = None + created: Union[str, datetime] = None + updated: Union[str, datetime] = None class BaseWithOwner(BaseWithMetadata): @@ -159,13 +159,13 @@ class BaseWithOwner(BaseWithMetadata): class BaseWithVerMetadata(BaseWithOwner): - version: Optional[str] = "" + version: str = "" class APIResponse(BaseModel): success: bool - data: Optional[Union[list, Type[BaseModel], dict]] - error: Optional[str] + data: Union[list, Type[BaseModel], dict] = None + error: str = None def with_raise(self, format=None) -> "APIResponse": if not self.success: diff --git a/controller/src/schemas/data_source.py b/controller/src/schemas/data_source.py index b8682b8..308506a 100644 --- a/controller/src/schemas/data_source.py +++ b/controller/src/schemas/data_source.py @@ -13,7 +13,6 @@ # limitations under the License. from enum import Enum -from typing import Optional from controller.src.schemas.base import BaseWithVerMetadata @@ -33,4 +32,4 @@ class DataSource(BaseWithVerMetadata): data_source_type: DataSourceType project_id: str - database_kwargs: Optional[dict[str, str]] = {} + database_kwargs: dict[str, str] = {} diff --git a/controller/src/schemas/dataset.py b/controller/src/schemas/dataset.py index a11eea1..538c693 100644 --- a/controller/src/schemas/dataset.py +++ b/controller/src/schemas/dataset.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +from typing import List from controller.src.schemas.base import BaseWithVerMetadata @@ -23,5 +23,5 @@ class Dataset(BaseWithVerMetadata): task: str path: str project_id: str - sources: Optional[List[str]] - producer: Optional[str] + sources: List[str] = None + producer: str = None diff --git a/controller/src/schemas/document.py b/controller/src/schemas/document.py index 297c19d..4b3996d 100644 --- a/controller/src/schemas/document.py +++ b/controller/src/schemas/document.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional - from controller.src.schemas.base import BaseWithVerMetadata @@ -21,4 +19,4 @@ class Document(BaseWithVerMetadata): _top_level_fields = ["path", "origin"] path: str project_id: str - origin: Optional[str] + origin: str = None diff --git a/controller/src/schemas/model.py b/controller/src/schemas/model.py index 48742ac..c9d5f88 100644 --- a/controller/src/schemas/model.py +++ b/controller/src/schemas/model.py @@ -13,7 +13,6 @@ # limitations under the License. from enum import Enum -from typing import Optional from controller.src.schemas.base import BaseWithVerMetadata @@ -30,7 +29,7 @@ class Model(BaseWithVerMetadata): model_type: ModelType base_model: str project_id: str - task: Optional[str] - path: Optional[str] - producer: Optional[str] - deployment: Optional[str] + task: str = None + path: str = None + producer: str = None + deployment: str = None diff --git a/controller/src/schemas/prompt_template.py b/controller/src/schemas/prompt_template.py index 00efb34..3e4cb14 100644 --- a/controller/src/schemas/prompt_template.py +++ b/controller/src/schemas/prompt_template.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +from typing import List from controller.src.schemas.base import BaseWithVerMetadata @@ -23,4 +23,4 @@ class PromptTemplate(BaseWithVerMetadata): text: str project_id: str - arguments: Optional[List[str]] + arguments: List[str] = None diff --git a/controller/src/schemas/session.py b/controller/src/schemas/session.py index 0d82202..f22036c 100644 --- a/controller/src/schemas/session.py +++ b/controller/src/schemas/session.py @@ -13,7 +13,7 @@ # limitations under the License. from enum import Enum -from typing import List, Optional, Tuple +from typing import List, Tuple from pydantic import BaseModel @@ -22,9 +22,9 @@ class QueryItem(BaseModel): question: str - session_id: Optional[str] - filter: Optional[List[Tuple[str, str]]] - data_source: Optional[str] + session_id: str = None + filter: List[Tuple[str, str]] = None + data_source: str = None class ChatRole(str, Enum): @@ -38,9 +38,9 @@ class ChatRole(str, Enum): class Message(BaseModel): role: ChatRole content: str - extra_data: Optional[dict] - sources: Optional[List[dict]] - human_feedback: Optional[str] + extra_data: dict = None + sources: List[dict] = None + human_feedback: str = None class ChatSession(BaseWithOwner): @@ -48,4 +48,4 @@ class ChatSession(BaseWithOwner): _top_level_fields = ["workflow_id"] workflow_id: str - history: Optional[List[Message]] = [] + history: List[Message] = [] diff --git a/controller/src/schemas/user.py b/controller/src/schemas/user.py index 2c084d4..ec93ae5 100644 --- a/controller/src/schemas/user.py +++ b/controller/src/schemas/user.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional - from controller.src.schemas.base import BaseWithMetadata @@ -22,7 +20,7 @@ class User(BaseWithMetadata): _top_level_fields = ["email", "full_name"] email: str - full_name: Optional[str] - features: Optional[dict[str, str]] - policy: Optional[dict[str, str]] - is_admin: Optional[bool] = False + full_name: str = None + features: dict[str, str] = None + policy: dict[str, str] = None + is_admin: bool = False diff --git a/controller/src/schemas/workflow.py b/controller/src/schemas/workflow.py index 668f6fd..2623af4 100644 --- a/controller/src/schemas/workflow.py +++ b/controller/src/schemas/workflow.py @@ -14,7 +14,7 @@ import os from enum import Enum -from typing import List, Optional +from typing import List from controller.src.schemas.base import BaseWithVerMetadata @@ -32,10 +32,10 @@ class Workflow(BaseWithVerMetadata): workflow_type: WorkflowType project_id: str - deployment: Optional[str] - workflow_function: Optional[str] - configuration: Optional[dict] - graph: Optional[List[dict]] + deployment: str = None + workflow_function: str = None + configuration: dict = None + graph: List[dict] = None def get_infer_path(self): if self.deployment is None: