diff --git a/backend/app/storage.py b/backend/app/storage.py index 7be4d29e..2ea37c3a 100644 --- a/backend/app/storage.py +++ b/backend/app/storage.py @@ -4,6 +4,7 @@ import orjson from langchain.schema.messages import messages_from_dict from langchain.utilities.redis import get_client +from redis.client import Redis as RedisType def assistants_list_key(user_id: str): @@ -33,7 +34,7 @@ def thread_messages_key(user_id: str, thread_id: str): public_user_id = "eef39817-c173-4eb6-8be4-f77cf37054fb" -def dump(map: dict) -> dict: +def _dump(map: dict) -> dict: return {k: orjson.dumps(v) if v is not None else None for k, v in map.items()} @@ -41,8 +42,16 @@ def load(keys: list[str], values: list[bytes]) -> dict: return {k: orjson.loads(v) if v is not None else None for k, v in zip(keys, values)} +def _get_redis_client() -> RedisType: + """Get a Redis client.""" + url = os.environ.get("REDIS_URL") + if not url: + raise ValueError("REDIS_URL not set") + return get_client(url) + + def list_assistants(user_id: str): - client = get_client(os.environ.get("REDIS_URL")) + client = _get_redis_client() ids = [orjson.loads(id) for id in client.smembers(assistants_list_key(user_id))] with client.pipeline() as pipe: for id in ids: @@ -54,7 +63,7 @@ def list_assistants(user_id: str): def list_public_assistants(assistant_ids: list[str]): if not assistant_ids: return [] - client = get_client(os.environ.get("REDIS_URL")) + client = _get_redis_client() ids = [ id for id, is_public in zip( @@ -84,19 +93,19 @@ def put_assistant( "updated_at": datetime.utcnow(), "public": public, } - client = get_client(os.environ.get("REDIS_URL")) + client = _get_redis_client() with client.pipeline() as pipe: pipe.sadd(assistants_list_key(user_id), orjson.dumps(assistant_id)) - pipe.hset(assistant_key(user_id, assistant_id), mapping=dump(saved)) + pipe.hset(assistant_key(user_id, assistant_id), mapping=_dump(saved)) if public: pipe.sadd(assistants_list_key(public_user_id), orjson.dumps(assistant_id)) - pipe.hset(assistant_key(public_user_id, assistant_id), mapping=dump(saved)) + pipe.hset(assistant_key(public_user_id, assistant_id), mapping=_dump(saved)) pipe.execute() return saved def list_threads(user_id: str): - client = get_client(os.environ.get("REDIS_URL")) + client = _get_redis_client() ids = [orjson.loads(id) for id in client.smembers(threads_list_key(user_id))] with client.pipeline() as pipe: for id in ids: @@ -106,7 +115,7 @@ def list_threads(user_id: str): def get_thread_messages(user_id: str, thread_id: str): - client = get_client(os.environ.get("REDIS_URL")) + client = _get_redis_client() messages = client.lrange(thread_messages_key(user_id, thread_id), 0, -1) return { "messages": [ @@ -124,10 +133,10 @@ def put_thread(user_id: str, thread_id: str, *, assistant_id: str, name: str): "name": name, "updated_at": datetime.utcnow(), } - client = get_client(os.environ.get("REDIS_URL")) + client = _get_redis_client() with client.pipeline() as pipe: pipe.sadd(threads_list_key(user_id), orjson.dumps(thread_id)) - pipe.hset(thread_key(user_id, thread_id), mapping=dump(saved)) + pipe.hset(thread_key(user_id, thread_id), mapping=_dump(saved)) pipe.execute() return saved