Skip to content

Commit

Permalink
[Controller] Fix Optional fields default value in pydantic models (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
yonishelach authored Sep 4, 2024
1 parent 9dc7641 commit e50f563
Show file tree
Hide file tree
Showing 11 changed files with 39 additions and 45 deletions.
2 changes: 1 addition & 1 deletion controller/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,4 @@ RUN pip install -r /controller/requirements.txt
RUN python -m controller.src.main initdb

# Run the controller's API server:
CMD ["uvicorn", "controller.src.api:app", "--port", "8001"]
CMD ["uvicorn", "controller.src.api:app", "--port", "8001", "--reload"]
2 changes: 1 addition & 1 deletion controller/src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class CtrlConfig(BaseModel):
verbose: bool = True
log_level: str = "DEBUG"
# SQL Database
db_type = "sql"
db_type: str = "sql"
sql_connection_str: str = default_db_path
application_url: str = "http://localhost:8000"

Expand Down
18 changes: 9 additions & 9 deletions controller/src/schemas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from datetime import datetime
from enum import Enum
from http.client import HTTPException
from typing import Dict, Optional, Type, Union
from typing import Dict, Type, Union

import yaml
from pydantic import BaseModel
Expand Down Expand Up @@ -147,25 +147,25 @@ def __str__(self):

class BaseWithMetadata(Base):
name: str
uid: Optional[str]
description: Optional[str]
labels: Optional[Dict[str, Union[str, None]]]
created: Optional[Union[str, datetime]]
updated: Optional[Union[str, datetime]]
uid: str = None
description: str = None
labels: Dict[str, Union[str, None]] = None
created: Union[str, datetime] = None
updated: Union[str, datetime] = None


class BaseWithOwner(BaseWithMetadata):
owner_id: str


class BaseWithVerMetadata(BaseWithOwner):
version: Optional[str] = ""
version: str = ""


class APIResponse(BaseModel):
success: bool
data: Optional[Union[list, Type[BaseModel], dict]]
error: Optional[str]
data: Union[list, Type[BaseModel], dict] = None
error: str = None

def with_raise(self, format=None) -> "APIResponse":
if not self.success:
Expand Down
3 changes: 1 addition & 2 deletions controller/src/schemas/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

from enum import Enum
from typing import Optional

from controller.src.schemas.base import BaseWithVerMetadata

Expand All @@ -33,4 +32,4 @@ class DataSource(BaseWithVerMetadata):

data_source_type: DataSourceType
project_id: str
database_kwargs: Optional[dict[str, str]] = {}
database_kwargs: dict[str, str] = {}
6 changes: 3 additions & 3 deletions controller/src/schemas/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional
from typing import List

from controller.src.schemas.base import BaseWithVerMetadata

Expand All @@ -23,5 +23,5 @@ class Dataset(BaseWithVerMetadata):
task: str
path: str
project_id: str
sources: Optional[List[str]]
producer: Optional[str]
sources: List[str] = None
producer: str = None
4 changes: 1 addition & 3 deletions controller/src/schemas/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from controller.src.schemas.base import BaseWithVerMetadata


class Document(BaseWithVerMetadata):
_top_level_fields = ["path", "origin"]
path: str
project_id: str
origin: Optional[str]
origin: str = None
9 changes: 4 additions & 5 deletions controller/src/schemas/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

from enum import Enum
from typing import Optional

from controller.src.schemas.base import BaseWithVerMetadata

Expand All @@ -30,7 +29,7 @@ class Model(BaseWithVerMetadata):
model_type: ModelType
base_model: str
project_id: str
task: Optional[str]
path: Optional[str]
producer: Optional[str]
deployment: Optional[str]
task: str = None
path: str = None
producer: str = None
deployment: str = None
4 changes: 2 additions & 2 deletions controller/src/schemas/prompt_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional
from typing import List

from controller.src.schemas.base import BaseWithVerMetadata

Expand All @@ -23,4 +23,4 @@ class PromptTemplate(BaseWithVerMetadata):

text: str
project_id: str
arguments: Optional[List[str]]
arguments: List[str] = None
16 changes: 8 additions & 8 deletions controller/src/schemas/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from enum import Enum
from typing import List, Optional, Tuple
from typing import List, Tuple

from pydantic import BaseModel

Expand All @@ -22,9 +22,9 @@

class QueryItem(BaseModel):
question: str
session_id: Optional[str]
filter: Optional[List[Tuple[str, str]]]
data_source: Optional[str]
session_id: str = None
filter: List[Tuple[str, str]] = None
data_source: str = None


class ChatRole(str, Enum):
Expand All @@ -38,14 +38,14 @@ class ChatRole(str, Enum):
class Message(BaseModel):
role: ChatRole
content: str
extra_data: Optional[dict]
sources: Optional[List[dict]]
human_feedback: Optional[str]
extra_data: dict = None
sources: List[dict] = None
human_feedback: str = None


class ChatSession(BaseWithOwner):
_extra_fields = ["history"]
_top_level_fields = ["workflow_id"]

workflow_id: str
history: Optional[List[Message]] = []
history: List[Message] = []
10 changes: 4 additions & 6 deletions controller/src/schemas/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from controller.src.schemas.base import BaseWithMetadata


Expand All @@ -22,7 +20,7 @@ class User(BaseWithMetadata):
_top_level_fields = ["email", "full_name"]

email: str
full_name: Optional[str]
features: Optional[dict[str, str]]
policy: Optional[dict[str, str]]
is_admin: Optional[bool] = False
full_name: str = None
features: dict[str, str] = None
policy: dict[str, str] = None
is_admin: bool = False
10 changes: 5 additions & 5 deletions controller/src/schemas/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import os
from enum import Enum
from typing import List, Optional
from typing import List

from controller.src.schemas.base import BaseWithVerMetadata

Expand All @@ -32,10 +32,10 @@ class Workflow(BaseWithVerMetadata):

workflow_type: WorkflowType
project_id: str
deployment: Optional[str]
workflow_function: Optional[str]
configuration: Optional[dict]
graph: Optional[List[dict]]
deployment: str = None
workflow_function: str = None
configuration: dict = None
graph: List[dict] = None

def get_infer_path(self):
if self.deployment is None:
Expand Down

0 comments on commit e50f563

Please sign in to comment.