diff --git a/presets/tuning/text-generation/dataset.py b/presets/tuning/text-generation/dataset.py index e7a9e6867..8582ddf3a 100644 --- a/presets/tuning/text-generation/dataset.py +++ b/presets/tuning/text-generation/dataset.py @@ -43,6 +43,9 @@ def load_data(self): try: self.dataset = load_dataset(file_ext, data_files=dataset_path, split="train") print(f"Dataset loaded successfully from {dataset_path} with file type '{file_ext}'.") + + self.analyze_dataset_structure() + except Exception as e: print(f"Error loading dataset: {e}") raise ValueError(f"Unable to load dataset {dataset_path} with file type '{file_ext}'") @@ -56,6 +59,22 @@ def find_valid_dataset(self, data_dir): if ext in filename_lower: return os.path.join(root, file) return None + + def analyze_dataset_structure(self): + if self.dataset is None: + raise ValueError("Dataset is not loaded. Please load the dataset first.") + + columns = self.dataset.column_names + + if len(columns) > 1: + self.dataset_text_field = self.find_text_field(columns) + + def find_text_field(self, columns): + text_words = ["text", "content", "input", "question"] + + for column in columns: + if any(text_word in column.lower() for text_word in text_words): + return column def get_file_extension(self, file_path): """ Returns the file extension based on filetype guess or filename. """