Merge branch 'praateek/fuzzy-minhash-cc-improvements' into praateek/minhash_permuted
praateekmahajan committed Oct 21, 2024
2 parents 221a9bd + df62a1f commit e7a6c2b
Showing 2 changed files with 36 additions and 86 deletions.
119 changes: 35 additions & 84 deletions nemo_curator/modules/fuzzy_dedup.py
@@ -486,7 +486,6 @@ def __init__(
cache_dir=self.config.cache_dir,
jaccard_pairs_path=os.path.join(self.config.cache_dir, jaccard_pairs_fname),
id_column=self.config.id_field,
convert_str_ids=False,
jaccard_threshold=self.config.jaccard_threshold,
logger=self._logger,
profile_dir=self.config.profile_dir,
@@ -1422,7 +1421,6 @@ def __init__(
cache_dir: str,
jaccard_pairs_path: str,
id_column="id",
convert_str_ids=False,
jaccard_threshold: float = 0.8,
logger: Union[logging.LoggerAdapter, str] = "./",
profile_dir: Optional[str] = None,
@@ -1432,7 +1430,6 @@ def __init__(
self.id_column = id_column
self.left_id = f"{id_column}_x"
self.right_id = f"{id_column}_y"
self.convert_str_ids = convert_str_ids
self.jaccard_threshold = jaccard_threshold
self.profile_dir = profile_dir
if isinstance(logger, str):
@@ -1445,6 +1442,7 @@ def __init__(
self._logger = logger

def cc_workflow(self, output_path):
st = time.time()
deduped_parsed_id_path = self._write_dedup_parsed_id()
encoded_jaccard_pair_path = self._write_encoded_jaccard_pair(
deduped_parsed_id_path
@@ -1455,6 +1453,7 @@ def cc_workflow(self, output_path):
cc_path = self._run_connected_components(
deduped_encoded_jaccard_path, deduped_parsed_id_path, output_path
)
self._logger.info(f"End to End time in cc_workflow = {time.time() - st}s")
return cc_path
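
The two added lines above wrap the whole workflow in an end-to-end timer. For orientation, cc_workflow chains three cached preprocessing stages before the connected-components run; below is a condensed sketch of that pipeline shape, pieced together from the hunks in this diff (the _write_dedup_encoded_jaccard_pair call sits outside the shown lines and is inferred from that method's signature later in the file):

```python
# Condensed, illustrative sketch of the cc_workflow pipeline shape; not the
# verbatim implementation. `cc` stands in for a ConnectedComponents instance.
import time

def cc_workflow_sketch(cc, output_path):
    st = time.time()
    parsed_id_path = cc._write_dedup_parsed_id()            # unique doc ids -> integer uids
    encoded_pair_path = cc._write_encoded_jaccard_pair(parsed_id_path)  # pairs rewritten as uids
    deduped_pair_path = cc._write_dedup_encoded_jaccard_pair(encoded_pair_path)  # duplicate edges dropped (inferred from name)
    cc_path = cc._run_connected_components(deduped_pair_path, parsed_id_path, output_path)
    print(f"End to End time in cc_workflow = {time.time() - st}s")
    return cc_path
```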

def _run_connected_components(
@@ -1468,13 +1467,15 @@ def _run_connected_components(
self.profile_dir, "connected-components-run"
):

Comms.initialize(p2p=True)
Comms.initialize(p2p=False)
df = dask_cudf.read_parquet(
deduped_encoded_jaccard_path, blocksize="1GB", aggregate_files=True
)
df = df[df["jaccard"] == 1].reset_index(drop=True)

labels_df = dask_cudf.read_parquet(deduped_parsed_id_path)
labels_df = dask_cudf.read_parquet(
deduped_parsed_id_path, blocksize="1GB", aggregate_files=True
)
num_nodes = len(labels_df)
self_edge_df = labels_df[["uid"]].rename(columns={"uid": self.left_id})
self_edge_df[self.right_id] = self_edge_df[self.left_id]
@@ -1496,9 +1497,7 @@ def _run_connected_components(
labels_df = labels_df.merge(
result, left_on=["uid"], right_on=["vertex"], how="inner"
)
id_columns = (
["dataset_id", "doc_id"] if self.convert_str_ids else [self.id_column]
)
id_columns = [self.id_column]
labels_df = labels_df[id_columns + ["labels"]]
labels_df = labels_df.rename(columns={"labels": "group"})
labels_df = labels_df.persist()
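
Two things change in this method: comms initialization now passes p2p=False instead of p2p=True, and the labels frame is read with the same explicit blocksize and file aggregation as the edge list above it. For context, a minimal sketch of the multi-GPU setup this code path sits inside (import paths and the cluster boilerplate are assumptions drawn from typical dask-cuda/cugraph usage, not part of this diff):

```python
# Minimal sketch of the cluster + comms lifecycle around a multi-GPU
# connected-components run. Import paths are assumed, not shown in the hunk.
from dask.distributed import Client
from dask_cuda import LocalCUDACluster
from cugraph.dask.comms import comms as Comms

cluster = LocalCUDACluster()
client = Client(cluster)

Comms.initialize(p2p=False)  # this commit flips p2p=True -> p2p=False
try:
    # build the dask_cudf edge list and run connected components here
    pass
finally:
    Comms.destroy()
    client.close()
    cluster.close()
```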
@@ -1589,14 +1588,6 @@ def _write_dedup_encoded_jaccard_pair(self, encoded_jaccard_pair_path):
)
return output_path

def _convert_str_id_pair_to_int(self, df):
for id, tag in zip([self.left_id, self.right_id], ["x", "y"]):
dx = df[id].str.rsplit("-", n=1, expand=True)
df[f"dataset_id_{tag}"] = dx[0].astype("uint32").values
df[f"doc_id_{tag}"] = dx[1].astype("int64").values
df = df.drop(columns=[id])
return df
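
The deleted helper above was the only place this class handled composite string ids: it split values of the form "<dataset_id>-<doc_id>" into two integer columns for each side of a pair, which is no longer needed now that the class works directly with a single id column. A toy reproduction of what it did (hypothetical data, plain cudf rather than a Dask partition):

```python
# Toy reproduction of the removed _convert_str_id_pair_to_int logic on
# made-up ids; the real code ran once per dask_cudf partition.
import cudf

df = cudf.DataFrame({"id_x": ["3-101", "3-205"], "id_y": ["7-42", "3-101"]})
for col, tag in zip(["id_x", "id_y"], ["x", "y"]):
    parts = df[col].str.rsplit("-", n=1, expand=True)
    df[f"dataset_id_{tag}"] = parts[0].astype("uint32")
    df[f"doc_id_{tag}"] = parts[1].astype("int64")
    df = df.drop(columns=[col])
# df now carries dataset_id_x / doc_id_x / dataset_id_y / doc_id_y integer columns
```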

def _write_dedup_parsed_id(self):
dedup_parsed_id_path = f"{self.cache_dir}/dedup_parsed_id.parquet"
t0 = time.time()
@@ -1606,22 +1597,10 @@ def _write_dedup_parsed_id(self):
ddf = dask_cudf.read_parquet(
self.jaccard_pairs_path,
columns=[self.left_id, self.right_id],
blocksize="1GB",
blocksize="512MB",
aggregate_files=True,
)
id_columns = [self.id_column]
if self.convert_str_ids:
ddf = ddf.map_partitions(
self._convert_str_id_pair_to_int,
meta={
"dataset_id_x": "uint32",
"doc_id_x": "int64",
"dataset_id_y": "uint32",
"doc_id_y": "int64",
},
)
id_columns = ["dataset_id", "doc_id"]

unique_docs = ddf.map_partitions(
ConnectedComponents._get_unique_ids_per_partition, id_columns=id_columns
)
@@ -1647,73 +1626,45 @@ def _write_encoded_jaccard_pair(self, dedup_parsed_id_path):
ddf_id = dask_cudf.read_parquet(
dedup_parsed_id_path, blocksize="2GB", aggregate_files=True
)
ddf_id = ddf_id.persist()
len(ddf_id)
ddf = dask_cudf.read_parquet(
self.jaccard_pairs_path,
blocksize="256MB",
blocksize="1GB",
aggregate_files=True,
)
id_columns = [self.id_column]
if self.convert_str_ids:
ddf = ddf.map_partitions(
self._convert_str_id_pair_to_int,
meta={
"jaccard": "float32",
"dataset_id_x": "uint32",
"doc_id_x": "int64",
"dataset_id_y": "uint32",
"doc_id_y": "int64",
},
)
id_columns = ["dataset_id", "doc_id"]

num_workers = get_num_workers(get_current_client())
self._batched_merge_and_write(
self._merge_and_write(
ddf=ddf,
ddf_id=ddf_id,
output_path=output_path,
id_columns=id_columns,
batch_size=num_workers,
id_column=self.id_column,
)
self._logger.info(
f"Time taken for Encoding Jaccard Pairs = {time.time() - t0}s and output written at {output_path}"
)
return output_path

def _batched_merge_and_write(
self, ddf, ddf_id, output_path, id_columns, batch_size=32
):
total_batches = (ddf.npartitions + batch_size - 1) // batch_size
for batch_id, offset in enumerate(range(0, ddf.npartitions, batch_size)):
st = time.time()
subset_ddf = ddf.partitions[offset : offset + batch_size]
for tag in ["x", "y"]:
pair_ids = []
for id_col in id_columns:
pair_ids.append(f"{id_col}_{tag}")
subset_ddf = subset_ddf.merge(
ddf_id,
left_on=pair_ids,
right_on=id_columns,
how="inner",
broadcast=True,
)
subset_ddf = subset_ddf.drop(
columns=pair_ids,
)
subset_ddf = subset_ddf.rename(
columns={"uid": f"{self.id_column}_{tag}"}
)

subset_ddf = subset_ddf[[self.left_id, self.right_id, "jaccard"]]
output_batch_path = os.path.join(output_path, f"{batch_id}.parquet")
subset_ddf.to_parquet(output_batch_path, write_index=False)

et = time.time()
print(
f"batch_id = {batch_id}/{total_batches}, time = {et - st}", flush=True
def _merge_and_write(
self,
ddf: dask_cudf.DataFrame,
ddf_id: dask_cudf.DataFrame,
output_path: str,
id_column: str,
) -> None:
# Index the id -> uid mapping by the document id column
ddf_id = ddf_id.set_index(id_column)
for tag in ["x", "y"]:
pair_id = f"{id_column}_{tag}"
# Merge 'ddf' with 'ddf_id' to map ids to uids
ddf = ddf.merge(
ddf_id,
left_on=pair_id,
right_index=True,
how="inner",
broadcast=True,
)
ddf = ddf.drop(columns=pair_id)
ddf = ddf.rename(columns={"uid": f"{self.id_column}_{tag}"})
ddf = ddf[[self.left_id, self.right_id, "jaccard"]]
ddf.to_parquet(output_path, write_index=False)
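
The batched loop above is replaced by this single pass: the id-to-uid mapping is indexed by the document id column once, then broadcast-merged into the pair frame twice (once per side), so the pairs come out keyed by integer uids without per-batch parquet writes. A toy, single-GPU illustration of that substitution (hypothetical data; the real code works on dask_cudf frames and passes broadcast=True to merge):

```python
# Hypothetical data showing the id -> uid substitution performed by _merge_and_write.
import cudf

pairs = cudf.DataFrame(
    {"id_x": ["doc-a", "doc-b"], "id_y": ["doc-b", "doc-c"], "jaccard": [1.0, 0.9]}
)
id_map = cudf.DataFrame({"id": ["doc-a", "doc-b", "doc-c"], "uid": [0, 1, 2]}).set_index("id")

for tag in ["x", "y"]:
    pair_id = f"id_{tag}"
    pairs = (
        pairs.merge(id_map, left_on=pair_id, right_index=True, how="inner")
        .drop(columns=pair_id)
        .rename(columns={"uid": pair_id})
    )

pairs = pairs[["id_x", "id_y", "jaccard"]]  # edges now expressed as integer uids
```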

@staticmethod
def _get_unique_ids_per_partition(df, id_columns):
@@ -1723,11 +1674,11 @@ def _get_unique_ids_per_partition(df, id_columns):
for id_col in id_columns:
cols_to_drop.append(f"{id_col}_{tag}")

subset_df = df[cols_to_drop].drop_duplicates()
subset_df = df[cols_to_drop].drop_duplicates(ignore_index=True)
subset_df = subset_df.rename(
columns={f"{id_col}_{tag}": f"{id_col}" for id_col in id_columns}
)
unique_df_ls.append(subset_df)
unique_df = cudf.concat(unique_df_ls, ignore_index=True)
unique_df = unique_df.drop_duplicates()
unique_df = unique_df.drop_duplicates(ignore_index=True)
return unique_df
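
Both drop_duplicates calls in this helper now pass ignore_index=True, so the deduplicated frames come back with a fresh RangeIndex rather than carrying over the original row labels. The helper itself just stacks the _x and _y id columns and deduplicates them; a toy version on made-up data with a single id column:

```python
# Toy version of _get_unique_ids_per_partition on hypothetical data.
import cudf

df = cudf.DataFrame({"id_x": ["a", "b", "a"], "id_y": ["b", "c", "c"]})
unique_df_ls = []
for tag in ["x", "y"]:
    side = df[[f"id_{tag}"]].drop_duplicates(ignore_index=True)
    unique_df_ls.append(side.rename(columns={f"id_{tag}": "id"}))
unique_df = cudf.concat(unique_df_ls, ignore_index=True).drop_duplicates(ignore_index=True)
# unique_df holds every distinct document id seen in this partition
```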
3 changes: 1 addition & 2 deletions (second changed file)
@@ -32,15 +32,14 @@ def main(args):
st = time.time()
output_path = os.path.join(args.output_dir, "connected_components.parquet")
args.enable_spilling = True

client = get_client(**ArgumentHelper.parse_client_args(args))

components_stage = ConnectedComponents(
cache_dir=args.cache_dir,
jaccard_pairs_path=args.jaccard_pairs_path,
id_column=args.input_json_id_field,
convert_str_ids=True,
jaccard_threshold=args.jaccard_threshold,
logger=args.log_dir,
)
components_stage.cc_workflow(output_path=output_path)
print(f"All done in {time.time()-st:.1f} seconds")