Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate DataStore with MultiPack #834

Merged
merged 7 commits into from
Jun 23, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 210 additions & 28 deletions forte/data/base_pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import copy
import gzip
import pickle
Expand All @@ -27,19 +28,33 @@
Union,
Iterator,
Dict,
Tuple,
Any,
Iterable,
)


from functools import partial
from typing_inspect import get_origin
from packaging.version import Version
import jsonpickle

from forte.common import ProcessExecutionException, EntryNotFoundError
from forte.data.container import EntryContainer
from forte.data.index import BaseIndex
from forte.data.ontology.core import Entry, EntryType, GroupType, LinkType
from forte.version import PACK_VERSION, DEFAULT_PACK_VERSION
from forte.data.base_store import BaseStore
from forte.data.container import EntryContainer
from forte.data.ontology.core import (
Entry,
EntryType,
GroupType,
LinkType,
FList,
FDict,
)
from forte.version import (
PACK_VERSION,
DEFAULT_PACK_VERSION,
PACK_ID_COMPATIBLE_VERSION,
)

logger = logging.getLogger(__name__)

__all__ = ["BasePack", "BaseMeta", "PackType"]

Expand Down Expand Up @@ -97,26 +112,19 @@ class BasePack(EntryContainer[EntryType, LinkType, GroupType]):
# pylint: disable=too-many-public-methods
def __init__(self, pack_name: Optional[str] = None):
    """Set up the version tag, meta object, index and entry bookkeeping.

    Args:
        pack_name: Forwarded to ``_init_meta`` when the meta object of
            this pack is built.
    """
    super().__init__()

    # Version tag of this pack, taken from `PACK_VERSION`.
    self.pack_version: str = PACK_VERSION

    # Building the meta object is delegated to the `_init_meta` hook.
    self._meta: BaseMeta = self._init_meta(pack_name)
    self._index: BaseIndex = BaseIndex()

    # Annotation only: no store object is assigned in this constructor.
    self._data_store: BaseStore

    self.__control_component: Optional[str] = None

    # Mapping from an entry's tid to the name of the component associated
    # with that entry; the name tracks the "creator" of the entry.
    self._pending_entries: Dict[int, Optional[str]] = {}

def __getstate__(self):
state = self.__dict__.copy()
Expand All @@ -126,6 +134,20 @@ def __getstate__(self):
return state

def __setstate__(self, state):
# Pack version checking. We will no longer provide support for
# serialized Pack whose "pack_version" is less than
# PACK_ID_COMPATIBLE_VERSION.
pack_version: str = (
state["pack_version"]
if "pack_version" in state
else DEFAULT_PACK_VERSION
)
if Version(pack_version) < Version(PACK_ID_COMPATIBLE_VERSION):
raise ValueError(
"The pack cannot be deserialized because its version "
f"{pack_version} is outdated. We only support pack with "
f"version greater or equal to {PACK_ID_COMPATIBLE_VERSION}"
)
super().__setstate__(state)
if "meta" in self.__dict__:
self._meta = self.__dict__.pop("meta")
Expand All @@ -136,9 +158,6 @@ def __setstate__(self, state):
def _init_meta(self, pack_name: Optional[str] = None) -> BaseMeta:
raise NotImplementedError

def get_control_component(self):
return self.__control_component

def set_meta(self, **kwargs):
for k, v in kwargs.items():
if not hasattr(self._meta, k):
Expand All @@ -154,9 +173,15 @@ def __iter__(self) -> Iterator[EntryType]:
raise NotImplementedError

def __del__(self):
    """Discard entries that were created but never added to the index.

    Every tid still recorded in ``_pending_entries`` is removed from that
    mapping and deleted from the data store, then the problem is reported.

    Raises:
        ProcessExecutionException: If at least one pending entry was left.
    """
    # A finalizer also runs for objects whose `__init__` did not finish
    # (e.g. `_init_meta` raised before `_pending_entries` was assigned),
    # so the attribute may be missing altogether.
    pending = getattr(self, "_pending_entries", None)
    if not pending:
        return

    num_remaning_entries: int = len(pending)
    # Iterate over a snapshot of the keys: the mapping shrinks in the loop.
    for tid in list(pending.keys()):
        pending.pop(tid)
        self._data_store.delete_entry(tid=tid)
    raise ProcessExecutionException(
        f"There are {num_remaning_entries} "
        f"entries not added to the index correctly."
    )

Expand Down Expand Up @@ -224,7 +249,6 @@ def from_string(cls, data_content: str) -> "BasePack":

return pack

@abstractmethod
def delete_entry(self, entry: EntryType):
r"""Remove the entry from the pack.

Expand All @@ -234,7 +258,14 @@ def delete_entry(self, entry: EntryType):
Returns:
None
"""
raise NotImplementedError
self._data_store.delete_entry(tid=entry.tid)

# update basic index
self._index.remove_entry(entry)

# set other index invalid
self._index.turn_link_index_switch(on=False)
self._index.turn_group_index_switch(on=False)

def add_entry(
self, entry: Union[Entry, int], component_name: Optional[str] = None
Expand Down Expand Up @@ -281,7 +312,7 @@ def add_all_remaining_entries(self, component: Optional[str] = None):
Returns:
None
"""
for entry, c in list(self._pending_entries.values()):
for entry, c in list(self._pending_entries.items()):
c_ = component if component else c
self.add_entry(entry, c_)
self._pending_entries.clear()
Expand Down Expand Up @@ -430,20 +461,171 @@ def on_entry_creation(
# Use the auto-inferred control component.
c = self.__control_component

def entry_getter(cls: Entry, attr_name: str, field_type):
    """A getter function for dataclass fields of entry object.
    When the field contains ``tid``s, we will convert them to entry
    object on the fly.

    Args:
        cls: An ``Entry`` class object.
        attr_name: The name of the attribute.
        field_type: The type of the attribute.

    Returns:
        An ``FList``/``FDict`` wrapper for such fields, the entry a stored
        ``tid`` (or ``(pack_id, tid)`` pair) resolves to, or the stored
        value itself when it cannot be resolved to an entry.
    """

    def _is_tid(candidate) -> bool:
        # `bool` is a subclass of `int`; a stored True/False flag must not
        # be mistaken for the tid 1/0.
        return isinstance(candidate, int) and not isinstance(candidate, bool)

    data_store_ref = (
        cls.pack._data_store  # pylint: disable=protected-access
    )
    attr_val = data_store_ref.get_attribute(
        tid=cls.tid, attr_name=attr_name
    )
    if field_type in (FList, FDict):
        # Generate FList/FDict object on the fly
        return field_type(parent_entry=cls, data=attr_val)
    try:
        # TODO: Find a better solution to determine if a field is Entry
        #   Will be addressed by https://github.com/asyml/forte/issues/835
        # Convert tid to entry object on the fly
        if _is_tid(attr_val):
            # Single pack entry
            return cls.pack.get_entry(tid=attr_val)
        # The condition below is to check whether the attribute's value
        # is a pair of integers - `(pack_id, tid)`. If so we may have
        # encountered a `tid` that can only be resolved by
        # `MultiPack.get_subentry`.
        if (
            isinstance(attr_val, tuple)
            and len(attr_val) == 2
            and all(_is_tid(element) for element in attr_val)
            and hasattr(cls.pack, "get_subentry")
        ):
            # Multi pack entry
            return cls.pack.get_subentry(*attr_val)
    except KeyError:
        # The value did not resolve to an entry; hand it back unchanged.
        pass
    return attr_val

def entry_setter(cls: Entry, value: Any, attr_name: str, field_type):
    """Store a dataclass field of an entry object in the ``DataStore``.

    Entry objects contained in ``value`` are replaced by their ``tid``s
    before the value is written.

    Args:
        cls: An ``Entry`` class object.
        value: The value to be assigned to the attribute.
        attr_name: The name of the attribute.
        field_type: The type of the attribute.

    Raises:
        ValueError: If an ``FList``/``FDict`` field is given items that
            have no ``tid`` attribute.
    """
    store = cls.pack._data_store  # pylint: disable=protected-access

    # Users are assumed not to assign to FList/FDict fields directly; only
    # internal methods do, passing Iterator[Entry] / Dict[Any, Entry].
    stored_value: Any
    if field_type is FList:
        try:
            stored_value = [item.tid for item in value]
        except AttributeError as e:
            raise ValueError(
                "You are trying to assign value to a `FList` field, "
                "which can only accept an iterator of `Entry` objects."
            ) from e
    elif field_type is FDict:
        try:
            stored_value = {k: item.tid for k, item in value.items()}
        except AttributeError as e:
            raise ValueError(
                "You are trying to assign value to a `FDict` field, "
                "which can only accept a mapping whose values are "
                "`Entry` objects."
            ) from e
    elif isinstance(value, Entry):
        if value.pack.pack_id == cls.pack.pack_id:
            stored_value = value.tid
        else:
            # The entry lives in another pack. cls.pack is then assumed to
            # be a MultiPack that resolves it through
            # MultiPack.get_subentry(pack_id, tid), so both ids are kept.
            stored_value = (value.pack.pack_id, value.tid)
    else:
        stored_value = value

    store.set_attribute(
        tid=cls.tid, attr_name=attr_name, attr_value=stored_value
    )

# Save the input entry object in DataStore
self._save_entry_to_data_store(entry=entry)

# Register property functions for all dataclass fields.
for name, field in entry.__dataclass_fields__.items():
# Convert the typing annotation to the original class.
# This will be used to determine if a field is FList/FDict.
field_type = get_origin(field.type)
hunterhector marked this conversation as resolved.
Show resolved Hide resolved
setattr(
type(entry),
name,
# property(fget, fset) will register a conversion layer
# that specifies how to retrieve/assign value of this field.
property(
# We need to bind the attribute name and field type here
# for the getter and setter of each field.
fget=partial(
entry_getter, attr_name=name, field_type=field_type
),
fset=partial(
entry_setter, attr_name=name, field_type=field_type
),
hunterhector marked this conversation as resolved.
Show resolved Hide resolved
),
)

# Record that this entry hasn't been added to the index yet.
self._pending_entries[entry.tid] = entry, c
self._pending_entries[entry.tid] = c

# TODO: how to make this return the precise type here?
def get_entry(self, tid: int) -> EntryType:
    r"""Look up the entry_index with ``tid``. Specific implementation
    depends on the actual class.

    The index is consulted first; when it has no entry for ``tid`` the
    entry is built from the DataStore instead.

    Args:
        tid: The tid of the entry to look up.

    Returns:
        The entry found for ``tid``.

    Raises:
        KeyError: If the DataStore lookup yields ``None`` for ``tid``.
    """
    try:
        # Try to find entry in DataIndex
        return self._index.get_entry(tid)
    except KeyError:
        # Fall through to the DataStore lookup outside the handler so a
        # failure there is not reported as a chained index error.
        pass

    # Find entry in DataStore
    entry: EntryType = self._get_entry_from_data_store(tid=tid)
    if entry is None:
        raise KeyError(
            f"There is no entry with tid '{tid}' in this datapack"
        )
    return entry

def get_entry_raw(self, tid: int) -> List:
    r"""Fetch the raw data of the entry ``tid`` from the DataStore.

    Only the first element of what the store's ``get_entry`` returns is
    handed back.
    """
    store_result = self._data_store.get_entry(tid=tid)
    return store_result[0]

@abstractmethod
def _save_entry_to_data_store(self, entry: Entry):
    r"""Persist an already constructed entry object in the DataStore.

    Abstract hook: this base version only raises ``NotImplementedError``.
    """
    raise NotImplementedError

@abstractmethod
def _get_entry_from_data_store(self, tid: int) -> EntryType:
    r"""Build a class object from the entry data kept in the DataStore.

    Abstract hook: this base version only raises ``NotImplementedError``.
    """
    raise NotImplementedError

@property
@abstractmethod
def links(self):
    r"""List container holding all links of this data pack.

    Abstract property: this base version only raises
    ``NotImplementedError``.
    """
    raise NotImplementedError

@property
@abstractmethod
def groups(self):
    r"""List container holding all groups of this pack.

    Abstract property: this base version only raises
    ``NotImplementedError``.
    """
    raise NotImplementedError

@abstractmethod
def get_data(
self, context_type, request, skip_k
Expand Down
Loading