Skip to content

Commit

Permalink
Fix issue when fetching threads prior to langgraph 0.0.31
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Apr 10, 2024
1 parent 56a3016 commit 8c98e03
Showing 1 changed file with 23 additions and 4 deletions.
27 changes: 23 additions & 4 deletions backend/app/checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pickle
from datetime import datetime
from typing import AsyncIterator, Optional
from enum import Enum
from io import BytesIO
from typing import Any, AsyncIterator, Optional

from langchain_core.runnables import ConfigurableFieldSpec, RunnableConfig
from langgraph.checkpoint import BaseCheckpointSaver
Expand All @@ -9,6 +11,23 @@
from app.lifespan import get_pg_pool


class PostgresUnpickler(pickle.Unpickler):
def find_class(self, module_name: str, global_name: str) -> Any:
# backwards compatibility for threads prior to langgraph 0.0.31
print(module_name, global_name)
if (
module_name == "langgraph.pregel.reserved"
and global_name == "ReservedChannels"
):

class ReservedChannels(str, Enum):
is_last_step = "is_last_step"

return ReservedChannels

return super().find_class(module_name, global_name)


class PostgresCheckpoint(BaseCheckpointSaver):
class Config:
arbitrary_types_allowed = True
Expand Down Expand Up @@ -47,7 +66,7 @@ async def alist(self, config: RunnableConfig) -> AsyncIterator[CheckpointTuple]:
"thread_ts": value[1],
}
},
pickle.loads(value[0]),
PostgresUnpickler(BytesIO(value[0])).load(),
{
"configurable": {
"thread_id": thread_id,
Expand All @@ -70,7 +89,7 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
):
return CheckpointTuple(
config,
pickle.loads(value[0]),
PostgresUnpickler(BytesIO(value[0])).load(),
{
"configurable": {
"thread_id": thread_id,
Expand All @@ -92,7 +111,7 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"thread_ts": value[1],
}
},
pickle.loads(value[0]),
PostgresUnpickler(BytesIO(value[0])).load(),
{
"configurable": {
"thread_id": thread_id,
Expand Down

0 comments on commit 8c98e03

Please sign in to comment.