Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add json method for pack serialization/deserialization #917

Merged
merged 13 commits into from
Jan 12, 2023
6 changes: 6 additions & 0 deletions forte/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,9 @@
# Name of the key to access a set of parent names of an entry type from
# ``_type_attributes`` of ``DataStore``.
PARENT_CLASS_KEY = "parent_class"

# Name of the class field in JSON serialization schema of BasePack
JSON_CLASS_FIELD = "_json_class"

# Name of the state field in JSON serialization schema of BasePack
JSON_STATE_FIELD = "_json_state"
77 changes: 61 additions & 16 deletions forte/data/base_pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import logging
import copy
import gzip
import json
import pickle
import uuid
from abc import abstractmethod
Expand All @@ -38,6 +39,7 @@
import jsonpickle

from forte.common import ProcessExecutionException, EntryNotFoundError
from forte.common.constants import JSON_CLASS_FIELD, JSON_STATE_FIELD
from forte.data.index import BaseIndex
from forte.data.base_store import BaseStore
from forte.data.container import EntryContainer
Expand All @@ -56,6 +58,7 @@
DEFAULT_PACK_VERSION,
PACK_ID_COMPATIBLE_VERSION,
)
from forte.utils import get_full_module_name, get_class

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -130,7 +133,7 @@ def __init__(self, pack_name: Optional[str] = None):
self._pending_entries: Dict[int, Optional[str]] = {}

def __getstate__(self):
state = self.__dict__.copy()
state = super().__getstate__()
state.pop("_index")
state.pop("_pending_entries")
state.pop("_BasePack__control_component")
Expand Down Expand Up @@ -203,7 +206,7 @@ def pack_name(self, pack_name: str):
def _deserialize(
cls,
data_source: Union[Path, str],
serialize_method: str = "jsonpickle",
serialize_method: str = "json",
zip_pack: bool = False,
) -> "BasePack[Any, Any, Any]":
"""
Expand All @@ -216,8 +219,8 @@ def _deserialize(
serialization.
serialize_method: The method used to serialize the data, this
should be the same as how serialization is done. The current
options are `jsonpickle` and `pickle`. The default method
is `jsonpickle`.
options are `json`, `jsonpickle` and `pickle`. The default method
is `json`.
zip_pack: Boolean value indicating whether the input source is
zipped.

Expand All @@ -226,9 +229,9 @@ def _deserialize(
"""
_open = gzip.open if zip_pack else open

if serialize_method == "jsonpickle":
if serialize_method in ("jsonpickle", "json"):
with _open(data_source, mode="rt") as f: # type: ignore
pack = cls.from_string(f.read())
pack = cls.from_string(f.read(), json_method=serialize_method)
else:
with _open(data_source, mode="rb") as f: # type: ignore
pack = pickle.load(f)
Expand All @@ -239,8 +242,30 @@ def _deserialize(
return pack

@classmethod
def from_string(cls, data_content: str) -> "BasePack":
pack = jsonpickle.decode(data_content)
def from_string(
cls, data_content: str, json_method: str = "json"
) -> "BasePack":
if json_method == "jsonpickle":
pack = jsonpickle.decode(data_content)
elif json_method == "json":

def object_hook(json_dict):
"""
Custom object hook for JSON deserialization. It will call
`__setstate__` to deserialize the json content into a class
object.
"""
if json_dict.keys() == {JSON_STATE_FIELD, JSON_CLASS_FIELD}:
state = json_dict[JSON_STATE_FIELD]
obj_type = get_class(json_dict[JSON_CLASS_FIELD])
obj = obj_type.__new__(obj_type)
hunterhector marked this conversation as resolved.
Show resolved Hide resolved
obj.__setstate__(state)
return obj
return json_dict

pack = json.loads(data_content, object_hook=object_hook)
else:
raise ValueError(f"Unsupported JSON method {json_method}.")
if not hasattr(pack, "pack_version"):
pack.pack_version = DEFAULT_PACK_VERSION

Expand Down Expand Up @@ -317,7 +342,7 @@ def add_all_remaining_entries(self, component: Optional[str] = None):
def to_string(
self,
drop_record: Optional[bool] = False,
json_method: str = "jsonpickle",
json_method: str = "json",
indent: Optional[int] = None,
) -> str:
"""
Expand All @@ -326,8 +351,8 @@ def to_string(
Args:
drop_record: Whether to drop the creation records, default is False.
json_method: What method is used to convert data pack to json.
Only supports `json_pickle` for now. Default value is
`json_pickle`.
Only supports `json` and `jsonpickle` for now. Default value is
`json`.
indent: The indent used for json string.

Returns: String representation of the data pack.
Expand All @@ -337,6 +362,24 @@ def to_string(
self._field_records.clear()
if json_method == "jsonpickle":
return jsonpickle.encode(self, unpicklable=True, indent=indent)
elif json_method == "json":

def json_serialize_handler(obj):
"""
Custom object handler for JSON serialization. It will call
`__getstate__` to serialize a class object into the its json
format.
"""
if hasattr(obj, "__getstate__"):
return {
JSON_CLASS_FIELD: get_full_module_name(obj),
JSON_STATE_FIELD: obj.__getstate__(),
}
raise TypeError(f"Type {type(obj)} not serializable")

return json.dumps(
self, indent=indent, default=json_serialize_handler
)
else:
raise ValueError(f"Unsupported JSON method {json_method}.")

Expand All @@ -345,7 +388,7 @@ def serialize(
output_path: Union[str, Path],
zip_pack: bool = False,
drop_record: bool = False,
serialize_method: str = "jsonpickle",
serialize_method: str = "json",
indent: Optional[int] = None,
):
r"""
Expand All @@ -357,8 +400,8 @@ def serialize(
zip_pack: Whether to compress the result with `gzip`.
drop_record: Whether to drop the creation records, default is False.
serialize_method: The method used to serialize the data. Currently
supports `jsonpickle` (outputs str) and Python's built-in
`pickle` (outputs bytes).
supports `json` (outputs str), `jsonpickle` (outputs str) and
Python's built-in `pickle` (outputs bytes).
indent: Whether to indent the file if written as JSON.

Returns: Results of serialization.
Expand All @@ -375,10 +418,12 @@ def serialize(
if serialize_method == "pickle":
with _open(output_path, mode="wb") as pickle_out:
pickle.dump(self, pickle_out)
elif serialize_method == "jsonpickle":
elif serialize_method in ("jsonpickle", "json"):
with _open(output_path, mode="wt", encoding="utf-8") as json_out:
json_out.write(
self.to_string(drop_record, "jsonpickle", indent=indent)
self.to_string(
drop_record, json_method=serialize_method, indent=indent
)
)
else:
raise NotImplementedError(
Expand Down
18 changes: 18 additions & 0 deletions forte/data/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,24 @@ def __setstate__(self, state):
if "field_records" in self.__dict__:
self._field_records = self.__dict__.pop("field_records")

for record_field in ("_creation_records", "_field_records"):
setattr(
self,
record_field,
{
key: set(val)
for key, val in getattr(self, record_field).items()
},
)

def __getstate__(self):
state = self.__dict__.copy()
for record_field in ("_creation_records", "_field_records"):
state[record_field] = {
key: list(val) for key, val in state.pop(record_field).items()
}
return state

@abstractmethod
def on_entry_creation(self, entry: E):
raise NotImplementedError
Expand Down
6 changes: 3 additions & 3 deletions forte/data/data_pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,7 +892,7 @@ def get_original_index(
def deserialize(
cls,
data_source: Union[Path, str],
serialize_method: str = "jsonpickle",
serialize_method: str = "json",
zip_pack: bool = False,
) -> "DataPack":
"""
Expand All @@ -904,8 +904,8 @@ def deserialize(
data_source: The path storing data source.
serialize_method: The method used to serialize the data, this
should be the same as how serialization is done. The current
options are `jsonpickle` and `pickle`. The default method
is `jsonpickle`.
options are `json`, `jsonpickle` and `pickle`. The default method
is `json`.
zip_pack: Boolean value indicating whether the input source is
zipped.

Expand Down
16 changes: 5 additions & 11 deletions forte/data/multi_pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
from pathlib import Path
from typing import Dict, List, Union, Iterator, Optional, Type, Any, Tuple, cast

import jsonpickle

from packaging.version import Version
from sortedcontainers import SortedList

Expand All @@ -41,7 +39,6 @@
)
from forte.data.types import DataRequest
from forte.utils import get_full_module_name
from forte.version import DEFAULT_PACK_VERSION


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -886,7 +883,7 @@ def get( # type: ignore
def deserialize(
cls,
data_path: Union[Path, str],
serialize_method: str = "jsonpickle",
serialize_method: str = "json",
zip_pack: bool = False,
) -> "MultiPack":
"""
Expand Down Expand Up @@ -923,18 +920,15 @@ def deserialize(
return mp

@classmethod
def from_string(cls, data_content: str):
def from_string(cls, data_content: str, json_method: str = "json"):
# pylint: disable=protected-access
# can not use explict type hint for mp as pylint does not allow type change
# from base_pack to multi_pack which is problematic so use jsonpickle instead

mp = jsonpickle.decode(data_content)
if not hasattr(mp, "pack_version"):
mp.pack_version = DEFAULT_PACK_VERSION
mp = super().from_string(data_content, json_method)
hunterhector marked this conversation as resolved.
Show resolved Hide resolved
# (fix 595) change the dictionary's key after deserialization from str back to int
mp._inverse_pack_ref = { # pylint: disable=no-member
mp._inverse_pack_ref = { # type: ignore # pylint: disable=no-member
int(k): v
for k, v in mp._inverse_pack_ref.items() # pylint: disable=no-member
for k, v in mp._inverse_pack_ref.items() # type: ignore # pylint: disable=no-member
}

return mp
Expand Down
1 change: 1 addition & 0 deletions forte/data/ontology/top.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,6 +1094,7 @@ def __getstate__(self):
# Entry store is being integrated into DataStore
state = self.__dict__.copy()
state["modality"] = self.modality_name
state.pop("_Entry__pack")

if isinstance(state["_cache"], np.ndarray):
state["_cache"] = list(self._cache.tolist()) # type: ignore
Expand Down
4 changes: 2 additions & 2 deletions forte/processors/base/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def initialize(self, resources: Resources, configs: Config):
self._zip_pack = configs.zip_pack
self._indent = configs.indent

if self.configs.serialize_method == "jsonpickle":
if self.configs.serialize_method in ("jsonpickle", "json"):
self._suffix = ".json.gz" if self._zip_pack else ".json"
else:
self._suffix = ".pickle.gz" if self._zip_pack else ".pickle"
Expand Down Expand Up @@ -197,7 +197,7 @@ def initialize(self, resources: Resources, configs: Config):
ensure_dir(multi_index)
self.multi_idx_out = open(multi_index, "w", encoding="utf-8")

if self.configs.serialize_method == "jsonpickle":
if self.configs.serialize_method in ("jsonpickle", "json"):
self._suffix = ".json.gz" if self.configs.zip_pack else ".json"
else:
self._suffix = ".pickle.gz" if self.configs.zip_pack else ".pickle"
Expand Down
Loading