Skip to content

Commit

Permalink
Localize access to REDIS_URL (#36)
Browse files Browse the repository at this point in the history
* Localize a bit access to REDIS_URL
  • Loading branch information
eyurtsev authored Nov 13, 2023
1 parent 639e67b commit 0eeb14a
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions backend/app/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import orjson
from langchain.schema.messages import messages_from_dict
from langchain.utilities.redis import get_client
from redis.client import Redis as RedisType


def assistants_list_key(user_id: str):
Expand Down Expand Up @@ -33,16 +34,24 @@ def thread_messages_key(user_id: str, thread_id: str):
public_user_id = "eef39817-c173-4eb6-8be4-f77cf37054fb"


def dump(map: dict) -> dict:
def _dump(map: dict) -> dict:
return {k: orjson.dumps(v) if v is not None else None for k, v in map.items()}


def load(keys: list[str], values: list[bytes]) -> dict:
return {k: orjson.loads(v) if v is not None else None for k, v in zip(keys, values)}


def _get_redis_client() -> RedisType:
"""Get a Redis client."""
url = os.environ.get("REDIS_URL")
if not url:
raise ValueError("REDIS_URL not set")
return get_client(url)


def list_assistants(user_id: str):
client = get_client(os.environ.get("REDIS_URL"))
client = _get_redis_client()
ids = [orjson.loads(id) for id in client.smembers(assistants_list_key(user_id))]
with client.pipeline() as pipe:
for id in ids:
Expand All @@ -54,7 +63,7 @@ def list_assistants(user_id: str):
def list_public_assistants(assistant_ids: list[str]):
if not assistant_ids:
return []
client = get_client(os.environ.get("REDIS_URL"))
client = _get_redis_client()
ids = [
id
for id, is_public in zip(
Expand Down Expand Up @@ -84,19 +93,19 @@ def put_assistant(
"updated_at": datetime.utcnow(),
"public": public,
}
client = get_client(os.environ.get("REDIS_URL"))
client = _get_redis_client()
with client.pipeline() as pipe:
pipe.sadd(assistants_list_key(user_id), orjson.dumps(assistant_id))
pipe.hset(assistant_key(user_id, assistant_id), mapping=dump(saved))
pipe.hset(assistant_key(user_id, assistant_id), mapping=_dump(saved))
if public:
pipe.sadd(assistants_list_key(public_user_id), orjson.dumps(assistant_id))
pipe.hset(assistant_key(public_user_id, assistant_id), mapping=dump(saved))
pipe.hset(assistant_key(public_user_id, assistant_id), mapping=_dump(saved))
pipe.execute()
return saved


def list_threads(user_id: str):
client = get_client(os.environ.get("REDIS_URL"))
client = _get_redis_client()
ids = [orjson.loads(id) for id in client.smembers(threads_list_key(user_id))]
with client.pipeline() as pipe:
for id in ids:
Expand All @@ -106,7 +115,7 @@ def list_threads(user_id: str):


def get_thread_messages(user_id: str, thread_id: str):
client = get_client(os.environ.get("REDIS_URL"))
client = _get_redis_client()
messages = client.lrange(thread_messages_key(user_id, thread_id), 0, -1)
return {
"messages": [
Expand All @@ -124,10 +133,10 @@ def put_thread(user_id: str, thread_id: str, *, assistant_id: str, name: str):
"name": name,
"updated_at": datetime.utcnow(),
}
client = get_client(os.environ.get("REDIS_URL"))
client = _get_redis_client()
with client.pipeline() as pipe:
pipe.sadd(threads_list_key(user_id), orjson.dumps(thread_id))
pipe.hset(thread_key(user_id, thread_id), mapping=dump(saved))
pipe.hset(thread_key(user_id, thread_id), mapping=_dump(saved))
pipe.execute()
return saved

Expand Down

0 comments on commit 0eeb14a

Please sign in to comment.