Skip to content

Commit

Permalink
Added text dropout option to the WebDataset class
Browse files Browse the repository at this point in the history
  • Loading branch information
vramanuj committed May 2, 2023
1 parent 582750d commit 8090dc9
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

import utils.distributed as dist_utils
import pandas as pd
import random as r

from utils.logging import Path

Expand Down Expand Up @@ -109,6 +110,7 @@ def __init__(
train=True,
resolution=512,
filters=None,
text_dropout=0.0,
**kwargs,
):
self.filters = filters or {}
Expand All @@ -122,6 +124,7 @@ def __init__(
train=train,
num_examples_to_see=num_examples_to_see,
filters=self.filters,
text_dropout=text_dropout
)

self.loader = wds.WebLoader(
Expand All @@ -134,7 +137,7 @@ def __init__(

logging.info(f"Unused dataset parameters for WebDataset: {kwargs}")

def get_dataset(self, url, tokenizer, train, num_examples_to_see, filters):
def get_dataset(self, url, tokenizer, train, num_examples_to_see, filters, text_dropout=0.0):
transform = CenterCropSDTransform(center_crop=True, size=self.resolution)

pipeline = [wds.ResampledShards(url)]
Expand Down Expand Up @@ -162,6 +165,7 @@ def get_dataset(self, url, tokenizer, train, num_examples_to_see, filters):
pixel_values="jpg;png;jpeg;webp", input_ids="txt", text_raw="txt"
),
wds.map(filter_keys(set(["pixel_values", "input_ids", "text_raw"]))),
wds.map_dict(input_ids=lambda text: "" if r.random() < text_dropout else text),
wds.map_dict(
pixel_values=transform,
input_ids=lambda text: tokenizer(
Expand Down

0 comments on commit 8090dc9

Please sign in to comment.