import json import pandas as pd import numpy as np from pathlib import Path from typing import Dict, List, Optional, Union, Any, Tuple from datasets import Dataset, load_dataset import os from dataclasses import dataclass from abc import ABC, abstractmethod import logging from sklearn.model_selection import train_test_split import re import argparse import sys import yaml logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @dataclass class StylingConfig: """Configuration for styling tasks""" # Data source configuration data_source: str = "huggingface" # "huggingface" or "custom" dataset_name: Optional[str] = None # For Hugging Face datasets data_path: Optional[str] = None # For custom datasets data_format: str = "jsonl" # jsonl, csv, json # Field mapping - User configures which fields map to input/output input_field: str = "text" # Field in dataset containing source text (e.g., "text", "source", etc.) output_field: str = "styled_text" # Field in dataset containing styled text (e.g., "styled_text", "target", etc.) instruction: str = "Rewrite the following text in a formal style" # Style instruction from YAML # Data processing max_samples: Optional[int] = None train_split: float = 0.8 validation_split: float = 0.1 test_split: float = 0.1 # Text preprocessing clean_text: bool = True remove_special_chars: bool = False lowercase: bool = False # Keep original case for styling min_length: int = 10 max_length: int = 1000 # Output configuration output_format: str = "styling" # instruction, conversation, qa output_dir: str = "./data" # Hugging Face specific hf_split: str = "train" hf_cache_dir: Optional[str] = None # Split configuration test_split_from: str = "train" val_split_from: str = "train" # Custom data specific encoding: str = "utf-8" delimiter: str = "," # For CSV files # Alpaca prompt configuration alpaca_prompt: str = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that follows the instruction ### Instruction: {} ### Input: {} ### Response: {}""" eos_token: str = "<|eot_id|>" # Use <|eot_id|> as EOS token class DataValidator: """Validates styling data quality and format""" @staticmethod def validate_styling_data(data: Dict[str, List[Dict]], config: StylingConfig, is_processed: bool = False) -> Tuple[bool, List[str]]: """Validate styling dataset splits""" errors = [] # Check if we have the expected splits expected_splits = ["train", "validation", "test"] for split in expected_splits: if split not in data: errors.append(f"Missing '{split}' split") elif split == "train" and not data[split]: errors.append(f"Train split cannot be empty") # Allow validation and test splits to be empty for small datasets if errors: return False, errors total_samples = sum(len(split_data) for split_data in data.values()) logger.info(f"Validating {total_samples} total samples across all splits...") # Determine field names based on whether data is processed or not input_field = "input" if is_processed else config.input_field output_field = "output" if is_processed else config.output_field # Validate each split for split_name, split_data in data.items(): if not split_data: logger.info(f"Skipping validation for empty {split_name} split") continue logger.info(f"Validating {split_name} split with {len(split_data)} samples...") # Check required fields missing_input_count = 0 missing_output_count = 0 for i, item in enumerate(split_data): if input_field not in item: errors.append(f"Missing input field '{input_field}' in {split_name} split, item {i}") missing_input_count += 1 if output_field not in item: errors.append(f"Missing output field '{output_field}' in {split_name} split, item {i}") missing_output_count += 1 logger.info(f"{split_name} - Items missing input field: {missing_input_count}") logger.info(f"{split_name} - Items missing output field: {missing_output_count}") # Check data types type_errors = 0 for i, item in enumerate(split_data): if not isinstance(item.get(input_field, ""), str): errors.append(f"Input field '{input_field}' must be string in {split_name} split, item {i}") type_errors += 1 if not isinstance(item.get(output_field, ""), str): errors.append(f"Output field '{output_field}' must be string in {split_name} split, item {i}") type_errors += 1 logger.info(f"{split_name} - Type errors: {type_errors}") # Check for empty inputs/outputs empty_inputs = sum(1 for item in split_data if not item.get(input_field, "").strip()) empty_outputs = sum(1 for item in split_data if not item.get(output_field, "").strip()) if empty_inputs > 0: errors.append(f"Found {empty_inputs} items with empty input text in {split_name} split") if empty_outputs > 0: errors.append(f"Found {empty_outputs} items with empty output text in {split_name} split") logger.info(f"{split_name} - Empty inputs: {empty_inputs}") logger.info(f"{split_name} - Empty outputs: {empty_outputs}") # Show sample of processed data for debugging if split_data: logger.info(f"Sample processed items from {split_name}:") for i in range(min(3, len(split_data))): item = split_data[i] logger.info(f" Item {i}: input='{item.get(input_field, '')[:50]}...', output='{item.get(output_field, '')[:50]}...'") return len(errors) == 0, errors @staticmethod def analyze_dataset(data: Dict[str, List[Dict]], config: StylingConfig, is_processed: bool = False) -> Dict[str, Any]: """Analyze dataset characteristics across all splits""" analysis = { "splits": {}, "overall": { "total_samples": 0, "split_sizes": {} } } # Determine field names based on whether data is processed or not input_field = "input" if is_processed else config.input_field output_field = "output" if is_processed else config.output_field # Analyze each split for split_name, split_data in data.items(): if not split_data: # Handle empty splits split_analysis = { "total_samples": 0, "text_length_stats": {}, "missing_values": {} } analysis["splits"][split_name] = split_analysis analysis["overall"]["split_sizes"][split_name] = 0 continue split_analysis = { "total_samples": len(split_data), "text_length_stats": {}, "missing_values": {} } # Text length statistics for both input and output for field_name, field in [("input", input_field), ("output", output_field)]: text_lengths = [len(item.get(field, "")) for item in split_data] if text_lengths: split_analysis["text_length_stats"][field_name] = { "min": min(text_lengths), "max": max(text_lengths), "mean": np.mean(text_lengths), "median": np.median(text_lengths) } # Missing values for field in [input_field, output_field]: missing_count = sum(1 for item in split_data if not item.get(field)) split_analysis["missing_values"][field] = missing_count analysis["splits"][split_name] = split_analysis analysis["overall"]["total_samples"] += len(split_data) analysis["overall"]["split_sizes"][split_name] = len(split_data) return analysis class BaseDataLoader(ABC): """Abstract base class for data loaders""" @abstractmethod def load(self, config: StylingConfig) -> Dict[str, List[Dict]]: """Load data and return dictionary with train/val/test splits""" pass @abstractmethod def preprocess(self, data: Dict[str, List[Dict]], config: StylingConfig) -> Dict[str, List[Dict]]: """Apply preprocessing steps to all splits""" pass class HuggingFaceDataLoader(BaseDataLoader): """Load datasets from Hugging Face Hub""" def load(self, config: StylingConfig) -> Dict[str, List[Dict]]: """Load dataset from Hugging Face Hub with flexible split handling""" if not config.dataset_name: raise ValueError("Dataset name is required for Hugging Face datasets") logger.info(f"Loading Hugging Face dataset: {config.dataset_name}") try: # First, let's check what splits are available in the dataset dataset = load_dataset( config.dataset_name, cache_dir=config.hf_cache_dir ) # Log available splits available_splits = list(dataset.keys()) logger.info(f"Available splits in dataset: {available_splits}") # Initialize split data splits_data = { "train": [], "validation": [], "test": [] } # Handle train split if "train" in available_splits: train_dataset = dataset["train"] logger.info(f"Using 'train' split with {len(train_dataset)} samples") splits_data["train"] = list(train_dataset) else: logger.error("No 'train' split found in dataset!") logger.error(f"Available splits: {available_splits}") raise ValueError(f"Dataset {config.dataset_name} does not have a 'train' split") # Handle validation split if config.val_split_from == "use_val_if_available" and "validation" in available_splits: val_dataset = dataset["validation"] logger.info(f"Using 'validation' split with {len(val_dataset)} samples") splits_data["validation"] = list(val_dataset) elif config.val_split_from == "use_val_if_available" and "val" in available_splits: val_dataset = dataset["val"] logger.info(f"Using 'val' split with {len(val_dataset)} samples") splits_data["validation"] = list(val_dataset) elif config.val_split_from == "use_val_if_available": logger.warning("No validation split found in dataset. Will create from train split.") logger.info(f"Available splits: {available_splits}") logger.info(f"Will use {config.validation_split * 100}% of train data for validation") else: logger.info(f"Will create validation split from train data ({config.validation_split * 100}%)") # Handle test split if config.test_split_from == "use_test_if_available" and "test" in available_splits: test_dataset = dataset["test"] logger.info(f"Using 'test' split with {len(test_dataset)} samples") splits_data["test"] = list(test_dataset) elif config.test_split_from == "use_val_if_available" and "validation" in available_splits: test_dataset = dataset["validation"] logger.info(f"Using 'validation' split as test with {len(test_dataset)} samples") splits_data["test"] = list(test_dataset) elif config.test_split_from == "use_val_if_available" and "val" in available_splits: test_dataset = dataset["val"] logger.info(f"Using 'val' split as test with {len(test_dataset)} samples") splits_data["test"] = list(test_dataset) elif config.test_split_from == "use_test_if_available": logger.warning("No test split found in dataset. Will create from train split.") logger.info(f"Available splits: {available_splits}") logger.info(f"Will use {config.test_split * 100}% of train data for test") else: logger.info(f"Will create test split from train data ({config.test_split * 100}%)") # If we need to create splits from train data if not splits_data["validation"] or not splits_data["test"]: train_data = splits_data["train"] # Handle very small datasets if len(train_data) < 3: logger.warning(f"Dataset has only {len(train_data)} samples. Using all data for training.") splits_data["train"] = train_data splits_data["validation"] = [] splits_data["test"] = [] else: # Calculate remaining percentages for train total_train_percentage = config.train_split + config.validation_split + config.test_split if total_train_percentage != 1.0: logger.warning(f"Split percentages don't sum to 1.0 (got {total_train_percentage}). Normalizing...") # Normalize percentages config.train_split = config.train_split / total_train_percentage config.validation_split = config.validation_split / total_train_percentage config.test_split = config.test_split / total_train_percentage # Create splits from train data if not splits_data["validation"] and not splits_data["test"]: # Split train into train, val, test train_size = int(len(train_data) * config.train_split) val_size = int(len(train_data) * config.validation_split) # Handle small datasets if len(train_data) < 10: # For small datasets, use more conservative splits config.train_split = 0.6 config.validation_split = 0.2 config.test_split = 0.2 logger.info(f"Small dataset detected. Adjusted split ratios to: train={config.train_split}, val={config.validation_split}, test={config.test_split}") # Ensure minimum sizes min_val_size = max(1, int(len(train_data) * 0.1)) min_test_size = max(1, int(len(train_data) * 0.1)) val_size = max(min_val_size, int(len(train_data) * config.validation_split)) test_size = max(min_test_size, int(len(train_data) * config.test_split)) train_size = len(train_data) - val_size - test_size # Ensure train has at least 1 sample if train_size < 1: if val_size > 1: val_size -= 1 train_size += 1 elif test_size > 1: test_size -= 1 train_size += 1 logger.info(f"Adjusted split sizes: train={train_size}, val={val_size}, test={test_size}") # First split: train + (val+test) new_train, temp_data = train_test_split( train_data, test_size=val_size + test_size, random_state=42 ) # Second split: val + test new_val, new_test = train_test_split( temp_data, test_size=test_size / (val_size + test_size) if (val_size + test_size) > 0 else 0, random_state=42 ) splits_data["train"] = new_train splits_data["validation"] = new_val splits_data["test"] = new_test elif not splits_data["validation"]: # Only need to create val from train val_size = max(1, int(len(train_data) * config.validation_split)) new_train, new_val = train_test_split( train_data, test_size=val_size, random_state=42 ) splits_data["train"] = new_train splits_data["validation"] = new_val elif not splits_data["test"]: # Only need to create test from train test_size = max(1, int(len(train_data) * config.test_split)) new_train, new_test = train_test_split( train_data, test_size=test_size, random_state=42 ) splits_data["train"] = new_train splits_data["test"] = new_test logger.info(f"Final split sizes:") logger.info(f" Train: {len(splits_data['train'])} samples") logger.info(f" Validation: {len(splits_data['validation'])} samples") logger.info(f" Test: {len(splits_data['test'])} samples") # Ensure all splits exist (even if empty) for the pipeline if "validation" not in splits_data: splits_data["validation"] = [] if "test" not in splits_data: splits_data["test"] = [] # Apply max_samples limit to each split if specified if config.max_samples: for split_name in splits_data: if splits_data[split_name]: original_size = len(splits_data[split_name]) splits_data[split_name] = splits_data[split_name][:config.max_samples] logger.info(f"Limited {split_name} split from {original_size} to {len(splits_data[split_name])} samples") # Log dataset info for debugging for split_name, split_data in splits_data.items(): if split_data: logger.info(f"Sample data item from {split_name}: {split_data[0]}") logger.info(f"Available fields in {split_name} split: {list(split_data[0].keys())}") # Check if the required fields exist if config.input_field not in split_data[0]: logger.warning(f"Input field '{config.input_field}' not found in {split_name}. Available fields: {list(split_data[0].keys())}") # Suggest alternative fields text_fields = [f for f in split_data[0].keys() if any(keyword in f.lower() for keyword in ['text', 'sentence', 'content', 'input', 'comment', 'message'])] if text_fields: logger.info(f"Suggested text fields for {split_name}: {text_fields}") if config.output_field not in split_data[0]: logger.warning(f"Output field '{config.output_field}' not found in {split_name}. Available fields: {list(split_data[0].keys())}") # Suggest alternative fields output_fields = [f for f in split_data[0].keys() if any(keyword in f.lower() for keyword in ['output', 'response', 'result', 'target', 'styled'])] if output_fields: logger.info(f"Suggested output fields for {split_name}: {output_fields}") logger.info(f"Successfully loaded dataset {config.dataset_name}") return splits_data except Exception as e: logger.error(f"Error loading dataset {config.dataset_name}: {e}") raise def preprocess(self, data: Dict[str, List[Dict]], config: StylingConfig) -> Dict[str, List[Dict]]: """Apply preprocessing steps to all splits separately""" processed_splits = {} logger.info(f"=== PREPROCESSING DATA ===") for split_name, split_data in data.items(): logger.info(f"Processing {split_name} split with {len(split_data)} items...") # Log field availability for debugging if split_data: available_fields = set(split_data[0].keys()) logger.info(f"Available fields in {split_name}: {available_fields}") logger.info(f"Looking for input field: '{config.input_field}', output field: '{config.output_field}'") if config.input_field not in available_fields: logger.error(f"Input field '{config.input_field}' not found in {split_name}. Available fields: {available_fields}") if config.output_field not in available_fields: logger.error(f"Output field '{config.output_field}' not found in {split_name}. Available fields: {available_fields}") # Count items with missing fields missing_input = sum(1 for item in split_data if config.input_field not in item or not item.get(config.input_field)) missing_output = sum(1 for item in split_data if config.output_field not in item or not item.get(config.output_field)) logger.info(f"{split_name} - Items missing input field: {missing_input}") logger.info(f"{split_name} - Items missing output field: {missing_output}") # Show sample of raw data before preprocessing logger.info(f"=== SAMPLE RAW DATA FROM {split_name.upper()} BEFORE PREPROCESSING ===") for i in range(min(3, len(split_data))): item = split_data[i] logger.info(f"Raw item {i} from {split_name}:") for key, value in item.items(): if isinstance(value, str) and len(value) > 100: logger.info(f" {key}: '{value[:100]}...'") else: logger.info(f" {key}: {value}") # Process each item in the split processed_data = [] processed_count = 0 skipped_count = 0 # Reset debug counter for each split self._debug_count = 0 for i, item in enumerate(split_data): processed_item = self._preprocess_item(item, config) if processed_item is not None: processed_data.append(processed_item) processed_count += 1 else: skipped_count += 1 if skipped_count <= 3: # Log first few skipped items logger.info(f"Skipped item {i} from {split_name}: {item}") processed_splits[split_name] = processed_data logger.info(f"{split_name} - Preprocessed {processed_count} samples, skipped {skipped_count} samples") # Show sample of processed data if processed_data: logger.info(f"=== SAMPLE PROCESSED DATA FROM {split_name.upper()} ===") for i in range(min(3, len(processed_data))): logger.info(f"Processed item {i} from {split_name}: {processed_data[i]}") return processed_splits def _preprocess_item(self, item: Dict, config: StylingConfig) -> Optional[Dict]: """Preprocess a single item""" # Extract input and output using configurable field names input_text = item.get(config.input_field, "") output_text = item.get(config.output_field, "") # Log what we're extracting (for first few items) if hasattr(self, '_debug_count'): self._debug_count += 1 else: self._debug_count = 1 if self._debug_count <= 3: logger.debug(f"Processing item {self._debug_count}:") logger.debug(f" Looking for input field '{config.input_field}': {input_text}") logger.debug(f" Looking for output field '{config.output_field}': {output_text}") # Handle None values if input_text is None: input_text = "" if output_text is None: output_text = "" # Convert to string if needed input_text = str(input_text) output_text = str(output_text) if self._debug_count <= 3: logger.debug(f" After conversion - input: '{input_text[:50]}...', output: '{output_text[:50]}...'") # Clean text if requested if config.clean_text: original_input = input_text original_output = output_text input_text = self._clean_text(input_text, config) output_text = self._clean_text(output_text, config) if self._debug_count <= 3: logger.debug(f" After cleaning - input: '{original_input[:50]}...' -> '{input_text[:50]}...'") logger.debug(f" After cleaning - output: '{original_output[:50]}...' -> '{output_text[:50]}...'") # Check length constraints if len(input_text) < config.min_length or len(input_text) > config.max_length: if self._debug_count <= 3: logger.debug(f" Skipping - input length {len(input_text)} not in range [{config.min_length}, {config.max_length}]") return None if len(output_text) < config.min_length or len(output_text) > config.max_length: if self._debug_count <= 3: logger.debug(f" Skipping - output length {len(output_text)} not in range [{config.min_length}, {config.max_length}]") return None # Create processed item - Always use "input" and "output" for internal processing processed_item = { "input": input_text, "output": output_text } if self._debug_count <= 3: logger.debug(f" Final processed item: {processed_item}") return processed_item def _clean_text(self, text: str, config: StylingConfig) -> str: """Clean and normalize text""" if not isinstance(text, str): return "" # Remove extra whitespace text = re.sub(r'\s+', ' ', text).strip() # Convert to lowercase if requested if config.lowercase: text = text.lower() # Remove special characters if requested if config.remove_special_chars: text = re.sub(r'[^\w\s]', '', text) return text class CustomDataLoader(BaseDataLoader): """Load custom datasets from local files""" def load(self, config: StylingConfig) -> Dict[str, List[Dict]]: """Load custom dataset from local file and create splits""" if not config.data_path: raise ValueError("Data path is required for custom datasets") file_path = Path(config.data_path) if not file_path.exists(): raise FileNotFoundError(f"Data file not found: {file_path}") logger.info(f"Loading custom dataset: {file_path}") if config.data_format == "jsonl": raw_data = self._load_jsonl(file_path, config) elif config.data_format == "csv": raw_data = self._load_csv(file_path, config) elif config.data_format == "json": raw_data = self._load_json(file_path, config) else: raise ValueError(f"Unsupported format: {config.data_format}") if config.max_samples: raw_data = raw_data[:config.max_samples] logger.info(f"Loaded {len(raw_data)} samples from {file_path}") # Create splits from the raw data splits_data = self._create_splits(raw_data, config) return splits_data def _create_splits(self, data: List[Dict], config: StylingConfig) -> Dict[str, List[Dict]]: """Create train/validation/test splits from raw data""" logger.info(f"Creating splits from {len(data)} samples...") # Handle very small datasets if len(data) < 3: logger.warning(f"Dataset has only {len(data)} samples. Using all data for training.") return { "train": data, "validation": [], "test": [] } # Calculate split sizes with minimum guarantees total_samples = len(data) # Ensure minimum sizes for each split min_val_size = max(1, int(total_samples * 0.1)) # At least 1 sample for validation min_test_size = max(1, int(total_samples * 0.1)) # At least 1 sample for test # Adjust split ratios if dataset is too small if total_samples < 10: # For small datasets, use more conservative splits config.train_split = 0.6 config.validation_split = 0.2 config.test_split = 0.2 logger.info(f"Small dataset detected. Adjusted split ratios to: train={config.train_split}, val={config.validation_split}, test={config.test_split}") # Calculate actual split sizes val_size = max(min_val_size, int(total_samples * config.validation_split)) test_size = max(min_test_size, int(total_samples * config.test_split)) train_size = total_samples - val_size - test_size # Ensure train split has at least 1 sample if train_size < 1: # Adjust validation and test to ensure train has at least 1 sample if val_size > 1: val_size -= 1 train_size += 1 elif test_size > 1: test_size -= 1 train_size += 1 logger.info(f"Adjusted split sizes to ensure train has at least 1 sample: train={train_size}, val={val_size}, test={test_size}") logger.info(f"Split sizes: train={train_size}, validation={val_size}, test={test_size}") # Create splits if val_size == 0 and test_size == 0: # All data goes to train splits_data = { "train": data, "validation": [], "test": [] } elif val_size == 0: # Split between train and test train_data, test_data = train_test_split(data, test_size=test_size, random_state=42) splits_data = { "train": train_data, "validation": [], "test": test_data } elif test_size == 0: # Split between train and validation train_data, val_data = train_test_split(data, test_size=val_size, random_state=42) splits_data = { "train": train_data, "validation": val_data, "test": [] } else: # Full three-way split # First split: train + (val+test) train_data, temp_data = train_test_split( data, test_size=val_size + test_size, random_state=42 ) # Second split: val + test val_data, test_data = train_test_split( temp_data, test_size=test_size, random_state=42 ) splits_data = { "train": train_data, "validation": val_data, "test": test_data } logger.info(f"Created splits:") logger.info(f" Train: {len(splits_data['train'])} samples") logger.info(f" Validation: {len(splits_data['validation'])} samples") logger.info(f" Test: {len(splits_data['test'])} samples") return splits_data def _load_jsonl(self, file_path: Path, config: StylingConfig) -> List[Dict]: """Load JSONL file""" data = [] with open(file_path, 'r', encoding=config.encoding) as f: for line_num, line in enumerate(f, 1): if line.strip(): try: data.append(json.loads(line)) except json.JSONDecodeError as e: logger.warning(f"Invalid JSON at line {line_num}: {e}") return data def _load_csv(self, file_path: Path, config: StylingConfig) -> List[Dict]: """Load CSV file""" df = pd.read_csv(file_path, encoding=config.encoding, delimiter=config.delimiter) return df.to_dict('records') def _load_json(self, file_path: Path, config: StylingConfig) -> List[Dict]: """Load JSON file""" with open(file_path, 'r', encoding=config.encoding) as f: data = json.load(f) if isinstance(data, list): return data elif isinstance(data, dict) and "data" in data: return data["data"] else: return [data] def preprocess(self, data: Dict[str, List[Dict]], config: StylingConfig) -> Dict[str, List[Dict]]: """Apply preprocessing steps to all splits separately""" processed_splits = {} logger.info(f"=== PREPROCESSING CUSTOM DATA ===") for split_name, split_data in data.items(): logger.info(f"Processing {split_name} split with {len(split_data)} items...") processed_data = [] processed_count = 0 skipped_count = 0 # Reset debug counter for each split self._debug_count = 0 for i, item in enumerate(split_data): processed_item = self._preprocess_item(item, config) if processed_item is not None: processed_data.append(processed_item) processed_count += 1 else: skipped_count += 1 if skipped_count <= 3: # Log first few skipped items logger.info(f"Skipped item {i} from {split_name}: {item}") processed_splits[split_name] = processed_data logger.info(f"{split_name} - Preprocessed {processed_count} samples, skipped {skipped_count} samples") return processed_splits def _preprocess_item(self, item: Dict, config: StylingConfig) -> Optional[Dict]: """Preprocess a single item""" # Extract input and output using configurable field names input_text = item.get(config.input_field, "") output_text = item.get(config.output_field, "") # Handle None values if input_text is None: input_text = "" if output_text is None: output_text = "" # Convert to string if needed input_text = str(input_text) output_text = str(output_text) # Clean text if requested if config.clean_text: input_text = self._clean_text(input_text, config) output_text = self._clean_text(output_text, config) # Check length constraints if len(input_text) < config.min_length or len(input_text) > config.max_length: return None if len(output_text) < config.min_length or len(output_text) > config.max_length: return None # Create processed item - Always use "input" and "output" for internal processing processed_item = { "input": input_text, "output": output_text } return processed_item def _clean_text(self, text: str, config: StylingConfig) -> str: """Clean and normalize text""" if not isinstance(text, str): return "" # Remove extra whitespace text = re.sub(r'\s+', ' ', text).strip() # Convert to lowercase if requested if config.lowercase: text = text.lower() # Remove special characters if requested if config.remove_special_chars: text = re.sub(r'[^\w\s]', '', text) return text class StylingDataPipeline: """Main styling pipeline""" def __init__(self): self.validator = DataValidator() self.hf_loader = HuggingFaceDataLoader() self.custom_loader = CustomDataLoader() def create_config( self, data_source: str, dataset_name: Optional[str] = None, data_path: Optional[str] = None, input_field: str = "input", output_field: str = "output", instruction: str = "Rewrite the following text in a formal style", **kwargs ) -> StylingConfig: """Create styling configuration""" return StylingConfig( data_source=data_source, dataset_name=dataset_name, data_path=data_path, input_field=input_field, output_field=output_field, instruction=instruction, **kwargs ) def load_config_from_yaml(self, yaml_path: str) -> StylingConfig: """Load configuration from YAML file""" try: config_dict = load_yaml_config(yaml_path) # Create configuration object from YAML data config = StylingConfig( data_source=config_dict.get('data_source', 'custom'), dataset_name=config_dict.get('dataset_name'), data_path=config_dict.get('data_path'), data_format=config_dict.get('data_format', 'jsonl'), input_field=config_dict.get('input_field', 'text'), output_field=config_dict.get('output_field', 'styled_text'), instruction=config_dict.get('instruction', 'Rewrite the following text in a formal style'), max_samples=config_dict.get('max_samples'), train_split=config_dict.get('train_split', 0.8), validation_split=config_dict.get('validation_split', 0.1), test_split=config_dict.get('test_split', 0.1), clean_text=config_dict.get('clean_text', True), remove_special_chars=config_dict.get('remove_special_chars', False), lowercase=config_dict.get('lowercase', False), min_length=config_dict.get('min_length', 10), max_length=config_dict.get('max_length', 1000), output_format=config_dict.get('output_format', 'styling'), output_dir=config_dict.get('output_dir', './data'), hf_split=config_dict.get('hf_split', 'train'), hf_cache_dir=config_dict.get('hf_cache_dir'), test_split_from=config_dict.get('test_split_from', 'train'), val_split_from=config_dict.get('val_split_from', 'train'), encoding=config_dict.get('encoding', 'utf-8'), delimiter=config_dict.get('delimiter', ',') ) logger.info(f"Configuration loaded from YAML: {yaml_path}") logger.info(f"Output directory: {config.output_dir}") logger.info(f"Instruction: {config.instruction}") return config except Exception as e: logger.error(f"Error loading configuration from YAML {yaml_path}: {e}") raise def load_and_preprocess(self, config: StylingConfig) -> Tuple[Dict[str, List[Dict]], Dict[str, Any]]: """Load and preprocess data""" # Load data if config.data_source == "huggingface": raw_splits = self.hf_loader.load(config) processed_splits = self.hf_loader.preprocess(raw_splits, config) elif config.data_source == "custom": raw_splits = self.custom_loader.load(config) processed_splits = self.custom_loader.preprocess(raw_splits, config) else: raise ValueError(f"Unsupported data source: {config.data_source}") # Validate processed data is_valid, errors = self.validator.validate_styling_data(processed_splits, config, is_processed=True) if not is_valid: logger.error("Data validation failed:") for error in errors: logger.error(f" - {error}") raise ValueError("Data validation failed") # Analyze dataset analysis = self.validator.analyze_dataset(processed_splits, config, is_processed=True) return processed_splits, analysis def convert_to_alpaca_format(self, data: Dict[str, List[Dict]], config: StylingConfig) -> Dict[str, List[Dict]]: """Convert styling data to Alpaca format with instruction""" alpaca_splits = {} for split_name, split_data in data.items(): alpaca_data = [] for item in split_data: # Ensure input and output fields exist, default to empty string if missing input_text = item.get("input", "") output_text = item.get("output", "") # Handle None values if input_text is None: input_text = "" if output_text is None: output_text = "" # Convert to string if needed input_text = str(input_text) output_text = str(output_text) alpaca_data.append({ "instruction": config.instruction, "input": input_text, "output": output_text }) alpaca_splits[split_name] = alpaca_data return alpaca_splits def format_for_training(self, data: Dict[str, List[Dict]], config: StylingConfig) -> Dict[str, List[str]]: """Format entries for training using Alpaca prompt format""" formatted_splits = {} for split_name, split_data in data.items(): formatted_texts = [] for item in split_data: # Ensure input and output fields exist, default to empty string if missing input_text = item.get("input", "") output_text = item.get("output", "") # Handle None values if input_text is None: input_text = "" if output_text is None: output_text = "" # Convert to string if needed input_text = str(input_text) output_text = str(output_text) text = config.alpaca_prompt.format( config.instruction, input_text, output_text ) + config.eos_token formatted_texts.append(text) formatted_splits[split_name] = formatted_texts return formatted_splits def convert_to_hf_dataset(self, dataset_entries: List[Dict], config: StylingConfig): """Convert dataset entries to HuggingFace dataset format with text formatting""" from datasets import Dataset # Create HuggingFace dataset from list of dictionaries hf_dataset = Dataset.from_list(dataset_entries) # Apply formatting function to generate the text field def formatting_prompts_func(examples): instructions = examples["instruction"] inputs = examples["input"] outputs = examples["output"] texts = [] for instruction, input_text, output in zip(instructions, inputs, outputs): # Handle None values and ensure strings if input_text is None: input_text = "" if output is None: output = "" # Convert to string if needed input_text = str(input_text) output = str(output) # Use the config's EOS token and alpaca prompt text = config.alpaca_prompt.format(instruction, input_text, output) + config.eos_token texts.append(text) return {"text": texts} # Apply the formatting function formatted_dataset = hf_dataset.map(formatting_prompts_func, batched=True) return formatted_dataset def save_hf_dataset_to_disk(self, hf_dataset, save_path: str): """Save HuggingFace dataset to disk""" try: hf_dataset.save_to_disk(save_path) logger.info(f"HuggingFace dataset saved to disk at: {save_path}") return True except Exception as e: logger.error(f"Error saving HuggingFace dataset to disk: {e}") return False def load_hf_dataset_from_disk(self, load_path: str): """Load HuggingFace dataset from disk""" try: from datasets import load_from_disk hf_dataset = load_from_disk(load_path) logger.info(f"HuggingFace dataset loaded from disk: {load_path}") logger.info(f"Dataset has {len(hf_dataset)} entries") logger.info(f"Dataset features: {hf_dataset.features}") return hf_dataset except Exception as e: logger.error(f"Error loading HuggingFace dataset from disk: {e}") return None def save_data(self, data: Dict[str, List[Dict]], output_dir: str, format: str = "jsonl"): """Save processed data splits to files""" output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) for split_name, split_data in data.items(): if format == "jsonl": output_file = output_path / f"{split_name}.jsonl" with open(output_file, 'w', encoding='utf-8') as f: for item in split_data: f.write(json.dumps(item, ensure_ascii=False) + '\n') elif format == "json": output_file = output_path / f"{split_name}.json" with open(output_file, 'w', encoding='utf-8') as f: json.dump(split_data, f, ensure_ascii=False, indent=2) elif format == "csv": output_file = output_path / f"{split_name}.csv" df = pd.DataFrame(split_data) df.to_csv(output_file, index=False) logger.info(f"Saved {len(split_data)} samples to {output_file}") def run_pipeline( self, config: StylingConfig, output_format: str = "styling", save_splits: bool = True, create_hf_dataset: bool = False, save_hf_dataset: bool = False, hf_dataset_path: str = None ) -> Dict[str, Any]: """Run complete styling pipeline""" logger.info("Starting styling pipeline...") # Load and preprocess data processed_splits, analysis = self.load_and_preprocess(config) # Convert to desired output format if output_format == "alpaca": formatted_splits = self.convert_to_alpaca_format(processed_splits, config) else: formatted_splits = processed_splits # Save data if requested if save_splits: # Save directly in the output directory, not in a subdirectory output_dir = Path(config.output_dir) self.save_data(formatted_splits, str(output_dir)) # Convert to HuggingFace dataset if requested hf_dataset = None hf_dataset_save_path = None if create_hf_dataset: # Flatten all splits into one list for HF dataset all_entries = [] for split_name, split_data in formatted_splits.items(): for item in split_data: # Ensure we have the instruction field if "instruction" not in item: item["instruction"] = config.instruction all_entries.append(item) hf_dataset = self.convert_to_hf_dataset(all_entries, config) logger.info(f"HuggingFace dataset created with {len(hf_dataset)} entries") logger.info(f"Dataset features: {hf_dataset.features}") # Save HuggingFace dataset to disk if requested if save_hf_dataset: if hf_dataset_path is None: # Generate default path using the YAML output_dir hf_dataset_path = str(Path(config.output_dir) / "hf_dataset") success = self.save_hf_dataset_to_disk(hf_dataset, hf_dataset_path) if success: hf_dataset_save_path = hf_dataset_path logger.info(f"HuggingFace dataset saved to: {hf_dataset_save_path}") else: logger.warning("Failed to save HuggingFace dataset to disk") # Create result summary result = { "config": config, "analysis": analysis, "splits": { split_name: len(split_data) for split_name, split_data in formatted_splits.items() }, "output_format": output_format, "output_dir": config.output_dir, "data": formatted_splits, # Include the actual processed data "instruction": config.instruction } # Add HuggingFace dataset info to result if created if hf_dataset is not None: result["hf_dataset"] = hf_dataset if hf_dataset_save_path: result["hf_dataset_path"] = hf_dataset_save_path logger.info("Styling pipeline completed successfully!") return result # Helper functions def create_huggingface_config(dataset_name: str, input_field: str = "text", output_field: str = "output", instruction: str = "Rewrite the following text in a formal style", **kwargs) -> StylingConfig: """Helper function to create a HuggingFace configuration""" return StylingConfig( data_source="huggingface", dataset_name=dataset_name, input_field=input_field, output_field=output_field, instruction=instruction, **kwargs ) def create_custom_config(data_path: str, data_format: str = "jsonl", input_field: str = "text", output_field: str = "styled_text", instruction: str = "Rewrite the following text in a formal style", **kwargs) -> StylingConfig: """Helper function to create a custom data configuration""" return StylingConfig( data_source="custom", data_path=data_path, data_format=data_format, input_field=input_field, output_field=output_field, instruction=instruction, **kwargs ) def save_hf_dataset_to_disk(hf_dataset, save_path: str) -> bool: """Utility function to save HuggingFace dataset to disk""" try: hf_dataset.save_to_disk(save_path) print(f"HuggingFace dataset saved to disk at: {save_path}") return True except Exception as e: print(f"Error saving HuggingFace dataset to disk: {e}") return False def load_hf_dataset_from_disk(load_path: str): """Utility function to load HuggingFace dataset from disk""" try: from datasets import load_from_disk hf_dataset = load_from_disk(load_path) print(f"HuggingFace dataset loaded from disk: {load_path}") print(f"Dataset has {len(hf_dataset)} entries") print(f"Dataset features: {hf_dataset.features}") return hf_dataset except Exception as e: print(f"Error loading HuggingFace dataset from disk: {e}") return None def load_yaml_config(config_path: str) -> Dict[str, Any]: """Load and parse YAML configuration file with proper structure handling""" try: with open(config_path, 'r', encoding='utf-8') as f: yaml_data = yaml.safe_load(f) # Extract configuration from YAML structure config_dict = {} # Handle task section if 'task' in yaml_data: task_data = yaml_data['task'] config_dict.update({ 'task_name': task_data.get('name'), 'task_type': task_data.get('type') }) # Handle data section if 'data' in yaml_data: data_config = yaml_data['data'] config_dict.update({ 'data_source': data_config.get('source'), 'dataset_name': data_config.get('dataset_name'), 'data_path': data_config.get('data_path'), 'data_format': data_config.get('data_format'), 'input_field': data_config.get('input_field'), 'output_field': data_config.get('output_field'), 'instruction': data_config.get('instruction'), 'max_samples': data_config.get('max_samples'), 'train_split': data_config.get('train_split'), 'validation_split': data_config.get('validation_split'), 'test_split': data_config.get('test_split'), 'clean_text': data_config.get('clean_text'), 'lowercase': data_config.get('lowercase'), 'min_length': data_config.get('min_length'), 'max_length': data_config.get('max_length'), 'output_format': data_config.get('output_format'), 'output_dir': data_config.get('output_dir'), 'encoding': data_config.get('encoding'), 'delimiter': data_config.get('delimiter') }) # Handle model section if 'model' in yaml_data: model_data = yaml_data['model'] config_dict.update({ 'model_name': model_data.get('name'), 'model_max_length': model_data.get('max_length') }) # Handle training section if 'training' in yaml_data: training_data = yaml_data['training'] config_dict.update({ 'num_epochs': training_data.get('num_epochs'), 'batch_size': training_data.get('batch_size'), 'learning_rate': training_data.get('learning_rate'), 'weight_decay': training_data.get('weight_decay'), 'warmup_ratio': training_data.get('warmup_ratio'), 'lr_scheduler_type': training_data.get('lr_scheduler_type') }) # Handle inference section if 'inference' in yaml_data: inference_data = yaml_data['inference'] config_dict.update({ 'inference_batch_size': inference_data.get('batch_size'), 'max_new_tokens': inference_data.get('max_new_tokens'), 'temperature': inference_data.get('temperature') }) logger.info(f"Successfully parsed YAML configuration from: {config_path}") logger.info(f"Extracted {len(config_dict)} configuration parameters") return config_dict except Exception as e: logger.error(f"Error loading YAML config from {config_path}: {e}") raise def main(): """Main function with YAML configuration support""" parser = argparse.ArgumentParser(description="Styling Data Processing Pipeline") # YAML configuration parser.add_argument("--config", type=str, help="Path to YAML configuration file") # Data source arguments parser.add_argument("--data-source", choices=["huggingface", "custom"], help="Data source") parser.add_argument("--dataset-name", type=str, help="HuggingFace dataset name") parser.add_argument("--data-path", type=str, help="Path to custom data file") parser.add_argument("--data-format", choices=["jsonl", "csv", "json"], help="Data format") # Field mapping parser.add_argument("--input-field", type=str, help="Input field name") parser.add_argument("--output-field", type=str, help="Output field name") parser.add_argument("--instruction", type=str, help="Style instruction") # Data processing parser.add_argument("--max-samples", type=int, help="Maximum samples to process") parser.add_argument("--train-split", type=float, help="Training split ratio") parser.add_argument("--validation-split", type=float, help="Validation split ratio") parser.add_argument("--test-split", type=float, help="Test split ratio") # Text preprocessing parser.add_argument("--clean-text", action="store_true", help="Clean and normalize text") parser.add_argument("--remove-special-chars", action="store_true", help="Remove special characters") parser.add_argument("--lowercase", action="store_true", help="Convert text to lowercase") parser.add_argument("--min-length", type=int, help="Minimum text length") parser.add_argument("--max-length", type=int, help="Maximum text length") # Output configuration parser.add_argument("--output-format", choices=["styling", "alpaca"], help="Output format") parser.add_argument("--output-dir", type=str, help="Output directory") # HuggingFace dataset options parser.add_argument("--create-hf-dataset", action="store_true", help="Create HuggingFace dataset") parser.add_argument("--hf-dataset-path", type=str, help="Path to save HuggingFace dataset") # Logging parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO", help="Logging level") args = parser.parse_args() # Set up logging logging.basicConfig( level=getattr(logging, args.log_level), format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) # Load configuration config_dict = {} # Load YAML config if provided if args.config: try: config_dict = load_yaml_config(args.config) except Exception as e: logger.error(f"Error loading YAML config: {e}") sys.exit(1) # Override YAML config with CLI arguments cli_overrides = {} if args.data_source: cli_overrides['data_source'] = args.data_source if args.dataset_name: cli_overrides['dataset_name'] = args.dataset_name if args.data_path: cli_overrides['data_path'] = args.data_path if args.data_format: cli_overrides['data_format'] = args.data_format if args.input_field: cli_overrides['input_field'] = args.input_field if args.output_field: cli_overrides['output_field'] = args.output_field if args.instruction: cli_overrides['instruction'] = args.instruction if args.max_samples: cli_overrides['max_samples'] = args.max_samples if args.train_split: cli_overrides['train_split'] = args.train_split if args.validation_split: cli_overrides['validation_split'] = args.validation_split if args.test_split: cli_overrides['test_split'] = args.test_split if args.clean_text: cli_overrides['clean_text'] = True if args.remove_special_chars: cli_overrides['remove_special_chars'] = True if args.lowercase: cli_overrides['lowercase'] = True if args.min_length: cli_overrides['min_length'] = args.min_length if args.max_length: cli_overrides['max_length'] = args.max_length if args.output_format: cli_overrides['output_format'] = args.output_format if args.output_dir: cli_overrides['output_dir'] = args.output_dir # HuggingFace dataset options if args.create_hf_dataset: cli_overrides['create_hf_dataset'] = True if args.hf_dataset_path: cli_overrides['hf_dataset_path'] = args.hf_dataset_path # Logging if args.log_level: cli_overrides['log_level'] = args.log_level # Merge configurations for key, value in cli_overrides.items(): if key in config_dict: logger.info(f"Overriding YAML config '{key}' with CLI value: {value}") config_dict[key] = value # Validate required arguments if not config_dict.get('data_source'): parser.error("--data-source is required (either in YAML config or CLI)") if config_dict.get('data_source') == "huggingface" and not config_dict.get('dataset_name'): parser.error("--dataset-name is required for HuggingFace datasets") if config_dict.get('data_source') == "custom" and not config_dict.get('data_path'): parser.error("--data-path is required for custom datasets") # Create configuration object - properly handle YAML structure config = StylingConfig( data_source=config_dict.get('data_source', 'huggingface'), dataset_name=config_dict.get('dataset_name'), data_path=config_dict.get('data_path'), data_format=config_dict.get('data_format', 'jsonl'), input_field=config_dict.get('input_field', 'text'), output_field=config_dict.get('output_field', 'styled_text'), instruction=config_dict.get('instruction', 'Rewrite the following text in a formal style'), max_samples=config_dict.get('max_samples'), train_split=config_dict.get('train_split', 0.8), validation_split=config_dict.get('validation_split', 0.1), test_split=config_dict.get('test_split', 0.1), clean_text=config_dict.get('clean_text', True), remove_special_chars=config_dict.get('remove_special_chars', False), lowercase=config_dict.get('lowercase', False), min_length=config_dict.get('min_length', 10), max_length=config_dict.get('max_length', 1000), output_format=config_dict.get('output_format', 'styling'), output_dir=config_dict.get('output_dir', './data'), hf_split=config_dict.get('hf_split', 'train'), hf_cache_dir=config_dict.get('hf_cache_dir'), test_split_from=config_dict.get('test_split_from', 'train'), val_split_from=config_dict.get('val_split_from', 'train'), encoding=config_dict.get('encoding', 'utf-8'), delimiter=config_dict.get('delimiter', ',') ) # Initialize pipeline pipeline = StylingDataPipeline() try: print(f"Starting styling pipeline with {config.data_source} data source...") if args.config: print(f"Using YAML configuration: {args.config}") print(f"Style instruction: {config.instruction}") print() # Check if we should create HuggingFace dataset create_hf_dataset = cli_overrides.get('create_hf_dataset', False) hf_dataset_path = cli_overrides.get('hf_dataset_path') # If creating HF dataset, also save it by default save_hf_dataset = create_hf_dataset result = pipeline.run_pipeline( config, config.output_format, save_splits=True, create_hf_dataset=create_hf_dataset, save_hf_dataset=save_hf_dataset, hf_dataset_path=hf_dataset_path ) print(f"✅ Pipeline completed successfully!") print(f" Data source: {config.data_source}") if config.data_source == "huggingface": print(f" Dataset: {config.dataset_name}") else: print(f" Data file: {config.data_path}") print(f" Total samples: {result['analysis']['overall']['total_samples']}") print(f" Split sizes: {result['analysis']['overall']['split_sizes']}") print(f" Output directory: {config.output_dir}") print(f" Style instruction: {config.instruction}") except Exception as e: print(f"❌ Error running pipeline: {e}") sys.exit(1) if __name__ == "__main__": main()