Improved HuggingFace Connectors #612

Draft · wants to merge 1 commit into base: main
54 changes: 46 additions & 8 deletions kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py
@@ -2,10 +2,15 @@

from typing import Any

from datasets import load_dataset, Dataset
from huggingface_hub import HfApi
from kedro.io import AbstractVersionedDataset

import logging


logger = logging.getLogger(__name__)


class HFDataset(AbstractVersionedDataset):
"""``HFDataset`` loads Hugging Face datasets
@@ -36,20 +41,53 @@ class HFDataset(AbstractVersionedDataset):

"""

    def __init__(
        self,
        *,
        dataset_name: str,
        filepath: str = None,
        credentials: dict[str, Any] = None,
        save_to_disk: bool = True,
        save_to_hub: bool = False,
    ):
        self.dataset_name = dataset_name
        self.filepath = filepath
        self.credentials = credentials
        self.save_to_disk = save_to_disk
        self.save_to_hub = save_to_hub

    def _load(self) -> Dataset:
        # Prefer a local copy when a filepath is configured; fall back to the Hub.
        if self.filepath is not None:
            try:
                return Dataset.load_from_disk(self.filepath)
            except FileNotFoundError:
                logger.info("No local copy found, loading from the Hugging Face Hub.")
        return load_dataset(self.dataset_name)

    def _save(self, data: Dataset) -> None:
        if self.save_to_disk:
            logger.info("Saving to local disk.")
            data.save_to_disk(self.filepath)

        if self.save_to_hub:
Comment on lines +67 to +71

Member: I think this branching here is kind of unusual. As an alternative, kedro-mlflow provides 2 different datasets for this purpose, MlflowModelTrackingDataset and MlflowModelLocalFileSystemDataset. Do you think we should do the same here @merelcht ?

Member: Yes, looking at this it might be an idea to split this into a dataset that does loading & saving to disk and one that does it remotely.

            logger.info("Saving to Hugging Face Hub.")

            if isinstance(self.credentials, dict):
                token = self.credentials.get("write")
            elif isinstance(self.credentials, str):
                token = self.credentials
            else:
                token = None

            data.push_to_hub(self.dataset_name, token=token)

    def _describe(self) -> dict[str, Any]:
        api = HfApi()
        dataset_info = list(api.list_datasets(search=self.dataset_name))[0]
        return {
            "dataset_name": self.dataset_name,
            "filepath": self.filepath,
            "save_to_disk": self.save_to_disk,
            "save_to_hub": self.save_to_hub,
            "dataset_tags": dataset_info.tags,
            "dataset_author": dataset_info.author,
        }
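The credential branching in `_save` above can be factored out and unit-tested on its own. A minimal sketch, assuming the same contract as the diff (the `resolve_token` helper name is illustrative, not part of this PR):

```python
def resolve_token(credentials):
    """Resolve a Hugging Face Hub token from catalog credentials.

    Mirrors the branching in HFDataset._save: accepts a dict with a
    'write' key, a bare token string, or None when no credentials
    are configured.
    """
    if isinstance(credentials, dict):
        return credentials.get("write")
    if isinstance(credentials, str):
        return credentials
    return None
```

Extracting this keeps `_save` focused on I/O and makes the credential contract trivial to test without touching the Hub.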
56 changes: 56 additions & 0 deletions kedro-datasets/kedro_datasets/huggingface/transformer.py
@@ -0,0 +1,56 @@
from typing import Any
from kedro.io import AbstractDataset
from transformers import AutoTokenizer, AutoModel

import logging
import importlib

from collections import namedtuple


TransformerModel = namedtuple("TransformerModel", ["model", "tokenizer"])


logger = logging.getLogger(__name__)


class HFTransformer(AbstractDataset):
Member: What about HFAutoTransformerDataset?

(About ...Dataset, I know it's not technically a dataset and that there's potential confusion with Hugging Face Datasets, but I'm wondering if we should keep consistency here. @merelcht ?)

Member: Good point, I'd prefer to keep it consistent even if this isn't technically a dataset. HFAutoTransformerDataset sounds like a good name.

    def __init__(
        self,
        checkpoint: str,
        model_type: str = None,
        tokenizer_kwargs: dict = None,
        model_kwargs: dict = None,
    ):
        self.checkpoint = checkpoint

        if model_type is not None:
            try:
                # Look up the class (e.g. AutoModelForCausalLM) on the
                # transformers package; import_module(model_type) would try
                # to import it as a module and always fail.
                self.model = getattr(importlib.import_module("transformers"), model_type)
            except AttributeError as e:
                logger.error(
                    f"Given model type={model_type} doesn't exist in transformers"
                )
                raise e
        else:
            self.model = AutoModel

        self.tokenizer_kwargs = tokenizer_kwargs
        self.model_kwargs = model_kwargs

    def _load(self) -> TransformerModel:
        # Guard against None kwargs so the ** unpacking doesn't raise.
        model = self.model.from_pretrained(self.checkpoint, **(self.model_kwargs or {}))
        tokenizer = AutoTokenizer.from_pretrained(
            self.checkpoint, **(self.tokenizer_kwargs or {})
        )

        return TransformerModel(model=model, tokenizer=tokenizer)

    def _save(self, data) -> None:
        raise NotImplementedError("Pretrained models don't support saving for now")

    def _describe(self) -> dict[str, Any]:
        return {
            "checkpoint": self.checkpoint,
            "model_type": self.model,
            "tokenizer_kwargs": self.tokenizer_kwargs,
            "model_kwargs": self.model_kwargs,
        }
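For reference, catalog entries using these connectors might look like the following sketch. The `type` paths, dataset name, and checkpoint are illustrative and assume the classes are exposed under `kedro_datasets.huggingface` as in the rest of this plugin:

```yaml
imdb_reviews:
  type: huggingface.HFDataset
  dataset_name: imdb
  filepath: data/01_raw/imdb
  save_to_disk: true
  save_to_hub: false

bert_encoder:
  type: huggingface.HFTransformer
  checkpoint: bert-base-uncased
```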