Skip to content

Commit

Permalink
Fixture DB Connections (#310)
Browse files Browse the repository at this point in the history
* Add kullback_leibler_divergence metric

* Fix numpy's version in order to avoid errors when building docker image

* Add Singleton Class, implement singleton in connection and middleware in session

* undo numpy version

* implement model outside with like done in other documents

* extract variable from with
  • Loading branch information
shincap8 authored Oct 31, 2024
1 parent a9f79aa commit b311016
Show file tree
Hide file tree
Showing 16 changed files with 303 additions and 253 deletions.
14 changes: 10 additions & 4 deletions backend/app/infrastructure/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from sqlalchemy import MetaData, create_engine
from sqlalchemy.orm import sessionmaker

from app.infrastructure.utils.singleton import Singleton


load_dotenv()

Expand All @@ -26,15 +28,19 @@
)


class Connection:
class Connection(metaclass=Singleton):
def __init__(self) -> None:
self.engine = create_engine(CONNECTION_URI, echo=False, pool_pre_ping=True)
self.Session = sessionmaker(bind=self.engine, expire_on_commit=False)
self.session = self.Session()
self.engine = create_engine(
CONNECTION_URI, echo=False, pool_pre_ping=True, pool_size=2, pool_recycle=60
)
self.metadata = MetaData()

def refresh_session(self):
self.session = self.Session()

def close_session(self):
self.session.close()

@property
def session(self):
return sessionmaker(bind=self.engine, expire_on_commit=True)()
10 changes: 6 additions & 4 deletions backend/app/infrastructure/repositories/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ def get_all(self):

def add(self, instance_dict: dict):
new_instance = self.model(**instance_dict)
self.session.add(new_instance)
self.session.commit()
new_instance = self.instance_converter.instance_to_dict(new_instance)
return new_instance
with self.session as session:
session.add(new_instance)
session.commit()
session.refresh(new_instance)
new_instance = self.instance_converter.instance_to_dict(new_instance)
return new_instance

def get_by_id(self, id: int) -> dict:
instance = self.session.query(self.model).get(id)
Expand Down
7 changes: 4 additions & 3 deletions backend/app/infrastructure/repositories/badge.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(self) -> None:

def add_badge(self, user_id: int, name: str) -> None:
model = self.model(uid=user_id, name=name)
self.session.add(model)
self.session.flush()
self.session.commit()
with self.session as session:
session.add(model)
session.flush()
session.commit()
15 changes: 9 additions & 6 deletions backend/app/infrastructure/repositories/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ def __init__(self) -> None:
super().__init__(Context)

def increment_counter_total_samples_and_update_date(self, context_id: int) -> None:
self.session.query(self.model).filter(self.model.id == context_id).update(
{self.model.total_used: self.model.total_used + 1}
)
self.session.commit()
with self.session as session:
session.query(self.model).filter(self.model.id == context_id).update(
{self.model.total_used: self.model.total_used + 1}
)
session.commit()

def get_real_round_id(self, context_id: int) -> int:
instance = (
Expand Down Expand Up @@ -83,5 +84,7 @@ def get_context_by_key_value_in_contextjson(self, search_txt: str):
)

def upload_contexts(self, context: dict):
self.session.add(self.model(**context))
self.session.commit()
model = self.model(**context)
with self.session as session:
session.add(model)
session.commit()
7 changes: 4 additions & 3 deletions backend/app/infrastructure/repositories/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,10 @@ def create_dataset_in_db(
dataset = Dataset(
tid=task_id, name=dataset_name, access_type=access_type, rid=0
)
self.session.add(dataset)
self.session.commit()
return dataset
with self.session as session:
session.add(dataset)
session.commit()
return dataset

def get_downstream_datasets(self, task_id: int) -> dict:
downstream_datasets = (
Expand Down
60 changes: 33 additions & 27 deletions backend/app/infrastructure/repositories/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,47 +90,53 @@ def get_example_to_validate_fooling(
)

def increment_counter_total_verified(self, example_id: int):
self.session.query(self.model).filter(self.model.id == example_id).update(
{self.model.total_verified: self.model.total_verified + 1}
)
self.session.flush()
self.session.commit()
with self.session as session:
session.query(self.model).filter(self.model.id == example_id).update(
{self.model.total_verified: self.model.total_verified + 1}
)
session.flush()
session.commit()

def increment_counter_total_correct(self, example_id: int):
self.session.query(self.model).filter(self.model.id == example_id).update(
{self.model.verified_correct: self.model.verified_correct + 1}
)
self.session.flush()
self.session.commit()
with self.session as session:
session.query(self.model).filter(self.model.id == example_id).update(
{self.model.verified_correct: self.model.verified_correct + 1}
)
session.flush()
session.commit()

def increment_counter_total_incorrect(self, example_id: int):
self.session.query(self.model).filter(self.model.id == example_id).update(
{self.model.verified_incorrect: self.model.verified_incorrect + 1}
)
self.session.flush()
self.session.commit()
with self.session as session:
session.query(self.model).filter(self.model.id == example_id).update(
{self.model.verified_incorrect: self.model.verified_incorrect + 1}
)
session.flush()
session.commit()

def increment_counter_total_flagged(self, example_id: int):
self.session.query(self.model).filter(self.model.id == example_id).update(
{self.model.verified_flagged: self.model.verified_flagged + 1}
)
self.session.flush()
self.session.commit()
with self.session as session:
session.query(self.model).filter(self.model.id == example_id).update(
{self.model.verified_flagged: self.model.verified_flagged + 1}
)
session.flush()
session.commit()

def mark_as_verified(self, example_id: int):
example = self.get_by_id(example_id)
example["verified"] = 1
self.session.flush()
self.session.commit()
with self.session as session:
session.flush()
session.commit()

def update_creation_generative_example_by_example_id(
self, example_id: int, model_input: Json, metadata: Json
):
self.session.query(self.model).filter_by(id=example_id).update(
{"input_json": model_input, "metadata_json": metadata}
)
self.session.flush()
self.session.commit()
with self.session as session:
session.query(self.model).filter_by(id=example_id).update(
{"input_json": model_input, "metadata_json": metadata}
)
session.flush()
session.commit()

def download_created_examples_user(self, task_id: int, user_id: int, amount: int):
return (
Expand Down
20 changes: 11 additions & 9 deletions backend/app/infrastructure/repositories/historical_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,19 @@ def save_historical_data(self, task_id: int, user_id: int, data: str):
user_id=user_id,
history=data,
)
self.session.add(model)
self.session.flush()
self.session.commit()
return self.get_historical_data_by_task_and_user(task_id, user_id)
with self.session as session:
session.add(model)
session.flush()
session.commit()
return self.get_historical_data_by_task_and_user(task_id, user_id)

def delete_historical_data(self, task_id: int, user_id: int):
self.session.query(HistoricalData).filter(
HistoricalData.task_id == task_id
).filter(HistoricalData.user_id == user_id).delete()
self.session.flush()
self.session.commit()
with self.session as session:
session.query(HistoricalData).filter(
HistoricalData.task_id == task_id
).filter(HistoricalData.user_id == user_id).delete()
session.flush()
session.commit()

def get_occurrences_with_more_than_one_hundred(self, task_id: int):
return (
Expand Down
17 changes: 10 additions & 7 deletions backend/app/infrastructure/repositories/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ def metadata_exists(self, model: dict):

def create_registry(self, model: dict):
job = {"prompt": model["prompt"], "user_id": model["user_id"]}
self.session.add(self.model(**job))
self.session.commit()
new_instance = self.model(**job)
with self.session as session:
session.add(new_instance)
session.commit()

def determine_queue_position(self, model: dict):
my_position = (
Expand All @@ -45,8 +47,9 @@ def determine_queue_position(self, model: dict):
return {"queue_position": queue_position, "all_positions": all_positions}

def remove_registry(self, model: dict):
self.session.query(self.model).filter(
self.model.prompt == model["prompt"]
).filter(self.model.user_id == model["user_id"]).delete()
self.session.flush()
self.session.commit()
with self.session as session:
session.query(self.model).filter(
self.model.prompt == model["prompt"]
).filter(self.model.user_id == model["user_id"]).delete()
session.flush()
session.commit()
98 changes: 52 additions & 46 deletions backend/app/infrastructure/repositories/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,21 @@ def get_model_in_the_loop(self, task_id: int) -> dict:
return models_in_the_loop

def update_light_model(self, id: int, light_model: str) -> None:
instance = self.session.query(self.model).filter(self.model.id == id).first()
light_model = f"{light_model}/model/single_evaluation"
instance.light_model = light_model
instance.is_in_the_loop = 0
self.session.flush()
self.session.commit()
with self.session as session:
instance = session.query(self.model).filter(self.model.id == id).first()
instance.light_model = light_model
instance.is_in_the_loop = 0
session.flush()
session.commit()

def update_model_status(self, id: int) -> None:
instance = self.session.query(self.model).filter(self.model.id == id).first()
instance.deployment_status = "deployed"
instance.is_published = 0
self.session.flush()
self.session.commit()
with self.session as session:
instance = session.query(self.model).filter(self.model.id == id).first()
instance.deployment_status = "deployed"
instance.is_published = 0
session.flush()
session.commit()

def get_lambda_models(self) -> list:
models = (
Expand Down Expand Up @@ -89,10 +91,11 @@ def create_new_model(
deployment_status=deployment_status,
secret=secret,
)
self.session.add(model)
self.session.flush()
self.session.commit()
return model.__dict__
with self.session as session:
session.add(model)
session.flush()
session.commit()
return model.__dict__

def get_active_models_by_task_id(self, task_id: int) -> list:
models = (
Expand Down Expand Up @@ -128,18 +131,19 @@ def get_task_id_by_model_id(self, model_id: int) -> int:
)

def update_published_status(self, model_id: int):
instance = (
self.session.query(self.model).filter(self.model.id == model_id).first()
)
instance.is_published = (
0
if instance.is_published == 1
else 1
if instance.is_published == 0
else instance.is_published
)
self.session.flush()
self.session.commit()
with self.session as session:
instance = (
session.query(self.model).filter(self.model.id == model_id).first()
)
instance.is_published = (
0
if instance.is_published == 1
else 1
if instance.is_published == 0
else instance.is_published
)
session.flush()
session.commit()

def get_models_by_user_id(self, user_id: int) -> list:
return (
Expand Down Expand Up @@ -173,10 +177,11 @@ def get_total_models_by_user_id(self, user_id):
return self.session.query(self.model).filter(self.model.uid == user_id).count()

def delete_model(self, model_id: int):
self.session.query(Score).filter(Score.mid == model_id).delete()
self.session.query(self.model).filter(self.model.id == model_id).delete()
self.session.flush()
self.session.commit()
with self.session as session:
session.query(Score).filter(Score.mid == model_id).delete()
session.query(self.model).filter(self.model.id == model_id).delete()
session.flush()
session.commit()

def get_all_model_info_by_id(self, model_id: int):
valid_datasets = (
Expand Down Expand Up @@ -229,23 +234,24 @@ def update_model_info(
license: str,
source_url: str,
):
(
self.session.query(self.model)
.filter(self.model.id == model_id)
.update(
{
"name": name,
"desc": desc,
"longdesc": longdesc,
"params": params,
"languages": languages,
"license": license,
"source_url": source_url,
}
with self.session as session:
(
session.query(self.model)
.filter(self.model.id == model_id)
.update(
{
"name": name,
"desc": desc,
"longdesc": longdesc,
"params": params,
"languages": languages,
"license": license,
"source_url": source_url,
}
)
)
)
self.session.flush()
return self.session.commit()
session.flush()
return session.commit()

def download_model_results(self, task_id: int):
m = aliased(Model)
Expand Down
Loading

0 comments on commit b311016

Please sign in to comment.