diff --git a/backend/app/infrastructure/connection.py b/backend/app/infrastructure/connection.py index 85db1479f..1202ca53c 100644 --- a/backend/app/infrastructure/connection.py +++ b/backend/app/infrastructure/connection.py @@ -12,6 +12,8 @@ from sqlalchemy import MetaData, create_engine from sqlalchemy.orm import sessionmaker +from app.infrastructure.utils.singleton import Singleton + load_dotenv() @@ -26,11 +28,11 @@ ) -class Connection: +class Connection(metaclass=Singleton): def __init__(self) -> None: - self.engine = create_engine(CONNECTION_URI, echo=False, pool_pre_ping=True) - self.Session = sessionmaker(bind=self.engine, expire_on_commit=False) - self.session = self.Session() + self.engine = create_engine( + CONNECTION_URI, echo=False, pool_pre_ping=True, pool_size=2, pool_recycle=60 + ) self.metadata = MetaData() def refresh_session(self): @@ -38,3 +40,7 @@ def refresh_session(self): def close_session(self): self.session.close() + + @property + def session(self): + return sessionmaker(bind=self.engine, expire_on_commit=True)() diff --git a/backend/app/infrastructure/repositories/abstract.py b/backend/app/infrastructure/repositories/abstract.py index b9ed91f48..3f41ca501 100644 --- a/backend/app/infrastructure/repositories/abstract.py +++ b/backend/app/infrastructure/repositories/abstract.py @@ -24,10 +24,12 @@ def get_all(self): def add(self, instance_dict: dict): new_instance = self.model(**instance_dict) - self.session.add(new_instance) - self.session.commit() - new_instance = self.instance_converter.instance_to_dict(new_instance) - return new_instance + with self.session as session: + session.add(new_instance) + session.commit() + session.refresh(new_instance) + new_instance = self.instance_converter.instance_to_dict(new_instance) + return new_instance def get_by_id(self, id: int) -> dict: instance = self.session.query(self.model).get(id) diff --git a/backend/app/infrastructure/repositories/badge.py b/backend/app/infrastructure/repositories/badge.py index 214aa984f..4ad8fabed 100644 --- a/backend/app/infrastructure/repositories/badge.py +++ b/backend/app/infrastructure/repositories/badge.py @@ -16,6 +16,7 @@ def __init__(self) -> None: def add_badge(self, user_id: int, name: str) -> None: model = self.model(uid=user_id, name=name) - self.session.add(model) - self.session.flush() - self.session.commit() + with self.session as session: + session.add(model) + session.flush() + session.commit() diff --git a/backend/app/infrastructure/repositories/context.py b/backend/app/infrastructure/repositories/context.py index 1a85ae581..54fa7f2b0 100644 --- a/backend/app/infrastructure/repositories/context.py +++ b/backend/app/infrastructure/repositories/context.py @@ -16,10 +16,11 @@ def __init__(self) -> None: super().__init__(Context) def increment_counter_total_samples_and_update_date(self, context_id: int) -> None: - self.session.query(self.model).filter(self.model.id == context_id).update( - {self.model.total_used: self.model.total_used + 1} - ) - self.session.commit() + with self.session as session: + session.query(self.model).filter(self.model.id == context_id).update( + {self.model.total_used: self.model.total_used + 1} + ) + session.commit() def get_real_round_id(self, context_id: int) -> int: instance = ( @@ -83,5 +84,6 @@ def get_context_by_key_value_in_contextjson(self, search_txt: str): ) def upload_contexts(self, context: dict): - self.session.add(self.model(**context)) - self.session.commit() + with self.session as session: + session.add(self.model(**context)) + session.commit() diff --git a/backend/app/infrastructure/repositories/dataset.py b/backend/app/infrastructure/repositories/dataset.py index 62bf30284..e8449a332 100644 --- a/backend/app/infrastructure/repositories/dataset.py +++ b/backend/app/infrastructure/repositories/dataset.py @@ -104,9 +104,10 @@ def create_dataset_in_db( dataset = Dataset( tid=task_id, name=dataset_name, access_type=access_type, rid=0 ) - self.session.add(dataset) - self.session.commit() - return dataset + with self.session as session: + session.add(dataset) + session.commit() + return dataset def get_downstream_datasets(self, task_id: int) -> dict: downstream_datasets = ( diff --git a/backend/app/infrastructure/repositories/example.py b/backend/app/infrastructure/repositories/example.py index c13d2250b..015cece7b 100644 --- a/backend/app/infrastructure/repositories/example.py +++ b/backend/app/infrastructure/repositories/example.py @@ -90,47 +90,53 @@ def get_example_to_validate_fooling( ) def increment_counter_total_verified(self, example_id: int): - self.session.query(self.model).filter(self.model.id == example_id).update( - {self.model.total_verified: self.model.total_verified + 1} - ) - self.session.flush() - self.session.commit() + with self.session as session: + session.query(self.model).filter(self.model.id == example_id).update( + {self.model.total_verified: self.model.total_verified + 1} + ) + session.flush() + session.commit() def increment_counter_total_correct(self, example_id: int): - self.session.query(self.model).filter(self.model.id == example_id).update( - {self.model.verified_correct: self.model.verified_correct + 1} - ) - self.session.flush() - self.session.commit() + with self.session as session: + session.query(self.model).filter(self.model.id == example_id).update( + {self.model.verified_correct: self.model.verified_correct + 1} + ) + session.flush() + session.commit() def increment_counter_total_incorrect(self, example_id: int): - self.session.query(self.model).filter(self.model.id == example_id).update( - {self.model.verified_incorrect: self.model.verified_incorrect + 1} - ) - self.session.flush() - self.session.commit() + with self.session as session: + session.query(self.model).filter(self.model.id == example_id).update( + {self.model.verified_incorrect: self.model.verified_incorrect + 1} + ) + session.flush() + session.commit() def increment_counter_total_flagged(self, example_id: int): - self.session.query(self.model).filter(self.model.id == example_id).update( - {self.model.verified_flagged: self.model.verified_flagged + 1} - ) - self.session.flush() - self.session.commit() + with self.session as session: + session.query(self.model).filter(self.model.id == example_id).update( + {self.model.verified_flagged: self.model.verified_flagged + 1} + ) + session.flush() + session.commit() def mark_as_verified(self, example_id: int): example = self.get_by_id(example_id) example["verified"] = 1 - self.session.flush() - self.session.commit() + with self.session as session: + session.flush() + session.commit() def update_creation_generative_example_by_example_id( self, example_id: int, model_input: Json, metadata: Json ): - self.session.query(self.model).filter_by(id=example_id).update( - {"input_json": model_input, "metadata_json": metadata} - ) - self.session.flush() - self.session.commit() + with self.session as session: + session.query(self.model).filter_by(id=example_id).update( + {"input_json": model_input, "metadata_json": metadata} + ) + session.flush() + session.commit() def download_created_examples_user(self, task_id: int, user_id: int, amount: int): return ( diff --git a/backend/app/infrastructure/repositories/historical_data.py b/backend/app/infrastructure/repositories/historical_data.py index ed25f8153..000ba808d 100644 --- a/backend/app/infrastructure/repositories/historical_data.py +++ b/backend/app/infrastructure/repositories/historical_data.py @@ -26,17 +26,19 @@ def save_historical_data(self, task_id: int, user_id: int, data: str): user_id=user_id, history=data, ) - self.session.add(model) - self.session.flush() - self.session.commit() - return self.get_historical_data_by_task_and_user(task_id, user_id) + with self.session as session: + session.add(model) + session.flush() + session.commit() + return self.get_historical_data_by_task_and_user(task_id, user_id) def delete_historical_data(self, task_id: int, user_id: int): - self.session.query(HistoricalData).filter( - HistoricalData.task_id == task_id - ).filter(HistoricalData.user_id == user_id).delete() - self.session.flush() - self.session.commit() + with self.session as session: + session.query(HistoricalData).filter( + HistoricalData.task_id == task_id + ).filter(HistoricalData.user_id == user_id).delete() + session.flush() + session.commit() def get_occurrences_with_more_than_one_hundred(self, task_id: int): return ( diff --git a/backend/app/infrastructure/repositories/jobs.py b/backend/app/infrastructure/repositories/jobs.py index 0348cfe07..358721b54 100644 --- a/backend/app/infrastructure/repositories/jobs.py +++ b/backend/app/infrastructure/repositories/jobs.py @@ -25,8 +25,10 @@ def metadata_exists(self, model: dict): def create_registry(self, model: dict): job = {"prompt": model["prompt"], "user_id": model["user_id"]} - self.session.add(self.model(**job)) - self.session.commit() + new_instance = self.model(**job) + with self.session as session: + session.add(new_instance) + session.commit() def determine_queue_position(self, model: dict): my_position = ( @@ -45,8 +47,9 @@ def determine_queue_position(self, model: dict): return {"queue_position": queue_position, "all_positions": all_positions} def remove_registry(self, model: dict): - self.session.query(self.model).filter( - self.model.prompt == model["prompt"] - ).filter(self.model.user_id == model["user_id"]).delete() - self.session.flush() - self.session.commit() + with self.session as session: + session.query(self.model).filter( + self.model.prompt == model["prompt"] + ).filter(self.model.user_id == model["user_id"]).delete() + session.flush() + session.commit() diff --git a/backend/app/infrastructure/repositories/model.py b/backend/app/infrastructure/repositories/model.py index 8b2f7abd3..7013bb5b8 100644 --- a/backend/app/infrastructure/repositories/model.py +++ b/backend/app/infrastructure/repositories/model.py @@ -40,19 +40,21 @@ def get_model_in_the_loop(self, task_id: int) -> dict: return models_in_the_loop def update_light_model(self, id: int, light_model: str) -> None: - instance = self.session.query(self.model).filter(self.model.id == id).first() - light_model = f"{light_model}/model/single_evaluation" - instance.light_model = light_model - instance.is_in_the_loop = 0 - self.session.flush() - self.session.commit() + with self.session as session: + instance = session.query(self.model).filter(self.model.id == id).first() + light_model = f"{light_model}/model/single_evaluation" + instance.light_model = light_model + instance.is_in_the_loop = 0 + session.flush() + session.commit() def update_model_status(self, id: int) -> None: - instance = self.session.query(self.model).filter(self.model.id == id).first() - instance.deployment_status = "deployed" - instance.is_published = 0 - self.session.flush() - self.session.commit() + with self.session as session: + instance = session.query(self.model).filter(self.model.id == id).first() + instance.deployment_status = "deployed" + instance.is_published = 0 + session.flush() + session.commit() def get_lambda_models(self) -> list: models = ( @@ -89,10 +91,11 @@ def create_new_model( deployment_status=deployment_status, secret=secret, ) - self.session.add(model) - self.session.flush() - self.session.commit() - return model.__dict__ + with self.session as session: + session.add(model) + session.flush() + session.commit() + return model.__dict__ def get_active_models_by_task_id(self, task_id: int) -> list: models = ( @@ -128,18 +131,19 @@ def get_task_id_by_model_id(self, model_id: int) -> int: ) def update_published_status(self, model_id: int): - instance = ( - self.session.query(self.model).filter(self.model.id == model_id).first() - ) - instance.is_published = ( - 0 - if instance.is_published == 1 - else 1 - if instance.is_published == 0 - else instance.is_published - ) - self.session.flush() - self.session.commit() + with self.session as session: + instance = ( + session.query(self.model).filter(self.model.id == model_id).first() + ) + instance.is_published = ( + 0 + if instance.is_published == 1 + else 1 + if instance.is_published == 0 + else instance.is_published + ) + session.flush() + session.commit() def get_models_by_user_id(self, user_id: int) -> list: return ( @@ -173,10 +177,11 @@ def get_total_models_by_user_id(self, user_id): return self.session.query(self.model).filter(self.model.uid == user_id).count() def delete_model(self, model_id: int): - self.session.query(Score).filter(Score.mid == model_id).delete() - self.session.query(self.model).filter(self.model.id == model_id).delete() - self.session.flush() - self.session.commit() + with self.session as session: + session.query(Score).filter(Score.mid == model_id).delete() + session.query(self.model).filter(self.model.id == model_id).delete() + session.flush() + session.commit() def get_all_model_info_by_id(self, model_id: int): valid_datasets = ( @@ -229,23 +234,24 @@ def update_model_info( license: str, source_url: str, ): - ( - self.session.query(self.model) - .filter(self.model.id == model_id) - .update( - { - "name": name, - "desc": desc, - "longdesc": longdesc, - "params": params, - "languages": languages, - "license": license, - "source_url": source_url, - } + with self.session as session: + ( + session.query(self.model) + .filter(self.model.id == model_id) + .update( + { + "name": name, + "desc": desc, + "longdesc": longdesc, + "params": params, + "languages": languages, + "license": license, + "source_url": source_url, + } + ) ) - ) - self.session.flush() - return self.session.commit() + session.flush() + return session.commit() def download_model_results(self, task_id: int): m = aliased(Model) diff --git a/backend/app/infrastructure/repositories/round.py b/backend/app/infrastructure/repositories/round.py index 90fcd4c23..96e767e0e 100644 --- a/backend/app/infrastructure/repositories/round.py +++ b/backend/app/infrastructure/repositories/round.py @@ -23,27 +23,30 @@ def get_round_info_by_round_and_task(self, task_id: int, round_id: int): return round_info def increment_counter_examples_collected(self, round_id: int): - self.session.query(self.model).filter(self.model.id == round_id).update( - {self.model.total_collected: self.model.total_collected + 1} - ) - self.session.flush() - self.session.commit() + with self.session as session: + session.query(self.model).filter(self.model.id == round_id).update( + {self.model.total_collected: self.model.total_collected + 1} + ) + session.flush() + session.commit() def increment_counter_examples_fooled(self, round_id: int): - self.session.query(self.model).filter(self.model.id == round_id).update( - {self.model.total_fooled: self.model.total_fooled + 1} - ) - self.session.flush() - self.session.commit() + with self.session as session: + session.query(self.model).filter(self.model.id == round_id).update( + {self.model.total_fooled: self.model.total_fooled + 1} + ) + session.flush() + session.commit() def increment_counter_examples_verified_fooled(self, round_id: int, task_id: int): - self.session.query(self.model).filter( - (self.model.rid == round_id) & (self.model.tid == task_id) - ).update( - {self.model.total_verified_fooled: self.model.total_verified_fooled + 1} - ) - self.session.flush() - self.session.commit() + with self.session as session: + session.query(self.model).filter( + (self.model.rid == round_id) & (self.model.tid == task_id) + ).update( + {self.model.total_verified_fooled: self.model.total_verified_fooled + 1} + ) + session.flush() + session.commit() def get_task_id_by_round_id(self, round_id: int): return ( diff --git a/backend/app/infrastructure/repositories/rounduserexampleinfo.py b/backend/app/infrastructure/repositories/rounduserexampleinfo.py index 8436ce6ed..ec38b1590 100644 --- a/backend/app/infrastructure/repositories/rounduserexampleinfo.py +++ b/backend/app/infrastructure/repositories/rounduserexampleinfo.py @@ -25,43 +25,47 @@ def verify_user_and_round_exist(self, user_id: int, round_id: int) -> bool: ) def create_user_and_round_example_info(self, round_id: int, user_id: int) -> None: - self.session.add( - self.model( - r_realid=round_id, - uid=user_id, - examples_submitted=0, - total_fooled=0, - total_verified_not_correct_fooled=0, + with self.session as session: + session.add( + self.model( + r_realid=round_id, + uid=user_id, + examples_submitted=0, + total_fooled=0, + total_verified_not_correct_fooled=0, + ) ) - ) - self.session.flush() - self.session.commit() + session.flush() + session.commit() def increment_counter_examples_submitted(self, round_id: int, user_id: int): - self.session.query(self.model).filter( - (self.model.r_realid == round_id) & (self.model.uid == user_id) - ).update({self.model.examples_submitted: self.model.examples_submitted + 1}) - self.session.flush() - self.session.commit() + with self.session as session: + session.query(self.model).filter( + (self.model.r_realid == round_id) & (self.model.uid == user_id) + ).update({self.model.examples_submitted: self.model.examples_submitted + 1}) + session.flush() + session.commit() def increment_counter_examples_fooled(self, round_id: int, user_id: int): - self.session.query(self.model).filter( - (self.model.r_realid == round_id) & (self.model.uid == user_id) - ).update({self.model.total_fooled: self.model.total_fooled + 1}) - self.session.flush() - self.session.commit() + with self.session as session: + session.query(self.model).filter( + (self.model.r_realid == round_id) & (self.model.uid == user_id) + ).update({self.model.total_fooled: self.model.total_fooled + 1}) + session.flush() + session.commit() def increment_examples_submitted_today(self, round_id: int, user_id: int): - self.session.query(self.model).filter( - (self.model.r_realid == round_id) & (self.model.uid == user_id) - ).update( - { - self.model.amount_examples_on_a_day: self.model.amount_examples_on_a_day - + 1 - } - ) - self.session.flush() - self.session.commit() + with self.session as session: + session.query(self.model).filter( + (self.model.r_realid == round_id) & (self.model.uid == user_id) + ).update( + { + self.model.amount_examples_on_a_day: self.model.amount_examples_on_a_day + + 1 + } + ) + session.flush() + session.commit() def amounts_examples_created_today(self, round_id: int, user_id: int): return ( @@ -85,28 +89,30 @@ def get_last_date_used(self, round_id: int, user_id: int): ) def update_last_used_and_counter(self, round_id: int, user_id: int): - self.session.query(self.model).filter( - (self.model.r_realid == round_id) & (self.model.uid == user_id) - ).update( - { - self.model.last_used: datetime.date.today(), - self.model.amount_examples_on_a_day: 0, - } - ) - self.session.flush() - self.session.commit() + with self.session as session: + session.query(self.model).filter( + (self.model.r_realid == round_id) & (self.model.uid == user_id) + ).update( + { + self.model.last_used: datetime.date.today(), + self.model.amount_examples_on_a_day: 0, + } + ) + session.flush() + session.commit() def create_first_entry_for_day(self, round_id: int, user_id: int): - self.session.query(self.model).filter( - (self.model.r_realid == round_id) & (self.model.uid == user_id) - ).update( - { - self.model.amount_examples_on_a_day: 1, - self.model.last_used: datetime.date.today(), - } - ) - self.session.flush() - self.session.commit() + with self.session as session: + session.query(self.model).filter( + (self.model.r_realid == round_id) & (self.model.uid == user_id) + ).update( + { + self.model.amount_examples_on_a_day: 1, + self.model.last_used: datetime.date.today(), + } + ) + session.flush() + session.commit() def get_counter_examples_submitted(self, round_id: int, user_id: int): return ( @@ -116,11 +122,12 @@ def get_counter_examples_submitted(self, round_id: int, user_id: int): ) def reset_counter_examples_submitted(self, round_id: int, user_id: int): - self.session.query(self.model).filter( - (self.model.r_realid == round_id) & (self.model.uid == user_id) - ).update({self.model.amount_examples_on_a_day: 0}) - self.session.flush() - self.session.commit() + with self.session as session: + session.query(self.model).filter( + (self.model.r_realid == round_id) & (self.model.uid == user_id) + ).update({self.model.amount_examples_on_a_day: 0}) + session.flush() + session.commit() def number_of_examples_created(self, round_id: int, user_id: int): return ( diff --git a/backend/app/infrastructure/repositories/score.py b/backend/app/infrastructure/repositories/score.py index 8c5b65f69..634c8aa38 100644 --- a/backend/app/infrastructure/repositories/score.py +++ b/backend/app/infrastructure/repositories/score.py @@ -83,9 +83,10 @@ def fix_matthews_correlation(self, model_id: int): IS NOT NULL AND mid = :model_id """ ) - self.session.execute(sql, {"model_id": model_id}) - self.session.flush() - self.session.commit() + with self.session as session: + session.execute(sql, {"model_id": model_id}) + session.flush() + session.commit() def fix_f1_score(self, model_id: int): sql = text( @@ -101,6 +102,7 @@ def fix_f1_score(self, model_id: int): IS NOT NULL AND mid = :model_id """ ) - self.session.execute(sql, {"model_id": model_id}) - self.session.flush() - self.session.commit() + with self.session as session: + session.execute(sql, {"model_id": model_id}) + session.flush() + session.commit() diff --git a/backend/app/infrastructure/repositories/task.py b/backend/app/infrastructure/repositories/task.py index a15d9b6ed..9341cd753 100644 --- a/backend/app/infrastructure/repositories/task.py +++ b/backend/app/infrastructure/repositories/task.py @@ -44,9 +44,10 @@ def get_model_id_and_task_code(self, task): return instance def update_last_activity_date(self, task_id: int): - self.session.query(self.model).filter(self.model.id == task_id).update({}) - self.session.flush() - self.session.commit() + with self.session as session: + session.query(self.model).filter(self.model.id == task_id).update({}) + session.flush() + session.commit() def get_active_tasks_with_round_info(self): return ( @@ -113,11 +114,12 @@ def get_task_instructions(self, task_id: int): ) def update_task_instructions(self, task_id: int, instructions: dict): - self.session.query(self.model).filter_by(id=task_id).update( - {"general_instructions": instructions} - ) - self.session.flush() - self.session.commit() + with self.session as session: + session.query(self.model).filter_by(id=task_id).update( + {"general_instructions": instructions} + ) + session.flush() + session.commit() def get_challenges_types(self): return self.session.query(ChallengesTypes).all() @@ -144,11 +146,12 @@ def validate_no_duplicate_task_name(self, task_name: str): ) def update_config_yaml(self, task_id: int, config_yaml: dict): - self.session.query(self.model).filter_by(id=task_id).update( - {"config_yaml": config_yaml} - ) - self.session.flush() - self.session.commit() + with self.session as session: + session.query(self.model).filter_by(id=task_id).update( + {"config_yaml": config_yaml} + ) + session.flush() + session.commit() def get_s3_bucket_by_task_id(self, task_id: int): return ( diff --git a/backend/app/infrastructure/repositories/user.py b/backend/app/infrastructure/repositories/user.py index 5e53cd960..e143113ae 100644 --- a/backend/app/infrastructure/repositories/user.py +++ b/backend/app/infrastructure/repositories/user.py @@ -25,16 +25,18 @@ def create_user(self, email: str, password: str, username: str) -> dict: return self.add({"email": email, "password": password, "username": username}) def increment_examples_fooled(self, user_id: int): - self.session.query(self.model).filter(self.model.id == user_id).update( - {self.model.total_fooled: self.model.total_fooled + 1} - ) - self.session.commit() + with self.session as session: + session.query(self.model).filter(self.model.id == user_id).update( + {self.model.total_fooled: self.model.total_fooled + 1} + ) + session.commit() def increment_model_submitted_count(self, user_id: int): - self.session.query(self.model).filter(self.model.id == user_id).update( - {self.model.models_submitted: self.model.models_submitted + 1} - ) - self.session.commit() + with self.session as session: + session.query(self.model).filter(self.model.id == user_id).update( + {self.model.models_submitted: self.model.models_submitted + 1} + ) + session.commit() def get_user_email(self, user_id: int) -> str: return ( @@ -59,40 +61,45 @@ def get_is_admin(self, user_id: int) -> bool: ) is not None def increment_examples_verified(self, user_id: int): - self.session.query(self.model).filter(self.model.id == user_id).update( - {self.model.examples_verified: self.model.examples_verified + 1} - ) - self.session.commit() + with self.session as session: + session.query(self.model).filter(self.model.id == user_id).update( + {self.model.examples_verified: self.model.examples_verified + 1} + ) + session.commit() def increment_examples_verified_correct(self, user_id: int): - self.session.query(self.model).filter(self.model.id == user_id).update( - { - self.model.examples_verified_correct: self.model.examples_verified_correct - + 1 - } - ) - self.session.commit() + with self.session as session: + session.query(self.model).filter(self.model.id == user_id).update( + { + self.model.examples_verified_correct: self.model.examples_verified_correct + + 1 + } + ) + session.commit() def increment_examples_verified_correct_fooled(self, user_id: int): - self.session.query(self.model).filter(self.model.id == user_id).update( - {self.model.total_verified_fooled: self.model.total_verified_fooled + 1} - ) - self.session.commit() + with self.session as session: + session.query(self.model).filter(self.model.id == user_id).update( + {self.model.total_verified_fooled: self.model.total_verified_fooled + 1} + ) + session.commit() def increment_examples_verified_incorrect_fooled(self, user_id: int): - self.session.query(self.model).filter(self.model.id == user_id).update( - { - self.model.total_verified_not_correct_fooled: self.model.total_verified_not_correct_fooled - + 1 - } - ) - self.session.commit() + with self.session as session: + session.query(self.model).filter(self.model.id == user_id).update( + { + self.model.total_verified_not_correct_fooled: self.model.total_verified_not_correct_fooled + + 1 + } + ) + session.commit() def increment_examples_created(self, user_id: int): - self.session.query(self.model).filter(self.model.id == user_id).update( - {self.model.examples_submitted: self.model.examples_submitted + 1} - ) - self.session.commit() + with self.session as session: + session.query(self.model).filter(self.model.id == user_id).update( + {self.model.examples_submitted: self.model.examples_submitted + 1} + ) + session.commit() def get_badges_by_user_id(self, user_id: int) -> dict: return self.session.query(Badge).filter(Badge.uid == user_id).all() diff --git a/backend/app/infrastructure/utils/singleton.py b/backend/app/infrastructure/utils/singleton.py new file mode 100644 index 000000000..13d0a58eb --- /dev/null +++ b/backend/app/infrastructure/utils/singleton.py @@ -0,0 +1,17 @@ +# Copyright (c) MLCommons and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +class Singleton(type): + """ + Metaclass implementation of the Singleton pattern + Taken from https://stackoverflow.com/q/6760685 + """ + + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] diff --git a/backend/app/main.py b/backend/app/main.py index 4ce5d98ce..6cb835e9c 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -73,22 +73,3 @@ def read_root(): historical_data.router, prefix="/historical_data", tags=["historical_data"] ) app.include_router(round.router, prefix="/round", tags=["round"]) - - -class RefreshSessionMiddleware: - """ - Refreshes SQLAlchemy session on each call. - There may be issues since close_session is blocking and it is ran in an async function. - Possibly blocking the main asyncio loop. - Keep in mind that middlewares run on the main asyncio thread. - """ - - def __init__(self, app): - self.app = app - - async def __call__(self, scope, receive, send): - await self.app(scope, receive, send) - Connection().close_session() - - -app.add_middleware(RefreshSessionMiddleware)