From 2b2f1cca02888ff9354462972b8fc563c6718a85 Mon Sep 17 00:00:00 2001 From: Suqi Sun Date: Tue, 14 Jun 2022 23:55:54 -0400 Subject: [PATCH 1/5] Integrate DataStore with MultiPack --- forte/data/__init__.py | 1 + forte/data/base_pack.py | 200 +++++++++-- forte/data/base_store.py | 223 +++++++++++- forte/data/data_pack.py | 339 ++---------------- forte/data/data_store.py | 131 +++++-- forte/data/entry_converter.py | 209 +++++++++++ forte/data/multi_pack.py | 239 ++++++------ forte/data/ontology/top.py | 76 ++-- forte/processors/nlp/ner_predictor.py | 2 +- .../forte/data/entry_data_structures_test.py | 1 - 10 files changed, 886 insertions(+), 535 deletions(-) create mode 100644 forte/data/entry_converter.py diff --git a/forte/data/__init__.py b/forte/data/__init__.py index 01858ebca..1b06afe10 100644 --- a/forte/data/__init__.py +++ b/forte/data/__init__.py @@ -20,3 +20,4 @@ from forte.data.data_store import * from forte.data.selector import * from forte.data.index import * +from forte.data.entry_converter import * diff --git a/forte/data/base_pack.py b/forte/data/base_pack.py index 94081acca..81a18c353 100644 --- a/forte/data/base_pack.py +++ b/forte/data/base_pack.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import copy import gzip import pickle @@ -27,19 +28,33 @@ Union, Iterator, Dict, - Tuple, Any, Iterable, ) - - +from functools import partial +from typing_inspect import get_origin +from packaging.version import Version import jsonpickle from forte.common import ProcessExecutionException, EntryNotFoundError -from forte.data.container import EntryContainer from forte.data.index import BaseIndex -from forte.data.ontology.core import Entry, EntryType, GroupType, LinkType -from forte.version import PACK_VERSION, DEFAULT_PACK_VERSION +from forte.data.base_store import BaseStore +from forte.data.container import EntryContainer +from forte.data.ontology.core import ( + Entry, + EntryType, + GroupType, + LinkType, + FList, + FDict, +) +from forte.version import ( + PACK_VERSION, + DEFAULT_PACK_VERSION, + PACK_ID_COMPATIBLE_VERSION, +) + +logger = logging.getLogger(__name__) __all__ = ["BasePack", "BaseMeta", "PackType"] @@ -97,26 +112,19 @@ class BasePack(EntryContainer[EntryType, LinkType, GroupType]): # pylint: disable=too-many-public-methods def __init__(self, pack_name: Optional[str] = None): super().__init__() - self.links: List[LinkType] = [] - self.groups: List[GroupType] = [] self.pack_version: str = PACK_VERSION self._meta: BaseMeta = self._init_meta(pack_name) self._index: BaseIndex = BaseIndex() + self._data_store: BaseStore + self.__control_component: Optional[str] = None - # This Dict maintains a mapping from entry's tid to the Entry object - # itself (for MultiPack) or entry's tid (for DataPack) and the component + # This Dict maintains a mapping from entry's tid to the component # name associated with the entry. # The component name is used for tracking the "creator" of this entry. - # TODO: Will need to unify the format for MultiPack and DataPack after - # DataStore is integrated with MultiPack and MultiPack entries. In - # future we should only maintain a mapping from entry's tid to the - # corresponding component, i.e., Dict[int, Optional[str]]. 
- self._pending_entries: Dict[ - int, Tuple[Union[int, Entry], Optional[str]] - ] = {} + self._pending_entries: Dict[int, Optional[str]] = {} def __getstate__(self): state = self.__dict__.copy() @@ -126,6 +134,20 @@ def __getstate__(self): return state def __setstate__(self, state): + # Pack version checking. We will no longer provide support for + # serialized Pack whose "pack_version" is less than + # PACK_ID_COMPATIBLE_VERSION. + pack_version: str = ( + state["pack_version"] + if "pack_version" in state + else DEFAULT_PACK_VERSION + ) + if Version(pack_version) < Version(PACK_ID_COMPATIBLE_VERSION): + raise ValueError( + "The pack cannot be deserialized because its version " + f"{pack_version} is outdated. We only support pack with " + f"version greater or equal to {PACK_ID_COMPATIBLE_VERSION}" + ) super().__setstate__(state) if "meta" in self.__dict__: self._meta = self.__dict__.pop("meta") @@ -136,9 +158,6 @@ def __setstate__(self, state): def _init_meta(self, pack_name: Optional[str] = None) -> BaseMeta: raise NotImplementedError - def get_control_component(self): - return self.__control_component - def set_meta(self, **kwargs): for k, v in kwargs.items(): if not hasattr(self._meta, k): @@ -154,9 +173,15 @@ def __iter__(self) -> Iterator[EntryType]: raise NotImplementedError def __del__(self): - if len(self._pending_entries) > 0: + num_remaning_entries: int = len(self._pending_entries) + if num_remaning_entries > 0: + # Remove all the remaining tids in _pending_entries. + tids: List = list(self._pending_entries.keys()) + for tid in tids: + self._pending_entries.pop(tid) + self._data_store.delete_entry(tid=tid) raise ProcessExecutionException( - f"There are {len(self._pending_entries)} " + f"There are {num_remaning_entries} " f"entries not added to the index correctly." ) @@ -224,7 +249,6 @@ def from_string(cls, data_content: str) -> "BasePack": return pack - @abstractmethod def delete_entry(self, entry: EntryType): r"""Remove the entry from the pack. @@ -234,7 +258,14 @@ def delete_entry(self, entry: EntryType): Returns: None """ - raise NotImplementedError + self._data_store.delete_entry(tid=entry.tid) + + # update basic index + self._index.remove_entry(entry) + + # set other index invalid + self._index.turn_link_index_switch(on=False) + self._index.turn_group_index_switch(on=False) def add_entry( self, entry: Union[Entry, int], component_name: Optional[str] = None @@ -281,7 +312,7 @@ def add_all_remaining_entries(self, component: Optional[str] = None): Returns: None """ - for entry, c in list(self._pending_entries.values()): + for entry, c in list(self._pending_entries.items()): c_ = component if component else c self.add_entry(entry, c_) self._pending_entries.clear() @@ -430,20 +461,133 @@ def on_entry_creation( # Use the auto-inferred control component. c = self.__control_component + def entry_getter(cls: Entry, attr_name: str, field_type): + """A getter function for dataclass fields of entry object. + When the field contains ``tid``s, we will convert them to entry + object on the fly. 
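+
+        A rough sketch of the resulting behavior, where ``entry.ref`` is a
+        hypothetical dataclass field whose stored value is a ``tid`` (or a
+        ``(pack_id, tid)`` pair for multi pack references)::
+
+            entry.ref  # resolved to the referenced Entry object on access
+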
+ """ + data_store_ref = ( + cls.pack._data_store # pylint: disable=protected-access + ) + attr_val = data_store_ref.get_attribute( + tid=cls.tid, attr_name=attr_name + ) + if field_type in (FList, FDict): + # Generate FList/FDict object on the fly + return field_type(parent_entry=cls, data=attr_val) + try: + # TODO: Find a better solution to determine if a field is Entry + # Convert tid to entry object on the fly + if isinstance(attr_val, int): + # Single pack entry + return cls.pack.get_entry(tid=attr_val) + elif ( + isinstance(attr_val, tuple) + and len(attr_val) == 2 + and all(isinstance(element, int) for element in attr_val) + and hasattr(cls.pack, "get_subentry") + ): + # Multi pack entry + return cls.pack.get_subentry(*attr_val) + except KeyError: + pass + return attr_val + + def entry_setter(cls: Entry, value: Any, attr_name: str, field_type): + """A setter function for dataclass fields of entry object. + When the value contains entry objects, we will convert them into + ``tid``s before storing to ``DataStore``. + """ + attr_value: Any + data_store_ref = ( + cls.pack._data_store # pylint: disable=protected-access + ) + if field_type is FList: + attr_value = [ + entry.tid if isinstance(entry, Entry) else entry + for entry in value + ] + elif field_type is FDict: + attr_value = { + key: entry.tid if isinstance(entry, Entry) else entry + for key, entry in value.items() + } + elif isinstance(value, Entry): + attr_value = ( + value.tid + if value.pack.pack_id == cls.pack.pack_id + else (value.pack.pack_id, value.tid) + ) + else: + attr_value = value + data_store_ref.set_attribute( + tid=cls.tid, attr_name=attr_name, attr_value=attr_value + ) + + # Save the input entry object in DataStore + self._save_entry_to_data_store(entry=entry) + + # Register property functions for all dataclass fields. + for name, field in entry.__dataclass_fields__.items(): + field_type = get_origin(field.type) + setattr( + type(entry), + name, + property( + fget=partial( + entry_getter, attr_name=name, field_type=field_type + ), + fset=partial( + entry_setter, attr_name=name, field_type=field_type + ), + ), + ) + # Record that this entry hasn't been added to the index yet. - self._pending_entries[entry.tid] = entry, c + self._pending_entries[entry.tid] = c # TODO: how to make this return the precise type here? def get_entry(self, tid: int) -> EntryType: - r"""Look up the entry_index with key ``ptr``. Specific implementation + r"""Look up the entry_index with ``tid``. 
Specific implementation depends on the actual class.""" - entry: EntryType = self._index.get_entry(tid) + try: + # Try to find entry in DataIndex + entry: EntryType = self._index.get_entry(tid) + except KeyError: + # Find entry in DataStore + entry = self._get_entry_from_data_store(tid=tid) if entry is None: raise KeyError( f"There is no entry with tid '{tid}'' in this datapack" ) return entry + def get_entry_raw(self, tid: int) -> List: + r"""Retrieve the raw entry data in list format from DataStore.""" + return self._data_store.get_entry(tid=tid)[0] + + @abstractmethod + def _save_entry_to_data_store(self, entry: Entry): + r"""Save an existing entry object into DataStore""" + raise NotImplementedError + + @abstractmethod + def _get_entry_from_data_store(self, tid: int) -> EntryType: + r"""Generate a class object from entry data in DataStore""" + raise NotImplementedError + + @property + @abstractmethod + def links(self): + r"""A List container of all links in this data pack.""" + raise NotImplementedError + + @property + @abstractmethod + def groups(self): + r"""A List container of all groups in this pack.""" + raise NotImplementedError + @abstractmethod def get_data( self, context_type, request, skip_k diff --git a/forte/data/base_store.py b/forte/data/base_store.py index 6981858ba..15349db62 100644 --- a/forte/data/base_store.py +++ b/forte/data/base_store.py @@ -20,6 +20,7 @@ class BaseStore: + # pylint: disable=too-many-public-methods r"""The base class which will be used by :class:`~forte.data.data_store.DataStore`.""" def __init__(self): @@ -123,7 +124,14 @@ def _deserialize( ) @abstractmethod - def add_annotation_raw(self, type_name: str, begin: int, end: int) -> int: + def add_annotation_raw( + self, + type_name: str, + begin: int, + end: int, + tid: Optional[int] = None, + allow_duplicate: bool = True, + ) -> int: r"""This function adds an annotation entry with ``begin`` and ``end`` indices to the ``type_name`` sorted list in ``self.__elements``, returns the ``tid`` for the inserted entry. @@ -132,6 +140,11 @@ def add_annotation_raw(self, type_name: str, begin: int, end: int) -> int: type_name: The index of Annotation sorted list in ``self.__elements``. begin: Begin index of the entry. end: End index of the entry. + tid: ``tid`` of the Annotation entry that is being added. + It's optional, and it will be auto-assigned if not given. + allow_duplicate: Whether we allow duplicate in the DataStore. When + it's set to False, the function will return the ``tid`` of + existing entry if a duplicate is found. Default value is True. Returns: ``tid`` of the entry. """ @@ -139,7 +152,11 @@ def add_annotation_raw(self, type_name: str, begin: int, end: int) -> int: @abstractmethod def add_link_raw( - self, type_name: str, parent_tid: int, child_tid: int + self, + type_name: str, + parent_tid: int, + child_tid: int, + tid: Optional[int] = None, ) -> Tuple[int, int]: r"""This function adds a link entry with ``parent_tid`` and ``child_tid`` to the ``type_name`` list in ``self.__elements``, returns the ``tid`` and the @@ -150,6 +167,8 @@ def add_link_raw( type_name: The index of Link list in ``self.__elements``. parent_tid: ``tid`` of the parent entry. child_tid: ``tid`` of the child entry. + tid: ``tid`` of the Link entry that is being added. + It's optional, and it will be auto-assigned if not given. Returns: ``tid`` of the entry and its index in the ``type_name`` list. 
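+
+        A minimal sketch of the expected call pattern (``store`` is assumed
+        to be a concrete subclass instance and ``parent_tid`` / ``child_tid``
+        the ids of existing entries)::
+
+            tid, index_id = store.add_link_raw(
+                type_name="forte.data.ontology.top.Link",
+                parent_tid=parent_tid,
+                child_tid=child_tid,
+            )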
@@ -159,7 +178,7 @@ def add_link_raw( @abstractmethod def add_group_raw( - self, type_name: str, member_type: str + self, type_name: str, member_type: str, tid: Optional[int] = None ) -> Tuple[int, int]: r"""This function adds a group entry with ``member_type`` to the ``type_name`` list in ``self.__elements``, returns the ``tid`` and the @@ -176,6 +195,204 @@ def add_group_raw( """ raise NotImplementedError + @abstractmethod + def add_generics_raw( + self, type_name: str, tid: Optional[int] = None + ) -> Tuple[int, int]: + r"""This function adds a generics entry with ``type_name`` to the + current data store object. Returns the ``tid`` and the ``index_id`` + for the inserted entry in the list. This ``index_id`` is the index + of the entry in the ``type_name`` list. + + Args: + type_name: The fully qualified type name of the new Generics. + tid: ``tid`` of generics entry. + + Returns: + ``tid`` of the entry and its index in the (``type_id``)th list. + + """ + raise NotImplementedError + + @abstractmethod + def add_audio_annotation_raw( + self, + type_name: str, + begin: int, + end: int, + tid: Optional[int] = None, + allow_duplicate=True, + ) -> int: + + r""" + This function adds an audio annotation entry with ``begin`` and ``end`` + indices to current data store object. Returns the ``tid`` for the + inserted entry. + + Args: + type_name: The fully qualified type name of the new AudioAnnotation. + begin: Begin index of the entry. + end: End index of the entry. + tid: ``tid`` of the Annotation entry that is being added. + It's optional, and it will be + auto-assigned if not given. + allow_duplicate: Whether we allow duplicate in the DataStore. When + it's set to False, the function will return the ``tid`` of + existing entry if a duplicate is found. Default value is True. + + Returns: + ``tid`` of the entry. + """ + raise NotImplementedError + + @abstractmethod + def add_image_annotation_raw( + self, + type_name: str, + image_payload_idx: int, + tid: Optional[int] = None, + ) -> int: + + r""" + This function adds an image annotation entry with ``image_payload_idx`` + indices to current data store object. Returns the ``tid`` for the + inserted entry. + + Args: + type_name: The fully qualified type name of the new AudioAnnotation. + image_payload_idx: the index of the image payload. + tid: ``tid`` of the Annotation entry that is being added. + It's optional, and it will be + auto-assigned if not given. + + Returns: + ``tid`` of the entry. + """ + raise NotImplementedError + + @abstractmethod + def add_grid_raw( + self, + type_name: str, + image_payload_idx: int, + tid: Optional[int] = None, + ) -> int: + + r""" + This function adds an image annotation entry with ``image_payload_idx`` + indices to current data store object. Returns the ``tid`` for the + inserted entry. + + Args: + type_name: The fully qualified type name of the new grid. + image_payload_idx: the index of the image payload. + tid: ``tid`` of the Annotation entry that is being added. + It's optional, and it will be + auto-assigned if not given. + + Returns: + ``tid`` of the entry. + """ + raise NotImplementedError + + @abstractmethod + def add_multipack_generic_raw( + self, type_name: str, tid: Optional[int] = None + ) -> Tuple[int, int]: + r"""This function adds a multi pack generic entry with ``type_name`` to + the current data store object. Returns the ``tid`` and the ``index_id`` + for the inserted entry in the list. This ``index_id`` is the index + of the entry in the ``type_name`` list. 
+ + Args: + type_name: The fully qualified type name of the new Generics. + tid: ``tid`` of multi pack generic entry. + + Returns: + ``tid`` of the entry and its index in the (``type_id``)th list. + + """ + raise NotImplementedError + + @abstractmethod + def add_multipack_link_raw( + self, + type_name: str, + parent_pack_id: int, + parent_tid: int, + child_pack_id: int, + child_tid: int, + tid: Optional[int] = None, + ) -> Tuple[int, int]: + r"""This function adds a multi pack link entry with ``parent_tid`` and + ``child_tid`` to current data store object. Returns the ``tid`` and + the ``index_id`` for the inserted entry in the list. This ``index_id`` + is the index of the entry in the ``type_name`` list. + + Args: + type_name: The fully qualified type name of the new + ``MultiPackLink``. + parent_pack_id: ``pack_id`` of the parent entry. + parent_tid: ``tid`` of the parent entry. + child_pack_id: ``pack_id`` of the child entry. + child_tid: ``tid`` of the child entry. + tid: ``tid`` of the ``MultiPackLink`` entry that is being added. + It's optional, and it will be auto-assigned if not given. + + Returns: + ``tid`` of the entry and its index in the ``type_name`` list. + """ + raise NotImplementedError + + @abstractmethod + def add_multipack_group_raw( + self, type_name: str, member_type: str, tid: Optional[int] = None + ) -> Tuple[int, int]: + r"""This function adds a multi pack group entry with ``member_type`` to + the current data store object. Returns the ``tid`` and the ``index_id`` + for the inserted entry in the list. This ``index_id`` is the index + of the entry in the ``type_name`` list. + + Args: + type_name: The fully qualified type name of the new + ``MultiPackGroup``. + member_type: Fully qualified name of its members. + tid: ``tid`` of the ``MultiPackGroup`` entry that is being added. + It's optional, and it will be auto-assigned if not given. + + Returns: + ``tid`` of the entry and its index in the (``type_id``)th list. + """ + raise NotImplementedError + + @abstractmethod + def all_entries(self, entry_type_name: str) -> Iterator[List]: + """ + Retrieve all entry data of entry type ``entry_type_name`` and + entries of subclasses of entry type ``entry_type_name``. + + Args: + entry_type_name (str): the type name of entries that the User wants to retrieve. + + Yields: + Iterator of raw entry data in list format. + """ + raise NotImplementedError + + @abstractmethod + def num_entries(self, entry_type_name: str) -> int: + """ + Compute the number of entries of given ``entry_type_name`` and + entries of subclasses of entry type ``entry_type_name``. + + Args: + entry_type_name (str): the type name of entries that the User wants to get its count. + + Returns: + The number of entries of given ``entry_type_name``. 
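+
+        A minimal usage sketch (``store`` is assumed to be a concrete
+        ``DataStore`` instance)::
+
+            n = store.num_entries("forte.data.ontology.top.Annotation")
+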
+ """ + raise NotImplementedError + @abstractmethod def set_attribute(self, tid: int, attr_name: str, attr_value: Any): r"""This function locates the entry data with ``tid`` and sets its diff --git a/forte/data/data_pack.py b/forte/data/data_pack.py index 9e0de82aa..d18fc2d52 100644 --- a/forte/data/data_pack.py +++ b/forte/data/data_pack.py @@ -27,9 +27,6 @@ Callable, Tuple, ) -from functools import partial -from typing_inspect import get_origin -from packaging.version import Version import numpy as np from sortedcontainers import SortedList @@ -40,9 +37,10 @@ from forte.common.constants import TID_INDEX from forte.data import data_utils_io from forte.data.data_store import DataStore +from forte.data.entry_converter import EntryConverter from forte.data.base_pack import BaseMeta, BasePack from forte.data.index import BaseIndex -from forte.data.ontology.core import Entry, FList, FDict +from forte.data.ontology.core import Entry from forte.data.ontology.core import EntryType from forte.data.ontology.top import ( Annotation, @@ -57,7 +55,6 @@ from forte.data.span import Span from forte.data.types import ReplaceOperationsType, DataRequest from forte.utils import get_class, get_full_module_name -from forte.version import PACK_ID_COMPATIBLE_VERSION, DEFAULT_PACK_VERSION logger = logging.getLogger(__name__) @@ -197,21 +194,6 @@ def __setstate__(self, state): 3) initialize the indexes. 4) Obtain the pack ids. """ - # Pack version checking. We will no longer provide support for - # serialized DataPack whose "pack_version" is less than - # PACK_ID_COMPATIBLE_VERSION. - pack_version: str = ( - state["pack_version"] - if "pack_version" in state - else DEFAULT_PACK_VERSION - ) - if Version(pack_version) < Version(PACK_ID_COMPATIBLE_VERSION): - raise ValueError( - "The DataPack cannot be deserialized because its version " - f"{pack_version} is outdated. 
We only support DataPack with " - f"version greater or equal to {PACK_ID_COMPATIBLE_VERSION}" - ) - self._entry_converter = EntryConverter() super().__setstate__(state) @@ -270,7 +252,7 @@ def all_annotations(self) -> Iterator[Annotation]: for entry in self._data_store.all_entries( "forte.data.ontology.top.Annotation" ): - yield self.get_entry(tid=entry[TID_INDEX]) + yield self.get_entry(tid=entry[TID_INDEX]) # type: ignore @property def num_annotations(self) -> int: @@ -296,7 +278,7 @@ def all_links(self) -> Iterator[Link]: for entry in self._data_store.all_entries( "forte.data.ontology.top.Link" ): - yield self.get_entry(tid=entry[TID_INDEX]) + yield self.get_entry(tid=entry[TID_INDEX]) # type: ignore @property def num_links(self) -> int: @@ -320,7 +302,7 @@ def all_groups(self) -> Iterator[Group]: for entry in self._data_store.all_entries( "forte.data.ontology.top.Group" ): - yield self.get_entry(tid=entry[TID_INDEX]) + yield self.get_entry(tid=entry[TID_INDEX]) # type: ignore @property def num_groups(self): @@ -343,7 +325,7 @@ def all_generic_entries(self) -> Iterator[Generics]: for entry in self._data_store.all_entries( "forte.data.ontology.top.Generics" ): - yield self.get_entry(tid=entry[TID_INDEX]) + yield self.get_entry(tid=entry[TID_INDEX]) # type: ignore @property def num_generics_entries(self): @@ -367,7 +349,7 @@ def all_audio_annotations(self) -> Iterator[AudioAnnotation]: for entry in self._data_store.all_entries( "forte.data.ontology.top.AudioAnnotation" ): - yield self.get_entry(tid=entry[TID_INDEX]) + yield self.get_entry(tid=entry[TID_INDEX]) # type: ignore @property def num_audio_annotations(self): @@ -413,7 +395,7 @@ def audio_annotations(self): """ return SortedList(self.all_audio_annotations) - @property # type: ignore + @property def links(self): """ A List container of all links in this data pack. @@ -422,18 +404,9 @@ def links(self): type :class:`~forte.data.ontology.top.Link`. """ - # TODO: Right now we create a new variable `_links` here to avoid - # conflicts from BasePack and MultiPack. After DataStore is fully - # integrated with MultiPack, we should reconsider the design here. - if isinstance(self, DataPack): - self._links = SortedList(self.all_links) - return self._links - - @links.setter - def links(self, val): - self._links = val - - @property # type: ignore + return SortedList(self.all_links) + + @property def groups(self): """ A List container of all groups in this data pack. @@ -442,12 +415,7 @@ def groups(self): type :class:`~forte.data.ontology.top.Group`. """ - # TODO: Right now we create a new variable `_groups` here to avoid - # conflicts from BasePack and MultiPack. After DataStore is fully - # integrated with MultiPack, we should reconsider the design here. - if isinstance(self, DataPack): - self._groups = SortedList(self.all_links) - return self._groups + return SortedList(self.all_groups) @groups.setter def groups(self, val): @@ -786,14 +754,7 @@ def delete_entry(self, entry: EntryType): object to be deleted from the pack. 
""" - self._data_store.delete_entry(tid=entry.tid) - - # update basic index - self._index.remove_entry(entry) - - # set other index invalid - self._index.turn_link_index_switch(on=False) - self._index.turn_group_index_switch(on=False) + super().delete_entry(entry=entry) self._index.deactivate_coverage_index() @classmethod @@ -1420,7 +1381,7 @@ def require_annotations(entry_class=Annotation) -> bool: range_annotation=range_annotation # type: ignore and (range_annotation.begin, range_annotation.end), ): - entry: EntryType = self.get_entry(tid=entry_data[TID_INDEX]) + entry: Entry = self.get_entry(tid=entry_data[TID_INDEX]) # Filter by components if components is not None: if not self.is_created_by(entry, components): @@ -1436,7 +1397,7 @@ def require_annotations(entry_class=Annotation) -> bool: ): continue - yield entry + yield entry # type: ignore except ValueError: # type_name does not exist in DataStore yield from [] @@ -1452,129 +1413,13 @@ def update(self, datapack: "DataPack"): # better solution. self.__dict__.update(datapack.__dict__) - def get_entry(self, tid: int) -> EntryType: - r"""Look up the entry_index with ``tid``. Specific implementation - depends on the actual class.""" - try: - # Try to find entry in DataIndex - entry: EntryType = self._index.get_entry(tid) - except KeyError: - # Find entry in DataStore - entry = self._entry_converter.get_entry_object(tid, self) - if entry is None: - raise KeyError( - f"There is no entry with tid '{tid}'' in this datapack" - ) - return entry - - def get_entry_raw(self, tid: int) -> List: - r"""Retrieve the raw entry data in list format from DataStore.""" - return self._data_store.get_entry(tid=tid)[0] - - def on_entry_creation( - self, entry: Entry, component_name: Optional[str] = None - ): - """ - Call this when adding a new entry, will be called - in :class:`~forte.data.ontology.core.Entry` when - its `__init__` function is called. - - Here we override BasePack.on_entry_creation() to make sure each new - entry is stored into ``DataStore`` on creation. - - Args: - entry: The entry to be added. - component_name: A name to record that the entry is created by - this component. - - Returns: - - """ - c = component_name - - if c is None: - # Use the auto-inferred control component. - c = self.get_control_component() - - def entry_getter(cls: Entry, attr_name: str, field_type): - """A getter function for dataclass fields of entry object. - When the field contains ``tid``s, we will convert them to entry - object on the fly. - """ - data_store_ref = ( - cls.pack._data_store # pylint: disable=protected-access - ) - attr_val = data_store_ref.get_attribute( - tid=cls.tid, attr_name=attr_name - ) - if field_type in (FList, FDict): - # Generate FList/FDict object on the fly - return field_type(parent_entry=cls, data=attr_val) - try: - # TODO: Find a better solution to determine if a field is Entry - if isinstance(attr_val, int): - # Convert tid to entry object on the fly - return cls.pack.get_entry(tid=attr_val) - except KeyError: - pass - return attr_val - - def entry_setter(cls: Entry, value: Any, attr_name: str, field_type): - """A setter function for dataclass fields of entry object. - When the value contains entry objects, we will convert them into - ``tid``s before storing to ``DataStore``. 
- """ - attr_value: Any - data_store_ref = ( - cls.pack._data_store # pylint: disable=protected-access - ) - if field_type is FList: - attr_value = [ - entry.tid if isinstance(entry, Entry) else entry - for entry in value - ] - elif field_type is FDict: - attr_value = { - key: entry.tid if isinstance(entry, Entry) else entry - for key, entry in value.items() - } - elif isinstance(value, Entry): - attr_value = value.tid - else: - attr_value = value - data_store_ref.set_attribute( - tid=cls.tid, attr_name=attr_name, attr_value=attr_value - ) - - # Save the input entry object in DataStore + def _save_entry_to_data_store(self, entry: Entry): + r"""Save an existing entry object into DataStore""" self._entry_converter.save_entry_object(entry=entry, pack=self) - # Register property functions for all dataclass fields. - for name, field in entry.__dataclass_fields__.items(): - field_type = get_origin(field.type) - setattr( - type(entry), - name, - property( - fget=partial( - entry_getter, attr_name=name, field_type=field_type - ), - fset=partial( - entry_setter, attr_name=name, field_type=field_type - ), - ), - ) - - # Record that this entry hasn't been added to the index yet. - self._pending_entries[entry.tid] = entry.tid, c - - def __del__(self): - super().__del__() - # Remove all the remaining tids in _pending_entries. - tids: List = list(self._pending_entries.keys()) - for tid in tids: - self._pending_entries.pop(tid) - self._data_store.delete_entry(tid=tid) + def _get_entry_from_data_store(self, tid: int) -> EntryType: + r"""Generate a class object from entry data in DataStore""" + return self._entry_converter.get_entry_object(tid=tid, pack=self) class DataIndex(BaseIndex): @@ -1940,147 +1785,3 @@ def in_audio_span(self, inner_entry: Union[int, Entry], span: Span) -> bool: # check here. return False return inner_begin >= span.begin and inner_end <= span.end - - -class EntryConverter: - r""" - Facilitate the conversion between entry data in list format from - ``DataStore`` and entry class object. - """ - - def __init__(self) -> None: - # Mapping from entry's tid to the entry objects for caching - self._entry_dict: Dict[int, Entry] = {} - - def save_entry_object( - self, entry: Entry, pack: DataPack, allow_duplicate: bool = True - ): - """ - Save an existing entry object into DataStore. 
- """ - # Check if the entry is already stored - data_store_ref = pack._data_store # pylint: disable=protected-access - try: - data_store_ref.get_entry(tid=entry.tid) - logger.info( - "The entry with tid=%d is already saved into DataStore", - entry.tid, - ) - return - except KeyError: - # The entry is not found in DataStore - pass - - # Create a new registry in DataStore based on entry's type - if isinstance(entry, Annotation): - data_store_ref.add_annotation_raw( - type_name=entry.entry_type(), - begin=entry.begin, - end=entry.end, - tid=entry.tid, - allow_duplicate=allow_duplicate, - ) - elif isinstance(entry, Link): - data_store_ref.add_link_raw( - type_name=entry.entry_type(), - parent_tid=entry.parent, - child_tid=entry.child, - tid=entry.tid, - ) - elif isinstance(entry, Group): - data_store_ref.add_group_raw( - type_name=entry.entry_type(), - member_type=get_full_module_name(entry.MemberType), - tid=entry.tid, - ) - elif isinstance(entry, Generics): - data_store_ref.add_generics_raw( - type_name=entry.entry_type(), - tid=entry.tid, - ) - elif isinstance(entry, AudioAnnotation): - data_store_ref.add_audio_annotation_raw( - type_name=entry.entry_type(), - begin=entry.begin, - end=entry.end, - tid=entry.tid, - allow_duplicate=allow_duplicate, - ) - elif isinstance(entry, ImageAnnotation): - data_store_ref.add_image_annotation_raw( - type_name=entry.entry_type(), - image_payload_idx=entry.image_payload_idx, - tid=entry.tid, - allow_duplicate=allow_duplicate, - ) - elif isinstance(entry, Grids): - data_store_ref.add_grid_raw( - type_name=entry.entry_type(), - image_payload_idx=entry.image_payload_idx, - tid=entry.tid, - allow_duplicate=allow_duplicate, - ) - else: - raise ValueError( - f"Invalid entry type {type(entry)}. A valid entry " - f"should be an instance of Annotation, Link, Group, Generics " - "or AudioAnnotation." - ) - - # Store all the dataclass attributes to DataStore - for attribute in entry.__dataclass_fields__: - value = getattr(entry, attribute, None) - if not value: - continue - if isinstance(value, Entry): - value = value.tid - elif isinstance(value, FDict): - value = {key: val.tid for key, val in value.items()} - elif isinstance(value, FList): - value = [val.tid for val in value] - data_store_ref.set_attribute( - tid=entry.tid, attr_name=attribute, attr_value=value - ) - - # Cache the stored entry and its tid - self._entry_dict[entry.tid] = entry - - def get_entry_object(self, tid: int, pack: DataPack) -> EntryType: - """ - Convert a tid to its corresponding entry object. - """ - - # Check if the tid is cached - if tid in self._entry_dict: - return self._entry_dict[tid] # type: ignore - - data_store_ref = pack._data_store # pylint: disable=protected-access - entry_data, entry_type = data_store_ref.get_entry(tid=tid) - entry_class = get_class(entry_type) - entry: Entry - # Here the entry arguments are optional (begin, end, parent, ...) and - # the value can be arbitrary since they will all be routed to DataStore. - if issubclass(entry_class, (Annotation, AudioAnnotation)): - entry = entry_class(pack=pack, begin=0, end=0) - elif issubclass(entry_class, (Link, Group, Generics)): - entry = entry_class(pack=pack) - else: - raise ValueError( - f"Invalid entry type {type(entry_class)}. A valid entry " - f"should be an instance of Annotation, Link, Group, Generics " - "or AudioAnnotation." - ) - - # TODO: Remove the new tid and direct the entry object to the correct - # tid. The implementation here is a little bit hacky. Will need a stable - # solution in future. 
- # pylint: disable=protected-access - if entry.tid in self._entry_dict: - self._entry_dict.pop(entry.tid) - if entry.tid in pack._pending_entries: - pack._pending_entries.pop(entry.tid) - data_store_ref.delete_entry(tid=entry.tid) - entry._tid = entry_data[TID_INDEX] - - self._entry_dict[tid] = entry - return entry # type: ignore diff --git a/forte/data/data_store.py b/forte/data/data_store.py index d0180d507..70cf2d6ae 100644 --- a/forte/data/data_store.py +++ b/forte/data/data_store.py @@ -29,6 +29,9 @@ ImageAnnotation, Link, Generics, + MultiPackGeneric, + MultiPackGroup, + MultiPackLink, ) from forte.data.ontology.core import Entry, FList, FDict from forte.common import constants @@ -917,7 +920,16 @@ def _add_entry_raw( except KeyError: self.__elements[type_name] = SortedList(key=sorting_fn) self.__elements[type_name].add(entry) - elif entry_type in [Link, Group, Generics, ImageAnnotation, Grids]: + elif entry_type in [ + Link, + Group, + Generics, + ImageAnnotation, + Grids, + MultiPackLink, + MultiPackGroup, + MultiPackGeneric, + ]: try: self.__elements[type_name].append(entry) except KeyError: @@ -1025,8 +1037,8 @@ def add_audio_annotation_raw( """ # We should create the `entry data` with the format # [begin, end, tid, type_id, None, ...]. - # A helper function _new_annotation() can be used to generate a - # annotation type entry data with default fields. + # A helper function _new_audio_annotation() can be used to generate an + # audio annotation type entry data with default fields. # A reference to the entry should be store in both self.__elements and # self.__tid_ref_dict. entry = self._new_audio_annotation(type_name, begin, end, tid) @@ -1043,7 +1055,6 @@ def add_image_annotation_raw( type_name: str, image_payload_idx: int, tid: Optional[int] = None, - allow_duplicate=True, ) -> int: r""" @@ -1057,26 +1068,17 @@ def add_image_annotation_raw( tid: ``tid`` of the Annotation entry that is being added. It's optional, and it will be auto-assigned if not given. - allow_duplicate: Whether we allow duplicate in the DataStore. When - it's set to False, the function will return the ``tid`` of - existing entry if a duplicate is found. Default value is True. Returns: ``tid`` of the entry. """ # We should create the `entry data` with the format - # [begin, end, tid, type_id, None, ...]. - # A helper function _new_annotation() can be used to generate a - # annotation type entry data with default fields. + # [image_payload_index, None, tid, type_name, None, ...]. + # A helper function _new_image_annotation() can be used to generate an + # image annotation type entry data with default fields. # A reference to the entry should be store in both self.__elements and # self.__tid_ref_dict. entry = self._new_image_annotation(type_name, image_payload_idx, tid) - - if not allow_duplicate: - tid_search_result = self._get_existing_ann_entry_tid(entry) - # if found existing entry - if tid_search_result != -1: - return tid_search_result return self._add_entry_raw(AudioAnnotation, type_name, entry) def add_grid_raw( @@ -1084,7 +1086,6 @@ def add_grid_raw( type_name: str, image_payload_idx: int, tid: Optional[int] = None, - allow_duplicate=True, ) -> int: r""" @@ -1098,26 +1099,17 @@ def add_grid_raw( tid: ``tid`` of the Annotation entry that is being added. It's optional, and it will be auto-assigned if not given. - allow_duplicate: Whether we allow duplicate in the DataStore. When - it's set to False, the function will return the ``tid`` of - existing entry if a duplicate is found. Default value is True. 
Returns: ``tid`` of the entry. """ # We should create the `entry data` with the format - # [begin, end, tid, type_id, None, ...]. - # A helper function _new_annotation() can be used to generate a - # annotation type entry data with default fields. + # [image_payload_index, None, tid, type_name, None, ...]. + # A helper function _new_grid() can be used to generate a + # grid type entry data with default fields. # A reference to the entry should be store in both self.__elements and # self.__tid_ref_dict. entry = self._new_grid(type_name, image_payload_idx, tid) - - if not allow_duplicate: - tid_search_result = self._get_existing_ann_entry_tid(entry) - # if found existing entry - if tid_search_result != -1: - return tid_search_result return self._add_entry_raw(Grids, type_name, entry) def _get_existing_ann_entry_tid(self, entry: List[Any]): @@ -1228,6 +1220,87 @@ def add_generics_raw( entry = self._new_generics(type_name, tid) return self._add_entry_raw(Generics, type_name, entry) + def add_multipack_generic_raw( + self, type_name: str, tid: Optional[int] = None + ) -> Tuple[int, int]: + r"""This function adds a multi pack generic entry with ``type_name`` to + the current data store object. Returns the ``tid`` and the ``index_id`` + for the inserted entry in the list. This ``index_id`` is the index + of the entry in the ``type_name`` list. + + Args: + type_name: The fully qualified type name of the new Generics. + tid: ``tid`` of multi pack generic entry. + + Returns: + ``tid`` of the entry and its index in the (``type_id``)th list. + + """ + tid: int = self._new_tid() if tid is None else tid + entry = [None, None, tid, type_name] + entry += self._default_attributes_for_type(type_name) + return self._add_entry_raw(MultiPackGeneric, type_name, entry) + + def add_multipack_link_raw( + self, + type_name: str, + parent_pack_id: int, + parent_tid: int, + child_pack_id: int, + child_tid: int, + tid: Optional[int] = None, + ) -> Tuple[int, int]: + r"""This function adds a multi pack link entry with ``parent_tid`` and + ``child_tid`` to current data store object. Returns the ``tid`` and + the ``index_id`` for the inserted entry in the list. This ``index_id`` + is the index of the entry in the ``type_name`` list. + + Args: + type_name: The fully qualified type name of the new + ``MultiPackLink``. + parent_pack_id: ``pack_id`` of the parent entry. + parent_tid: ``tid`` of the parent entry. + child_pack_id: ``pack_id`` of the child entry. + child_tid: ``tid`` of the child entry. + tid: ``tid`` of the ``MultiPackLink`` entry that is being added. + It's optional, and it will be auto-assigned if not given. + + Returns: + ``tid`` of the entry and its index in the ``type_name`` list. + """ + tid: int = self._new_tid() if tid is None else tid + entry: List[Any] = [ + [parent_pack_id, parent_tid], + [child_pack_id, child_tid], + tid, + type_name, + ] + entry += self._default_attributes_for_type(type_name) + return self._add_entry_raw(MultiPackLink, type_name, entry) + + def add_multipack_group_raw( + self, type_name: str, member_type: str, tid: Optional[int] = None + ) -> Tuple[int, int]: + r"""This function adds a multi pack group entry with ``member_type`` to + the current data store object. Returns the ``tid`` and the ``index_id`` + for the inserted entry in the list. This ``index_id`` is the index + of the entry in the ``type_name`` list. + + Args: + type_name: The fully qualified type name of the new + ``MultiPackGroup``. + member_type: Fully qualified name of its members. 
+ tid: ``tid`` of the ``MultiPackGroup`` entry that is being added. + It's optional, and it will be auto-assigned if not given. + + Returns: + ``tid`` of the entry and its index in the (``type_id``)th list. + """ + tid: int = self._new_tid() if tid is None else tid + entry = [member_type, [], tid, type_name] + entry += self._default_attributes_for_type(type_name) + return self._add_entry_raw(MultiPackGroup, type_name, entry) + def set_attribute(self, tid: int, attr_name: str, attr_value: Any): r"""This function locates the entry data with ``tid`` and sets its ``attr_name`` with `attr_value`. It first finds ``attr_id`` according diff --git a/forte/data/entry_converter.py b/forte/data/entry_converter.py new file mode 100644 index 000000000..3350609cf --- /dev/null +++ b/forte/data/entry_converter.py @@ -0,0 +1,209 @@ +# Copyright 2022 The Forte Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Dict +from forte.common.constants import TID_INDEX +from forte.data.base_pack import PackType +from forte.data.ontology.core import Entry, FList, FDict +from forte.data.ontology.core import EntryType +from forte.data.ontology.top import ( + Annotation, + Link, + Group, + Generics, + AudioAnnotation, + ImageAnnotation, + Grids, + MultiPackGeneric, + MultiPackGroup, + MultiPackLink, +) +from forte.utils import get_class, get_full_module_name + +logger = logging.getLogger(__name__) + +__all__ = ["EntryConverter"] + + +class EntryConverter: + r""" + Facilitate the conversion between entry data in list format from + ``DataStore`` and entry class object. + """ + + def __init__(self) -> None: + # Mapping from entry's tid to the entry objects for caching + self._entry_dict: Dict[int, Entry] = {} + + def save_entry_object( + self, entry: Entry, pack: PackType, allow_duplicate: bool = True + ): + """ + Save an existing entry object into DataStore. 
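+
+        A minimal sketch of how the converter is typically driven; the
+        ``pack`` and ``entry`` names below are assumptions::
+
+            converter = EntryConverter()
+            converter.save_entry_object(entry=entry, pack=pack)
+            entry_obj = converter.get_entry_object(tid=entry.tid, pack=pack)
+
+        Inside ``DataPack`` and ``MultiPack`` this method is normally
+        reached through ``_save_entry_to_data_store``.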
+ """ + # Check if the entry is already stored + data_store_ref = pack._data_store # pylint: disable=protected-access + try: + data_store_ref.get_entry(tid=entry.tid) + logger.info( + "The entry with tid=%d is already saved into DataStore", + entry.tid, + ) + return + except KeyError: + # The entry is not found in DataStore + pass + + # Create a new registry in DataStore based on entry's type + if isinstance(entry, Annotation): + data_store_ref.add_annotation_raw( + type_name=entry.entry_type(), + begin=entry.begin, + end=entry.end, + tid=entry.tid, + allow_duplicate=allow_duplicate, + ) + elif isinstance(entry, Link): + data_store_ref.add_link_raw( + type_name=entry.entry_type(), + parent_tid=entry.parent, + child_tid=entry.child, + tid=entry.tid, + ) + elif isinstance(entry, Group): + data_store_ref.add_group_raw( + type_name=entry.entry_type(), + member_type=get_full_module_name(entry.MemberType), + tid=entry.tid, + ) + elif isinstance(entry, Generics): + data_store_ref.add_generics_raw( + type_name=entry.entry_type(), + tid=entry.tid, + ) + elif isinstance(entry, AudioAnnotation): + data_store_ref.add_audio_annotation_raw( + type_name=entry.entry_type(), + begin=entry.begin, + end=entry.end, + tid=entry.tid, + allow_duplicate=allow_duplicate, + ) + elif isinstance(entry, ImageAnnotation): + data_store_ref.add_image_annotation_raw( + type_name=entry.entry_type(), + image_payload_idx=entry.image_payload_idx, + tid=entry.tid, + ) + elif isinstance(entry, Grids): + data_store_ref.add_grid_raw( + type_name=entry.entry_type(), + image_payload_idx=entry.image_payload_idx, + tid=entry.tid, + ) + elif isinstance(entry, MultiPackLink): + data_store_ref.add_multipack_link_raw( + type_name=entry.entry_type(), + parent_pack_id=entry.parent[0], + parent_tid=entry.parent[1], + child_pack_id=entry.child[0], + child_tid=entry.child[1], + tid=entry.tid, + ) + elif isinstance(entry, MultiPackGroup): + data_store_ref.add_multipack_group_raw( + type_name=entry.entry_type(), + member_type=get_full_module_name(entry.MemberType), + tid=entry.tid, + ) + elif isinstance(entry, MultiPackGeneric): + data_store_ref.add_multipack_generic_raw( + type_name=entry.entry_type(), + tid=entry.tid, + ) + else: + raise ValueError( + f"Invalid entry type {type(entry)}. A valid entry " + f"should be an instance of Annotation, Link, Group, Generics " + "or AudioAnnotation." + ) + + # Store all the dataclass attributes to DataStore + for attribute in entry.__dataclass_fields__: + value = getattr(entry, attribute, None) + if not value: + continue + if isinstance(value, Entry): + value = value.tid + elif isinstance(value, FDict): + value = {key: val.tid for key, val in value.items()} + elif isinstance(value, FList): + value = [val.tid for val in value] + data_store_ref.set_attribute( + tid=entry.tid, attr_name=attribute, attr_value=value + ) + + # Cache the stored entry and its tid + self._entry_dict[entry.tid] = entry + + def get_entry_object(self, tid: int, pack: PackType) -> EntryType: + """ + Convert a tid to its corresponding entry object. + """ + + # Check if the tid is cached + if tid in self._entry_dict: + return self._entry_dict[tid] # type: ignore + + data_store_ref = pack._data_store # pylint: disable=protected-access + entry_data, entry_type = data_store_ref.get_entry(tid=tid) + entry_class = get_class(entry_type) + entry: Entry + # Here the entry arguments are optional (begin, end, parent, ...) and + # the value can be arbitrary since they will all be routed to DataStore. 
+ if issubclass(entry_class, (Annotation, AudioAnnotation)): + entry = entry_class(pack=pack, begin=0, end=0) + elif issubclass( + entry_class, + ( + Link, + Group, + Generics, + MultiPackGeneric, + MultiPackGroup, + MultiPackLink, + ), + ): + entry = entry_class(pack=pack) + else: + raise ValueError( + f"Invalid entry type {type(entry_class)}. A valid entry " + f"should be an instance of Annotation, Link, Group, Generics " + "or AudioAnnotation." + ) + + # TODO: Remove the new tid and direct the entry object to the correct + # tid. The implementation here is a little bit hacky. Will need a stable + # solution in future. + # pylint: disable=protected-access + if entry.tid in self._entry_dict: + self._entry_dict.pop(entry.tid) + if entry.tid in pack._pending_entries: + pack._pending_entries.pop(entry.tid) + data_store_ref.delete_entry(tid=entry.tid) + entry._tid = entry_data[TID_INDEX] + + self._entry_dict[tid] = entry + return entry # type: ignore diff --git a/forte/data/multi_pack.py b/forte/data/multi_pack.py index b3d660656..045541040 100644 --- a/forte/data/multi_pack.py +++ b/forte/data/multi_pack.py @@ -16,16 +16,19 @@ import logging from pathlib import Path -from typing import Dict, List, Set, Union, Iterator, Optional, Type, Any, Tuple +from typing import Dict, List, Union, Iterator, Optional, Type, Any, Tuple import jsonpickle -from sortedcontainers import SortedList from packaging.version import Version +from sortedcontainers import SortedList from forte.common import ProcessExecutionException +from forte.common.constants import TID_INDEX from forte.data.base_pack import BaseMeta, BasePack from forte.data.data_pack import DataPack +from forte.data.data_store import DataStore +from forte.data.entry_converter import EntryConverter from forte.data.index import BaseIndex from forte.data.ontology.core import Entry from forte.data.ontology.core import EntryType @@ -37,7 +40,7 @@ MultiPackGeneric, ) from forte.data.types import DataRequest -from forte.utils import get_class +from forte.utils import get_class, get_full_module_name from forte.version import DEFAULT_PACK_VERSION, PACK_ID_COMPATIBLE_VERSION @@ -84,39 +87,24 @@ def __init__(self, pack_name: Optional[str] = None): # Reference to the real packs. self._packs: List[DataPack] = [] - self.links: SortedList[MultiPackLink] = SortedList() - self.groups: SortedList[MultiPackGroup] = SortedList() - self.generics: SortedList[MultiPackGeneric] = SortedList() - # Used to automatically give name to sub packs. self.__default_pack_prefix = "_pack" + self._data_store = DataStore() + self._entry_converter = EntryConverter() + self._index: MultiIndex = MultiIndex() def __setstate__(self, state): r"""In deserialization, we set up the index and the references to the data packs inside. """ + self._entry_converter = EntryConverter() super().__setstate__(state) - self.links = SortedList(self.links) - self.groups = SortedList(self.groups) - self.generics = SortedList(self.generics) - self._index = MultiIndex() # TODO: index those pointers? - self._index.update_basic_index(list(self.links)) - self._index.update_basic_index(list(self.groups)) - self._index.update_basic_index(list(self.generics)) - - for a in self.links: - a.set_pack(self) - - for a in self.groups: - a.set_pack(self) - - for a in self.generics: - a.set_pack(self) + self._index.update_basic_index(list(iter(self))) # Rebuild the name to index lookup. 
self._name_index = {n: i for (i, n) in enumerate(self._pack_names)} @@ -155,11 +143,7 @@ def __getstate__(self): state = super().__getstate__() # Do not directly serialize the pack itself. state.pop("_packs") - - state["links"] = list(state["links"]) - state["groups"] = list(state["groups"]) - state["generics"] = list(state["generics"]) - + state.pop("_entry_converter") return state def __iter__(self): @@ -578,18 +562,23 @@ def all_links(self) -> Iterator[MultiPackLink]: :class:`~forte.data.ontology.top.MultiPackLink`. """ - yield from self.links + for entry in self._data_store.all_entries( + "forte.data.ontology.top.MultiPackLink" + ): + yield self.get_entry(tid=entry[TID_INDEX]) # type: ignore @property def num_links(self) -> int: """ - Number of groups in this multi pack. + Number of links in this multi pack. Returns: Number of links. """ - return len(self.groups) + return self._data_store.num_entries( + "forte.data.ontology.top.MultiPackLink" + ) @property def all_groups(self) -> Iterator[MultiPackGroup]: @@ -601,7 +590,10 @@ def all_groups(self) -> Iterator[MultiPackGroup]: :class:`~forte.data.ontology.top.MultiPackGroup`. """ - yield from self.groups + for entry in self._data_store.all_entries( + "forte.data.ontology.top.MultiPackGroup" + ): + yield self.get_entry(tid=entry[TID_INDEX]) # type: ignore @property def num_groups(self) -> int: @@ -612,11 +604,56 @@ def num_groups(self) -> int: Number of groups. """ - return len(self.groups) + return self._data_store.num_entries( + "forte.data.ontology.top.MultiPackGroup" + ) @property def generic_entries(self) -> Iterator[MultiPackGeneric]: - yield from self.generics + """ + An iterator of all generics in this multi pack. + + Returns: + Iterator of all generics, of type + :class:`~forte.data.ontology.top.MultiPackGeneric`. + + """ + for entry in self._data_store.all_entries( + "forte.data.ontology.top.MultiPackGeneric" + ): + yield self.get_entry(tid=entry[TID_INDEX]) # type: ignore + + @property + def links(self): + """ + A List container of all links in this multi pack. + + Returns: List of all links, of + type :class:`~forte.data.ontology.top.MultiPackLink`. + + """ + return SortedList(self.all_links) + + @property + def groups(self): + """ + A List container of all groups in this multi pack. + + Returns: List of all groups, of + type :class:`~forte.data.ontology.top.MultiPackGroup`. + + """ + return SortedList(self.all_groups) + + @property + def generics(self): + """ + A SortedList container of all generic entries in this multi pack. + + Returns: SortedList of generics + + """ + return SortedList(self.generic_entries) def add_all_remaining_entries(self, component: Optional[str] = None): """ @@ -738,9 +775,7 @@ def get_cross_pack_data( # TODO: Not finished yet pass - def __add_entry_with_check( - self, entry: EntryType, allow_duplicate: bool = True - ) -> EntryType: + def __add_entry_with_check(self, entry: Union[EntryType, int]) -> EntryType: r"""Internal method to add an :class:`~forte.data.ontology.core.Entry` object to the :class:`~forte.data.multi_pack.MultiPack` object. @@ -753,37 +788,20 @@ def __add_entry_with_check( Returns: The input entry itself """ - if isinstance(entry, MultiPackLink): - target = self.links - elif isinstance(entry, MultiPackGroup): - target = self.groups - elif isinstance(entry, MultiPackGeneric): - target = self.generics - else: - raise ValueError( - f"Invalid entry type {type(entry)} for Multipack. A valid " - f"entry should be an instance of MultiPackLink, MultiPackGroup" - f", or MultiPackGeneric." 
- ) - - add_new = allow_duplicate or (entry not in target) - - if add_new: - target.add(entry) + if isinstance(entry, int): + # If entry is a TID, convert it to the class object. + entry = self._entry_converter.get_entry_object(tid=entry, pack=self) - # TODO: add the pointers? + # update the data pack index if needed + # TODO: MultiIndex will be deprecated in future + self._index.update_basic_index([entry]) + if self._index.link_index_on and isinstance(entry, MultiPackLink): + self._index.update_link_index([entry]) + if self._index.group_index_on and isinstance(entry, MultiPackGroup): + self._index.update_group_index([entry]) - # update the data pack index if needed - self._index.update_basic_index([entry]) - if self._index.link_index_on and isinstance(entry, MultiPackLink): - self._index.update_link_index([entry]) - if self._index.group_index_on and isinstance(entry, MultiPackGroup): - self._index.update_group_index([entry]) - - self._pending_entries.pop(entry.tid) - return entry - else: - return target[target.index(entry)] + self._pending_entries.pop(entry.tid) + return entry # type: ignore def get( # type: ignore self, @@ -833,40 +851,25 @@ def get( # type: ignore else: entry_type_ = entry_type - entry_iter: Iterator[Entry] - - if not include_sub_type: - entry_iter = self.get_entries_of(entry_type_) - elif issubclass(entry_type_, MultiPackLink): - entry_iter = self.links - elif issubclass(entry_type_, MultiPackGroup): - entry_iter = self.groups - elif issubclass(entry_type_, MultiPackGeneric): - entry_iter = self.generics - else: - raise ValueError( - f"The entry type: {entry_type_} is not supported by MultiPack." - ) - - all_types: Set[Type] - if include_sub_type: - all_types = self._expand_to_sub_types(entry_type_) - if components is not None: if isinstance(components, str): components = [components] - for e in entry_iter: - # Will check for the type matching if sub types are also requested. - if include_sub_type and type(e) not in all_types: - continue - - # Check for the component. - if components is not None: - if not self.is_created_by(e, components): - continue + try: + for entry_data in self._data_store.get( + type_name=get_full_module_name(entry_type_), + include_sub_type=include_sub_type, + ): + entry: Entry = self.get_entry(tid=entry_data[TID_INDEX]) + # Filter by components + if components is not None: + if not self.is_created_by(entry, components): + continue - yield e # type: ignore + yield entry # type: ignore + except ValueError: + # type_name does not exist in DataStore + yield from [] @classmethod def deserialize( @@ -925,7 +928,7 @@ def from_string(cls, data_content: str): return mp - def _add_entry(self, entry: EntryType) -> EntryType: # type: ignore + def _add_entry(self, entry: Union[Entry, int]) -> EntryType: r"""Force add an :class:`forte.data.ontology.core.Entry` object to the :class:`~forte.data.multi_pack.MultiPack` object. @@ -938,41 +941,7 @@ def _add_entry(self, entry: EntryType) -> EntryType: # type: ignore Returns: The input entry itself """ - return self.__add_entry_with_check(entry, True) - - def delete_entry(self, entry: EntryType): - r"""Delete an :class:`~forte.data.ontology.core.Entry` object from the - :class:`~forte.data.multi_pack.MultiPack`. - - Args: - entry: An :class:`~forte.data.ontology.core.Entry` - object to be deleted from the pack. 
- - """ - if isinstance(entry, MultiPackLink): - target = self.links - elif isinstance(entry, MultiPackGroup): - target = self.groups - elif isinstance(entry, MultiPackGeneric): - target = self.generics - else: - raise ValueError( - f"Invalid entry type {type(entry)}. A valid entry " - f"should be an instance of Annotation, Link, or Group." - ) - - begin = 0 - for i, e in enumerate(target[begin:]): - if e.tid == entry.tid: - target.pop(i + begin) - break - - # update basic index - self._index.remove_entry(entry) - - # set other index invalid - self._index.turn_link_index_switch(on=False) - self._index.turn_group_index_switch(on=False) + return self.__add_entry_with_check(entry) # type: ignore @classmethod def validate_link(cls, entry: EntryType) -> bool: @@ -985,6 +954,14 @@ def validate_group(cls, entry: EntryType) -> bool: def view(self): return copy.deepcopy(self) + def _save_entry_to_data_store(self, entry: Entry): + r"""Save an existing entry object into DataStore""" + self._entry_converter.save_entry_object(entry=entry, pack=self) + + def _get_entry_from_data_store(self, tid: int) -> EntryType: + r"""Generate a class object from entry data in DataStore""" + return self._entry_converter.get_entry_object(tid=tid, pack=self) + class MultiIndex(BaseIndex): pass diff --git a/forte/data/ontology/top.py b/forte/data/ontology/top.py index c74bb6819..9325a8c22 100644 --- a/forte/data/ontology/top.py +++ b/forte/data/ontology/top.py @@ -379,7 +379,6 @@ def __init__( pack: PackType, members: Optional[Iterable[Entry]] = None, ): # pylint: disable=useless-super-delegation - self._member_type: Type[Entry] = Entry super().__init__(pack, members) def add_member(self, member: Entry): @@ -438,8 +437,8 @@ def __init__( parent: Optional[Entry] = None, child: Optional[Entry] = None, ): - self._parent: Optional[Tuple[int, int]] = None - self._child: Optional[Tuple[int, int]] = None + self._parent: Tuple = (None, None) + self._child: Tuple = (None, None) super().__init__(pack) @@ -449,17 +448,49 @@ def __init__( self.set_child(child) @property - def parent(self) -> Tuple[int, int]: - if self._parent is None: - raise ValueError("Parent is not set for this link.") + def parent(self): + r"""Get ``pack_id`` and ``tid`` of the parent node. To get the object + of the parent node, call :meth:`get_parent`. The function will first + try to retrieve the parent from ``DataStore`` in ``self.pack``. If + this attempt fails, it will directly return the value in ``_parent``. + """ + try: + self._parent = self.pack.get_entry_raw(self.tid)[PARENT_TID_INDEX] + except KeyError: + # self.tid not found in DataStore + pass return self._parent + @parent.setter + def parent(self, val: Tuple): + r"""Setter function of ``parent``. The update will also be populated + into ``DataStore`` in ``self.pack``. + """ + self._parent = val + self.pack.get_entry_raw(self.tid)[PARENT_TID_INDEX] = val + @property - def child(self) -> Tuple[int, int]: - if self._child is None: - raise ValueError("Child is not set for this link.") + def child(self): + r"""Get ``pack_id`` and ``tid`` of the child node. To get the object + of the child node, call :meth:`get_child`. The function will first try + to retrieve the child from ``DataStore`` in ``self.pack``. If + this attempt fails, it will directly return the value in ``_child``. 
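+
+        A small illustrative sketch (``link`` is assumed to be an existing
+        ``MultiPackLink`` in a multi pack)::
+
+            child_pack_id, child_tid = link.child
+            child_entry = link.get_child()
+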
+ """ + try: + self._child = self.pack.get_entry_raw(self.tid)[CHILD_TID_INDEX] + except KeyError: + # self.tid not found in DataStore + pass return self._child + @child.setter + def child(self, val: Tuple): + r"""Setter function of ``child``. The update will also be populated + into ``DataStore`` in ``self.pack``. + """ + self._child = val + self.pack.get_entry_raw(self.tid)[CHILD_TID_INDEX] = val + def parent_id(self) -> int: """ Return the ``tid`` of the parent entry. @@ -485,9 +516,9 @@ def parent_pack_id(self) -> int: Returns: The `pack_id` of the parent pack.. """ - if self._parent is None: + if self.parent[0] is None: raise ValueError("Parent is not set for this link.") - return self.pack.packs[self._parent[0]].pack_id + return self.pack.packs[self.parent[0]].pack_id def child_pack_id(self) -> int: """ @@ -496,9 +527,9 @@ def child_pack_id(self) -> int: Returns: The `pack_id` of the child pack. """ - if self._child is None: + if self.child[0] is None: raise ValueError("Child is not set for this link.") - return self.pack.packs[self._child[0]].pack_id + return self.pack.packs[self.child[0]].pack_id def set_parent(self, parent: Entry): r"""This will set the `parent` of the current instance with given Entry. @@ -517,7 +548,7 @@ def set_parent(self, parent: Entry): ) # fix bug/enhancement #559: using pack_id instead of index # self._parent = self.pack.get_pack_index(parent.pack_id), parent.tid - self._parent = parent.pack_id, parent.tid + self.parent = parent.pack_id, parent.tid def set_child(self, child: Entry): r"""This will set the `child` of the current instance with given Entry. @@ -537,7 +568,7 @@ def set_child(self, child: Entry): ) # fix bug/enhancement #559: using pack_id instead of index # self._child = self.pack.get_pack_index(child.pack_id), child.tid - self._child = child.pack_id, child.tid + self.child = child.pack_id, child.tid def get_parent(self) -> Entry: r"""Get the parent entry of the link. @@ -549,7 +580,7 @@ def get_parent(self) -> Entry: if self._parent is None: raise ValueError("The parent of this link is not set.") - pack_idx, parent_tid = self._parent + pack_idx, parent_tid = self.parent return self.pack.get_subentry(pack_idx, parent_tid) def get_child(self) -> Entry: @@ -562,7 +593,7 @@ def get_child(self) -> Entry: if self._child is None: raise ValueError("The parent of this link is not set.") - pack_idx, child_tid = self._child + pack_idx, child_tid = self.child return self.pack.get_subentry(pack_idx, child_tid) @@ -576,7 +607,6 @@ class MultiPackGroup(MultiEntry, BaseGroup[Entry]): def __init__( self, pack: PackType, members: Optional[Iterable[Entry]] = None ): # pylint: disable=useless-super-delegation - self._members: List[Tuple[int, int]] = [] super().__init__(pack) if members is not None: self.add_members(members) @@ -587,15 +617,15 @@ def add_member(self, member: Entry): f"The members of {type(self)} should be " f"instances of {self.MemberType}, but got {type(member)}" ) - - self._members.append( - # fix bug/enhancement 559: use pack_id instead of index - (member.pack_id, member.tid) # self.pack.get_pack_index(..) 
+ self.pack.get_entry_raw(self.tid)[MEMBER_TID_INDEX].append( + (member.pack_id, member.tid) ) def get_members(self) -> List[Entry]: members = [] - for pack_idx, member_tid in self._members: + for pack_idx, member_tid in self.pack.get_entry_raw(self.tid)[ + MEMBER_TID_INDEX + ]: members.append(self.pack.get_subentry(pack_idx, member_tid)) return members diff --git a/forte/processors/nlp/ner_predictor.py b/forte/processors/nlp/ner_predictor.py index 5ed2c220a..d94e2c624 100644 --- a/forte/processors/nlp/ner_predictor.py +++ b/forte/processors/nlp/ner_predictor.py @@ -199,7 +199,7 @@ def pack( for j in range(len(predict_results["Token"]["tid"][i])): tid: int = predict_results["Token"]["tid"][i][j] # type: ignore - orig_token: Token = pack.get_entry(tid) + orig_token: Token = pack.get_entry(tid) # type: ignore ner_tag: str = predict_results["Token"]["ner"][i][j] orig_token.ner = ner_tag diff --git a/tests/forte/data/entry_data_structures_test.py b/tests/forte/data/entry_data_structures_test.py index 6b050c3f4..3a55ea72a 100644 --- a/tests/forte/data/entry_data_structures_test.py +++ b/tests/forte/data/entry_data_structures_test.py @@ -163,7 +163,6 @@ def setUp(self): def test_entry_attribute_mp_pointer(self): mpe: ExampleMPEntry = self.pack.get_single(ExampleMPEntry) self.assertIsInstance(mpe.refer_entry, ExampleEntry) - self.assertIsInstance(mpe.__dict__["refer_entry"], ExampleEntry) serialized_mp = self.pack.to_string(drop_record=True) recovered_mp = MultiPack.from_string(serialized_mp) From f72210a177f43bfabc770223e532b9e756855215 Mon Sep 17 00:00:00 2001 From: Suqi Sun Date: Thu, 16 Jun 2022 17:01:53 -0400 Subject: [PATCH 2/5] Fix some issues --- forte/data/__init__.py | 1 - forte/data/base_pack.py | 23 ++++--- forte/data/base_store.py | 62 ++++++++++-------- forte/data/data_store.py | 11 ++-- forte/data/entry_converter.py | 99 ++++++++++++++--------------- forte/data/multi_pack.py | 8 ++- forte/data/ontology/top.py | 9 +++ tests/forte/data/data_store_test.py | 50 +++++++++++++++ 8 files changed, 172 insertions(+), 91 deletions(-) diff --git a/forte/data/__init__.py b/forte/data/__init__.py index 1b06afe10..01858ebca 100644 --- a/forte/data/__init__.py +++ b/forte/data/__init__.py @@ -20,4 +20,3 @@ from forte.data.data_store import * from forte.data.selector import * from forte.data.index import * -from forte.data.entry_converter import * diff --git a/forte/data/base_pack.py b/forte/data/base_pack.py index 81a18c353..cb0b61565 100644 --- a/forte/data/base_pack.py +++ b/forte/data/base_pack.py @@ -477,6 +477,7 @@ def entry_getter(cls: Entry, attr_name: str, field_type): return field_type(parent_entry=cls, data=attr_val) try: # TODO: Find a better solution to determine if a field is Entry + # Will be addressed by https://github.com/asyml/forte/issues/835 # Convert tid to entry object on the fly if isinstance(attr_val, int): # Single pack entry @@ -502,20 +503,20 @@ def entry_setter(cls: Entry, value: Any, attr_name: str, field_type): data_store_ref = ( cls.pack._data_store # pylint: disable=protected-access ) + # Assumption: When users assign value to a FList/FDict field, the + # value's type has to be Iterator[Entry]/Dict[Any, Entry]. 
             if field_type is FList:
-                attr_value = [
-                    entry.tid if isinstance(entry, Entry) else entry
-                    for entry in value
-                ]
+                attr_value = [entry.tid for entry in value]
             elif field_type is FDict:
-                attr_value = {
-                    key: entry.tid if isinstance(entry, Entry) else entry
-                    for key, entry in value.items()
-                }
+                attr_value = {key: entry.tid for key, entry in value.items()}
             elif isinstance(value, Entry):
                 attr_value = (
                     value.tid
                     if value.pack.pack_id == cls.pack.pack_id
+                    # When value's pack and cls's pack are not the same, we
+                    # assume that cls.pack is a MultiPack, which will resolve
+                    # value.tid using MultiPack.get_subentry(pack_id, tid).
+                    # In this case, both pack_id and tid should be stored.
                     else (value.pack.pack_id, value.tid)
                 )
             else:
@@ -529,11 +530,17 @@ def entry_setter(cls: Entry, value: Any, attr_name: str, field_type):
 
         # Register property functions for all dataclass fields.
        for name, field in entry.__dataclass_fields__.items():
+            # Convert the typing annotation to the original class.
+            # This will be used to determine if a field is FList/FDict.
             field_type = get_origin(field.type)
             setattr(
                 type(entry),
                 name,
+                # property(fget, fset) will register a conversion layer
+                # that specifies how to retrieve/assign value of this field.
                 property(
+                    # We need to bind the attribute name and field type here
+                    # for the getter and setter of each field.
                     fget=partial(
                         entry_getter, attr_name=name, field_type=field_type
                     ),
diff --git a/forte/data/base_store.py b/forte/data/base_store.py
index 15349db62..f2dab0d34 100644
--- a/forte/data/base_store.py
+++ b/forte/data/base_store.py
@@ -270,31 +270,6 @@ def add_image_annotation_raw(
         """
         raise NotImplementedError
 
-    @abstractmethod
-    def add_grid_raw(
-        self,
-        type_name: str,
-        image_payload_idx: int,
-        tid: Optional[int] = None,
-    ) -> int:
-
-        r"""
-        This function adds an image annotation entry with ``image_payload_idx``
-        indices to current data store object. Returns the ``tid`` for the
-        inserted entry.
-
-        Args:
-            type_name: The fully qualified type name of the new grid.
-            image_payload_idx: the index of the image payload.
-            tid: ``tid`` of the Annotation entry that is being added.
-                It's optional, and it will be
-                auto-assigned if not given.
-
-        Returns:
-            ``tid`` of the entry.
-        """
-        raise NotImplementedError
-
     @abstractmethod
     def add_multipack_generic_raw(
         self, type_name: str, tid: Optional[int] = None
@@ -539,3 +514,40 @@ def prev_entry(self, tid: int) -> Optional[List]:
 
         """
         raise NotImplementedError
+
+    @abstractmethod
+    def _is_subclass(
+        self, type_name: str, cls, no_dynamic_subclass: bool = False
+    ) -> bool:
+        r"""This function takes a fully qualified ``type_name`` class name and
+        a ``cls`` class, and returns whether the ``type_name`` class is a
+        subclass of ``cls``. It accepts two kinds of classes: classes defined
+        in forte and classes defined in user-provided ontology files.
+
+
+        Args:
+            type_name: A fully qualified name of an entry class.
+            cls: An entry class.
+            no_dynamic_subclass: A boolean value controlling where to look for
+              subclasses. If True, this function will not check the subclass
+              relations via `issubclass` but rely on pre-populated states only.
+
+        Returns:
+            A boolean value indicating whether the ``type_name`` class is a
+            subclass of ``cls``.
+
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def _is_annotation(self, type_name: str) -> bool:
+        r"""This function takes a ``type_name`` and returns whether that
+        type is an annotation type or not.
+        Args:
+            type_name: The name of the type in `self.__elements`.
+ + Returns: + A boolean value whether this type_name belongs to an annotation + type or not. + """ + raise NotImplementedError diff --git a/forte/data/data_store.py b/forte/data/data_store.py index 70cf2d6ae..1a07ee241 100644 --- a/forte/data/data_store.py +++ b/forte/data/data_store.py @@ -784,7 +784,7 @@ def _is_subclass( """ if type_name not in DataStore._type_attributes: - DataStore._type_attributes[type_name] = {} + self._get_type_info(type_name=type_name) if "parent_class" not in DataStore._type_attributes[type_name]: DataStore._type_attributes[type_name]["parent_class"] = set() cls_qualified_name = get_full_module_name(cls) @@ -838,9 +838,10 @@ def _is_annotation(self, type_name: str) -> bool: A boolean value whether this type_name belongs to an annotation type or not. """ - # TODO: use is_subclass() in DataStore to replace this - entry_class = get_class(type_name) - return issubclass(entry_class, (Annotation, AudioAnnotation)) + return any( + self._is_subclass(type_name, entry_class) + for entry_class in (Annotation, AudioAnnotation) + ) def all_entries(self, entry_type_name: str) -> Iterator[List]: """ @@ -1079,7 +1080,7 @@ def add_image_annotation_raw( # A reference to the entry should be store in both self.__elements and # self.__tid_ref_dict. entry = self._new_image_annotation(type_name, image_payload_idx, tid) - return self._add_entry_raw(AudioAnnotation, type_name, entry) + return self._add_entry_raw(ImageAnnotation, type_name, entry) def add_grid_raw( self, diff --git a/forte/data/entry_converter.py b/forte/data/entry_converter.py index 3350609cf..bd1ab2863 100644 --- a/forte/data/entry_converter.py +++ b/forte/data/entry_converter.py @@ -13,8 +13,7 @@ # limitations under the License. import logging -from typing import Dict -from forte.common.constants import TID_INDEX +from typing import Dict, Optional from forte.data.base_pack import PackType from forte.data.ontology.core import Entry, FList, FDict from forte.data.ontology.core import EntryType @@ -29,6 +28,8 @@ MultiPackGeneric, MultiPackGroup, MultiPackLink, + SinglePackEntries, + MultiPackEntries, ) from forte.utils import get_class, get_full_module_name @@ -50,11 +51,12 @@ def __init__(self) -> None: def save_entry_object( self, entry: Entry, pack: PackType, allow_duplicate: bool = True ): + # pylint: disable=protected-access """ Save an existing entry object into DataStore. 
""" # Check if the entry is already stored - data_store_ref = pack._data_store # pylint: disable=protected-access + data_store_ref = pack._data_store try: data_store_ref.get_entry(tid=entry.tid) logger.info( @@ -67,68 +69,69 @@ def save_entry_object( pass # Create a new registry in DataStore based on entry's type - if isinstance(entry, Annotation): + if data_store_ref._is_subclass(entry.entry_type(), Annotation): data_store_ref.add_annotation_raw( type_name=entry.entry_type(), - begin=entry.begin, - end=entry.end, + begin=entry.begin, # type: ignore + end=entry.end, # type: ignore tid=entry.tid, allow_duplicate=allow_duplicate, ) - elif isinstance(entry, Link): + elif data_store_ref._is_subclass(entry.entry_type(), Link): data_store_ref.add_link_raw( type_name=entry.entry_type(), - parent_tid=entry.parent, - child_tid=entry.child, + parent_tid=entry.parent, # type: ignore + child_tid=entry.child, # type: ignore tid=entry.tid, ) - elif isinstance(entry, Group): + elif data_store_ref._is_subclass(entry.entry_type(), Group): data_store_ref.add_group_raw( type_name=entry.entry_type(), - member_type=get_full_module_name(entry.MemberType), + member_type=get_full_module_name(entry.MemberType), # type: ignore tid=entry.tid, ) - elif isinstance(entry, Generics): + elif data_store_ref._is_subclass(entry.entry_type(), Generics): data_store_ref.add_generics_raw( type_name=entry.entry_type(), tid=entry.tid, ) - elif isinstance(entry, AudioAnnotation): + elif data_store_ref._is_subclass(entry.entry_type(), AudioAnnotation): data_store_ref.add_audio_annotation_raw( type_name=entry.entry_type(), - begin=entry.begin, - end=entry.end, + begin=entry.begin, # type: ignore + end=entry.end, # type: ignore tid=entry.tid, allow_duplicate=allow_duplicate, ) - elif isinstance(entry, ImageAnnotation): + elif data_store_ref._is_subclass(entry.entry_type(), ImageAnnotation): data_store_ref.add_image_annotation_raw( type_name=entry.entry_type(), - image_payload_idx=entry.image_payload_idx, + image_payload_idx=entry.image_payload_idx, # type: ignore tid=entry.tid, ) - elif isinstance(entry, Grids): - data_store_ref.add_grid_raw( + elif data_store_ref._is_subclass(entry.entry_type(), Grids): + # Will be deprecated in future + data_store_ref.add_grid_raw( # type: ignore type_name=entry.entry_type(), - image_payload_idx=entry.image_payload_idx, + image_payload_idx=entry.image_payload_idx, # type: ignore tid=entry.tid, ) - elif isinstance(entry, MultiPackLink): + elif data_store_ref._is_subclass(entry.entry_type(), MultiPackLink): data_store_ref.add_multipack_link_raw( type_name=entry.entry_type(), - parent_pack_id=entry.parent[0], - parent_tid=entry.parent[1], - child_pack_id=entry.child[0], - child_tid=entry.child[1], + parent_pack_id=entry.parent[0], # type: ignore + parent_tid=entry.parent[1], # type: ignore + child_pack_id=entry.child[0], # type: ignore + child_tid=entry.child[1], # type: ignore tid=entry.tid, ) - elif isinstance(entry, MultiPackGroup): + elif data_store_ref._is_subclass(entry.entry_type(), MultiPackGroup): data_store_ref.add_multipack_group_raw( type_name=entry.entry_type(), - member_type=get_full_module_name(entry.MemberType), + member_type=get_full_module_name(entry.MemberType), # type: ignore tid=entry.tid, ) - elif isinstance(entry, MultiPackGeneric): + elif data_store_ref._is_subclass(entry.entry_type(), MultiPackGeneric): data_store_ref.add_multipack_generic_raw( type_name=entry.entry_type(), tid=entry.tid, @@ -158,7 +161,9 @@ def save_entry_object( # Cache the stored entry and its tid 
self._entry_dict[entry.tid] = entry - def get_entry_object(self, tid: int, pack: PackType) -> EntryType: + def get_entry_object( + self, tid: int, pack: PackType, type_name: Optional[str] = None + ) -> EntryType: """ Convert a tid to its corresponding entry object. """ @@ -168,42 +173,36 @@ def get_entry_object(self, tid: int, pack: PackType) -> EntryType: return self._entry_dict[tid] # type: ignore data_store_ref = pack._data_store # pylint: disable=protected-access - entry_data, entry_type = data_store_ref.get_entry(tid=tid) - entry_class = get_class(entry_type) + if type_name is None: + _, type_name = data_store_ref.get_entry(tid=tid) + entry_class = get_class(type_name) entry: Entry + # pylint: disable=protected-access # Here the entry arguments are optional (begin, end, parent, ...) and # the value can be arbitrary since they will all be routed to DataStore. - if issubclass(entry_class, (Annotation, AudioAnnotation)): + if data_store_ref._is_annotation(type_name): entry = entry_class(pack=pack, begin=0, end=0) - elif issubclass( - entry_class, - ( - Link, - Group, - Generics, - MultiPackGeneric, - MultiPackGroup, - MultiPackLink, - ), + elif any( + data_store_ref._is_subclass(type_name, type_class) + for type_class in SinglePackEntries + MultiPackEntries ): entry = entry_class(pack=pack) else: + valid_entries: str = ", ".join( + map(get_full_module_name, SinglePackEntries + MultiPackEntries) + ) raise ValueError( - f"Invalid entry type {type(entry_class)}. A valid entry " - f"should be an instance of Annotation, Link, Group, Generics " - "or AudioAnnotation." + f"Invalid entry type {type_name}. A valid entry should be an" + f" instance of {valid_entries}." ) - # TODO: Remove the new tid and direct the entry object to the correct - # tid. The implementation here is a little bit hacky. Will need a stable - # solution in future. - # pylint: disable=protected-access + # Remove the new tid and direct the entry object to the correct tid. 
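+        # Constructing entry_class(pack=pack) above registered the object
+        # under a freshly generated tid; the clean-up below removes that
+        # temporary record (from this converter's cache, the pack's pending
+        # entries, and the DataStore) before re-pointing the object at the
+        # requested tid.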
         if entry.tid in self._entry_dict:
             self._entry_dict.pop(entry.tid)
         if entry.tid in pack._pending_entries:
             pack._pending_entries.pop(entry.tid)
         data_store_ref.delete_entry(tid=entry.tid)
-        entry._tid = entry_data[TID_INDEX]
+        entry._tid = tid
 
         self._entry_dict[tid] = entry
         return entry  # type: ignore
diff --git a/forte/data/multi_pack.py b/forte/data/multi_pack.py
index 045541040..92ec744c8 100644
--- a/forte/data/multi_pack.py
+++ b/forte/data/multi_pack.py
@@ -24,7 +24,7 @@
 from sortedcontainers import SortedList
 
 from forte.common import ProcessExecutionException
-from forte.common.constants import TID_INDEX
+from forte.common.constants import TID_INDEX, ENTRY_TYPE_INDEX
 from forte.data.base_pack import BaseMeta, BasePack
 from forte.data.data_pack import DataPack
 from forte.data.data_store import DataStore
@@ -860,7 +860,11 @@ def get(  # type: ignore
                 type_name=get_full_module_name(entry_type_),
                 include_sub_type=include_sub_type,
             ):
-                entry: Entry = self.get_entry(tid=entry_data[TID_INDEX])
+                entry: Entry = self._entry_converter.get_entry_object(
+                    tid=entry_data[TID_INDEX],
+                    pack=self,
+                    type_name=entry_data[ENTRY_TYPE_INDEX],
+                )
                 # Filter by components
                 if components is not None:
                     if not self.is_created_by(entry, components):
diff --git a/forte/data/ontology/top.py b/forte/data/ontology/top.py
index 9325a8c22..e621fe028 100644
--- a/forte/data/ontology/top.py
+++ b/forte/data/ontology/top.py
@@ -55,6 +55,15 @@
 
 QueryType = Union[Dict[str, Any], np.ndarray]
 
+"""
+To create a new top level entry, the following steps are required to
+make sure it is available across the ontology system:
+    1. Create a new top level class that inherits from `Entry` or `MultiEntry`
+    2. Add the new class to `SinglePackEntries` or `MultiPackEntries`
+    3. Register a new method in `DataStore`: `add__raw()`
+    4. Insert a new conditional branch in `EntryConverter.save_entry_object()`
+"""
+
 
 class Generics(Entry):
     def __init__(self, pack: PackType):
diff --git a/tests/forte/data/data_store_test.py b/tests/forte/data/data_store_test.py
index f278cba73..2fbd40546 100644
--- a/tests/forte/data/data_store_test.py
+++ b/tests/forte/data/data_store_test.py
@@ -758,6 +758,56 @@ def test_add_group_raw(self):
             ["test_group", [], tid, "forte.data.ontology.top.Group"],
         )
 
+    def test_add_multientry_raw(self):
+        self.data_store.add_multipack_generic_raw(
+            "forte.data.ontology.top.MultiPackGeneric"
+        )
+        # check number of MultiPackGeneric
+        self.assertEqual(
+            len(
+                self.data_store._DataStore__elements[
+                    "forte.data.ontology.top.MultiPackGeneric"
+                ]
+            ),
+            1,
+        )
+
+        self.data_store.add_multipack_group_raw(
+            "forte.data.ontology.top.MultiPackGroup", "test_group"
+        )
+        # check number of MultiPackGroup
+        self.assertEqual(
+            len(
+                self.data_store._DataStore__elements[
+                    "forte.data.ontology.top.MultiPackGroup"
+                ]
+            ),
+            1,
+        )
+
+        self.data_store.add_multipack_link_raw(
+            "forte.data.ontology.top.MultiPackLink", 100, 1234, 20, 9999
+        )
+        # check number of MultiPackLink
+        self.assertEqual(
+            len(
+                self.data_store._DataStore__elements[
+                    "forte.data.ontology.top.MultiPackLink"
+                ]
+            ),
+            1,
+        )
+
+        # check add MultiPackLink with tid
+        tid = 77968
+        self.data_store.add_multipack_link_raw(
+            "forte.data.ontology.top.MultiPackLink", 100, 1234, 20, 9999, tid
+        )
+        self.assertEqual(
+            self.data_store.get_entry(tid=tid)[0],
+            [[100, 1234], [20, 9999], tid, "forte.data.ontology.top.MultiPackLink"],
+        )
+
     def test_get_attribute(self):
         speaker = self.data_store.get_attribute(9999, "speaker")
         classifications = self.data_store.get_attribute(3456, "classifications")

From 39933386221c26715e38d7a7fce9c184d170743e Mon Sep 17 00:00:00 2001
From: Suqi Sun
Date: Thu, 16 Jun 2022 19:33:43 -0400
Subject: [PATCH 3/5] Fix _get_type_info

---
 forte/data/data_store.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/forte/data/data_store.py b/forte/data/data_store.py
index 1a07ee241..3e5c627f7 100644
--- a/forte/data/data_store.py
+++ b/forte/data/data_store.py
@@ -508,7 +508,10 @@ def _get_type_info(self, type_name: str) -> Dict[str, Any]:
             dynamic import is disabled.
         """
         # check if type is in dictionary
-        if type_name in DataStore._type_attributes:
+        if (
+            type_name in DataStore._type_attributes
+            and "attributes" in DataStore._type_attributes[type_name]
+        ):
             return DataStore._type_attributes[type_name]
         if not self._dynamically_add_type:
             raise ValueError(
@@ -784,7 +787,7 @@ def _is_subclass(
 
         """
         if type_name not in DataStore._type_attributes:
-            self._get_type_info(type_name=type_name)
+            DataStore._type_attributes[type_name] = {}
         if "parent_class" not in DataStore._type_attributes[type_name]:
             DataStore._type_attributes[type_name]["parent_class"] = set()
         cls_qualified_name = get_full_module_name(cls)

From c1c06ee0a82aeaaafb654defa2908a4b1e6a1ea6 Mon Sep 17 00:00:00 2001
From: Suqi Sun
Date: Tue, 21 Jun 2022 18:00:21 -0400
Subject: [PATCH 4/5] Fix issues in property getter and setter

---
 forte/data/base_pack.py | 39 +++++++++++++++++++++++++++++++++++----
 1 file changed, 35 insertions(+), 4 deletions(-)

diff --git a/forte/data/base_pack.py b/forte/data/base_pack.py
index cb0b61565..16b178d69 100644
--- a/forte/data/base_pack.py
+++ b/forte/data/base_pack.py
@@ -465,6 +465,11 @@ def entry_getter(cls: Entry, attr_name: str, field_type):
             """A getter function for dataclass fields of entry object.
When the field contains ``tid``s, we will convert them to entry object on the fly. + + Args: + cls: An ``Entry`` class object. + attr_name: The name of the attribute. + field_type: The type of the attribute. """ data_store_ref = ( cls.pack._data_store # pylint: disable=protected-access @@ -482,6 +487,10 @@ def entry_getter(cls: Entry, attr_name: str, field_type): if isinstance(attr_val, int): # Single pack entry return cls.pack.get_entry(tid=attr_val) + # The condition below is to check whether the attribute's value + # is a pair of integers - `(pack_id, tid)`. If so we may have + # encountered a `tid` that can only be resolved by + # `MultiPack.get_subentry`. elif ( isinstance(attr_val, tuple) and len(attr_val) == 2 @@ -498,17 +507,39 @@ def entry_setter(cls: Entry, value: Any, attr_name: str, field_type): """A setter function for dataclass fields of entry object. When the value contains entry objects, we will convert them into ``tid``s before storing to ``DataStore``. + + Args: + cls: An ``Entry`` class object. + value: The value to be assigned to the attribute. + attr_name: The name of the attribute. + field_type: The type of the attribute. """ attr_value: Any data_store_ref = ( cls.pack._data_store # pylint: disable=protected-access ) - # Assumption: When users assign value to a FList/FDict field, the - # value's type has to be Iterator[Entry]/Dict[Any, Entry]. + # Assumption: Users will not assign value to a FList/FDict field. + # Only internal methods can set the FList/FDict field, and value's + # type has to be Iterator[Entry]/Dict[Any, Entry]. if field_type is FList: - attr_value = [entry.tid for entry in value] + try: + attr_value = [entry.tid for entry in value] + except AttributeError as e: + raise ValueError( + "You are trying to assign value to a `FList` field, " + "which can only accept an iterator of `Entry` objects." + ) from e elif field_type is FDict: - attr_value = {key: entry.tid for key, entry in value.items()} + try: + attr_value = { + key: entry.tid for key, entry in value.items() + } + except AttributeError as e: + raise ValueError( + "You are trying to assign value to a `FDict` field, " + "which can only accept a mapping whose values are " + "`Entry` objects." + ) from e elif isinstance(value, Entry): attr_value = ( value.tid From 86c650c4757ea2dd2fbbffa1e80a7f03b2ba1966 Mon Sep 17 00:00:00 2001 From: Suqi Sun Date: Wed, 22 Jun 2022 17:58:19 -0400 Subject: [PATCH 5/5] Remove Entry serialization/deserialization logics --- docs/code/data.rst | 6 - forte/data/base_pack.py | 10 +- forte/data/container.py | 14 -- forte/data/data_store.py | 7 +- forte/data/entry_converter.py | 8 +- forte/data/multi_pack.py | 8 - forte/data/ontology/core.py | 234 +----------------- .../forte/data/entry_data_structures_test.py | 2 +- 8 files changed, 12 insertions(+), 277 deletions(-) diff --git a/docs/code/data.rst b/docs/code/data.rst index e11c51c58..0232be75e 100644 --- a/docs/code/data.rst +++ b/docs/code/data.rst @@ -387,12 +387,6 @@ Container :members: -:hidden:`BasePointer` ------------------------------- -.. 
autoclass:: forte.data.container.BasePointer - :members: - - Types =============== :hidden:`ReplaceOperationsType` diff --git a/forte/data/base_pack.py b/forte/data/base_pack.py index 16b178d69..a519ab7fd 100644 --- a/forte/data/base_pack.py +++ b/forte/data/base_pack.py @@ -173,15 +173,9 @@ def __iter__(self) -> Iterator[EntryType]: raise NotImplementedError def __del__(self): - num_remaning_entries: int = len(self._pending_entries) - if num_remaning_entries > 0: - # Remove all the remaining tids in _pending_entries. - tids: List = list(self._pending_entries.keys()) - for tid in tids: - self._pending_entries.pop(tid) - self._data_store.delete_entry(tid=tid) + if len(self._pending_entries) > 0: raise ProcessExecutionException( - f"There are {num_remaning_entries} " + f"There are {len(self._pending_entries)} " f"entries not added to the index correctly." ) diff --git a/forte/data/container.py b/forte/data/container.py index 7948bff63..3dce7ef93 100644 --- a/forte/data/container.py +++ b/forte/data/container.py @@ -24,7 +24,6 @@ __all__ = [ "EntryContainer", "ContainerType", - "BasePointer", ] E = TypeVar("E") @@ -32,19 +31,6 @@ G = TypeVar("G") -class BasePointer: - """ - Objects to point to other objects in the data pack. - """ - - def __str__(self): - raise NotImplementedError - - def __getstate__(self): - state = self.__dict__.copy() - return state - - class EntryContainer(Generic[E, L, G]): def __init__(self): # Record the set of entries created by some components. diff --git a/forte/data/data_store.py b/forte/data/data_store.py index 3e5c627f7..1a07ee241 100644 --- a/forte/data/data_store.py +++ b/forte/data/data_store.py @@ -508,10 +508,7 @@ def _get_type_info(self, type_name: str) -> Dict[str, Any]: dynamic import is disabled. """ # check if type is in dictionary - if ( - type_name in DataStore._type_attributes - and "attributes" in DataStore._type_attributes[type_name] - ): + if type_name in DataStore._type_attributes: return DataStore._type_attributes[type_name] if not self._dynamically_add_type: raise ValueError( @@ -787,7 +784,7 @@ def _is_subclass( """ if type_name not in DataStore._type_attributes: - DataStore._type_attributes[type_name] = {} + self._get_type_info(type_name=type_name) if "parent_class" not in DataStore._type_attributes[type_name]: DataStore._type_attributes[type_name]["parent_class"] = set() cls_qualified_name = get_full_module_name(cls) diff --git a/forte/data/entry_converter.py b/forte/data/entry_converter.py index bd1ab2863..b4dfed83e 100644 --- a/forte/data/entry_converter.py +++ b/forte/data/entry_converter.py @@ -137,10 +137,12 @@ def save_entry_object( tid=entry.tid, ) else: + valid_entries: str = ", ".join( + map(get_full_module_name, SinglePackEntries + MultiPackEntries) + ) raise ValueError( - f"Invalid entry type {type(entry)}. A valid entry " - f"should be an instance of Annotation, Link, Group, Generics " - "or AudioAnnotation." + f"Invalid entry type {entry.entry_type()}. A valid entry should" + f" be an instance of {valid_entries}." 
) # Store all the dataclass attributes to DataStore diff --git a/forte/data/multi_pack.py b/forte/data/multi_pack.py index 92ec744c8..eb437b276 100644 --- a/forte/data/multi_pack.py +++ b/forte/data/multi_pack.py @@ -124,14 +124,6 @@ def relink(self, packs: Iterator[DataPack]): None """ self._packs.extend(packs) - for a in self.links: - a.relink_pointer() - - for a in self.groups: - a.relink_pointer() - - for a in self.generics: - a.relink_pointer() def __getstate__(self): r""" diff --git a/forte/data/ontology/core.py b/forte/data/ontology/core.py index 72a9c9899..10f3b2641 100644 --- a/forte/data/ontology/core.py +++ b/forte/data/ontology/core.py @@ -36,9 +36,7 @@ import numpy as np -from packaging.version import Version - -from forte.data.container import ContainerType, BasePointer +from forte.data.container import ContainerType __all__ = [ "Entry", @@ -47,17 +45,12 @@ "LinkType", "GroupType", "EntryType", - "Pointer", - "MpPointer", "FDict", "FList", "FNdArray", "MultiEntry", ] -from forte.utils import get_full_module_name -from forte.version import DEFAULT_PACK_VERSION, PACK_ID_COMPATIBLE_VERSION - default_entry_fields = [ "_Entry__pack", "_tid", @@ -79,78 +72,6 @@ "__orig_class__", ] -_f_struct_keys: Dict[str, bool] = {} -_pointer_keys: Dict[str, bool] = {} - - -def set_state_func(instance, state): - # pylint: disable=protected-access - """ - An internal used function. `instance` is an instance of Entry or a - MultiEntry. This function will populate the internal states for them. - - Args: - instance: - state: - - Returns: - None - """ - # During de-serialization, convert the list back to numpy array. - if "_embedding" in state: - state["_embedding"] = np.array(state["_embedding"]) - else: - state["_embedding"] = np.empty(0) - - # NOTE: the __pack will be set via set_pack from the Pack side. - cls_name = get_full_module_name(instance) - for k, v in state.items(): - key = cls_name + "_" + k - if _f_struct_keys.get(key, False): - v._set_parent(instance) - else: - if isinstance(v, (FList, FDict)): - v._set_parent(instance) - _f_struct_keys[key] = True - else: - _f_struct_keys[key] = False - - instance.__dict__.update(state) - - -def get_state_func(instance): - # pylint: disable=protected-access - r"""In serialization, the reference to pack is not set, and - it will be set by the container. - - This also implies that it is not advised to serialize an entry on its - own, without the ``Container`` as the context, there is little semantics - remained in an entry and unexpected errors could occur. - """ - state = instance.__dict__.copy() - # During serialization, convert the numpy array as a list. - emb = list(instance._embedding.tolist()) - if len(emb) == 0: - state.pop("_embedding") - else: - state["_embedding"] = emb - - cls_name = get_full_module_name(instance) - for k, v in state.items(): - key = cls_name + "_" + k - if k in _pointer_keys: - if _pointer_keys[key]: - state[k] = v.as_pointer(instance) - else: - if isinstance(v, Entry): - state[k] = v.as_pointer(instance) - _pointer_keys[key] = True - else: - _pointer_keys[key] = False - - state.pop("_Entry__pack") - return state - @dataclass class Entry(Generic[ContainerType]): @@ -192,19 +113,6 @@ def __init__(self, pack: ContainerType): self.pack._validate(self) self.pack.on_entry_creation(self) - def __getstate__(self): - r"""In serialization, the pack is not serialize, and it will be set - by the container. 
- - This also implies that it is not advised to serialize an entry on its - own, without the ``Container`` as the context, there is little semantics - remained in an entry. - """ - return get_state_func(self) - - def __setstate__(self, state): - set_state_func(self, state) - # using property decorator # a getter function for self._embedding @property @@ -250,62 +158,6 @@ def pack_id(self) -> int: def set_pack(self, pack: ContainerType): self.__pack = pack - def relink_pointer(self): - """ - This function is normally called after deserialization. It can be called - when the pack reference of this entry is ready (i.e. after `set_pack`). - The purpose is to convert the `Pointer` objects into actual entries. - """ - cls_name = get_full_module_name(self) - for k, v in self.__dict__.items(): - key = cls_name + "_" + k - if k in _pointer_keys: - if _pointer_keys[key]: - setattr(self, k, self._resolve_pointer(v)) - else: - if isinstance(v, BasePointer): - _pointer_keys[key] = True - setattr(self, k, self._resolve_pointer(v)) - else: - _pointer_keys[key] = False - - def as_pointer(self, from_entry: "Entry"): - """ - Return this entry as a pointer of this entry relative to the - ``from_entry``. - - Args: - from_entry: the entry to point from. - - Returns: - A pointer to the this entry from the ``from_entry``. - """ - if isinstance(from_entry, MultiEntry): - return MpPointer( - # bug fix/enhancement 559: change pack index to pack_id for multi-entry/multi-pack - self.pack_id, - self.tid, # from_entry.pack.get_pack_index(self.pack_id) - ) - elif isinstance(from_entry, Entry): - return Pointer(self.tid) - - def _resolve_pointer(self, ptr: BasePointer): - """ - Resolve into an entry on the provided pointer ``ptr`` from this entry. - - Args: - ptr: A pointer that refer to an entity. - - Returns: - None - """ - if isinstance(ptr, Pointer): - return self.pack.get_entry(ptr.tid) - else: - raise TypeError( - f"Unsupported pointer type {ptr.__class__} for entry" - ) - def entry_type(self) -> str: """Return the full name of this entry type.""" module = self.__class__.__module__ @@ -414,44 +266,7 @@ class MultiEntry(Entry, ABC): A :class:`forte.data.ontology.top.MultiPackGroup` object represents a collection of multiple entries among different data packs. """ - - def as_pointer(self, from_entry: "Entry") -> "Pointer": - """ - Get a pointer of the entry relative to this entry - - Args: - from_entry: The entry relative from. - - Returns: - A pointer relative to the this entry. - """ - if isinstance(from_entry, MultiEntry): - return Pointer(self.tid) - elif isinstance(from_entry, Entry): - raise ValueError( - "Do not support reference a multi pack entry from an entry." 
- ) - - def _resolve_pointer(self, ptr: BasePointer) -> Entry: - if isinstance(ptr, Pointer): - return self.pack.get_entry(ptr.tid) - elif isinstance(ptr, MpPointer): - # bugfix/new feature 559: in new version pack_index will be using pack_id internally - pack_array_index = ptr.pack_index # old version - pack_version = "" - try: - pack_version = self.pack.pack_version - except AttributeError: - pack_version = DEFAULT_PACK_VERSION # set to default if lacking version attribute - - if Version(pack_version) >= Version(PACK_ID_COMPATIBLE_VERSION): - pack_array_index = self.pack.get_pack_index( - ptr.pack_index - ) # default: new version - - return self.pack.packs[pack_array_index].get_entry(ptr.tid) - else: - raise TypeError(f"Unknown pointer type {ptr.__class__}") + pass EntryType = TypeVar("EntryType", bound=Entry) @@ -676,51 +491,6 @@ def data(self, array: Union[np.ndarray, List]): self._shape = self._data.shape -class Pointer(BasePointer): - """ - A pointer that points to an entry in the current pack, this is basically - containing the entry's tid. - """ - - def __init__(self, tid: int): - self._tid: int = tid - - @property - def tid(self): - return self._tid - - def __str__(self): - return f"[Entry Pointer]:{self.tid}" - - def __eq__(self, other): - return self.tid == other.tid - - -class MpPointer(BasePointer): - """ - Multi pack Pointer. A pointer that refers to an entry of one of the pack in - the multi pack. This contains the pack's index and the entries' tid. - """ - - def __init__(self, pack_index: int, tid: int): - self._pack_index: int = pack_index - self._tid: int = tid - - @property - def pack_index(self): - return self._pack_index - - @property - def tid(self): - return self._tid - - def __str__(self): - return f"[Entry Pointer]:{self.pack_index},{self.tid}" - - def __eq__(self, other): - return self._pack_index == other._pack_index and self.tid == other.tid - - class BaseLink(Entry, ABC): def __init__( self, diff --git a/tests/forte/data/entry_data_structures_test.py b/tests/forte/data/entry_data_structures_test.py index 3a55ea72a..03cd47152 100644 --- a/tests/forte/data/entry_data_structures_test.py +++ b/tests/forte/data/entry_data_structures_test.py @@ -6,7 +6,7 @@ from forte.data.data_pack import DataPack from forte.data.multi_pack import MultiPack from forte.data.ontology import Generics, MultiPackGeneric, Annotation -from forte.data.ontology.core import FList, FDict, Pointer +from forte.data.ontology.core import FList, FDict from forte.pipeline import Pipeline from forte.processors.base import PackProcessor, MultiPackProcessor from ft.onto.base_ontology import EntityMention