Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Localize access to REDIS_URL #36

Merged
merged 1 commit into from
Nov 13, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions backend/app/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import orjson
from langchain.schema.messages import messages_from_dict
from langchain.utilities.redis import get_client
from redis.client import Redis as RedisType


def assistants_list_key(user_id: str):
Expand Down Expand Up @@ -33,16 +34,24 @@ def thread_messages_key(user_id: str, thread_id: str):
public_user_id = "eef39817-c173-4eb6-8be4-f77cf37054fb"


def dump(map: dict) -> dict:
def _dump(map: dict) -> dict:
return {k: orjson.dumps(v) if v is not None else None for k, v in map.items()}


def load(keys: list[str], values: list[bytes]) -> dict:
return {k: orjson.loads(v) if v is not None else None for k, v in zip(keys, values)}


def _get_redis_client() -> RedisType:
"""Get a Redis client."""
url = os.environ.get("REDIS_URL")
if not url:
raise ValueError("REDIS_URL not set")
return get_client(url)


def list_assistants(user_id: str):
client = get_client(os.environ.get("REDIS_URL"))
client = _get_redis_client()
ids = [orjson.loads(id) for id in client.smembers(assistants_list_key(user_id))]
with client.pipeline() as pipe:
for id in ids:
Expand All @@ -54,7 +63,7 @@ def list_assistants(user_id: str):
def list_public_assistants(assistant_ids: list[str]):
if not assistant_ids:
return []
client = get_client(os.environ.get("REDIS_URL"))
client = _get_redis_client()
ids = [
id
for id, is_public in zip(
Expand Down Expand Up @@ -84,19 +93,19 @@ def put_assistant(
"updated_at": datetime.utcnow(),
"public": public,
}
client = get_client(os.environ.get("REDIS_URL"))
client = _get_redis_client()
with client.pipeline() as pipe:
pipe.sadd(assistants_list_key(user_id), orjson.dumps(assistant_id))
pipe.hset(assistant_key(user_id, assistant_id), mapping=dump(saved))
pipe.hset(assistant_key(user_id, assistant_id), mapping=_dump(saved))
if public:
pipe.sadd(assistants_list_key(public_user_id), orjson.dumps(assistant_id))
pipe.hset(assistant_key(public_user_id, assistant_id), mapping=dump(saved))
pipe.hset(assistant_key(public_user_id, assistant_id), mapping=_dump(saved))
pipe.execute()
return saved


def list_threads(user_id: str):
client = get_client(os.environ.get("REDIS_URL"))
client = _get_redis_client()
ids = [orjson.loads(id) for id in client.smembers(threads_list_key(user_id))]
with client.pipeline() as pipe:
for id in ids:
Expand All @@ -106,7 +115,7 @@ def list_threads(user_id: str):


def get_thread_messages(user_id: str, thread_id: str):
client = get_client(os.environ.get("REDIS_URL"))
client = _get_redis_client()
messages = client.lrange(thread_messages_key(user_id, thread_id), 0, -1)
return {
"messages": [
Expand All @@ -124,10 +133,10 @@ def put_thread(user_id: str, thread_id: str, *, assistant_id: str, name: str):
"name": name,
"updated_at": datetime.utcnow(),
}
client = get_client(os.environ.get("REDIS_URL"))
client = _get_redis_client()
with client.pipeline() as pipe:
pipe.sadd(threads_list_key(user_id), orjson.dumps(thread_id))
pipe.hset(thread_key(user_id, thread_id), mapping=dump(saved))
pipe.hset(thread_key(user_id, thread_id), mapping=_dump(saved))
pipe.execute()
return saved

Expand Down