Add add_episode & task logic
aliberts committed Oct 21, 2024
1 parent 9ebf8b8 commit 299451a
Showing 4 changed files with 203 additions and 18 deletions.
179 changes: 169 additions & 10 deletions lerobot/common/datasets/lerobot_dataset.py
@@ -19,16 +19,19 @@
from typing import Callable

import datasets
import pyarrow.parquet as pq
import torch
import torch.utils
from datasets import load_dataset
from huggingface_hub import snapshot_download
from huggingface_hub import snapshot_download, upload_folder

from lerobot.common.datasets.compute_stats import aggregate_stats
from lerobot.common.datasets.image_writer import ImageWriter
from lerobot.common.datasets.utils import (
append_jsonl,
check_delta_timestamps,
check_timestamps_sync,
create_branch,
create_empty_dataset_info,
get_delta_indices,
get_episode_data_index,
@@ -160,6 +163,7 @@ def __init__(
self.video_backend = video_backend if video_backend is not None else "pyav"
self.image_writer = image_writer
self.episode_buffer = {}
self.consolidated = True
self.delta_indices = None

# Load metadata
@@ -192,6 +196,24 @@ def __init__(
# - [ ] Update episode_index (arg update=True)
# - [ ] Update info.json (arg update=True)

def push_to_repo(self, push_videos: bool = True) -> None:
if not self.consolidated:
raise RuntimeError(
"You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet."
"Please use the '.consolidate()' method first."
)
ignore_patterns = ["images/"]
if not push_videos:
ignore_patterns.append("videos/")

upload_folder(
repo_id=self.repo_id,
folder_path=self.root,
repo_type="dataset",
ignore_patterns=ignore_patterns,
)
create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset")
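# A minimal usage sketch (illustrative): consolidation is required before any upload,
# and the images/ folder is never uploaded.
#     dataset.consolidate()
#     dataset.push_to_repo(push_videos=False)  # additionally skips videos/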

def pull_from_repo(
self,
allow_patterns: list[str] | str | None = None,
@@ -303,11 +325,6 @@ def num_samples(self) -> int:
"""Number of samples/frames in selected episodes."""
return len(self.hf_dataset)

@property
def total_frames(self) -> int:
"""Total number of frames saved in this dataset."""
return self.info["total_frames"]

@property
def num_episodes(self) -> int:
"""Number of episodes selected."""
@@ -318,6 +335,16 @@ def total_episodes(self) -> int:
"""Total number of episodes available."""
return self.info["total_episodes"]

@property
def total_frames(self) -> int:
"""Total number of frames saved in this dataset."""
return self.info["total_frames"]

@property
def total_tasks(self) -> int:
"""Total number of different tasks performed in this dataset."""
return self.info["total_tasks"]

@property
def total_chunks(self) -> int:
"""Total number of chunks (groups of episodes)."""
@@ -331,7 +358,46 @@ def chunks_size(self) -> int:
@property
def shapes(self) -> dict:
"""Shapes for the different features."""
self.info.get("shapes")
return self.info["shapes"]

@property
def features(self) -> datasets.Features:
"""Shapes for the different features."""
if self.hf_dataset is not None:
return self.hf_dataset.features
elif self.episode_buffer is None:
raise NotImplementedError(
"Dataset features must be infered from an existing hf_dataset or episode_buffer."
)

features = {}
for key in self.episode_buffer:
if key in ["episode_index", "frame_index", "index", "task_index"]:
features[key] = datasets.Value(dtype="int64")
elif key in ["next.done", "next.success"]:
features[key] = datasets.Value(dtype="bool")
elif key in ["timestamp", "next.reward"]:
features[key] = datasets.Value(dtype="float32")
elif key in self.image_keys:
features[key] = datasets.Image()
elif key in self.keys:
features[key] = datasets.Sequence(
length=self.shapes[key], feature=datasets.Value(dtype="float32")
)

return datasets.Features(features)
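# For illustration, assuming an episode_buffer that holds a 6-dimensional "observation.state"
# key (the actual keys and shapes depend on the robot config), this property would build
# something along the lines of:
#     datasets.Features({
#         "observation.state": datasets.Sequence(length=6, feature=datasets.Value(dtype="float32")),
#         "timestamp": datasets.Value(dtype="float32"),
#         "episode_index": datasets.Value(dtype="int64"),
#         # ... plus the remaining scalar and image keys
#     })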

@property
def task_to_task_index(self) -> dict:
return {task: task_idx for task_idx, task in self.tasks.items()}

def get_task_index(self, task: str) -> int:
"""
Given a task in natural language, returns its task_index if the task already exists in the dataset,
otherwise returns the index that this new task will be assigned once the episode is saved.
"""
task_index = self.task_to_task_index.get(task, None)
return task_index if task_index is not None else self.total_tasks
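# For example, if self.tasks == {0: "pick up the cube"}, get_task_index("pick up the cube")
# returns 0, while an unseen task returns self.total_tasks (here 1), which becomes its
# task_index once the episode is saved to the metadata.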

def current_episode_index(self, idx: int) -> int:
episode_index = self.hf_dataset["episode_index"][idx]
@@ -447,12 +513,12 @@ def __repr__(self):
f")"
)

def _create_episode_buffer(self) -> dict:
def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
# TODO(aliberts): Handle resume
return {
"chunk": self.total_chunks,
"episode_index": self.total_episodes,
"size": 0,
"episode_index": self.total_episodes if episode_index is None else episode_index,
"task_index": None,
"frame_index": [],
"timestamp": [],
"next.done": [],
@@ -490,6 +556,92 @@ def add_frame(self, frame: dict) -> None:
file_path=img_path,
)

def add_episode(self, task: str, encode_videos: bool = False) -> None:
"""
This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
the hub.
Use encode_videos if you want to encode videos while saving each episode. Otherwise,
you can do it later during dataset.consolidate(). This gives more flexibility over when to
spend time on video encoding.
"""
episode_length = self.episode_buffer.pop("size")
episode_index = self.episode_buffer["episode_index"]
task_index = self.get_task_index(task)
self.episode_buffer["next.done"][-1] = True

for key in self.episode_buffer:
if key in self.keys:
self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
elif key == "episode_index":
self.episode_buffer[key] = torch.full((episode_length,), episode_index)
elif key == "task_index":
self.episode_buffer[key] = torch.full((episode_length,), task_index)
else:
self.episode_buffer[key] = torch.tensor(self.episode_buffer[key])

self._save_episode_to_metadata(episode_index, episode_length, task, task_index)
self._save_episode_table(episode_index)

if encode_videos:
pass # TODO

# Reset the buffer
self.episode_buffer = self._create_episode_buffer()
self.consolidated = False
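# A recording sketch built on this method (frame keys, task string and fps are illustrative;
# the actual frame layout depends on the robot config):
#     for t in range(num_frames):
#         dataset.add_frame({
#             "observation.state": state[t],
#             "action": action[t],
#             "timestamp": t / fps,
#             "next.done": t == num_frames - 1,
#         })
#     dataset.add_episode(task="pick up the cube")  # writes parquet + metadata, sets consolidated=False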

def _save_episode_table(self, episode_index: int) -> None:
features = self.features
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=features, split="train")
ep_table = ep_dataset._data.table
ep_data_path = self.get_data_file_path(ep_index=episode_index, return_str=False)
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
pq.write_table(ep_table, ep_data_path)

def _save_episode_to_metadata(
self, episode_index: int, episode_length: int, task: str, task_index: int
) -> None:
self.info["total_episodes"] += 1
self.info["total_frames"] += episode_length

if task_index not in self.tasks:
self.info["total_tasks"] += 1
self.tasks[task_index] = task
task_dict = {
"task_index": task_index,
"task": task,
}
append_jsonl(task_dict, self.root / "meta/tasks.jsonl")

chunk = self.get_episode_chunk(episode_index)
if chunk >= self.total_chunks:
self.info["total_chunks"] += 1

self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
self.info["total_videos"] += len(self.video_keys)
write_json(self.info, self.root / "meta/info.json")

episode_dict = {
"episode_index": episode_index,
"tasks": [task],
"length": episode_length,
}
append_jsonl(episode_dict, self.root / "meta/episodes.jsonl")
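# With illustrative values, the appended lines would look like:
#     meta/tasks.jsonl    -> {"task_index": 0, "task": "pick up the cube"}
#     meta/episodes.jsonl -> {"episode_index": 0, "tasks": ["pick up the cube"], "length": 250}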

def delete_episode(self) -> None:
pass # TODO

def consolidate(self) -> None:
pass # TODO
# Sanity checks:
# - [ ] shapes
# - [ ] ep_lengths
# - [ ] number of files
# - [ ] names of files (e.g. parquet 00000-of-00001 and 00001-of-00002)
# - [ ] no remaining self.image_writer.dir
self.consolidated = True

@classmethod
def create(
cls,
@@ -508,19 +660,26 @@ def create(
obj._version = CODEBASE_VERSION
obj.tolerance_s = tolerance_s
obj.image_writer = image_writer
obj.hf_dataset = None

if not all(cam.fps == fps for cam in robot.cameras.values()):
logging.warning(
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset. "
"In this case, frames from lower fps cameras will be repeated to fill in the blanks."
)

obj.tasks = {}
obj.info = create_empty_dataset_info(obj._version, fps, robot, use_videos)
write_json(obj.info, obj.root / "meta/info.json")

# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
obj.episode_buffer = obj._create_episode_buffer()

# This bool indicates that the current LeRobotDataset instance is in sync with the files on disk.
# It is used to know when certain operations are needed (for instance, computing dataset statistics).
# In order to be able to push the dataset to the hub, it needs to be consolidated first.
obj.consolidated = True

# obj.episodes = None
# obj.image_transforms = None
# obj.delta_timestamps = None
5 changes: 5 additions & 0 deletions lerobot/common/datasets/utils.py
@@ -81,6 +81,11 @@ def write_json(data: dict, fpath: Path) -> None:
json.dump(data, f, indent=4, ensure_ascii=False)


def append_jsonl(data: dict, fpath: Path) -> None:
with jsonlines.open(fpath, "a") as writer:
writer.write(data)
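# Usage sketch (payload and path are illustrative): each call appends exactly one JSON
# record per line, e.g.
#     append_jsonl({"task_index": 0, "task": "pick up the cube"}, Path("meta/tasks.jsonl"))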


def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
to torch tensors. Importantly, images are converted from PIL, which corresponds to
2 changes: 1 addition & 1 deletion lerobot/common/robot_devices/control_utils.py
@@ -248,7 +248,7 @@ def control_loop(
if teleoperate and policy is not None:
raise ValueError("When `teleoperate` is True, `policy` should be None.")

if dataset is not None and fps is not None and dataset["fps"] != fps:
if dataset is not None and fps is not None and dataset.fps != fps:
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")

timestamp = 0
35 changes: 28 additions & 7 deletions lerobot/scripts/control_robot.py
@@ -109,8 +109,6 @@
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.populate_dataset import (
create_lerobot_dataset,
delete_current_episode,
save_current_episode,
)
from lerobot.common.robot_devices.control_utils import (
control_loop,
@@ -195,6 +193,7 @@ def record(
robot: Robot,
root: str,
repo_id: str,
single_task: str,
pretrained_policy_name_or_path: str | None = None,
policy_overrides: List[str] | None = None,
fps: int | None = None,
@@ -219,6 +218,11 @@
device = None
use_amp = None

if single_task:
task = single_task
else:
raise NotImplementedError("Only single-task recording is supported for now")

# Load pretrained policy
if pretrained_policy_name_or_path is not None:
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
@@ -235,8 +239,8 @@
sanity_check_dataset_name(repo_id, policy)
image_writer = ImageWriter(
write_dir=root,
num_image_writer_processes=num_image_writer_processes,
num_image_writer_threads=num_image_writer_threads_per_camera * robot.num_cameras,
num_processes=num_image_writer_processes,
num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
)
dataset = LeRobotDataset.create(repo_id, fps, robot, image_writer=image_writer)

@@ -261,7 +265,12 @@
if recorded_episodes >= num_episodes:
break

episode_index = dataset["num_episodes"]
# TODO(aliberts): add task prompt for multitask here. Might need to temporarily disable event if
# input() messes with them.
# if multi_task:
# task = input("Enter your task description: ")

episode_index = dataset.episode_buffer["episode_index"]
log_say(f"Recording episode {episode_index}", play_sounds)
record_episode(
dataset=dataset,
@@ -289,11 +298,11 @@
log_say("Re-record episode", play_sounds)
events["rerecord_episode"] = False
events["exit_early"] = False
delete_current_episode(dataset)
dataset.delete_episode()
continue

# Save the episode to disk; this also resets the buffer with the next episode_index
save_current_episode(dataset)
dataset.add_episode(task)

if events["stop_recording"]:
break
@@ -378,9 +387,21 @@ def replay(
)

parser_record = subparsers.add_parser("record", parents=[base_parser])
task_args = parser_record.add_mutually_exclusive_group(required=True)
parser_record.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
)
task_args.add_argument(
"--single-task",
type=str,
help="A short but accurate description of the task performed during the recording.",
)
# TODO(aliberts): add multi-task support
# task_args.add_argument(
# "--multi-task",
# type=int,
# help="You will need to enter the task performed at the start of each episode.",
# )
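# Example invocation (illustrative; other required arguments such as the robot config are omitted):
#     python lerobot/scripts/control_robot.py record \
#         --fps 30 \
#         --single-task "Pick up the cube and place it in the bin."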
parser_record.add_argument(
"--root",
type=Path,
