diff --git a/backend/danswer/background/celery/celery_redis.py b/backend/danswer/background/celery/celery_redis.py index d9d61b6c8a6..f08bfd17e2f 100644 --- a/backend/danswer/background/celery/celery_redis.py +++ b/backend/danswer/background/celery/celery_redis.py @@ -21,6 +21,7 @@ ) from danswer.db.document_set import construct_document_select_by_docset from danswer.utils.variable_functionality import fetch_versioned_implementation +from danswer.utils.variable_functionality import global_version class RedisObjectHelper(ABC): @@ -172,6 +173,9 @@ def generate_tasks( async_results = [] + if not global_version.is_ee_version(): + return 0 + try: construct_document_select_by_usergroup = fetch_versioned_implementation( "danswer.db.user_group", diff --git a/backend/danswer/background/celery/tasks/vespa/tasks.py b/backend/danswer/background/celery/tasks/vespa/tasks.py index 62806c7b81d..3f347cbab3d 100644 --- a/backend/danswer/background/celery/tasks/vespa/tasks.py +++ b/backend/danswer/background/celery/tasks/vespa/tasks.py @@ -48,6 +48,7 @@ from danswer.utils.variable_functionality import ( fetch_versioned_implementation_with_fallback, ) +from danswer.utils.variable_functionality import global_version from danswer.utils.variable_functionality import noop_fallback @@ -87,21 +88,24 @@ def check_for_vespa_sync_task() -> None: ) # check if any user groups are not synced - try: - fetch_user_groups = fetch_versioned_implementation( - "danswer.db.user_group", "fetch_user_groups" - ) + if global_version.is_ee_version(): + try: + fetch_user_groups = fetch_versioned_implementation( + "danswer.db.user_group", "fetch_user_groups" + ) - user_groups = fetch_user_groups( - db_session=db_session, only_up_to_date=False - ) - for usergroup in user_groups: - try_generate_user_group_sync_tasks( - usergroup, db_session, r, lock_beat + user_groups = fetch_user_groups( + db_session=db_session, only_up_to_date=False ) - except ModuleNotFoundError: - # Always exceptions on the MIT version, which is expected - pass + for usergroup in user_groups: + try_generate_user_group_sync_tasks( + usergroup, db_session, r, lock_beat + ) + except ModuleNotFoundError: + # Always exceptions on the MIT version, which is expected + # We shouldn't actually get here if the ee version check works + pass + except SoftTimeLimitExceeded: task_logger.info( "Soft time limit exceeded, task is being terminated gracefully." diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py index 57d05513ac2..773165c5161 100755 --- a/backend/danswer/background/update.py +++ b/backend/danswer/background/update.py @@ -353,7 +353,7 @@ def kickoff_indexing_jobs( run_indexing_entrypoint, attempt.id, attempt.connector_credential_pair_id, - global_version.get_is_ee_version(), + global_version.is_ee_version(), pure=False, ) if not run: @@ -364,7 +364,7 @@ def kickoff_indexing_jobs( run_indexing_entrypoint, attempt.id, attempt.connector_credential_pair_id, - global_version.get_is_ee_version(), + global_version.is_ee_version(), pure=False, ) if not run: diff --git a/backend/danswer/main.py b/backend/danswer/main.py index b9231a9c561..d3aa8b00efd 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -329,7 +329,7 @@ def get_application() -> FastAPI: f"Starting Danswer Backend version {__version__} on http://{APP_HOST}:{str(APP_PORT)}/" ) - if global_version.get_is_ee_version(): + if global_version.is_ee_version(): logger.notice("Running Enterprise Edition") uvicorn.run(app, host=APP_HOST, port=APP_PORT) diff --git a/backend/danswer/utils/variable_functionality.py b/backend/danswer/utils/variable_functionality.py index 55f296aa8e7..dfe6def2a56 100644 --- a/backend/danswer/utils/variable_functionality.py +++ b/backend/danswer/utils/variable_functionality.py @@ -16,7 +16,7 @@ def __init__(self) -> None: def set_ee(self) -> None: self._is_ee = True - def get_is_ee_version(self) -> bool: + def is_ee_version(self) -> bool: return self._is_ee @@ -24,7 +24,7 @@ def get_is_ee_version(self) -> bool: def set_is_ee_based_on_env_variable() -> None: - if ENTERPRISE_EDITION_ENABLED and not global_version.get_is_ee_version(): + if ENTERPRISE_EDITION_ENABLED and not global_version.is_ee_version(): logger.notice("Enterprise Edition enabled") global_version.set_ee() @@ -54,7 +54,7 @@ def fetch_versioned_implementation(module: str, attribute: str) -> Any: implementation cannot be found or loaded. """ logger.debug("Fetching versioned implementation for %s.%s", module, attribute) - is_ee = global_version.get_is_ee_version() + is_ee = global_version.is_ee_version() module_full = f"ee.{module}" if is_ee else module try: