Files
2025-08-13 23:50:20 +00:00

1513 lines
65 KiB
Python

import json
import pandas as pd
import numpy as np
from pathlib import Path
from typing import Dict, List, Optional, Union, Any, Tuple
from datasets import Dataset, load_dataset
import os
from dataclasses import dataclass
from abc import ABC, abstractmethod
import logging
from sklearn.model_selection import train_test_split
import re
import argparse
import sys
import yaml
# Module-wide logger. Its level is pinned to DEBUG so that the per-item
# logger.debug traces emitted during preprocessing pass this logger's own level
# check; whether they are printed still depends on the configured handlers.
logger: logging.Logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.DEBUG)
@dataclass
class StylingConfig:
    """Settings that drive loading, cleaning, splitting and prompt formatting of styling data.

    The generated constructor accepts these fields positionally in declaration
    order, so existing fields must not be reordered; append new ones at the end.
    """

    # --- where the raw records come from ---
    data_source: str = "huggingface"  # "huggingface" (Hub dataset) or "custom" (local file)
    dataset_name: Optional[str] = None  # Hub dataset id, used when data_source == "huggingface"
    data_path: Optional[str] = None  # local file path, used when data_source == "custom"
    data_format: str = "jsonl"  # local file format: "jsonl", "csv" or "json"

    # --- which record keys hold the source text and its styled rewrite ---
    input_field: str = "text"
    output_field: str = "styled_text"
    instruction: str = "Rewrite the following text in a formal style"  # placed in every prompt

    # --- sampling and split ratios ---
    max_samples: Optional[int] = None  # cap on records; a falsy value disables the cap
    train_split: float = 0.8
    validation_split: float = 0.1
    test_split: float = 0.1

    # --- cleaning and length filtering ---
    clean_text: bool = True  # collapse whitespace; also gates the two options below
    remove_special_chars: bool = False
    lowercase: bool = False  # off by default: casing is part of the target style
    min_length: int = 10  # records whose input or output is shorter are dropped
    max_length: int = 1000  # ...or longer

    # --- output ---
    # NOTE(review): not read by the code visible here; the previous comment listed
    # "instruction, conversation, qa" as options — confirm against the consumer.
    output_format: str = "styling"
    output_dir: str = "./data"

    # --- Hugging Face options ---
    hf_split: str = "train"
    hf_cache_dir: Optional[str] = None

    # --- origin of the validation/test splits ---
    # "use_test_if_available" / "use_val_if_available" reuse a Hub split when present;
    # any other value carves the split out of the train data.
    test_split_from: str = "train"
    val_split_from: str = "train"

    # --- local-file options ---
    encoding: str = "utf-8"
    delimiter: str = ","  # CSV only

    # --- prompt template: three positional "{}" slots (instruction, input, response) ---
    alpaca_prompt: str = (
        "Below is an instruction that describes a task, paired with an input that "
        "provides further context. Write a response that follows the instruction\n"
        "### Instruction:\n"
        "{}\n"
        "### Input:\n"
        "{}\n"
        "### Response:\n"
        "{}"
    )
    eos_token: str = "<|eot_id|>"  # appended after each formatted training example
class DataValidator:
    """Quality and format checks for styling datasets split into train/validation/test."""

    @staticmethod
    def _as_text(value: Any) -> str:
        """Return *value* as a string for previews and length statistics.

        ``None`` becomes ``""`` and other non-string values are ``str()``-ed, mirroring
        how the loaders normalise values during preprocessing.
        """
        if value is None:
            return ""
        return value if isinstance(value, str) else str(value)

    @staticmethod
    def _is_blank(value: Any) -> bool:
        """Return True for ``None`` or a whitespace-only string.

        Other non-string values are not treated as blank; they are reported as
        type errors instead of crashing on ``.strip()``.
        """
        return value is None or (isinstance(value, str) and not value.strip())

    @staticmethod
    def validate_styling_data(data: Dict[str, List[Dict]], config: StylingConfig, is_processed: bool = False) -> Tuple[bool, List[str]]:
        """Validate the train/validation/test splits of a styling dataset.

        Args:
            data: Mapping of split name to a list of record dicts.
            config: Supplies the raw field names; only read when ``is_processed`` is False.
            is_processed: When True, records are expected to use the internal
                ``"input"``/``"output"`` keys produced by the loaders' preprocessing.

        Returns:
            ``(is_valid, errors)``; ``errors`` lists every problem found. Missing
            splits or an empty train split short-circuit the remaining checks.
        """
        errors: List[str] = []

        # All three splits must exist; only train must be non-empty (small datasets
        # may legitimately leave validation/test empty).
        for split in ("train", "validation", "test"):
            if split not in data:
                errors.append(f"Missing '{split}' split")
            elif split == "train" and not data[split]:
                errors.append("Train split cannot be empty")
        if errors:
            return False, errors

        total_samples = sum(len(split_data) for split_data in data.values())
        logger.info(f"Validating {total_samples} total samples across all splits...")

        # Processed records always use the internal "input"/"output" keys.
        input_field = "input" if is_processed else config.input_field
        output_field = "output" if is_processed else config.output_field
        as_text = DataValidator._as_text
        is_blank = DataValidator._is_blank

        for split_name, split_data in data.items():
            if not split_data:
                logger.info(f"Skipping validation for empty {split_name} split")
                continue
            logger.info(f"Validating {split_name} split with {len(split_data)} samples...")

            # Required fields
            missing_input_count = 0
            missing_output_count = 0
            for i, item in enumerate(split_data):
                if input_field not in item:
                    errors.append(f"Missing input field '{input_field}' in {split_name} split, item {i}")
                    missing_input_count += 1
                if output_field not in item:
                    errors.append(f"Missing output field '{output_field}' in {split_name} split, item {i}")
                    missing_output_count += 1
            logger.info(f"{split_name} - Items missing input field: {missing_input_count}")
            logger.info(f"{split_name} - Items missing output field: {missing_output_count}")

            # Types: a missing field defaults to "" here, so it is only reported above.
            type_errors = 0
            for i, item in enumerate(split_data):
                if not isinstance(item.get(input_field, ""), str):
                    errors.append(f"Input field '{input_field}' must be string in {split_name} split, item {i}")
                    type_errors += 1
                if not isinstance(item.get(output_field, ""), str):
                    errors.append(f"Output field '{output_field}' must be string in {split_name} split, item {i}")
                    type_errors += 1
            logger.info(f"{split_name} - Type errors: {type_errors}")

            # Empty values: missing, None and whitespace-only all count as empty.
            # is_blank never calls .strip() on a non-string, so bad types cannot crash here.
            empty_inputs = sum(1 for item in split_data if is_blank(item.get(input_field, "")))
            empty_outputs = sum(1 for item in split_data if is_blank(item.get(output_field, "")))
            if empty_inputs > 0:
                errors.append(f"Found {empty_inputs} items with empty input text in {split_name} split")
            if empty_outputs > 0:
                errors.append(f"Found {empty_outputs} items with empty output text in {split_name} split")
            logger.info(f"{split_name} - Empty inputs: {empty_inputs}")
            logger.info(f"{split_name} - Empty outputs: {empty_outputs}")

            # Preview a few records for debugging; as_text keeps slicing safe for any value.
            logger.info(f"Sample processed items from {split_name}:")
            for i, item in enumerate(split_data[:3]):
                logger.info(f" Item {i}: input='{as_text(item.get(input_field))[:50]}...', output='{as_text(item.get(output_field))[:50]}...'")

        return len(errors) == 0, errors

    @staticmethod
    def analyze_dataset(data: Dict[str, List[Dict]], config: StylingConfig, is_processed: bool = False) -> Dict[str, Any]:
        """Compute per-split and overall statistics.

        Args:
            data: Mapping of split name to a list of record dicts.
            config: Supplies the raw field names; only read when ``is_processed`` is False.
            is_processed: When True, records use the internal ``"input"``/``"output"`` keys.

        Returns:
            ``{"splits": {name: {"total_samples", "text_length_stats", "missing_values"}},
            "overall": {"total_samples", "split_sizes"}}``. ``text_length_stats`` is keyed
            by "input"/"output"; ``missing_values`` is keyed by the actual field names.
        """
        analysis: Dict[str, Any] = {
            "splits": {},
            "overall": {
                "total_samples": 0,
                "split_sizes": {}
            }
        }
        input_field = "input" if is_processed else config.input_field
        output_field = "output" if is_processed else config.output_field
        as_text = DataValidator._as_text

        for split_name, split_data in data.items():
            split_analysis: Dict[str, Any] = {
                "total_samples": len(split_data),
                "text_length_stats": {},
                "missing_values": {}
            }
            if split_data:
                for field_name, field in (("input", input_field), ("output", output_field)):
                    # as_text makes None/non-string values measurable instead of raising.
                    text_lengths = [len(as_text(item.get(field, ""))) for item in split_data]
                    split_analysis["text_length_stats"][field_name] = {
                        "min": min(text_lengths),
                        "max": max(text_lengths),
                        "mean": float(np.mean(text_lengths)),
                        "median": float(np.median(text_lengths))
                    }
                for field in (input_field, output_field):
                    # Absent keys and falsy values (None, "") both count as missing.
                    split_analysis["missing_values"][field] = sum(1 for item in split_data if not item.get(field))

            analysis["splits"][split_name] = split_analysis
            analysis["overall"]["total_samples"] += len(split_data)
            analysis["overall"]["split_sizes"][split_name] = len(split_data)

        return analysis
class BaseDataLoader(ABC):
    """Interface every data source (Hub datasets, local files) implements.

    Both methods work with data grouped by split name ("train", "validation", "test").
    """

    @abstractmethod
    def load(self, config: StylingConfig) -> Dict[str, List[Dict]]:
        """Read raw records for *config* and return them grouped by split name."""
        ...

    @abstractmethod
    def preprocess(self, data: Dict[str, List[Dict]], config: StylingConfig) -> Dict[str, List[Dict]]:
        """Clean and filter every split in *data* and return the processed splits."""
        ...
class HuggingFaceDataLoader(BaseDataLoader):
    """Load styling data from the Hugging Face Hub.

    ``load`` returns ``train``/``validation``/``test`` lists of raw records;
    ``preprocess`` maps each record to ``{"input": ..., "output": ...}`` and drops
    records that fail the configured length filter.
    """

    def load(self, config: StylingConfig) -> Dict[str, List[Dict]]:
        """Load ``config.dataset_name`` and return train/validation/test splits.

        The Hub ``train`` split is required. Validation/test are taken from the Hub
        when ``val_split_from``/``test_split_from`` request it ("use_val_if_available"
        / "use_test_if_available") and a matching split exists; otherwise they are
        carved out of the train data. ``config`` itself is not modified.

        Raises:
            ValueError: if no dataset name is configured or the dataset has no
                ``train`` split. Any error during loading is logged and re-raised.
        """
        if not config.dataset_name:
            raise ValueError("Dataset name is required for Hugging Face datasets")
        logger.info(f"Loading Hugging Face dataset: {config.dataset_name}")

        try:
            dataset = load_dataset(
                config.dataset_name,
                cache_dir=config.hf_cache_dir
            )
            available_splits = list(dataset.keys())
            logger.info(f"Available splits in dataset: {available_splits}")

            if "train" not in available_splits:
                logger.error("No 'train' split found in dataset!")
                logger.error(f"Available splits: {available_splits}")
                raise ValueError(f"Dataset {config.dataset_name} does not have a 'train' split")
            train_dataset = dataset["train"]
            logger.info(f"Using 'train' split with {len(train_dataset)} samples")

            splits_data: Dict[str, List[Dict]] = {"train": list(train_dataset)}
            splits_data["validation"] = self._hub_validation_split(dataset, available_splits, config)
            splits_data["test"] = self._hub_test_split(dataset, available_splits, config)

            if not splits_data["validation"] or not splits_data["test"]:
                self._fill_missing_splits(splits_data, config)

            logger.info("Final split sizes:")
            logger.info(f" Train: {len(splits_data['train'])} samples")
            logger.info(f" Validation: {len(splits_data['validation'])} samples")
            logger.info(f" Test: {len(splits_data['test'])} samples")

            # Apply max_samples to each non-empty split.
            if config.max_samples:
                for split_name in list(splits_data):
                    split_data = splits_data[split_name]
                    if split_data:
                        splits_data[split_name] = split_data[:config.max_samples]
                        logger.info(f"Limited {split_name} split from {len(split_data)} to {len(splits_data[split_name])} samples")

            for split_name, split_data in splits_data.items():
                if split_data:
                    self._log_field_hints(split_name, split_data[0], config)

            logger.info(f"Successfully loaded dataset {config.dataset_name}")
            return splits_data
        except Exception as e:
            logger.error(f"Error loading dataset {config.dataset_name}: {e}")
            raise

    @staticmethod
    def _hub_validation_split(dataset: Any, available_splits: List[str], config: StylingConfig) -> List[Dict]:
        """Return the Hub's validation split when requested and present, else ``[]``."""
        if config.val_split_from == "use_val_if_available":
            for name in ("validation", "val"):
                if name in available_splits:
                    logger.info(f"Using '{name}' split with {len(dataset[name])} samples")
                    return list(dataset[name])
            logger.warning("No validation split found in dataset. Will create from train split.")
            logger.info(f"Available splits: {available_splits}")
            logger.info(f"Will use {config.validation_split * 100}% of train data for validation")
        else:
            logger.info(f"Will create validation split from train data ({config.validation_split * 100}%)")
        return []

    @staticmethod
    def _hub_test_split(dataset: Any, available_splits: List[str], config: StylingConfig) -> List[Dict]:
        """Return the Hub split to use as test data when requested and present, else ``[]``.

        "use_val_if_available" reuses the Hub validation split as test data.
        """
        if config.test_split_from == "use_test_if_available" and "test" in available_splits:
            logger.info(f"Using 'test' split with {len(dataset['test'])} samples")
            return list(dataset["test"])
        if config.test_split_from == "use_val_if_available":
            for name in ("validation", "val"):
                if name in available_splits:
                    logger.info(f"Using '{name}' split as test with {len(dataset[name])} samples")
                    return list(dataset[name])
        if config.test_split_from == "use_test_if_available":
            logger.warning("No test split found in dataset. Will create from train split.")
            logger.info(f"Available splits: {available_splits}")
            logger.info(f"Will use {config.test_split * 100}% of train data for test")
        else:
            logger.info(f"Will create test split from train data ({config.test_split * 100}%)")
        return []

    def _fill_missing_splits(self, splits_data: Dict[str, List[Dict]], config: StylingConfig) -> None:
        """Carve the empty validation and/or test split out of ``splits_data["train"]``, in place."""
        train_data = splits_data["train"]
        n = len(train_data)
        if n < 3:
            # Too few samples to split. Splits already loaded from the Hub are kept;
            # only the missing ones stay empty.
            logger.warning(f"Dataset has only {n} samples. Using all data for training.")
            return

        _, val_ratio, test_ratio = self._normalized_ratios(config)
        need_val = not splits_data["validation"]
        need_test = not splits_data["test"]

        if need_val and need_test:
            val_size, test_size = self._three_way_sizes(n, val_ratio, test_ratio)
            new_train, held_out = train_test_split(
                train_data,
                test_size=val_size + test_size,
                random_state=42
            )
            # Pass an absolute count: scikit-learn rounds a float fraction up, and float
            # error in test_size / (val_size + test_size) can shift a sample between splits.
            new_val, new_test = train_test_split(
                held_out,
                test_size=test_size,
                random_state=42
            )
            splits_data.update(train=new_train, validation=new_val, test=new_test)
        elif need_val:
            # Keep at least one training sample even with extreme ratios.
            val_size = min(max(1, int(n * val_ratio)), n - 1)
            new_train, new_val = train_test_split(train_data, test_size=val_size, random_state=42)
            splits_data.update(train=new_train, validation=new_val)
        else:
            test_size = min(max(1, int(n * test_ratio)), n - 1)
            new_train, new_test = train_test_split(train_data, test_size=test_size, random_state=42)
            splits_data.update(train=new_train, test=new_test)

    @staticmethod
    def _normalized_ratios(config: StylingConfig) -> Tuple[float, float, float]:
        """Return (train, validation, test) ratios scaled to sum to 1, without modifying *config*.

        Raises:
            ValueError: if the ratios do not sum to a positive value.
        """
        ratios = (config.train_split, config.validation_split, config.test_split)
        total = sum(ratios)
        if total <= 0:
            raise ValueError(f"Split percentages must sum to a positive value (got {total})")
        # Tolerance avoids spurious warnings from float rounding (0.7 + 0.2 + 0.1 != 1.0).
        if abs(total - 1.0) > 1e-9:
            logger.warning(f"Split percentages don't sum to 1.0 (got {total}). Normalizing...")
            ratios = tuple(r / total for r in ratios)
        return ratios

    @staticmethod
    def _three_way_sizes(n: int, val_ratio: float, test_ratio: float) -> Tuple[int, int]:
        """Return (val_size, test_size) for carving validation and test out of *n* >= 3 samples.

        Datasets under 10 samples use fixed 60/20/20 ratios. Each held-out split gets
        at least one sample and at least 10% of *n*; sizes are then shrunk
        (validation first) until train keeps at least one sample.
        """
        if n < 10:
            val_ratio = test_ratio = 0.2
            logger.info("Small dataset detected. Adjusted split ratios to: train=0.6, val=0.2, test=0.2")
        val_size = max(1, int(n * 0.1), int(n * val_ratio))
        test_size = max(1, int(n * 0.1), int(n * test_ratio))
        # Loop rather than a single decrement so ratios summing above 1 cannot leave train empty.
        while n - val_size - test_size < 1 and (val_size > 1 or test_size > 1):
            if val_size > 1:
                val_size -= 1
            else:
                test_size -= 1
        logger.info(f"Adjusted split sizes: train={n - val_size - test_size}, val={val_size}, test={test_size}")
        return val_size, test_size

    @staticmethod
    def _log_field_hints(split_name: str, sample: Dict, config: StylingConfig) -> None:
        """Log the fields of *sample* and suggest alternatives when configured fields are absent."""
        field_names = list(sample.keys())
        logger.info(f"Sample data item from {split_name}: {sample}")
        logger.info(f"Available fields in {split_name} split: {field_names}")
        if config.input_field not in sample:
            logger.warning(f"Input field '{config.input_field}' not found in {split_name}. Available fields: {field_names}")
            text_fields = [f for f in field_names if any(keyword in f.lower() for keyword in ['text', 'sentence', 'content', 'input', 'comment', 'message'])]
            if text_fields:
                logger.info(f"Suggested text fields for {split_name}: {text_fields}")
        if config.output_field not in sample:
            logger.warning(f"Output field '{config.output_field}' not found in {split_name}. Available fields: {field_names}")
            output_fields = [f for f in field_names if any(keyword in f.lower() for keyword in ['output', 'response', 'result', 'target', 'styled'])]
            if output_fields:
                logger.info(f"Suggested output fields for {split_name}: {output_fields}")

    def preprocess(self, data: Dict[str, List[Dict]], config: StylingConfig) -> Dict[str, List[Dict]]:
        """Normalise every split to ``{"input": str, "output": str}`` records.

        Records whose cleaned input or output length falls outside
        ``[config.min_length, config.max_length]`` are dropped. Field availability
        and a few raw/processed samples are logged for debugging.
        """
        processed_splits: Dict[str, List[Dict]] = {}
        logger.info("=== PREPROCESSING DATA ===")
        for split_name, split_data in data.items():
            logger.info(f"Processing {split_name} split with {len(split_data)} items...")
            if split_data:
                self._log_raw_overview(split_name, split_data, config)

            processed_data: List[Dict] = []
            skipped_count = 0
            # Per-split counter that limits the per-item debug traces to the first few items.
            self._debug_count = 0
            for i, item in enumerate(split_data):
                processed_item = self._preprocess_item(item, config)
                if processed_item is not None:
                    processed_data.append(processed_item)
                else:
                    skipped_count += 1
                    if skipped_count <= 3:
                        logger.info(f"Skipped item {i} from {split_name}: {item}")

            processed_splits[split_name] = processed_data
            logger.info(f"{split_name} - Preprocessed {len(processed_data)} samples, skipped {skipped_count} samples")
            if processed_data:
                logger.info(f"=== SAMPLE PROCESSED DATA FROM {split_name.upper()} ===")
                for i, processed in enumerate(processed_data[:3]):
                    logger.info(f"Processed item {i} from {split_name}: {processed}")
        return processed_splits

    @staticmethod
    def _log_raw_overview(split_name: str, split_data: List[Dict], config: StylingConfig) -> None:
        """Log field availability, missing-value counts and up to three raw records of a split."""
        available_fields = set(split_data[0].keys())
        logger.info(f"Available fields in {split_name}: {available_fields}")
        logger.info(f"Looking for input field: '{config.input_field}', output field: '{config.output_field}'")
        if config.input_field not in available_fields:
            logger.error(f"Input field '{config.input_field}' not found in {split_name}. Available fields: {available_fields}")
        if config.output_field not in available_fields:
            logger.error(f"Output field '{config.output_field}' not found in {split_name}. Available fields: {available_fields}")

        # Absent keys and falsy values (None, "") both count as missing.
        missing_input = sum(1 for item in split_data if not item.get(config.input_field))
        missing_output = sum(1 for item in split_data if not item.get(config.output_field))
        logger.info(f"{split_name} - Items missing input field: {missing_input}")
        logger.info(f"{split_name} - Items missing output field: {missing_output}")

        logger.info(f"=== SAMPLE RAW DATA FROM {split_name.upper()} BEFORE PREPROCESSING ===")
        for i, item in enumerate(split_data[:3]):
            logger.info(f"Raw item {i} from {split_name}:")
            for key, value in item.items():
                if isinstance(value, str) and len(value) > 100:
                    logger.info(f" {key}: '{value[:100]}...'")
                else:
                    logger.info(f" {key}: {value}")

    def _preprocess_item(self, item: Dict, config: StylingConfig) -> Optional[Dict]:
        """Return ``{"input", "output"}`` for *item*, or ``None`` if it fails the length filter.

        Missing/``None`` values become ``""`` and other values are ``str()``-ed before
        optional cleaning; both texts must have length within
        ``[config.min_length, config.max_length]``.
        """
        self._debug_count = getattr(self, "_debug_count", 0) + 1
        verbose = self._debug_count <= 3  # only trace the first few items of each split

        input_text = item.get(config.input_field, "")
        output_text = item.get(config.output_field, "")
        if verbose:
            logger.debug(f"Processing item {self._debug_count}:")
            logger.debug(f" Looking for input field '{config.input_field}': {input_text}")
            logger.debug(f" Looking for output field '{config.output_field}': {output_text}")

        input_text = "" if input_text is None else str(input_text)
        output_text = "" if output_text is None else str(output_text)
        if verbose:
            logger.debug(f" After conversion - input: '{input_text[:50]}...', output: '{output_text[:50]}...'")

        if config.clean_text:
            original_input, original_output = input_text, output_text
            input_text = self._clean_text(input_text, config)
            output_text = self._clean_text(output_text, config)
            if verbose:
                logger.debug(f" After cleaning - input: '{original_input[:50]}...' -> '{input_text[:50]}...'")
                logger.debug(f" After cleaning - output: '{original_output[:50]}...' -> '{output_text[:50]}...'")

        for label, text in (("input", input_text), ("output", output_text)):
            if not config.min_length <= len(text) <= config.max_length:
                if verbose:
                    logger.debug(f" Skipping - {label} length {len(text)} not in range [{config.min_length}, {config.max_length}]")
                return None

        # Internal processing always uses the "input"/"output" keys.
        processed_item = {
            "input": input_text,
            "output": output_text
        }
        if verbose:
            logger.debug(f" Final processed item: {processed_item}")
        return processed_item

    def _clean_text(self, text: str, config: StylingConfig) -> str:
        """Normalise *text*; non-string input yields ``""``.

        Applies optional lowercasing and special-character removal, then collapses
        whitespace last so removed punctuation (e.g. "a - b") leaves no double,
        leading or trailing spaces behind.
        """
        if not isinstance(text, str):
            return ""
        if config.lowercase:
            text = text.lower()
        if config.remove_special_chars:
            text = re.sub(r'[^\w\s]', '', text)
        return re.sub(r'\s+', ' ', text).strip()
class CustomDataLoader(BaseDataLoader):
    """Load styling data from a local JSONL, CSV or JSON file and split it."""

    def load(self, config: StylingConfig) -> Dict[str, List[Dict]]:
        """Read ``config.data_path`` and return train/validation/test splits.

        Records that are not key/value objects are dropped with a warning; then
        ``max_samples`` (if set) truncates the data before splitting.

        Raises:
            ValueError: if no path is configured or ``data_format`` is unsupported.
            FileNotFoundError: if the file does not exist.
        """
        if not config.data_path:
            raise ValueError("Data path is required for custom datasets")
        file_path = Path(config.data_path)
        if not file_path.exists():
            raise FileNotFoundError(f"Data file not found: {file_path}")
        logger.info(f"Loading custom dataset: {file_path}")

        readers = {
            "jsonl": self._load_jsonl,
            "csv": self._load_csv,
            "json": self._load_json,
        }
        if config.data_format not in readers:
            raise ValueError(f"Unsupported format: {config.data_format}")
        raw_data = readers[config.data_format](file_path, config)

        # Preprocessing reads records with .get(); anything that is not a dict
        # (e.g. a bare string or number in a JSONL line) would crash there.
        records = [record for record in raw_data if isinstance(record, dict)]
        if len(records) != len(raw_data):
            logger.warning(f"Skipped {len(raw_data) - len(records)} records that are not key/value objects")

        if config.max_samples:
            records = records[:config.max_samples]
        logger.info(f"Loaded {len(records)} samples from {file_path}")
        return self._create_splits(records, config)

    def _create_splits(self, data: List[Dict], config: StylingConfig) -> Dict[str, List[Dict]]:
        """Split *data* into train/validation/test without modifying *config*.

        Fewer than 3 records all go to train. Otherwise, datasets under 10 records
        use fixed 60/20/20 ratios; validation and test each get at least one record
        and at least 10% of the data, and train keeps at least one record.
        """
        total_samples = len(data)
        logger.info(f"Creating splits from {total_samples} samples...")
        if total_samples < 3:
            logger.warning(f"Dataset has only {total_samples} samples. Using all data for training.")
            return {
                "train": data,
                "validation": [],
                "test": []
            }

        # Local copies: the caller's config keeps its configured ratios.
        val_ratio, test_ratio = config.validation_split, config.test_split
        if total_samples < 10:
            val_ratio = test_ratio = 0.2
            logger.info("Small dataset detected. Adjusted split ratios to: train=0.6, val=0.2, test=0.2")

        val_size = max(1, int(total_samples * 0.1), int(total_samples * val_ratio))
        test_size = max(1, int(total_samples * 0.1), int(total_samples * test_ratio))

        # Shrink the held-out splits (validation first) until train keeps a sample;
        # a loop is needed when the configured ratios sum to more than 1.
        adjusted = False
        while total_samples - val_size - test_size < 1 and (val_size > 1 or test_size > 1):
            if val_size > 1:
                val_size -= 1
            else:
                test_size -= 1
            adjusted = True
        train_size = total_samples - val_size - test_size
        if adjusted:
            logger.info(f"Adjusted split sizes to ensure train has at least 1 sample: train={train_size}, val={val_size}, test={test_size}")
        logger.info(f"Split sizes: train={train_size}, validation={val_size}, test={test_size}")

        # Both held-out sizes are >= 1 here, so a full three-way split always applies.
        train_data, held_out = train_test_split(
            data,
            test_size=val_size + test_size,
            random_state=42
        )
        val_data, test_data = train_test_split(
            held_out,
            test_size=test_size,
            random_state=42
        )
        splits_data = {
            "train": train_data,
            "validation": val_data,
            "test": test_data
        }

        logger.info("Created splits:")
        logger.info(f" Train: {len(splits_data['train'])} samples")
        logger.info(f" Validation: {len(splits_data['validation'])} samples")
        logger.info(f" Test: {len(splits_data['test'])} samples")
        return splits_data

    def _load_jsonl(self, file_path: Path, config: StylingConfig) -> List[Any]:
        """Parse one JSON value per non-blank line; malformed lines are logged and skipped."""
        data = []
        with open(file_path, 'r', encoding=config.encoding) as f:
            for line_num, line in enumerate(f, 1):
                if line.strip():
                    try:
                        data.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        logger.warning(f"Invalid JSON at line {line_num}: {e}")
        return data

    def _load_csv(self, file_path: Path, config: StylingConfig) -> List[Dict]:
        """Read a CSV file into a list of row dicts; empty cells become ``None``."""
        df = pd.read_csv(file_path, encoding=config.encoding, delimiter=config.delimiter)
        # pandas represents empty cells as NaN; mapping them to None makes preprocessing
        # treat them as missing instead of producing the literal text "nan".
        df = df.astype(object).where(df.notna(), None)
        return df.to_dict('records')

    def _load_json(self, file_path: Path, config: StylingConfig) -> List[Any]:
        """Read a JSON file holding a list of records, ``{"data": [...]}``, or a single record."""
        with open(file_path, 'r', encoding=config.encoding) as f:
            data = json.load(f)
        if isinstance(data, dict) and "data" in data:
            data = data["data"]
        # Always return a list so callers can iterate over records uniformly.
        return data if isinstance(data, list) else [data]

    def preprocess(self, data: Dict[str, List[Dict]], config: StylingConfig) -> Dict[str, List[Dict]]:
        """Normalise every split to ``{"input": str, "output": str}`` records.

        Records whose cleaned input or output length falls outside
        ``[config.min_length, config.max_length]`` are dropped.
        """
        processed_splits: Dict[str, List[Dict]] = {}
        logger.info("=== PREPROCESSING CUSTOM DATA ===")
        for split_name, split_data in data.items():
            logger.info(f"Processing {split_name} split with {len(split_data)} items...")
            processed_data: List[Dict] = []
            skipped_count = 0
            for i, item in enumerate(split_data):
                processed_item = self._preprocess_item(item, config)
                if processed_item is not None:
                    processed_data.append(processed_item)
                else:
                    skipped_count += 1
                    if skipped_count <= 3:
                        logger.info(f"Skipped item {i} from {split_name}: {item}")
            processed_splits[split_name] = processed_data
            logger.info(f"{split_name} - Preprocessed {len(processed_data)} samples, skipped {skipped_count} samples")
        return processed_splits

    def _preprocess_item(self, item: Dict, config: StylingConfig) -> Optional[Dict]:
        """Return ``{"input", "output"}`` for *item*, or ``None`` if it fails the length filter.

        Missing/``None`` values become ``""`` and other values are ``str()``-ed before
        optional cleaning.
        """
        input_text = item.get(config.input_field, "")
        output_text = item.get(config.output_field, "")
        input_text = "" if input_text is None else str(input_text)
        output_text = "" if output_text is None else str(output_text)

        if config.clean_text:
            input_text = self._clean_text(input_text, config)
            output_text = self._clean_text(output_text, config)

        for text in (input_text, output_text):
            if not config.min_length <= len(text) <= config.max_length:
                return None

        # Internal processing always uses the "input"/"output" keys.
        return {
            "input": input_text,
            "output": output_text
        }

    def _clean_text(self, text: str, config: StylingConfig) -> str:
        """Normalise *text*; non-string input yields ``""``.

        Applies optional lowercasing and special-character removal, then collapses
        whitespace last so removed punctuation (e.g. "a - b") leaves no double,
        leading or trailing spaces behind.
        """
        if not isinstance(text, str):
            return ""
        if config.lowercase:
            text = text.lower()
        if config.remove_special_chars:
            text = re.sub(r'[^\w\s]', '', text)
        return re.sub(r'\s+', ' ', text).strip()
class StylingDataPipeline:
"""Main styling pipeline"""
def __init__(self):
    """Create the validator and the two data-source loaders this pipeline dispatches to."""
    # Quality checks applied to every processed dataset.
    self.validator = DataValidator()
    # One loader per supported ``data_source`` value ("huggingface", "custom").
    self.hf_loader, self.custom_loader = HuggingFaceDataLoader(), CustomDataLoader()
def create_config(
    self,
    data_source: str,
    dataset_name: Optional[str] = None,
    data_path: Optional[str] = None,
    input_field: str = "input",
    output_field: str = "output",
    instruction: str = "Rewrite the following text in a formal style",
    **kwargs
) -> StylingConfig:
    """Return a new ``StylingConfig`` built from explicit arguments.

    The named parameters cover the most common settings; any other
    ``StylingConfig`` field can be passed via ``**kwargs`` (an unknown name makes
    the ``StylingConfig`` constructor raise ``TypeError``). Note that the field
    defaults here are ``"input"``/``"output"``, unlike the dataclass's own defaults.
    """
    explicit = {
        "data_source": data_source,
        "dataset_name": dataset_name,
        "data_path": data_path,
        "input_field": input_field,
        "output_field": output_field,
        "instruction": instruction,
    }
    return StylingConfig(**explicit, **kwargs)
def load_config_from_yaml(self, yaml_path: str) -> StylingConfig:
    """Build a ``StylingConfig`` from the YAML file at *yaml_path*.

    Every top-level key naming a ``StylingConfig`` field is passed through,
    including ``alpaca_prompt`` and ``eos_token``. Absent keys fall back to the
    dataclass defaults, except ``data_source``, which defaults to ``"custom"``
    here. Unknown keys are logged and ignored.

    Raises:
        ValueError: if the YAML document is not a mapping.
        Any error while reading or building the config is logged and re-raised.
    """
    from dataclasses import fields

    try:
        config_dict = load_yaml_config(yaml_path)
        if config_dict is None:
            # e.g. an empty YAML document, if load_yaml_config returns None for it;
            # treat it as "use all defaults".
            config_dict = {}
        if not isinstance(config_dict, dict):
            raise ValueError(f"Expected a mapping at the top level of {yaml_path}, got {type(config_dict).__name__}")

        # Derive accepted keys from the dataclass so defaults live in one place.
        known_fields = {f.name for f in fields(StylingConfig)}
        unknown_keys = sorted(str(key) for key in config_dict if key not in known_fields)
        if unknown_keys:
            logger.warning(f"Ignoring unknown configuration keys in {yaml_path}: {unknown_keys}")

        settings = {"data_source": "custom"}
        settings.update({key: value for key, value in config_dict.items() if key in known_fields})
        config = StylingConfig(**settings)

        logger.info(f"Configuration loaded from YAML: {yaml_path}")
        logger.info(f"Output directory: {config.output_dir}")
        logger.info(f"Instruction: {config.instruction}")
        return config
    except Exception as e:
        logger.error(f"Error loading configuration from YAML {yaml_path}: {e}")
        raise
def load_and_preprocess(self, config: StylingConfig) -> Tuple[Dict[str, List[Dict]], Dict[str, Any]]:
    """Run load -> preprocess -> validate -> analyze for the data source in *config*.

    Returns:
        A ``(processed_splits, analysis)`` pair.

    Raises:
        ValueError: for an unknown ``config.data_source`` or when the processed
            splits fail validation. Every exception is logged before re-raising.
    """
    logger.info(f"Starting data loading and preprocessing...")
    logger.info(f"Data source: {config.data_source}")

    # data_source -> (name of the loader attribute, label used in log messages)
    dispatch = {
        "huggingface": ("hf_loader", "HuggingFace"),
        "custom": ("custom_loader", "custom"),
    }
    try:
        if config.data_source not in dispatch:
            raise ValueError(f"Unsupported data source: {config.data_source}")
        loader_attr, label = dispatch[config.data_source]
        loader = getattr(self, loader_attr)

        logger.info(f"Loading {label} dataset...")
        raw_splits = loader.load(config)
        logger.info(f"Preprocessing {label} dataset...")
        processed_splits = loader.preprocess(raw_splits, config)

        logger.info(f"Data loading and preprocessing completed successfully")
        logger.info(f"Raw splits: {list(raw_splits)}")
        logger.info(f"Processed splits: {list(processed_splits)}")

        logger.info("Validating processed data...")
        is_valid, errors = self.validator.validate_styling_data(processed_splits, config, is_processed=True)
        if not is_valid:
            logger.error("Data validation failed:")
            for message in errors:
                logger.error(f" - {message}")
            raise ValueError("Data validation failed")
        logger.info("Data validation passed")

        logger.info("Analyzing dataset...")
        analysis = self.validator.analyze_dataset(processed_splits, config, is_processed=True)
        logger.info("Dataset analysis completed")
        return processed_splits, analysis
    except Exception as e:
        logger.error(f"Error in load_and_preprocess: {e}")
        raise
def convert_to_alpaca_format(self, data: Dict[str, List[Dict]], config: StylingConfig) -> Dict[str, List[Dict]]:
    """Turn each record into an Alpaca-style ``instruction``/``input``/``output`` dict.

    Every record gets ``config.instruction`` as its instruction. ``input`` and
    ``output`` are copied from the record: a missing key or a ``None`` value
    becomes ``""``, and any other value is passed through ``str()``.

    Returns:
        A new dict with the same split names, each mapped to a new list.
    """
    def as_text(value):
        # Missing/None -> "", everything else -> its str() form.
        return "" if value is None else str(value)

    return {
        split_name: [
            {
                "instruction": config.instruction,
                "input": as_text(record.get("input")),
                "output": as_text(record.get("output")),
            }
            for record in records
        ]
        for split_name, records in data.items()
    }
def format_for_training(self, data: Dict[str, List[Dict]], config: StylingConfig) -> Dict[str, List[str]]:
    """Render every record into a single training string.

    ``config.alpaca_prompt`` is filled positionally with ``config.instruction``,
    the record's ``input`` and its ``output``, and ``config.eos_token`` is
    appended. A missing key or ``None`` value becomes ``""``; any other value is
    passed through ``str()``.

    Returns:
        Split name -> list of rendered strings, in record order.
    """
    def render(record):
        raw_input, raw_output = record.get("input"), record.get("output")
        input_text = "" if raw_input is None else str(raw_input)
        output_text = "" if raw_output is None else str(raw_output)
        return config.alpaca_prompt.format(config.instruction, input_text, output_text) + config.eos_token

    formatted = {}
    for split_name, records in data.items():
        formatted[split_name] = [render(record) for record in records]
    return formatted
def convert_to_hf_dataset(self, dataset_entries: List[Dict], config: StylingConfig):
    """Build a HuggingFace ``Dataset`` from record dicts and add a ``text`` column.

    Each entry must carry ``instruction``, ``input`` and ``output`` keys. The
    ``text`` column is ``config.alpaca_prompt`` filled positionally with those
    three values plus ``config.eos_token``. A ``None`` input/output becomes
    ``""`` and other values are passed through ``str()``; the instruction is
    used as-is.

    Returns:
        The mapped ``datasets.Dataset`` (original columns plus ``text``).
    """
    from datasets import Dataset

    def to_text(value):
        return "" if value is None else str(value)

    def formatting_prompts_func(batch):
        # Batched map: every column arrives as a list of equal length.
        rows = zip(batch["instruction"], batch["input"], batch["output"])
        rendered = [
            config.alpaca_prompt.format(instruction, to_text(source), to_text(target)) + config.eos_token
            for instruction, source, target in rows
        ]
        return {"text": rendered}

    return Dataset.from_list(dataset_entries).map(formatting_prompts_func, batched=True)
def save_hf_dataset_to_disk(self, hf_dataset, save_path: str):
    """Persist ``hf_dataset`` by calling its ``save_to_disk(save_path)``.

    Returns:
        True on success. On any exception the error is logged and False is
        returned (the exception is not propagated).
    """
    try:
        hf_dataset.save_to_disk(save_path)
    except Exception as exc:
        logger.error(f"Error saving HuggingFace dataset to disk: {exc}")
        return False
    logger.info(f"HuggingFace dataset saved to disk at: {save_path}")
    return True
def load_hf_dataset_from_disk(self, load_path: str):
    """Load a dataset saved with ``save_to_disk`` via ``datasets.load_from_disk``.

    Logs the path, ``len()`` and ``features`` of the loaded object. Any
    exception (including one raised while logging those details) is logged and
    turned into a ``None`` return value.
    """
    try:
        from datasets import load_from_disk

        loaded = load_from_disk(load_path)
        logger.info(f"HuggingFace dataset loaded from disk: {load_path}")
        logger.info(f"Dataset has {len(loaded)} entries")
        logger.info(f"Dataset features: {loaded.features}")
    except Exception as exc:
        logger.error(f"Error loading HuggingFace dataset from disk: {exc}")
        return None
    return loaded
def save_data(self, data: Dict[str, List[Dict]], output_dir: str, format: str = "jsonl"):
    """Write each split in ``data`` to ``<output_dir>/<split_name>.<format>``.

    Args:
        data: Split name -> list of record dicts.
        output_dir: Target directory; created (with parents) if missing.
        format: ``"jsonl"`` (one JSON object per line), ``"json"`` (one indented
            JSON array) or ``"csv"`` (via pandas, without the index column).

    Raises:
        ValueError: If ``format`` is not one of the supported values. This is
            checked before the directory is created or any file is written.
    """
    supported = ("jsonl", "json", "csv")
    if format not in supported:
        # Without this guard an unknown format skipped every write branch and
        # then logged an unbound (or stale) ``output_file``.
        raise ValueError(f"Unsupported output format: {format!r} (expected one of {supported})")

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    for split_name, split_data in data.items():
        output_file = output_path / f"{split_name}.{format}"
        if format == "jsonl":
            with open(output_file, 'w', encoding='utf-8') as f:
                f.writelines(json.dumps(item, ensure_ascii=False) + '\n' for item in split_data)
        elif format == "json":
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(split_data, f, ensure_ascii=False, indent=2)
        else:  # csv
            pd.DataFrame(split_data).to_csv(output_file, index=False, encoding='utf-8')
        logger.info(f"Saved {len(split_data)} samples to {output_file}")
def run_pipeline(
    self,
    config: StylingConfig,
    output_format: str = "styling",
    save_splits: bool = True,
    create_hf_dataset: bool = False,
    save_hf_dataset: bool = False,
    hf_dataset_path: Optional[str] = None
) -> Dict[str, Any]:
    """Run the complete styling pipeline: load/preprocess, reformat, save.

    Args:
        config: Pipeline configuration.
        output_format: ``"alpaca"`` converts the processed splits with
            ``convert_to_alpaca_format``; any other value keeps them unchanged.
        save_splits: Write every split into ``config.output_dir`` via
            ``save_data`` (jsonl).
        create_hf_dataset: Also build ONE HuggingFace dataset from all splits
            concatenated (split membership is not kept in that dataset).
        save_hf_dataset: Save that dataset to disk. Only honoured together with
            ``create_hf_dataset``.
        hf_dataset_path: Save location; defaults to ``<output_dir>/hf_dataset``.

    Returns:
        Summary dict with ``config``, ``analysis``, ``splits`` (split -> size),
        ``output_format``, ``output_dir``, ``data`` (the formatted splits) and
        ``instruction``. ``hf_dataset`` is added when a dataset was created and
        ``hf_dataset_path`` when it was saved successfully.
    """
    logger.info("Starting styling pipeline...")
    processed_splits, analysis = self.load_and_preprocess(config)

    if output_format == "alpaca":
        formatted_splits = self.convert_to_alpaca_format(processed_splits, config)
    else:
        formatted_splits = processed_splits

    if save_splits:
        # Save directly in the output directory, not in a subdirectory.
        self.save_data(formatted_splits, str(Path(config.output_dir)))

    hf_dataset = None
    hf_dataset_save_path = None
    if create_hf_dataset:
        # Flatten all splits into one list. Records missing an instruction get
        # a COPY with it added, so the splits returned under result["data"]
        # are not mutated as a side effect of this option.
        all_entries = [
            item if "instruction" in item else {**item, "instruction": config.instruction}
            for split_data in formatted_splits.values()
            for item in split_data
        ]
        hf_dataset = self.convert_to_hf_dataset(all_entries, config)
        logger.info(f"HuggingFace dataset created with {len(hf_dataset)} entries")
        logger.info(f"Dataset features: {hf_dataset.features}")

        if save_hf_dataset:
            if hf_dataset_path is None:
                # Default location lives under the configured output_dir.
                hf_dataset_path = str(Path(config.output_dir) / "hf_dataset")
            if self.save_hf_dataset_to_disk(hf_dataset, hf_dataset_path):
                hf_dataset_save_path = hf_dataset_path
                logger.info(f"HuggingFace dataset saved to: {hf_dataset_save_path}")
            else:
                logger.warning("Failed to save HuggingFace dataset to disk")
    elif save_hf_dataset:
        logger.warning("save_hf_dataset=True has no effect unless create_hf_dataset=True")

    result = {
        "config": config,
        "analysis": analysis,
        "splits": {
            split_name: len(split_data) for split_name, split_data in formatted_splits.items()
        },
        "output_format": output_format,
        "output_dir": config.output_dir,
        "data": formatted_splits,  # The actual processed data
        "instruction": config.instruction
    }
    if hf_dataset is not None:
        result["hf_dataset"] = hf_dataset
    if hf_dataset_save_path:
        result["hf_dataset_path"] = hf_dataset_save_path

    logger.info("Styling pipeline completed successfully!")
    return result
# Helper functions
def create_huggingface_config(dataset_name: str, input_field: str = "text", output_field: str = "output", instruction: str = "Rewrite the following text in a formal style", **kwargs) -> StylingConfig:
    """Return a ``StylingConfig`` with ``data_source="huggingface"``.

    ``dataset_name``, the field mapping and the instruction are set from the
    arguments; any extra keyword arguments are forwarded to ``StylingConfig``
    unchanged. Passing ``data_source`` in ``kwargs`` raises ``TypeError``.
    """
    fixed = {
        "data_source": "huggingface",
        "dataset_name": dataset_name,
        "input_field": input_field,
        "output_field": output_field,
        "instruction": instruction,
    }
    # Two ** unpackings keep the "multiple values for keyword argument" error on clashes.
    return StylingConfig(**fixed, **kwargs)
def create_custom_config(data_path: str, data_format: str = "jsonl", input_field: str = "text", output_field: str = "styled_text", instruction: str = "Rewrite the following text in a formal style", **kwargs) -> StylingConfig:
    """Return a ``StylingConfig`` with ``data_source="custom"`` for a local file.

    ``data_path``, ``data_format``, the field mapping and the instruction are
    set from the arguments; extra keyword arguments are forwarded to
    ``StylingConfig`` unchanged. Passing ``data_source`` in ``kwargs`` raises
    ``TypeError``.
    """
    fixed = {
        "data_source": "custom",
        "data_path": data_path,
        "data_format": data_format,
        "input_field": input_field,
        "output_field": output_field,
        "instruction": instruction,
    }
    # Two ** unpackings keep the "multiple values for keyword argument" error on clashes.
    return StylingConfig(**fixed, **kwargs)
def save_hf_dataset_to_disk(hf_dataset, save_path: str) -> bool:
    """Call ``hf_dataset.save_to_disk(save_path)`` and report the outcome on stdout.

    Returns:
        True when saving (and the success message) completed. False when any
        exception occurred; the error is printed and not propagated.
    """
    succeeded = False
    try:
        hf_dataset.save_to_disk(save_path)
        print(f"HuggingFace dataset saved to disk at: {save_path}")
        succeeded = True
    except Exception as exc:
        print(f"Error saving HuggingFace dataset to disk: {exc}")
    return succeeded
def load_hf_dataset_from_disk(load_path: str):
    """Load a dataset saved with ``save_to_disk`` using ``datasets.load_from_disk``.

    Prints the path, ``len()`` and ``features`` of the loaded object. Any
    exception (including a missing ``datasets`` package or a bad path) is
    printed and ``None`` is returned instead.
    """
    try:
        from datasets import load_from_disk

        loaded = load_from_disk(load_path)
        print(f"HuggingFace dataset loaded from disk: {load_path}")
        print(f"Dataset has {len(loaded)} entries")
        print(f"Dataset features: {loaded.features}")
    except Exception as exc:
        print(f"Error loading HuggingFace dataset from disk: {exc}")
        return None
    return loaded
def load_yaml_config(config_path: str) -> Dict[str, Any]:
    """Load a YAML configuration file and flatten it into a single dict.

    Recognised top-level sections are ``task``, ``data``, ``model``,
    ``training`` and ``inference``. Each known key is copied under a flat name,
    e.g. ``data.source`` -> ``data_source``, ``model.max_length`` ->
    ``model_max_length``, ``inference.batch_size`` -> ``inference_batch_size``.

    Keys that are absent or explicitly null in the YAML are omitted from the
    result, so callers' ``dict.get(key, default)`` fallbacks apply. Previously
    they were stored as ``None``, which defeated those defaults.

    Args:
        config_path: Path to the YAML file (read as UTF-8).

    Returns:
        The flattened configuration dict. An empty file yields ``{}``.

    Raises:
        ValueError: If the document or a recognised section is not a mapping.
        Any I/O or YAML parse error is logged and re-raised.
    """
    # Per YAML section: YAML key -> flat configuration key.
    section_keys = {
        'task': {
            'name': 'task_name',
            'type': 'task_type',
        },
        'data': {
            'source': 'data_source',
            'dataset_name': 'dataset_name',
            'data_path': 'data_path',
            'data_format': 'data_format',
            'input_field': 'input_field',
            'output_field': 'output_field',
            'instruction': 'instruction',
            'max_samples': 'max_samples',
            'train_split': 'train_split',
            'validation_split': 'validation_split',
            'test_split': 'test_split',
            'clean_text': 'clean_text',
            'remove_special_chars': 'remove_special_chars',
            'lowercase': 'lowercase',
            'min_length': 'min_length',
            'max_length': 'max_length',
            'output_format': 'output_format',
            'output_dir': 'output_dir',
            'hf_split': 'hf_split',
            'hf_cache_dir': 'hf_cache_dir',
            'test_split_from': 'test_split_from',
            'val_split_from': 'val_split_from',
            'encoding': 'encoding',
            'delimiter': 'delimiter',
        },
        'model': {
            'name': 'model_name',
            'max_length': 'model_max_length',
        },
        'training': {
            'num_epochs': 'num_epochs',
            'batch_size': 'batch_size',
            'learning_rate': 'learning_rate',
            'weight_decay': 'weight_decay',
            'warmup_ratio': 'warmup_ratio',
            'lr_scheduler_type': 'lr_scheduler_type',
        },
        'inference': {
            'batch_size': 'inference_batch_size',
            'max_new_tokens': 'max_new_tokens',
            'temperature': 'temperature',
        },
    }
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            yaml_data = yaml.safe_load(f)
        # An empty YAML document parses to None; treat it as "no settings".
        if yaml_data is None:
            yaml_data = {}
        if not isinstance(yaml_data, dict):
            raise ValueError(
                f"Top-level YAML content must be a mapping, got {type(yaml_data).__name__}"
            )

        config_dict = {}
        for section, key_map in section_keys.items():
            section_data = yaml_data.get(section)
            if section_data is None:  # section missing or left empty (`data:`)
                continue
            if not isinstance(section_data, dict):
                raise ValueError(
                    f"YAML section '{section}' must be a mapping, got {type(section_data).__name__}"
                )
            for yaml_key, config_key in key_map.items():
                value = section_data.get(yaml_key)
                # Skip absent/null entries so callers' defaults apply instead of None.
                if value is not None:
                    config_dict[config_key] = value

        logger.info(f"Successfully parsed YAML configuration from: {config_path}")
        logger.info(f"Extracted {len(config_dict)} configuration parameters")
        return config_dict
    except Exception as e:
        logger.error(f"Error loading YAML config from {config_path}: {e}")
        raise
def main():
    """Command-line entry point for the styling data pipeline.

    Settings come from an optional YAML file (``--config``). Any CLI option
    that was actually supplied overrides the YAML value. The merged settings
    are turned into a ``StylingConfig``, ``StylingDataPipeline.run_pipeline``
    is run (splits are always saved; a HuggingFace dataset is built and saved
    when ``--create-hf-dataset`` is given), and a summary is printed.

    Exits with status 1 when the YAML cannot be loaded or the pipeline raises.
    Exits with status 2 (via ``parser.error``) when a required setting is missing.
    """
    parser = argparse.ArgumentParser(description="Styling Data Processing Pipeline")
    # YAML configuration
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")
    # Data source arguments
    parser.add_argument("--data-source", choices=["huggingface", "custom"], help="Data source")
    parser.add_argument("--dataset-name", type=str, help="HuggingFace dataset name")
    parser.add_argument("--data-path", type=str, help="Path to custom data file")
    parser.add_argument("--data-format", choices=["jsonl", "csv", "json"], help="Data format")
    # Field mapping
    parser.add_argument("--input-field", type=str, help="Input field name")
    parser.add_argument("--output-field", type=str, help="Output field name")
    parser.add_argument("--instruction", type=str, help="Style instruction")
    # Data processing
    parser.add_argument("--max-samples", type=int, help="Maximum samples to process")
    parser.add_argument("--train-split", type=float, help="Training split ratio")
    parser.add_argument("--validation-split", type=float, help="Validation split ratio")
    parser.add_argument("--test-split", type=float, help="Test split ratio")
    # Text preprocessing
    parser.add_argument("--clean-text", action="store_true", help="Clean and normalize text")
    parser.add_argument("--remove-special-chars", action="store_true", help="Remove special characters")
    parser.add_argument("--lowercase", action="store_true", help="Convert text to lowercase")
    parser.add_argument("--min-length", type=int, help="Minimum text length")
    parser.add_argument("--max-length", type=int, help="Maximum text length")
    # Output configuration
    parser.add_argument("--output-format", choices=["styling", "alpaca"], help="Output format")
    parser.add_argument("--output-dir", type=str, help="Output directory")
    # HuggingFace dataset options
    parser.add_argument("--create-hf-dataset", action="store_true", help="Create HuggingFace dataset")
    parser.add_argument("--hf-dataset-path", type=str, help="Path to save HuggingFace dataset")
    # Logging
    parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO", help="Logging level")
    args = parser.parse_args()

    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # Load YAML config if provided
    config_dict = {}
    if args.config:
        try:
            config_dict = load_yaml_config(args.config)
        except Exception as e:
            logger.error(f"Error loading YAML config: {e}")
            sys.exit(1)

    # Collect CLI overrides (argparse dest names == config keys).
    # store_true switches default to False and can only turn a feature ON, so
    # they are forwarded only when set; otherwise they would clobber YAML with
    # False. Every other option defaults to None, so compare with None rather
    # than truthiness to keep legitimate zeros (``--max-samples 0``,
    # ``--test-split 0.0``, ``--min-length 0``).
    flag_options = {'clean_text', 'remove_special_chars', 'lowercase', 'create_hf_dataset'}
    cli_overrides = {}
    for name in (
        'data_source', 'dataset_name', 'data_path', 'data_format',
        'input_field', 'output_field', 'instruction',
        'max_samples', 'train_split', 'validation_split', 'test_split',
        'clean_text', 'remove_special_chars', 'lowercase',
        'min_length', 'max_length', 'output_format', 'output_dir',
        'create_hf_dataset', 'hf_dataset_path', 'log_level',
    ):
        value = getattr(args, name)
        if name in flag_options:
            if value:
                cli_overrides[name] = True
        elif value is not None:
            cli_overrides[name] = value

    # Merge configurations: CLI wins over YAML.
    for key, value in cli_overrides.items():
        if key in config_dict:
            logger.info(f"Overriding YAML config '{key}' with CLI value: {value}")
        config_dict[key] = value

    # Validate required arguments
    if not config_dict.get('data_source'):
        parser.error("--data-source is required (either in YAML config or CLI)")
    if config_dict.get('data_source') == "huggingface" and not config_dict.get('dataset_name'):
        parser.error("--dataset-name is required for HuggingFace datasets")
    if config_dict.get('data_source') == "custom" and not config_dict.get('data_path'):
        parser.error("--data-path is required for custom datasets")

    def setting(key, default=None):
        # A key can be present but null (e.g. an empty YAML field); use the
        # default then too, not only when the key is missing entirely.
        value = config_dict.get(key)
        return default if value is None else value

    config = StylingConfig(
        data_source=setting('data_source', 'huggingface'),
        dataset_name=setting('dataset_name'),
        data_path=setting('data_path'),
        data_format=setting('data_format', 'jsonl'),
        input_field=setting('input_field', 'text'),
        output_field=setting('output_field', 'styled_text'),
        instruction=setting('instruction', 'Rewrite the following text in a formal style'),
        max_samples=setting('max_samples'),
        train_split=setting('train_split', 0.8),
        validation_split=setting('validation_split', 0.1),
        test_split=setting('test_split', 0.1),
        clean_text=setting('clean_text', True),
        remove_special_chars=setting('remove_special_chars', False),
        lowercase=setting('lowercase', False),
        min_length=setting('min_length', 10),
        max_length=setting('max_length', 1000),
        output_format=setting('output_format', 'styling'),
        output_dir=setting('output_dir', './data'),
        hf_split=setting('hf_split', 'train'),
        hf_cache_dir=setting('hf_cache_dir'),
        test_split_from=setting('test_split_from', 'train'),
        val_split_from=setting('val_split_from', 'train'),
        encoding=setting('encoding', 'utf-8'),
        delimiter=setting('delimiter', ',')
    )

    pipeline = StylingDataPipeline()
    try:
        print(f"Starting styling pipeline with {config.data_source} data source...")
        if args.config:
            print(f"Using YAML configuration: {args.config}")
        print(f"Style instruction: {config.instruction}")
        print()

        create_hf_dataset = cli_overrides.get('create_hf_dataset', False)
        hf_dataset_path = cli_overrides.get('hf_dataset_path')
        # If creating an HF dataset, also save it by default.
        save_hf_dataset = create_hf_dataset

        result = pipeline.run_pipeline(
            config,
            config.output_format,
            save_splits=True,
            create_hf_dataset=create_hf_dataset,
            save_hf_dataset=save_hf_dataset,
            hf_dataset_path=hf_dataset_path
        )

        print("✅ Pipeline completed successfully!")
        print(f" Data source: {config.data_source}")
        if config.data_source == "huggingface":
            print(f" Dataset: {config.dataset_name}")
        else:
            print(f" Data file: {config.data_path}")
        print(f" Total samples: {result['analysis']['overall']['total_samples']}")
        print(f" Split sizes: {result['analysis']['overall']['split_sizes']}")
        print(f" Output directory: {config.output_dir}")
        print(f" Style instruction: {config.instruction}")
    except Exception as e:
        print(f"❌ Error running pipeline: {e}")
        import traceback
        print("Full error traceback:")
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()