diff --git a/multimodal_transformers/data/load_data.py b/multimodal_transformers/data/load_data.py index 65758b1..f589534 100644 --- a/multimodal_transformers/data/load_data.py +++ b/multimodal_transformers/data/load_data.py @@ -108,10 +108,14 @@ def load_data_from_folder(folder_path, data_df = pd.concat([data_df, cat_df], axis=1) categorical_cols = cat_feat_processor.feat_names - train_df = data_df.loc[train_df.index] + len_train = len(train_df) + len_val = len(val_df) if val_df is not None else 0 + + train_df = data_df.iloc[:len_train] if val_df is not None: - val_df = data_df.loc[val_df.index] - test_df = data_df.loc[test_df.index] + val_df = data_df.iloc[len_train: len_train + len_val] + len_train = len_train + len_val + test_df = data_df.iloc[len_train:] categorical_encode_type = None