"""Data-preparation pipeline for text-classification datasets.

Loads a Hugging Face dataset or a local JSONL/CSV/JSON file, builds
train/validation/test splits, cleans the text, validates the result and
writes each split to disk. Runnable as a CLI (see ``main``) configured by a
YAML file and/or command-line overrides.
"""
import argparse
import json
import logging
import math
import os
import re
import sys
from abc import ABC, abstractmethod
from collections import Counter
from dataclasses import dataclass, fields
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import yaml
from datasets import Dataset, load_dataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

logger = logging.getLogger(__name__)
# Default to DEBUG so per-item traces are available when imported as a library;
# main() resets this to the --log-level requested on the command line.
logger.setLevel(logging.DEBUG)

# Precompiled patterns used by BaseDataLoader._clean_text.
_WHITESPACE_RE = re.compile(r'\s+')
_SPECIAL_CHARS_RE = re.compile(r'[^\w\s]')

# Seed shared by every train/validation/test split so runs are reproducible.
_SPLIT_SEED = 42


def _is_null(value: Any) -> bool:
    """Return True for None and float NaN (pandas' marker for an empty CSV cell)."""
    return value is None or (isinstance(value, float) and math.isnan(value))


def _is_missing_value(value: Any) -> bool:
    """Return True for null values and empty strings/containers.

    Unlike a plain truthiness test, numeric 0 and False count as real values.
    """
    if _is_null(value):
        return True
    if isinstance(value, (str, list, tuple, dict, set)):
        return len(value) == 0
    return False


@dataclass
class ClassificationConfig:
    """Configuration for classification tasks"""

    # Data source configuration
    data_source: str = "huggingface"  # "huggingface" or "custom"
    dataset_name: Optional[str] = None  # For Hugging Face datasets
    data_path: Optional[str] = None  # For custom datasets
    data_format: str = "jsonl"  # jsonl, csv, json

    # Field mapping
    input_field: str = "text"  # Field containing input text
    label_field: str = "label"  # Field containing labels
    id_field: Optional[str] = None  # Optional ID field, copied to the output as "id"

    # Data processing
    max_samples: Optional[int] = None
    train_split: float = 0.8
    validation_split: float = 0.1
    test_split: float = 0.1

    # Text preprocessing
    clean_text: bool = True
    remove_special_chars: bool = False
    lowercase: bool = True
    min_length: int = 10  # inclusive, measured on the (cleaned) text in characters
    max_length: int = 1000  # inclusive

    # Label processing -- NOTE(review): carried in the config but not applied by the loaders
    label_encoding: str = "auto"  # auto, numeric, string
    multilabel: bool = False
    label_separator: str = ","  # For multilabel datasets

    # Output configuration
    output_format: str = "classification"  # classification, instruction, conversation, qa
    output_dir: str = "./data"

    # Hugging Face specific
    hf_split: str = "train"  # NOTE(review): not read by HuggingFaceDataLoader.load
    hf_cache_dir: Optional[str] = None

    # Split configuration
    test_split_from: str = "train"  # "train", "use_test_if_available", or "use_val_if_available"
    val_split_from: str = "train"  # "train" or "use_val_if_available"

    # Custom data specific
    encoding: str = "utf-8"
    delimiter: str = ","  # For CSV files


class DataValidator:
    """Validates classification data quality and format"""

    @staticmethod
    def validate_classification_data(data: Dict[str, List[Dict]], config: ClassificationConfig,
                                     is_processed: bool = False) -> Tuple[bool, List[str]]:
        """Validate that train/validation/test splits are present and well-formed.

        Args:
            data: Mapping of split name to list of item dicts.
            config: Supplies the raw field names when ``is_processed`` is False.
            is_processed: True when items use the preprocessed ``input``/``label`` keys.

        Returns:
            ``(is_valid, errors)``; ``is_valid`` is True only when ``errors`` is empty.
        """
        errors: List[str] = []

        for split in ("train", "validation", "test"):
            if split not in data or not data[split]:
                errors.append(f"Missing or empty '{split}' split")
        if errors:
            return False, errors

        total_samples = sum(len(split_data) for split_data in data.values())
        logger.info(f"Validating {total_samples} total samples across all splits...")

        # Preprocessed items always use "input"/"label"; raw items use the configured names.
        input_field = "input" if is_processed else config.input_field
        label_field = "label" if is_processed else config.label_field

        for split_name, split_data in data.items():
            logger.info(f"Validating {split_name} split with {len(split_data)} samples...")

            missing_input_count = 0
            missing_label_count = 0
            for i, item in enumerate(split_data):
                if input_field not in item:
                    errors.append(f"Missing input field '{input_field}' in {split_name} split, item {i}")
                    missing_input_count += 1
                if label_field not in item:
                    errors.append(f"Missing label field '{label_field}' in {split_name} split, item {i}")
                    missing_label_count += 1
            logger.info(f"{split_name} - Items missing input field: {missing_input_count}")
            logger.info(f"{split_name} - Items missing label field: {missing_label_count}")

            type_errors = 0
            for i, item in enumerate(split_data):
                if not isinstance(item.get(input_field, ""), str):
                    errors.append(f"Input field '{input_field}' must be string in {split_name} split, item {i}")
                    type_errors += 1
            logger.info(f"{split_name} - Type errors: {type_errors}")

            # Non-string inputs were already reported above; only strings can be "blank".
            empty_inputs = 0
            for item in split_data:
                value = item.get(input_field, "")
                if isinstance(value, str) and not value.strip():
                    empty_inputs += 1
            if empty_inputs > 0:
                errors.append(f"Found {empty_inputs} items with empty input text in {split_name} split")
            logger.info(f"{split_name} - Empty inputs: {empty_inputs}")

            labels = [item.get(label_field) for item in split_data if item.get(label_field) is not None]
            label_counts = Counter(labels)
            unique_labels = set(label_counts)
            logger.info(f"{split_name} - Found {len(unique_labels)} unique labels: {unique_labels}")
            logger.info(f"{split_name} - Label distribution: {dict(label_counts)}")
            if len(unique_labels) < 1:
                errors.append(f"{split_name} split must have at least 1 label, found: {unique_labels}")

            if split_data:
                logger.info(f"Sample processed items from {split_name}:")
                for i, item in enumerate(split_data[:3]):
                    preview = str(item.get(input_field, ''))[:50]
                    logger.info(f"  Item {i}: input='{preview}...', label='{item.get(label_field, '')}'")

        return len(errors) == 0, errors

    @staticmethod
    def analyze_dataset(data: Dict[str, List[Dict]], config: ClassificationConfig,
                        is_processed: bool = False) -> Dict[str, Any]:
        """Compute per-split and overall statistics.

        Returns a dict with ``splits`` (per-split sample count, unique-label count,
        label distribution keyed by ``str(label)``, text-length stats, missing
        counts) and ``overall`` (total samples, number of distinct labels across
        all splits, per-split sizes).
        """
        analysis: Dict[str, Any] = {
            "splits": {},
            "overall": {
                "total_samples": 0,
                "all_unique_labels": set(),
                "split_sizes": {},
            },
        }

        input_field = "input" if is_processed else config.input_field
        label_field = "label" if is_processed else config.label_field

        for split_name, split_data in data.items():
            label_counts = Counter(item.get(label_field) for item in split_data)
            split_analysis: Dict[str, Any] = {
                "total_samples": len(split_data),
                "unique_labels": len(label_counts),
                "label_distribution": {},
                "text_length_stats": {},
                "missing_values": {},
            }

            for label, count in label_counts.items():
                split_analysis["label_distribution"][str(label)] = count
                analysis["overall"]["all_unique_labels"].add(str(label))

            # Non-string inputs (e.g. None) count as length 0 instead of raising.
            text_lengths = [
                len(value) if isinstance(value, str) else 0
                for value in (item.get(input_field, "") for item in split_data)
            ]
            if text_lengths:
                split_analysis["text_length_stats"] = {
                    "min": min(text_lengths),
                    "max": max(text_lengths),
                    "mean": float(np.mean(text_lengths)),
                    "median": float(np.median(text_lengths)),
                }

            for field_name in (input_field, label_field):
                split_analysis["missing_values"][field_name] = sum(
                    1 for item in split_data if _is_missing_value(item.get(field_name))
                )

            analysis["splits"][split_name] = split_analysis
            analysis["overall"]["total_samples"] += len(split_data)
            analysis["overall"]["split_sizes"][split_name] = len(split_data)

        # Callers read this as a count, not the set itself.
        analysis["overall"]["all_unique_labels"] = len(analysis["overall"]["all_unique_labels"])
        return analysis


class BaseDataLoader(ABC):
    """Abstract base class for data loaders.

    Subclasses implement ``load`` and ``preprocess``. The per-item preprocessing
    and splitting helpers are shared so every loader applies identical rules.
    """

    @abstractmethod
    def load(self, config: ClassificationConfig) -> Dict[str, List[Dict]]:
        """Load data and return dictionary with train/val/test splits"""

    @abstractmethod
    def preprocess(self, data: Dict[str, List[Dict]], config: ClassificationConfig) -> Dict[str, List[Dict]]:
        """Apply preprocessing steps to all splits"""

    @staticmethod
    def _stratified_split(data: List[Dict], test_size: float, label_field: str) -> Tuple[List[Dict], List[Dict]]:
        """Split ``data`` into ``(kept, held_out)``, stratifying on ``label_field`` when possible.

        Stratification is attempted when the first item has ``label_field``. If
        sklearn rejects it (e.g. a class with a single member, or unorderable
        labels), a plain random split with the same seed is used instead.
        """
        if data and label_field in data[0]:
            labels = [item.get(label_field) for item in data]
            try:
                return train_test_split(data, test_size=test_size, random_state=_SPLIT_SEED, stratify=labels)
            except (ValueError, TypeError) as e:
                logger.warning(f"Could not create stratified splits: {e}. Using random splits.")
        return train_test_split(data, test_size=test_size, random_state=_SPLIT_SEED)

    def _preprocess_split(self, split_name: str, split_data: List[Dict],
                          config: ClassificationConfig) -> List[Dict]:
        """Run ``_preprocess_item`` over one split, dropping rejected items and logging counts."""
        self._debug_count = 0  # reset per split so each split traces its first few items
        processed_data: List[Dict] = []
        skipped_count = 0
        for i, item in enumerate(split_data):
            processed_item = self._preprocess_item(item, config)
            if processed_item is not None:
                processed_data.append(processed_item)
                continue
            skipped_count += 1
            if skipped_count <= 3:  # Log first few skipped items
                logger.info(f"Skipped item {i} from {split_name}: {item}")
        logger.info(f"{split_name} - Preprocessed {len(processed_data)} samples, skipped {skipped_count} samples")
        return processed_data

    def _preprocess_item(self, item: Dict, config: ClassificationConfig) -> Optional[Dict]:
        """Convert one raw item into ``{"input", "label"[, "id"]}``.

        Missing/None/NaN inputs and labels become ``""``; everything else is
        stringified. Returns None when the (optionally cleaned) input length is
        outside ``[config.min_length, config.max_length]``.
        """
        input_text = item.get(config.input_field, "")
        label = item.get(config.label_field, "")

        self._debug_count = getattr(self, "_debug_count", 0) + 1
        trace = self._debug_count <= 3  # only trace the first few items per split

        if trace:
            logger.debug(f"Processing item {self._debug_count}:")
            logger.debug(f"  Looking for input field '{config.input_field}': {input_text}")
            logger.debug(f"  Looking for label field '{config.label_field}': {label}")

        # NaN is how pandas reports empty CSV cells; without this it would become the text "nan".
        input_text = "" if _is_null(input_text) else str(input_text)
        label = "" if _is_null(label) else str(label)
        if trace:
            logger.debug(f"  After conversion - input: '{input_text[:50]}...', label: '{label}'")

        if config.clean_text:
            original_text = input_text
            input_text = self._clean_text(input_text, config)
            if trace:
                logger.debug(f"  After cleaning - original: '{original_text[:50]}...', cleaned: '{input_text[:50]}...'")

        if not config.min_length <= len(input_text) <= config.max_length:
            if trace:
                logger.debug(f"  Skipping - length {len(input_text)} not in range [{config.min_length}, {config.max_length}]")
            return None

        processed_item = {"input": input_text, "label": label}
        if config.id_field and config.id_field in item:
            processed_item["id"] = item[config.id_field]

        if trace:
            logger.debug(f"  Final processed item: {processed_item}")
        return processed_item

    def _clean_text(self, text: str, config: ClassificationConfig) -> str:
        """Collapse whitespace, then optionally lowercase and strip non-word characters."""
        if not isinstance(text, str):
            return ""
        text = _WHITESPACE_RE.sub(' ', text).strip()
        if config.lowercase:
            text = text.lower()
        if config.remove_special_chars:
            text = _SPECIAL_CHARS_RE.sub('', text)
        return text


class HuggingFaceDataLoader(BaseDataLoader):
    """Load datasets from Hugging Face Hub"""

    def load(self, config: ClassificationConfig) -> Dict[str, List[Dict]]:
        """Load a Hub dataset and return ``{"train", "validation", "test"}`` lists of dicts.

        The dataset's own validation/test splits are used when
        ``config.val_split_from`` / ``config.test_split_from`` ask for them and
        they exist; otherwise the missing splits are carved out of ``train``.
        When carving, the split ratios are normalised to sum to 1 and written
        back to ``config``. ``max_samples`` then truncates every split.

        Raises:
            ValueError: no dataset name, no ``train`` split, or non-positive ratios.
        """
        if not config.dataset_name:
            raise ValueError("Dataset name is required for Hugging Face datasets")

        logger.info(f"Loading Hugging Face dataset: {config.dataset_name}")

        try:
            dataset = load_dataset(config.dataset_name, cache_dir=config.hf_cache_dir)
            available_splits = list(dataset.keys())
            logger.info(f"Available splits in dataset: {available_splits}")

            splits_data: Dict[str, List[Dict]] = {"train": [], "validation": [], "test": []}

            if "train" not in available_splits:
                logger.error("No 'train' split found in dataset!")
                logger.error(f"Available splits: {available_splits}")
                raise ValueError(f"Dataset {config.dataset_name} does not have a 'train' split")
            train_dataset = dataset["train"]
            logger.info(f"Using 'train' split with {len(train_dataset)} samples")
            splits_data["train"] = list(train_dataset)

            # Validation split
            val_source = None
            if config.val_split_from == "use_val_if_available":
                val_source = next((name for name in ("validation", "val") if name in available_splits), None)
                if val_source:
                    logger.info(f"Using '{val_source}' split with {len(dataset[val_source])} samples")
                    splits_data["validation"] = list(dataset[val_source])
                else:
                    logger.warning("No validation split found in dataset. Will create from train split.")
                    logger.info(f"Available splits: {available_splits}")
                    logger.info(f"Will use {config.validation_split * 100}% of train data for validation")
            else:
                logger.info(f"Will create validation split from train data ({config.validation_split * 100}%)")

            # Test split
            test_source = None
            if config.test_split_from == "use_test_if_available" and "test" in available_splits:
                test_source = "test"
                logger.info(f"Using 'test' split with {len(dataset['test'])} samples")
            elif config.test_split_from == "use_val_if_available":
                test_source = next((name for name in ("validation", "val") if name in available_splits), None)
                if test_source:
                    logger.info(f"Using '{test_source}' split as test with {len(dataset[test_source])} samples")
            if test_source:
                splits_data["test"] = list(dataset[test_source])
            elif config.test_split_from == "use_test_if_available":
                logger.warning("No test split found in dataset. Will create from train split.")
                logger.info(f"Available splits: {available_splits}")
                logger.info(f"Will use {config.test_split * 100}% of train data for test")
            else:
                logger.info(f"Will create test split from train data ({config.test_split * 100}%)")

            if val_source and val_source == test_source:
                logger.warning(
                    f"Validation and test splits both come from the '{val_source}' split; "
                    f"evaluation metrics will not be independent."
                )

            self._carve_missing_splits(splits_data, config)

            logger.info("Final split sizes:")
            logger.info(f"  Train: {len(splits_data['train'])} samples")
            logger.info(f"  Validation: {len(splits_data['validation'])} samples")
            logger.info(f"  Test: {len(splits_data['test'])} samples")

            if config.max_samples:
                for split_name, split_data in splits_data.items():
                    if split_data:
                        original_size = len(split_data)
                        splits_data[split_name] = split_data[:config.max_samples]
                        logger.info(f"Limited {split_name} split from {original_size} to {len(splits_data[split_name])} samples")

            for split_name, split_data in splits_data.items():
                if split_data:
                    self._log_field_suggestions(split_name, split_data, config)

            logger.info(f"Successfully loaded dataset {config.dataset_name}")
            return splits_data

        except Exception as e:
            logger.error(f"Error loading dataset {config.dataset_name}: {e}")
            raise

    def _carve_missing_splits(self, splits_data: Dict[str, List[Dict]], config: ClassificationConfig) -> None:
        """Fill empty validation/test splits in place by splitting the train split."""
        need_val = not splits_data["validation"]
        need_test = not splits_data["test"]
        if not (need_val or need_test):
            return

        train_data = splits_data["train"]
        total = config.train_split + config.validation_split + config.test_split
        if total <= 0:
            raise ValueError(f"Split percentages must sum to a positive value (got {total})")
        if not math.isclose(total, 1.0):
            logger.warning(f"Split percentages don't sum to 1.0 (got {total}). Normalizing...")
            config.train_split = config.train_split / total
            config.validation_split = config.validation_split / total
            config.test_split = config.test_split / total

        label_field = config.label_field
        if need_val and need_test:
            holdout = config.validation_split + config.test_split
            if holdout <= 0:
                raise ValueError("validation_split + test_split must be positive to create validation/test splits")
            new_train, temp_data = self._stratified_split(train_data, holdout, label_field)
            new_val, new_test = self._stratified_split(temp_data, config.test_split / holdout, label_field)
            splits_data["train"], splits_data["validation"], splits_data["test"] = new_train, new_val, new_test
        elif need_val:
            splits_data["train"], splits_data["validation"] = self._stratified_split(
                train_data, config.validation_split, label_field)
        else:
            splits_data["train"], splits_data["test"] = self._stratified_split(
                train_data, config.test_split, label_field)

    @staticmethod
    def _log_field_suggestions(split_name: str, split_data: List[Dict], config: ClassificationConfig) -> None:
        """Log the first item's fields and suggest alternatives for missing input/label fields."""
        first_fields = list(split_data[0].keys())
        logger.info(f"Sample data item from {split_name}: {split_data[0]}")
        logger.info(f"Available fields in {split_name} split: {first_fields}")

        if config.input_field not in split_data[0]:
            logger.warning(f"Input field '{config.input_field}' not found in {split_name}. Available fields: {first_fields}")
            text_keywords = ('text', 'sentence', 'content', 'input', 'comment', 'message')
            text_fields = [f for f in first_fields if any(k in f.lower() for k in text_keywords)]
            if text_fields:
                logger.info(f"Suggested text fields for {split_name}: {text_fields}")

        if config.label_field not in split_data[0]:
            logger.warning(f"Label field '{config.label_field}' not found in {split_name}. Available fields: {first_fields}")
            label_keywords = ('label', 'class', 'category', 'target', 'emotion', 'labels')
            label_fields = [f for f in first_fields if any(k in f.lower() for k in label_keywords)]
            if label_fields:
                logger.info(f"Suggested label fields for {split_name}: {label_fields}")

    def preprocess(self, data: Dict[str, List[Dict]], config: ClassificationConfig) -> Dict[str, List[Dict]]:
        """Apply preprocessing steps to all splits separately.

        Logs field availability, missing-field counts and a few raw/processed
        samples per split to help diagnose wrong ``input_field``/``label_field``
        settings. Items rejected by ``_preprocess_item`` are dropped.
        """
        processed_splits: Dict[str, List[Dict]] = {}
        logger.info("=== PREPROCESSING DATA ===")

        for split_name, split_data in data.items():
            logger.info(f"Processing {split_name} split with {len(split_data)} items...")
            if split_data:
                self._log_raw_diagnostics(split_name, split_data, config)

            processed_data = self._preprocess_split(split_name, split_data, config)
            processed_splits[split_name] = processed_data

            if processed_data:
                logger.info(f"=== SAMPLE PROCESSED DATA FROM {split_name.upper()} ===")
                for i, item in enumerate(processed_data[:3]):
                    logger.info(f"Processed item {i} from {split_name}: {item}")

        return processed_splits

    @staticmethod
    def _log_raw_diagnostics(split_name: str, split_data: List[Dict], config: ClassificationConfig) -> None:
        """Log field availability, missing-field counts and up to three raw items of a split."""
        available_fields = set(split_data[0].keys())
        logger.info(f"Available fields in {split_name}: {available_fields}")
        logger.info(f"Looking for input field: '{config.input_field}', label field: '{config.label_field}'")
        if config.input_field not in available_fields:
            logger.error(f"Input field '{config.input_field}' not found in {split_name}. Available fields: {available_fields}")
        if config.label_field not in available_fields:
            logger.error(f"Label field '{config.label_field}' not found in {split_name}. Available fields: {available_fields}")

        missing_input = sum(1 for item in split_data if not item.get(config.input_field))
        missing_label = sum(1 for item in split_data if item.get(config.label_field) is None)
        logger.info(f"{split_name} - Items missing input field: {missing_input}")
        logger.info(f"{split_name} - Items missing label field: {missing_label}")

        logger.info(f"=== SAMPLE RAW DATA FROM {split_name.upper()} BEFORE PREPROCESSING ===")
        for i, item in enumerate(split_data[:3]):
            logger.info(f"Raw item {i} from {split_name}:")
            for key, value in item.items():
                if isinstance(value, str) and len(value) > 100:
                    logger.info(f"  {key}: '{value[:100]}...'")
                else:
                    logger.info(f"  {key}: {value}")


class CustomDataLoader(BaseDataLoader):
    """Load custom datasets from local files"""

    def load(self, config: ClassificationConfig) -> Dict[str, List[Dict]]:
        """Load a local JSONL/CSV/JSON file, truncate to ``max_samples``, then split it.

        Raises:
            ValueError: no ``data_path`` or unsupported ``data_format``.
            FileNotFoundError: ``data_path`` does not exist.
        """
        if not config.data_path:
            raise ValueError("Data path is required for custom datasets")

        file_path = Path(config.data_path)
        if not file_path.exists():
            raise FileNotFoundError(f"Data file not found: {file_path}")

        logger.info(f"Loading custom dataset: {file_path}")

        readers = {"jsonl": self._load_jsonl, "csv": self._load_csv, "json": self._load_json}
        reader = readers.get(config.data_format)
        if reader is None:
            raise ValueError(f"Unsupported format: {config.data_format}")
        raw_data = reader(file_path, config)

        if config.max_samples:
            raw_data = raw_data[:config.max_samples]

        logger.info(f"Loaded {len(raw_data)} samples from {file_path}")
        return self._create_splits(raw_data, config)

    def _create_splits(self, data: List[Dict], config: ClassificationConfig) -> Dict[str, List[Dict]]:
        """Create train/validation/test splits from raw data.

        ``validation_split + test_split`` of the data is held out and divided
        between validation and test in that ratio; everything else stays in
        train (``train_split`` itself is not consulted). Splits are stratified
        on the label field when possible, otherwise random.
        """
        logger.info(f"Creating splits from {len(data)} samples...")

        holdout = config.validation_split + config.test_split
        if holdout <= 0:
            raise ValueError("validation_split + test_split must be positive to create validation/test splits")

        train_data, temp_data = self._stratified_split(data, holdout, config.label_field)
        val_data, test_data = self._stratified_split(temp_data, config.test_split / holdout, config.label_field)

        splits_data = {"train": train_data, "validation": val_data, "test": test_data}

        logger.info("Created splits:")
        logger.info(f"  Train: {len(splits_data['train'])} samples")
        logger.info(f"  Validation: {len(splits_data['validation'])} samples")
        logger.info(f"  Test: {len(splits_data['test'])} samples")
        return splits_data

    def _load_jsonl(self, file_path: Path, config: ClassificationConfig) -> List[Dict]:
        """Load a JSONL file, skipping blank lines and logging (not raising on) invalid ones."""
        data = []
        with open(file_path, 'r', encoding=config.encoding) as f:
            for line_num, line in enumerate(f, 1):
                if not line.strip():
                    continue
                try:
                    data.append(json.loads(line))
                except json.JSONDecodeError as e:
                    logger.warning(f"Invalid JSON at line {line_num}: {e}")
        return data

    def _load_csv(self, file_path: Path, config: ClassificationConfig) -> List[Dict]:
        """Load a CSV file into a list of row dicts (empty cells arrive as NaN)."""
        df = pd.read_csv(file_path, encoding=config.encoding, delimiter=config.delimiter)
        return df.to_dict('records')

    def _load_json(self, file_path: Path, config: ClassificationConfig) -> List[Dict]:
        """Load a JSON file holding a list, a ``{"data": [...]}`` wrapper, or a single object."""
        with open(file_path, 'r', encoding=config.encoding) as f:
            data = json.load(f)
        if isinstance(data, list):
            return data
        if isinstance(data, dict) and "data" in data:
            return data["data"]
        return [data]

    def preprocess(self, data: Dict[str, List[Dict]], config: ClassificationConfig) -> Dict[str, List[Dict]]:
        """Apply preprocessing steps to all splits separately, dropping rejected items."""
        logger.info("=== PREPROCESSING CUSTOM DATA ===")
        processed_splits: Dict[str, List[Dict]] = {}
        for split_name, split_data in data.items():
            logger.info(f"Processing {split_name} split with {len(split_data)} items...")
            processed_splits[split_name] = self._preprocess_split(split_name, split_data, config)
        return processed_splits


class ClassificationDataPipeline:
    """Main classification pipeline: load -> preprocess -> validate -> analyse -> save."""

    def __init__(self):
        self.validator = DataValidator()
        self.hf_loader = HuggingFaceDataLoader()
        self.custom_loader = CustomDataLoader()

    def create_config(
        self,
        data_source: str,
        dataset_name: Optional[str] = None,
        data_path: Optional[str] = None,
        input_field: str = "text",
        label_field: str = "label",
        **kwargs
    ) -> ClassificationConfig:
        """Create classification configuration"""
        return ClassificationConfig(
            data_source=data_source,
            dataset_name=dataset_name,
            data_path=data_path,
            input_field=input_field,
            label_field=label_field,
            **kwargs
        )

    def load_and_preprocess(self, config: ClassificationConfig) -> Tuple[Dict[str, List[Dict]], Dict[str, Any]]:
        """Load, preprocess, validate and analyse the data.

        Raises:
            ValueError: unsupported ``data_source`` or failed validation
                (the individual errors are logged first).
        """
        if config.data_source == "huggingface":
            loader: BaseDataLoader = self.hf_loader
        elif config.data_source == "custom":
            loader = self.custom_loader
        else:
            raise ValueError(f"Unsupported data source: {config.data_source}")

        raw_splits = loader.load(config)
        processed_splits = loader.preprocess(raw_splits, config)

        is_valid, errors = self.validator.validate_classification_data(processed_splits, config, is_processed=True)
        if not is_valid:
            logger.error("Data validation failed:")
            for error in errors:
                logger.error(f"  - {error}")
            raise ValueError("Data validation failed")

        analysis = self.validator.analyze_dataset(processed_splits, config, is_processed=True)
        return processed_splits, analysis

    def convert_to_classification_format(self, data: Dict[str, List[Dict]]) -> Dict[str, List[Dict]]:
        """Rename preprocessed ``input``/``label`` items to ``{"text", "label"}`` (drops ``id``)."""
        return {
            split_name: [{"text": item["input"], "label": item["label"]} for item in split_data]
            for split_name, split_data in data.items()
        }

    def save_data(self, data: Dict[str, List[Dict]], output_dir: str, format: str = "jsonl"):
        """Write each split to ``<output_dir>/<split>.<format>`` as UTF-8.

        Raises:
            ValueError: ``format`` is not one of jsonl, json, csv.
        """
        if format not in ("jsonl", "json", "csv"):
            raise ValueError(f"Unsupported save format: {format}")

        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        for split_name, split_data in data.items():
            output_file = output_path / f"{split_name}.{format}"
            if format == "jsonl":
                with open(output_file, 'w', encoding='utf-8') as f:
                    f.writelines(json.dumps(item, ensure_ascii=False) + '\n' for item in split_data)
            elif format == "json":
                with open(output_file, 'w', encoding='utf-8') as f:
                    json.dump(split_data, f, ensure_ascii=False, indent=2)
            else:
                pd.DataFrame(split_data).to_csv(output_file, index=False)
            logger.info(f"Saved {len(split_data)} samples to {output_file}")

    def run_pipeline(
        self,
        config: ClassificationConfig,
        output_format: str = "classification",
        save_splits: bool = True
    ) -> Dict[str, Any]:
        """Run the complete pipeline and return a summary including the processed data.

        Only ``output_format == "classification"`` reshapes items; other values
        keep the preprocessed ``input``/``label`` items. Saved files go to
        ``<config.output_dir>/<output_format>/``.
        """
        logger.info("Starting classification pipeline...")

        processed_splits, analysis = self.load_and_preprocess(config)

        if output_format == "classification":
            formatted_splits = self.convert_to_classification_format(processed_splits)
        else:
            formatted_splits = processed_splits

        if save_splits:
            output_dir = Path(config.output_dir) / output_format
            self.save_data(formatted_splits, str(output_dir))

        result = {
            "config": config,
            "analysis": analysis,
            "splits": {split_name: len(split_data) for split_name, split_data in formatted_splits.items()},
            "output_format": output_format,
            "output_dir": config.output_dir,
            "data": formatted_splits,  # Include the actual processed data
        }

        logger.info("Classification pipeline completed successfully!")
        return result


def create_huggingface_config(dataset_name: str, input_field: str = "text", label_field: str = "label",
                              **kwargs) -> ClassificationConfig:
    """Helper function to create a HuggingFace configuration"""
    return ClassificationConfig(
        data_source="huggingface",
        dataset_name=dataset_name,
        input_field=input_field,
        label_field=label_field,
        **kwargs
    )


def create_custom_config(data_path: str, data_format: str = "jsonl", input_field: str = "text",
                         label_field: str = "label", **kwargs) -> ClassificationConfig:
    """Helper function to create a custom data configuration"""
    return ClassificationConfig(
        data_source="custom",
        data_path=data_path,
        data_format=data_format,
        input_field=input_field,
        label_field=label_field,
        **kwargs
    )


def main():
    """CLI entry point.

    Settings are read from the ``data:`` section of the optional YAML file
    (``source`` there maps to ``data_source``); any CLI option given explicitly
    overrides the YAML value.
    """
    parser = argparse.ArgumentParser(description="Classification Data Processing Pipeline")

    # YAML configuration
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")

    # Data source arguments
    parser.add_argument("--data-source", choices=["huggingface", "custom"], help="Data source")
    parser.add_argument("--dataset-name", type=str, help="HuggingFace dataset name")
    parser.add_argument("--data-path", type=str, help="Path to custom data file")
    parser.add_argument("--data-format", choices=["jsonl", "csv", "json"], help="Data format")

    # Field mapping
    parser.add_argument("--input-field", type=str, help="Input field name")
    parser.add_argument("--label-field", type=str, help="Label field name")
    parser.add_argument("--id-field", type=str, help="Optional ID field name")

    # Data processing
    parser.add_argument("--max-samples", type=int, help="Maximum samples to process")
    parser.add_argument("--train-split", type=float, help="Training split ratio")
    parser.add_argument("--validation-split", type=float, help="Validation split ratio")
    parser.add_argument("--test-split", type=float, help="Test split ratio")

    # Text preprocessing
    parser.add_argument("--clean-text", action="store_true", help="Clean and normalize text")
    parser.add_argument("--remove-special-chars", action="store_true", help="Remove special characters")
    parser.add_argument("--lowercase", action="store_true", help="Convert text to lowercase")
    parser.add_argument("--min-length", type=int, help="Minimum text length")
    parser.add_argument("--max-length", type=int, help="Maximum text length")

    # Output configuration
    parser.add_argument("--output-format", choices=["classification", "instruction", "conversation", "qa"], help="Output format")
    parser.add_argument("--output-dir", type=str, help="Output directory")

    # Logging
    parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO", help="Logging level")

    args = parser.parse_args()

    log_level = getattr(logging, args.log_level)
    logging.basicConfig(level=log_level, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    # The module logger is pinned to DEBUG at import time and its records propagate to the
    # root handler regardless of the root level, so it must be lowered explicitly here.
    logger.setLevel(log_level)

    config_dict: Dict[str, Any] = {}
    if args.config:
        try:
            with open(args.config, 'r', encoding='utf-8') as f:
                config_dict = yaml.safe_load(f) or {}  # an empty YAML file loads as None
            logger.info(f"Loaded YAML configuration from: {args.config}")
        except Exception as e:
            logger.error(f"Error loading YAML config: {e}")
            sys.exit(1)
        if not isinstance(config_dict, dict):
            logger.error("Error loading YAML config: top level must be a mapping")
            sys.exit(1)

    # Every setting lives under the YAML 'data' section; CLI overrides must be merged there too.
    data_section = config_dict.get('data') or {}
    config_dict['data'] = data_section

    value_options = (
        'data_source', 'dataset_name', 'data_path', 'data_format',
        'input_field', 'label_field', 'id_field',
        'max_samples', 'train_split', 'validation_split', 'test_split',
        'min_length', 'max_length', 'output_format', 'output_dir',
    )
    flag_options = ('clean_text', 'remove_special_chars', 'lowercase')

    # `is not None` so explicit zeros (e.g. --train-split 0.0) are honoured.
    cli_overrides = {dest: getattr(args, dest) for dest in value_options if getattr(args, dest) is not None}
    # store_true flags can only switch an option on.
    cli_overrides.update({dest: True for dest in flag_options if getattr(args, dest)})

    for dest, value in cli_overrides.items():
        key = 'source' if dest == 'data_source' else dest
        if key in data_section:
            logger.info(f"Overriding YAML config '{key}' with CLI value: {value}")
        data_section[key] = value

    source = data_section.get('source')
    if not source:
        parser.error("--data-source is required (either in YAML config or CLI)")
    if source == "huggingface" and not data_section.get('dataset_name'):
        parser.error("--dataset-name is required for HuggingFace datasets")
    if source == "custom" and not data_section.get('data_path'):
        parser.error("--data-path is required for custom datasets")

    # Only keys actually present are passed, so absent ones fall back to the dataclass defaults.
    config_kwargs = {}
    for config_field in fields(ClassificationConfig):
        key = 'source' if config_field.name == 'data_source' else config_field.name
        if key in data_section:
            config_kwargs[config_field.name] = data_section[key]
    config = ClassificationConfig(**config_kwargs)

    pipeline = ClassificationDataPipeline()

    try:
        print(f"Starting classification pipeline with {config.data_source} data source...")
        if args.config:
            print(f"Using YAML configuration: {args.config}")
        print()

        result = pipeline.run_pipeline(config, config.output_format, save_splits=True)

        print("✅ Pipeline completed successfully!")
        print(f"   Data source: {config.data_source}")
        if config.data_source == "huggingface":
            print(f"   Dataset: {config.dataset_name}")
        else:
            print(f"   Data file: {config.data_path}")
        print(f"   Total samples: {result['analysis']['overall']['total_samples']}")
        print(f"   Unique labels: {result['analysis']['overall']['all_unique_labels']}")
        print(f"   Split sizes: {result['analysis']['overall']['split_sizes']}")
        print(f"   Output directory: {config.output_dir}")

    except Exception as e:
        print(f"❌ Error running pipeline: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()