Skip to content

Commit

Permalink
check before using fetch_versioned_implementation because it logs war…
Browse files Browse the repository at this point in the history
…nings that confuse users. (#2708)

Renamed get_is_ee_version to is_ee_version to be less redundant
  • Loading branch information
rkuo-danswer authored Oct 7, 2024
1 parent 30dc408 commit 1a3469d
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 19 deletions.
4 changes: 4 additions & 0 deletions backend/danswer/background/celery/celery_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from danswer.db.document_set import construct_document_select_by_docset
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import global_version


class RedisObjectHelper(ABC):
Expand Down Expand Up @@ -172,6 +173,9 @@ def generate_tasks(

async_results = []

if not global_version.is_ee_version():
return 0

try:
construct_document_select_by_usergroup = fetch_versioned_implementation(
"danswer.db.user_group",
Expand Down
30 changes: 17 additions & 13 deletions backend/danswer/background/celery/tasks/vespa/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import noop_fallback


Expand Down Expand Up @@ -87,21 +88,24 @@ def check_for_vespa_sync_task() -> None:
)

# check if any user groups are not synced
try:
fetch_user_groups = fetch_versioned_implementation(
"danswer.db.user_group", "fetch_user_groups"
)
if global_version.is_ee_version():
try:
fetch_user_groups = fetch_versioned_implementation(
"danswer.db.user_group", "fetch_user_groups"
)

user_groups = fetch_user_groups(
db_session=db_session, only_up_to_date=False
)
for usergroup in user_groups:
try_generate_user_group_sync_tasks(
usergroup, db_session, r, lock_beat
user_groups = fetch_user_groups(
db_session=db_session, only_up_to_date=False
)
except ModuleNotFoundError:
# Always exceptions on the MIT version, which is expected
pass
for usergroup in user_groups:
try_generate_user_group_sync_tasks(
usergroup, db_session, r, lock_beat
)
except ModuleNotFoundError:
# Always exceptions on the MIT version, which is expected
# We shouldn't actually get here if the ee version check works
pass

except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
Expand Down
4 changes: 2 additions & 2 deletions backend/danswer/background/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def kickoff_indexing_jobs(
run_indexing_entrypoint,
attempt.id,
attempt.connector_credential_pair_id,
global_version.get_is_ee_version(),
global_version.is_ee_version(),
pure=False,
)
if not run:
Expand All @@ -364,7 +364,7 @@ def kickoff_indexing_jobs(
run_indexing_entrypoint,
attempt.id,
attempt.connector_credential_pair_id,
global_version.get_is_ee_version(),
global_version.is_ee_version(),
pure=False,
)
if not run:
Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def get_application() -> FastAPI:
f"Starting Danswer Backend version {__version__} on http://{APP_HOST}:{str(APP_PORT)}/"
)

if global_version.get_is_ee_version():
if global_version.is_ee_version():
logger.notice("Running Enterprise Edition")

uvicorn.run(app, host=APP_HOST, port=APP_PORT)
6 changes: 3 additions & 3 deletions backend/danswer/utils/variable_functionality.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ def __init__(self) -> None:
def set_ee(self) -> None:
self._is_ee = True

def get_is_ee_version(self) -> bool:
def is_ee_version(self) -> bool:
return self._is_ee


global_version = DanswerVersion()


def set_is_ee_based_on_env_variable() -> None:
if ENTERPRISE_EDITION_ENABLED and not global_version.get_is_ee_version():
if ENTERPRISE_EDITION_ENABLED and not global_version.is_ee_version():
logger.notice("Enterprise Edition enabled")
global_version.set_ee()

Expand Down Expand Up @@ -54,7 +54,7 @@ def fetch_versioned_implementation(module: str, attribute: str) -> Any:
implementation cannot be found or loaded.
"""
logger.debug("Fetching versioned implementation for %s.%s", module, attribute)
is_ee = global_version.get_is_ee_version()
is_ee = global_version.is_ee_version()

module_full = f"ee.{module}" if is_ee else module
try:
Expand Down

0 comments on commit 1a3469d

Please sign in to comment.