Files

1513 lines
65 KiB
Python
Raw Permalink Normal View History

2025-08-13 21:17:01 +01:00
import argparse
import json
import logging
import math
import os
import re
import sys
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import yaml
from datasets import Dataset, load_dataset
from sklearn.model_selection import train_test_split
# Module-wide logger. DEBUG is enabled so the per-item tracing emitted by the
# loaders' _preprocess_item methods is not filtered out at this logger.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.DEBUG)
@dataclass
class StylingConfig:
    """Settings for building a text-styling (style-transfer) dataset.

    Groups where the data comes from, which record keys hold the source and
    the styled text, how records are split and cleaned, and the Alpaca-style
    prompt template used when examples are rendered for training.
    """

    # --- Data source --------------------------------------------------------
    data_source: str = "huggingface"  # "huggingface" or "custom"
    dataset_name: Optional[str] = None  # Hub dataset id (data_source="huggingface")
    data_path: Optional[str] = None  # local file (data_source="custom")
    data_format: str = "jsonl"  # custom files: "jsonl", "csv" or "json"

    # --- Field mapping (which record keys hold source / target text) --------
    input_field: str = "text"  # key of the unstyled source text
    output_field: str = "styled_text"  # key of the styled target text
    instruction: str = "Rewrite the following text in a formal style"  # instruction placed in every prompt

    # --- Sampling and split ratios ------------------------------------------
    max_samples: Optional[int] = None  # None = no cap
    train_split: float = 0.8
    validation_split: float = 0.1
    test_split: float = 0.1

    # --- Text preprocessing ---------------------------------------------------
    clean_text: bool = True  # collapse whitespace runs and strip ends
    remove_special_chars: bool = False
    lowercase: bool = False  # off by default: casing is part of the style
    min_length: int = 10  # character bounds, applied to both input and output
    max_length: int = 1000

    # --- Output ---------------------------------------------------------------
    output_format: str = "styling"
    output_dir: str = "./data"

    # --- Hugging Face specific ----------------------------------------------
    hf_split: str = "train"
    hf_cache_dir: Optional[str] = None

    # --- Where held-out splits come from --------------------------------------
    # The HF loader recognises "use_test_if_available" / "use_val_if_available";
    # any other value means "carve the split out of train".
    test_split_from: str = "train"
    val_split_from: str = "train"

    # --- Custom-file parsing --------------------------------------------------
    encoding: str = "utf-8"
    delimiter: str = ","  # CSV only

    # --- Prompt rendering -----------------------------------------------------
    # Three positional "{}" slots: instruction, input text, response text.
    alpaca_prompt: str = (
        "Below is an instruction that describes a task, paired with an input "
        "that provides further context. Write a response that follows the instruction\n"
        "### Instruction:\n"
        "{}\n"
        "### Input:\n"
        "{}\n"
        "### Response:\n"
        "{}"
    )
    eos_token: str = "<|eot_id|>"  # appended after each rendered example
class DataValidator:
    """Validate and summarise styling datasets.

    Both public methods take ``data`` as a mapping from split name to a list
    of record dicts. With ``is_processed=True`` records are read through the
    internal ``"input"``/``"output"`` keys; otherwise through
    ``config.input_field`` / ``config.output_field``.
    """

    @staticmethod
    def _field_names(config: StylingConfig, is_processed: bool) -> Tuple[str, str]:
        """Return the (input, output) record keys to inspect."""
        if is_processed:
            return "input", "output"
        return config.input_field, config.output_field

    @staticmethod
    def _is_blank(value: Any) -> bool:
        """Return True for a string that is empty or whitespace-only.

        Non-string values are not treated as blank here: they are reported as
        type errors instead, and calling ``.strip()`` on them would raise.
        """
        return isinstance(value, str) and not value.strip()

    @staticmethod
    def _text_length(value: Any) -> int:
        """Return the character length of a field value, treating None as empty."""
        if value is None:
            return 0
        if isinstance(value, str):
            return len(value)
        return len(str(value))

    @staticmethod
    def validate_styling_data(data: Dict[str, List[Dict]], config: StylingConfig, is_processed: bool = False) -> Tuple[bool, List[str]]:
        """Check that all splits exist and every record has usable string fields.

        Args:
            data: Split name -> list of record dicts.
            config: Supplies the field names when ``is_processed`` is False
                (not read otherwise).
            is_processed: Whether records already use the "input"/"output" keys.

        Returns:
            ``(is_valid, errors)``. "train" must be non-empty; "validation" and
            "test" must be present but may be empty (small datasets).
        """
        errors: List[str] = []

        for split in ("train", "validation", "test"):
            if split not in data:
                errors.append(f"Missing '{split}' split")
            elif split == "train" and not data[split]:
                errors.append("Train split cannot be empty")
        if errors:
            return False, errors

        total_samples = sum(len(split_data) for split_data in data.values())
        logger.info(f"Validating {total_samples} total samples across all splits...")

        input_field, output_field = DataValidator._field_names(config, is_processed)

        for split_name, split_data in data.items():
            if not split_data:
                logger.info(f"Skipping validation for empty {split_name} split")
                continue
            logger.info(f"Validating {split_name} split with {len(split_data)} samples...")

            # Required fields
            missing_input_count = 0
            missing_output_count = 0
            for i, item in enumerate(split_data):
                if input_field not in item:
                    errors.append(f"Missing input field '{input_field}' in {split_name} split, item {i}")
                    missing_input_count += 1
                if output_field not in item:
                    errors.append(f"Missing output field '{output_field}' in {split_name} split, item {i}")
                    missing_output_count += 1
            logger.info(f"{split_name} - Items missing input field: {missing_input_count}")
            logger.info(f"{split_name} - Items missing output field: {missing_output_count}")

            # Types (a missing field reads as "" and so passes here; it was reported above)
            type_errors = 0
            for i, item in enumerate(split_data):
                if not isinstance(item.get(input_field, ""), str):
                    errors.append(f"Input field '{input_field}' must be string in {split_name} split, item {i}")
                    type_errors += 1
                if not isinstance(item.get(output_field, ""), str):
                    errors.append(f"Output field '{output_field}' must be string in {split_name} split, item {i}")
                    type_errors += 1
            logger.info(f"{split_name} - Type errors: {type_errors}")

            # Blank texts; only strings are inspected so non-strings cannot crash .strip()
            empty_inputs = sum(1 for item in split_data if DataValidator._is_blank(item.get(input_field, "")))
            empty_outputs = sum(1 for item in split_data if DataValidator._is_blank(item.get(output_field, "")))
            if empty_inputs > 0:
                errors.append(f"Found {empty_inputs} items with empty input text in {split_name} split")
            if empty_outputs > 0:
                errors.append(f"Found {empty_outputs} items with empty output text in {split_name} split")
            logger.info(f"{split_name} - Empty inputs: {empty_inputs}")
            logger.info(f"{split_name} - Empty outputs: {empty_outputs}")

            # Sample of the data for debugging; str() so non-string values can be sliced
            logger.info(f"Sample processed items from {split_name}:")
            for i, item in enumerate(split_data[:3]):
                input_preview = str(item.get(input_field, ''))[:50]
                output_preview = str(item.get(output_field, ''))[:50]
                logger.info(f" Item {i}: input='{input_preview}...', output='{output_preview}...'")

        return len(errors) == 0, errors

    @staticmethod
    def analyze_dataset(data: Dict[str, List[Dict]], config: StylingConfig, is_processed: bool = False) -> Dict[str, Any]:
        """Compute per-split size, text-length and missing-value statistics.

        Args:
            data: Split name -> list of record dicts.
            config: Supplies the field names when ``is_processed`` is False.
            is_processed: Whether records already use the "input"/"output" keys.

        Returns:
            ``{"splits": {name: {"total_samples", "text_length_stats",
            "missing_values"}}, "overall": {"total_samples", "split_sizes"}}``.
            Length stats are plain Python numbers (None/missing counts as
            length 0), so the result serialises with json/yaml safe dumpers.
            Empty splits get empty stats dicts.
        """
        analysis: Dict[str, Any] = {
            "splits": {},
            "overall": {"total_samples": 0, "split_sizes": {}},
        }
        input_field, output_field = DataValidator._field_names(config, is_processed)

        for split_name, split_data in data.items():
            split_analysis: Dict[str, Any] = {
                "total_samples": len(split_data),
                "text_length_stats": {},
                "missing_values": {},
            }
            if split_data:
                for label, field in (("input", input_field), ("output", output_field)):
                    lengths = [DataValidator._text_length(item.get(field)) for item in split_data]
                    split_analysis["text_length_stats"][label] = {
                        "min": min(lengths),
                        "max": max(lengths),
                        "mean": float(np.mean(lengths)),
                        "median": float(np.median(lengths)),
                    }
                for field in (input_field, output_field):
                    # Falsy values ("", None, missing key) count as missing.
                    split_analysis["missing_values"][field] = sum(1 for item in split_data if not item.get(field))

            analysis["splits"][split_name] = split_analysis
            analysis["overall"]["total_samples"] += len(split_data)
            analysis["overall"]["split_sizes"][split_name] = len(split_data)

        return analysis
class BaseDataLoader(ABC):
    """Interface every styling data loader implements.

    A loader first reads its source into a mapping of split name
    ("train", "validation", "test") to lists of record dicts, then applies
    preprocessing to those splits.
    """

    @abstractmethod
    def load(self, config: StylingConfig) -> Dict[str, List[Dict]]:
        """Read the source described by ``config`` and return its train/validation/test splits."""

    @abstractmethod
    def preprocess(self, data: Dict[str, List[Dict]], config: StylingConfig) -> Dict[str, List[Dict]]:
        """Apply preprocessing to every split in ``data`` and return the result."""
class HuggingFaceDataLoader(BaseDataLoader):
    """Load styling datasets from the Hugging Face Hub.

    ``load`` returns ``{"train", "validation", "test"}`` lists of raw records.
    Predefined validation/test splits are reused when the config asks for
    them; a split that is still empty is carved out of ``train`` with
    ``random_state=42``. ``preprocess`` maps each record to
    ``{"input": ..., "output": ...}`` and drops records outside the length
    limits. ``config`` is never mutated by this loader.
    """

    # Compiled once, reused for every record in _clean_text.
    _WHITESPACE_RE = re.compile(r'\s+')
    _SPECIAL_CHARS_RE = re.compile(r'[^\w\s]')

    def load(self, config: StylingConfig) -> Dict[str, List[Dict]]:
        """Download ``config.dataset_name`` and return its train/validation/test splits.

        Args:
            config: Reads ``dataset_name``, ``hf_cache_dir``, the three split
                ratios, ``val_split_from`` / ``test_split_from``,
                ``max_samples`` and the input/output field names (the latter
                only for diagnostic logging).

        Returns:
            ``{"train": [...], "validation": [...], "test": [...]}``. When
            ``train`` has fewer than 3 rows, splits that must be carved from
            it stay empty (predefined ones are kept).

        Raises:
            ValueError: no dataset name configured, or no ``train`` split.
            Errors from ``load_dataset`` / splitting are logged and re-raised.
        """
        if not config.dataset_name:
            raise ValueError("Dataset name is required for Hugging Face datasets")

        logger.info(f"Loading Hugging Face dataset: {config.dataset_name}")
        try:
            dataset = load_dataset(
                config.dataset_name,
                cache_dir=config.hf_cache_dir
            )
            available_splits = list(dataset.keys())
            logger.info(f"Available splits in dataset: {available_splits}")

            splits_data = self._select_predefined_splits(dataset, available_splits, config)
            if not splits_data["validation"] or not splits_data["test"]:
                self._create_missing_splits(splits_data, config)

            logger.info("Final split sizes:")
            logger.info(f" Train: {len(splits_data['train'])} samples")
            logger.info(f" Validation: {len(splits_data['validation'])} samples")
            logger.info(f" Test: {len(splits_data['test'])} samples")

            self._apply_max_samples(splits_data, config)
            self._log_field_hints(splits_data, config)

            logger.info(f"Successfully loaded dataset {config.dataset_name}")
            return splits_data
        except Exception as e:
            logger.error(f"Error loading dataset {config.dataset_name}: {e}")
            raise

    def _select_predefined_splits(self, dataset: Any, available_splits: List[str], config: StylingConfig) -> Dict[str, List[Dict]]:
        """Take ``train`` plus whichever predefined validation/test splits the config allows."""
        splits_data: Dict[str, List[Dict]] = {"train": [], "validation": [], "test": []}

        if "train" not in available_splits:
            logger.error("No 'train' split found in dataset!")
            logger.error(f"Available splits: {available_splits}")
            raise ValueError(f"Dataset {config.dataset_name} does not have a 'train' split")
        train_dataset = dataset["train"]
        logger.info(f"Using 'train' split with {len(train_dataset)} samples")
        splits_data["train"] = list(train_dataset)

        # Validation: only reuse a predefined split when explicitly requested.
        if config.val_split_from == "use_val_if_available":
            for name in ("validation", "val"):
                if name in available_splits:
                    val_dataset = dataset[name]
                    logger.info(f"Using '{name}' split with {len(val_dataset)} samples")
                    splits_data["validation"] = list(val_dataset)
                    break
            else:
                logger.warning("No validation split found in dataset. Will create from train split.")
                logger.info(f"Available splits: {available_splits}")
                logger.info(f"Will use {config.validation_split * 100}% of train data for validation")
        else:
            logger.info(f"Will create validation split from train data ({config.validation_split * 100}%)")

        # Test: may come from a predefined test split or, on request, from the validation split.
        val_names = [name for name in ("validation", "val") if name in available_splits]
        if config.test_split_from == "use_test_if_available" and "test" in available_splits:
            test_dataset = dataset["test"]
            logger.info(f"Using 'test' split with {len(test_dataset)} samples")
            splits_data["test"] = list(test_dataset)
        elif config.test_split_from == "use_val_if_available" and val_names:
            test_dataset = dataset[val_names[0]]
            logger.info(f"Using '{val_names[0]}' split as test with {len(test_dataset)} samples")
            splits_data["test"] = list(test_dataset)
        elif config.test_split_from == "use_test_if_available":
            logger.warning("No test split found in dataset. Will create from train split.")
            logger.info(f"Available splits: {available_splits}")
            logger.info(f"Will use {config.test_split * 100}% of train data for test")
        else:
            logger.info(f"Will create test split from train data ({config.test_split * 100}%)")

        return splits_data

    def _create_missing_splits(self, splits_data: Dict[str, List[Dict]], config: StylingConfig) -> None:
        """Fill the empty validation and/or test split in place from ``splits_data["train"]``.

        Splits already populated from the dataset are never overwritten.
        Ratio normalisation and the small-dataset override are applied to
        local copies, so ``config`` keeps the values the caller set.
        """
        train_data = splits_data["train"]
        need_val = not splits_data["validation"]
        need_test = not splits_data["test"]

        if len(train_data) < 3:
            # Too few rows to hold anything out; any predefined split is kept.
            logger.warning(f"Dataset has only {len(train_data)} samples. Using all data for training.")
            return

        total = config.train_split + config.validation_split + config.test_split
        if total <= 0:
            raise ValueError(f"Split percentages must sum to a positive value (got {total})")
        val_ratio = config.validation_split
        test_ratio = config.test_split
        # isclose: e.g. 0.7 + 0.2 + 0.1 evaluates to 0.9999999999999999.
        if not math.isclose(total, 1.0):
            logger.warning(f"Split percentages don't sum to 1.0 (got {total}). Normalizing...")
            val_ratio /= total
            test_ratio /= total

        if need_val and need_test:
            if len(train_data) < 10:
                val_ratio = test_ratio = 0.2
                logger.info(f"Small dataset detected. Adjusted split ratios to: train=0.6, val={val_ratio}, test={test_ratio}")
            train_size, val_size, test_size = self._compute_split_sizes(len(train_data), val_ratio, test_ratio)
            logger.info(f"Adjusted split sizes: train={train_size}, val={val_size}, test={test_size}")

            new_train, holdout = train_test_split(
                train_data,
                test_size=val_size + test_size,
                random_state=42
            )
            # Integer size (not a fraction) so the test split gets exactly test_size rows.
            new_val, new_test = train_test_split(
                holdout,
                test_size=test_size,
                random_state=42
            )
            splits_data["train"] = new_train
            splits_data["validation"] = new_val
            splits_data["test"] = new_test
            return

        # Exactly one of validation/test is missing.
        target, ratio = ("validation", val_ratio) if need_val else ("test", test_ratio)
        # At least one held-out row, and at least one row left in train.
        holdout_size = min(max(1, int(len(train_data) * ratio)), len(train_data) - 1)
        new_train, new_holdout = train_test_split(
            train_data,
            test_size=holdout_size,
            random_state=42
        )
        splits_data["train"] = new_train
        splits_data[target] = new_holdout

    @staticmethod
    def _compute_split_sizes(n_samples: int, val_ratio: float, test_ratio: float) -> Tuple[int, int, int]:
        """Return ``(train, val, test)`` row counts for a three-way split.

        Validation and test each get at least ``max(1, int(10% of n_samples))``
        rows. Rows are then moved back to train (from validation first, then
        test, never below one each) until train holds at least one row.
        """
        min_holdout = max(1, int(n_samples * 0.1))
        val_size = max(min_holdout, int(n_samples * val_ratio))
        test_size = max(min_holdout, int(n_samples * test_ratio))
        train_size = n_samples - val_size - test_size
        while train_size < 1:
            if val_size > 1:
                val_size -= 1
            elif test_size > 1:
                test_size -= 1
            else:
                break
            train_size += 1
        return train_size, val_size, test_size

    @staticmethod
    def _apply_max_samples(splits_data: Dict[str, List[Dict]], config: StylingConfig) -> None:
        """Truncate every non-empty split in place to ``config.max_samples`` rows (if set)."""
        if not config.max_samples:
            return
        for split_name, rows in splits_data.items():
            if rows:
                original_size = len(rows)
                splits_data[split_name] = rows[:config.max_samples]
                logger.info(f"Limited {split_name} split from {original_size} to {len(splits_data[split_name])} samples")

    @staticmethod
    def _log_field_hints(splits_data: Dict[str, List[Dict]], config: StylingConfig) -> None:
        """Log a sample record per split and suggest likely keys when a configured field is absent."""
        input_keywords = ('text', 'sentence', 'content', 'input', 'comment', 'message')
        output_keywords = ('output', 'response', 'result', 'target', 'styled')
        for split_name, split_data in splits_data.items():
            if not split_data:
                continue
            sample = split_data[0]
            keys = list(sample.keys())
            logger.info(f"Sample data item from {split_name}: {sample}")
            logger.info(f"Available fields in {split_name} split: {keys}")
            if config.input_field not in sample:
                logger.warning(f"Input field '{config.input_field}' not found in {split_name}. Available fields: {keys}")
                text_fields = [k for k in keys if any(word in k.lower() for word in input_keywords)]
                if text_fields:
                    logger.info(f"Suggested text fields for {split_name}: {text_fields}")
            if config.output_field not in sample:
                logger.warning(f"Output field '{config.output_field}' not found in {split_name}. Available fields: {keys}")
                output_fields = [k for k in keys if any(word in k.lower() for word in output_keywords)]
                if output_fields:
                    logger.info(f"Suggested output fields for {split_name}: {output_fields}")

    def preprocess(self, data: Dict[str, List[Dict]], config: StylingConfig) -> Dict[str, List[Dict]]:
        """Convert every split's records to ``{"input", "output"}`` dicts, dropping invalid ones.

        Also logs field availability and sample records before and after
        preprocessing, to help diagnose misconfigured field names.
        """
        processed_splits: Dict[str, List[Dict]] = {}
        logger.info("=== PREPROCESSING DATA ===")

        for split_name, split_data in data.items():
            logger.info(f"Processing {split_name} split with {len(split_data)} items...")

            if split_data:
                available_fields = set(split_data[0].keys())
                logger.info(f"Available fields in {split_name}: {available_fields}")
                logger.info(f"Looking for input field: '{config.input_field}', output field: '{config.output_field}'")
                if config.input_field not in available_fields:
                    logger.error(f"Input field '{config.input_field}' not found in {split_name}. Available fields: {available_fields}")
                if config.output_field not in available_fields:
                    logger.error(f"Output field '{config.output_field}' not found in {split_name}. Available fields: {available_fields}")

                # A missing key and a falsy value both count as missing.
                missing_input = sum(1 for item in split_data if not item.get(config.input_field))
                missing_output = sum(1 for item in split_data if not item.get(config.output_field))
                logger.info(f"{split_name} - Items missing input field: {missing_input}")
                logger.info(f"{split_name} - Items missing output field: {missing_output}")

                logger.info(f"=== SAMPLE RAW DATA FROM {split_name.upper()} BEFORE PREPROCESSING ===")
                for i, item in enumerate(split_data[:3]):
                    logger.info(f"Raw item {i} from {split_name}:")
                    for key, value in item.items():
                        if isinstance(value, str) and len(value) > 100:
                            logger.info(f" {key}: '{value[:100]}...'")
                        else:
                            logger.info(f" {key}: {value}")

            processed_data: List[Dict] = []
            skipped_count = 0
            # Per-split reset so _preprocess_item traces the first items of each split.
            self._debug_count = 0
            for i, item in enumerate(split_data):
                processed_item = self._preprocess_item(item, config)
                if processed_item is None:
                    skipped_count += 1
                    if skipped_count <= 3:  # log only the first few skipped items
                        logger.info(f"Skipped item {i} from {split_name}: {item}")
                    continue
                processed_data.append(processed_item)

            processed_splits[split_name] = processed_data
            logger.info(f"{split_name} - Preprocessed {len(processed_data)} samples, skipped {skipped_count} samples")

            if processed_data:
                logger.info(f"=== SAMPLE PROCESSED DATA FROM {split_name.upper()} ===")
                for i, processed_item in enumerate(processed_data[:3]):
                    logger.info(f"Processed item {i} from {split_name}: {processed_item}")

        return processed_splits

    @staticmethod
    def _to_text(value: Any) -> str:
        """Coerce a raw field value to text; None and float NaN become ''."""
        if value is None:
            return ""
        # NaN can appear in pandas-backed datasets; str(nan) would be the literal "nan".
        if isinstance(value, float) and math.isnan(value):
            return ""
        return str(value)

    def _preprocess_item(self, item: Dict, config: StylingConfig) -> Optional[Dict]:
        """Turn one raw record into ``{"input", "output"}``, or None if it fails the length filter.

        Both texts are coerced to strings, optionally cleaned, and must have a
        length within ``[config.min_length, config.max_length]``. The first
        three items since the last ``preprocess`` reset are traced at DEBUG.
        """
        raw_input = item.get(config.input_field, "")
        raw_output = item.get(config.output_field, "")

        self._debug_count = getattr(self, '_debug_count', 0) + 1
        trace = self._debug_count <= 3
        if trace:
            logger.debug(f"Processing item {self._debug_count}:")
            logger.debug(f" Looking for input field '{config.input_field}': {raw_input}")
            logger.debug(f" Looking for output field '{config.output_field}': {raw_output}")

        input_text = self._to_text(raw_input)
        output_text = self._to_text(raw_output)
        if trace:
            logger.debug(f" After conversion - input: '{input_text[:50]}...', output: '{output_text[:50]}...'")

        if config.clean_text:
            original_input, original_output = input_text, output_text
            input_text = self._clean_text(input_text, config)
            output_text = self._clean_text(output_text, config)
            if trace:
                logger.debug(f" After cleaning - input: '{original_input[:50]}...' -> '{input_text[:50]}...'")
                logger.debug(f" After cleaning - output: '{original_output[:50]}...' -> '{output_text[:50]}...'")

        for label, text in (("input", input_text), ("output", output_text)):
            if not config.min_length <= len(text) <= config.max_length:
                if trace:
                    logger.debug(f" Skipping - {label} length {len(text)} not in range [{config.min_length}, {config.max_length}]")
                return None

        # Internal schema is always "input"/"output", whatever the source keys were.
        processed_item = {
            "input": input_text,
            "output": output_text
        }
        if trace:
            logger.debug(f" Final processed item: {processed_item}")
        return processed_item

    def _clean_text(self, text: str, config: StylingConfig) -> str:
        """Collapse whitespace, then optionally lowercase and strip non-word characters."""
        if not isinstance(text, str):
            return ""
        text = self._WHITESPACE_RE.sub(' ', text).strip()
        if config.lowercase:
            text = text.lower()
        if config.remove_special_chars:
            text = self._SPECIAL_CHARS_RE.sub('', text)
        return text
class CustomDataLoader(BaseDataLoader):
    """Load styling datasets from a local JSONL, CSV or JSON file.

    The whole file is read, records that are not JSON objects are dropped,
    the result is optionally truncated to ``max_samples`` and then split into
    train/validation/test with ``random_state=42``. ``config`` is never
    mutated by this loader.
    """

    # Compiled once, reused for every record in _clean_text.
    _WHITESPACE_RE = re.compile(r'\s+')
    _SPECIAL_CHARS_RE = re.compile(r'[^\w\s]')

    def load(self, config: StylingConfig) -> Dict[str, List[Dict]]:
        """Read ``config.data_path`` and return train/validation/test splits.

        Raises:
            ValueError: no data path, or an unsupported ``data_format``.
            FileNotFoundError: the data file does not exist.
        """
        if not config.data_path:
            raise ValueError("Data path is required for custom datasets")
        file_path = Path(config.data_path)
        if not file_path.exists():
            raise FileNotFoundError(f"Data file not found: {file_path}")
        logger.info(f"Loading custom dataset: {file_path}")

        readers = {
            "jsonl": self._load_jsonl,
            "csv": self._load_csv,
            "json": self._load_json,
        }
        reader = readers.get(config.data_format)
        if reader is None:
            raise ValueError(f"Unsupported format: {config.data_format}")
        raw_data = reader(file_path, config)

        # Later stages call item.get(); a JSONL line or JSON array entry that is
        # a bare number/string/list would crash there, so drop it up front.
        records = [record for record in raw_data if isinstance(record, dict)]
        if len(records) != len(raw_data):
            logger.warning(f"Skipped {len(raw_data) - len(records)} records that are not JSON objects")

        if config.max_samples:
            records = records[:config.max_samples]
        logger.info(f"Loaded {len(records)} samples from {file_path}")

        return self._create_splits(records, config)

    def _create_splits(self, data: List[Dict], config: StylingConfig) -> Dict[str, List[Dict]]:
        """Shuffle ``data`` into train/validation/test lists (``random_state=42``).

        Fewer than 3 records all go to train. Otherwise validation and test
        each get at least one record, and datasets under 10 records use fixed
        60/20/20 ratios (locally; ``config`` is left untouched).
        """
        logger.info(f"Creating splits from {len(data)} samples...")

        if len(data) < 3:
            logger.warning(f"Dataset has only {len(data)} samples. Using all data for training.")
            return {
                "train": data,
                "validation": [],
                "test": []
            }

        total_samples = len(data)
        val_ratio = config.validation_split
        test_ratio = config.test_split
        if total_samples < 10:
            val_ratio = test_ratio = 0.2
            logger.info(f"Small dataset detected. Adjusted split ratios to: train=0.6, val={val_ratio}, test={test_ratio}")

        train_size, val_size, test_size = self._compute_split_sizes(total_samples, val_ratio, test_ratio)
        logger.info(f"Split sizes: train={train_size}, validation={val_size}, test={test_size}")

        # _compute_split_sizes guarantees val_size >= 1 and test_size >= 1, so a
        # full three-way split always applies.
        train_data, holdout = train_test_split(
            data,
            test_size=val_size + test_size,
            random_state=42
        )
        val_data, test_data = train_test_split(
            holdout,
            test_size=test_size,
            random_state=42
        )
        splits_data = {
            "train": train_data,
            "validation": val_data,
            "test": test_data
        }

        logger.info("Created splits:")
        logger.info(f" Train: {len(splits_data['train'])} samples")
        logger.info(f" Validation: {len(splits_data['validation'])} samples")
        logger.info(f" Test: {len(splits_data['test'])} samples")
        return splits_data

    @staticmethod
    def _compute_split_sizes(n_samples: int, val_ratio: float, test_ratio: float) -> Tuple[int, int, int]:
        """Return ``(train, val, test)`` row counts for a three-way split.

        Validation and test each get at least ``max(1, int(10% of n_samples))``
        rows. Rows are then moved back to train (from validation first, then
        test, never below one each) until train holds at least one row.
        """
        min_holdout = max(1, int(n_samples * 0.1))
        val_size = max(min_holdout, int(n_samples * val_ratio))
        test_size = max(min_holdout, int(n_samples * test_ratio))
        train_size = n_samples - val_size - test_size
        while train_size < 1:
            if val_size > 1:
                val_size -= 1
            elif test_size > 1:
                test_size -= 1
            else:
                break
            train_size += 1
        return train_size, val_size, test_size

    def _load_jsonl(self, file_path: Path, config: StylingConfig) -> List[Dict]:
        """Parse one JSON value per non-blank line; malformed lines are logged and skipped."""
        data = []
        with open(file_path, 'r', encoding=config.encoding) as f:
            for line_num, line in enumerate(f, 1):
                if not line.strip():
                    continue
                try:
                    data.append(json.loads(line))
                except json.JSONDecodeError as e:
                    logger.warning(f"Invalid JSON at line {line_num}: {e}")
        return data

    def _load_csv(self, file_path: Path, config: StylingConfig) -> List[Dict]:
        """Read a delimited file into one dict per row (empty cells arrive as NaN)."""
        df = pd.read_csv(file_path, encoding=config.encoding, delimiter=config.delimiter)
        return df.to_dict('records')

    def _load_json(self, file_path: Path, config: StylingConfig) -> List[Dict]:
        """Read a JSON list, a ``{"data": [...]}`` wrapper, or a single object."""
        with open(file_path, 'r', encoding=config.encoding) as f:
            data = json.load(f)
        if isinstance(data, list):
            return data
        if isinstance(data, dict) and "data" in data:
            return data["data"]
        return [data]

    def preprocess(self, data: Dict[str, List[Dict]], config: StylingConfig) -> Dict[str, List[Dict]]:
        """Convert every split's records to ``{"input", "output"}`` dicts, dropping invalid ones."""
        processed_splits: Dict[str, List[Dict]] = {}
        logger.info("=== PREPROCESSING CUSTOM DATA ===")

        for split_name, split_data in data.items():
            logger.info(f"Processing {split_name} split with {len(split_data)} items...")
            processed_data: List[Dict] = []
            skipped_count = 0
            self._debug_count = 0  # per-split reset, kept for parity with the HF loader
            for i, item in enumerate(split_data):
                processed_item = self._preprocess_item(item, config)
                if processed_item is None:
                    skipped_count += 1
                    if skipped_count <= 3:  # log only the first few skipped items
                        logger.info(f"Skipped item {i} from {split_name}: {item}")
                    continue
                processed_data.append(processed_item)

            processed_splits[split_name] = processed_data
            logger.info(f"{split_name} - Preprocessed {len(processed_data)} samples, skipped {skipped_count} samples")

        return processed_splits

    @staticmethod
    def _to_text(value: Any) -> str:
        """Coerce a raw field value to text; None and float NaN become ''."""
        if value is None:
            return ""
        # pandas fills empty CSV cells with NaN; str(nan) would be the literal "nan".
        if isinstance(value, float) and math.isnan(value):
            return ""
        return str(value)

    def _preprocess_item(self, item: Dict, config: StylingConfig) -> Optional[Dict]:
        """Turn one raw record into ``{"input", "output"}``, or None if it fails the length filter.

        Both texts are coerced to strings (missing/None/NaN -> ''), optionally
        cleaned, and must have a length within
        ``[config.min_length, config.max_length]``.
        """
        input_text = self._to_text(item.get(config.input_field, ""))
        output_text = self._to_text(item.get(config.output_field, ""))

        if config.clean_text:
            input_text = self._clean_text(input_text, config)
            output_text = self._clean_text(output_text, config)

        for text in (input_text, output_text):
            if not config.min_length <= len(text) <= config.max_length:
                return None

        # Internal schema is always "input"/"output", whatever the source keys were.
        return {
            "input": input_text,
            "output": output_text
        }

    def _clean_text(self, text: str, config: StylingConfig) -> str:
        """Collapse whitespace, then optionally lowercase and strip non-word characters."""
        if not isinstance(text, str):
            return ""
        text = self._WHITESPACE_RE.sub(' ', text).strip()
        if config.lowercase:
            text = text.lower()
        if config.remove_special_chars:
            text = self._SPECIAL_CHARS_RE.sub('', text)
        return text
class StylingDataPipeline:
"""Main styling pipeline"""
def __init__(self):
    """Create the validator and one loader per supported ``data_source``."""
    # "huggingface" -> hf_loader, "custom" -> custom_loader (see load_and_preprocess).
    self.hf_loader = HuggingFaceDataLoader()
    self.custom_loader = CustomDataLoader()
    # Checks and summarises the processed splits.
    self.validator = DataValidator()
def create_config(
    self,
    data_source: str,
    dataset_name: Optional[str] = None,
    data_path: Optional[str] = None,
    input_field: str = "input",
    output_field: str = "output",
    instruction: str = "Rewrite the following text in a formal style",
    **kwargs
) -> StylingConfig:
    """Construct a :class:`StylingConfig` from keyword settings.

    Args:
        data_source: ``"huggingface"`` or ``"custom"`` (passed through unchecked).
        dataset_name: Hub dataset id, for Hugging Face sources.
        data_path: Local file path, for custom sources.
        input_field: Record key of the source text. Defaults to ``"input"``
            here, unlike the dataclass default ``"text"``.
        output_field: Record key of the styled text. Defaults to ``"output"``
            here, unlike the dataclass default ``"styled_text"``.
        instruction: Style instruction placed in every prompt.
        **kwargs: Any other ``StylingConfig`` field.

    Returns:
        The new configuration object.
    """
    named_settings = {
        "data_source": data_source,
        "dataset_name": dataset_name,
        "data_path": data_path,
        "input_field": input_field,
        "output_field": output_field,
        "instruction": instruction,
    }
    # kwargs cannot repeat the names above (they bind to the named parameters).
    named_settings.update(kwargs)
    return StylingConfig(**named_settings)
def load_config_from_yaml(self, yaml_path: str) -> StylingConfig:
    """Build a :class:`StylingConfig` from the YAML file at ``yaml_path``.

    Every top-level key that names a ``StylingConfig`` field is passed
    through, including ``alpaca_prompt`` and ``eos_token``. Unknown keys are
    ignored. Fields absent from the file keep the dataclass defaults, except
    ``data_source``, which defaults to ``"custom"`` for YAML-driven runs. An
    empty document (``None``) yields an all-defaults config.

    Raises:
        Any error from ``load_yaml_config`` or ``StylingConfig`` (e.g. a
        non-mapping document) is logged and re-raised.
    """
    from dataclasses import fields as dataclass_fields

    try:
        config_dict = load_yaml_config(yaml_path)
        # NOTE(review): assumes an empty YAML file comes back as None, as yaml.safe_load does.
        if config_dict is None:
            config_dict = {}

        # Derive accepted keys from the dataclass itself so defaults are defined in one place.
        known_fields = {f.name for f in dataclass_fields(StylingConfig)}
        settings = {key: value for key, value in config_dict.items() if key in known_fields}
        settings.setdefault('data_source', 'custom')
        config = StylingConfig(**settings)

        logger.info(f"Configuration loaded from YAML: {yaml_path}")
        logger.info(f"Output directory: {config.output_dir}")
        logger.info(f"Instruction: {config.instruction}")
        return config
    except Exception as e:
        logger.error(f"Error loading configuration from YAML {yaml_path}: {e}")
        raise
def load_and_preprocess(self, config: StylingConfig) -> Tuple[Dict[str, List[Dict]], Dict[str, Any]]:
    """Load, preprocess, validate and analyse the dataset described by ``config``.

    Args:
        config: ``config.data_source`` selects the loader ("huggingface" or "custom").

    Returns:
        ``(processed_splits, analysis)``. The splits use the internal
        "input"/"output" keys, and ``analysis`` comes from
        ``DataValidator.analyze_dataset``.

    Raises:
        ValueError: unsupported data source, or the processed data fails validation.
        Any loader error. Every exception is logged before being re-raised.
    """
    # data_source -> (loader attribute on self, label used in progress messages)
    loader_table = {
        "huggingface": ("hf_loader", "HuggingFace"),
        "custom": ("custom_loader", "custom"),
    }

    logger.info("Starting data loading and preprocessing...")
    logger.info(f"Data source: {config.data_source}")

    try:
        entry = loader_table.get(config.data_source)
        if entry is None:
            raise ValueError(f"Unsupported data source: {config.data_source}")
        attr_name, label = entry
        loader = getattr(self, attr_name)

        logger.info(f"Loading {label} dataset...")
        raw_splits = loader.load(config)
        logger.info(f"Preprocessing {label} dataset...")
        processed_splits = loader.preprocess(raw_splits, config)

        logger.info("Data loading and preprocessing completed successfully")
        logger.info(f"Raw splits: {list(raw_splits)}")
        logger.info(f"Processed splits: {list(processed_splits)}")

        logger.info("Validating processed data...")
        is_valid, errors = self.validator.validate_styling_data(processed_splits, config, is_processed=True)
        if not is_valid:
            logger.error("Data validation failed:")
            for message in errors:
                logger.error(f" - {message}")
            raise ValueError("Data validation failed")
        logger.info("Data validation passed")

        logger.info("Analyzing dataset...")
        analysis = self.validator.analyze_dataset(processed_splits, config, is_processed=True)
        logger.info("Dataset analysis completed")
        return processed_splits, analysis
    except Exception as e:
        logger.error(f"Error in load_and_preprocess: {e}")
        raise
def convert_to_alpaca_format(self, data: Dict[str, List[Dict]], config: StylingConfig) -> Dict[str, List[Dict]]:
    """Turn every record into an Alpaca-style ``instruction``/``input``/``output`` dict.

    The instruction always comes from ``config.instruction``. A record's
    ``input``/``output`` that is missing or ``None`` becomes ``""``; any other
    value is coerced with ``str()``.

    Args:
        data: Mapping of split name to a list of record dicts.
        config: Configuration providing ``instruction``.

    Returns:
        A new mapping with the same split names, one Alpaca dict per input
        record, in the original order.
    """
    def as_text(value) -> str:
        # Missing keys arrive here as None and collapse to the empty string.
        return "" if value is None else str(value)

    return {
        split_name: [
            {
                "instruction": config.instruction,
                "input": as_text(record.get("input")),
                "output": as_text(record.get("output")),
            }
            for record in records
        ]
        for split_name, records in data.items()
    }
def format_for_training(self, data: Dict[str, List[Dict]], config: StylingConfig) -> Dict[str, List[str]]:
    """Render each record as one Alpaca prompt string ready for training.

    ``config.alpaca_prompt`` is filled positionally with ``config.instruction``,
    the record's ``input`` and its ``output`` (missing or ``None`` values become
    ``""``; other values go through ``str()``), then ``config.eos_token`` is
    appended.

    Returns:
        Mapping of split name to the rendered strings, preserving split and
        record order.
    """
    def render(record: Dict) -> str:
        source = record.get("input")
        target = record.get("output")
        prompt = config.alpaca_prompt.format(
            config.instruction,
            "" if source is None else str(source),
            "" if target is None else str(target),
        )
        return prompt + config.eos_token

    return {name: [render(record) for record in records] for name, records in data.items()}
def convert_to_hf_dataset(self, dataset_entries: List[Dict], config: StylingConfig):
    """Build a ``datasets.Dataset`` from ``dataset_entries`` and add a ``text`` column.

    ``Dataset.from_list`` turns the entry dicts into columns. A batched ``map``
    then fills ``config.alpaca_prompt`` positionally with each row's
    instruction, input and output (``None`` becomes ``""``, other input/output
    values go through ``str()``) and appends ``config.eos_token``.

    The batch function indexes the ``instruction``, ``input`` and ``output``
    columns, so the entries are expected to provide those keys.

    Returns:
        The dataset produced by ``map``, with the generated ``text`` column.
    """
    from datasets import Dataset

    def add_text_column(batch):
        rows = zip(batch["instruction"], batch["input"], batch["output"])
        texts = []
        for instruction, source, target in rows:
            source = "" if source is None else str(source)
            target = "" if target is None else str(target)
            texts.append(config.alpaca_prompt.format(instruction, source, target) + config.eos_token)
        return {"text": texts}

    return Dataset.from_list(dataset_entries).map(add_text_column, batched=True)
def save_hf_dataset_to_disk(self, hf_dataset, save_path: str):
    """Persist ``hf_dataset`` by calling its ``save_to_disk(save_path)``.

    Best effort: any exception is logged and reported through the return
    value instead of propagating.

    Returns:
        ``True`` if the save (and its log line) completed, ``False`` otherwise.
    """
    try:
        hf_dataset.save_to_disk(save_path)
        logger.info(f"HuggingFace dataset saved to disk at: {save_path}")
        succeeded = True
    except Exception as e:
        logger.error(f"Error saving HuggingFace dataset to disk: {e}")
        succeeded = False
    return succeeded
def load_hf_dataset_from_disk(self, load_path: str):
    """Load a dataset previously written with ``save_to_disk``.

    Best effort: any failure (including ``datasets`` not being importable,
    or the loaded object lacking ``len``/``features``) is logged and ``None``
    is returned.
    """
    try:
        from datasets import load_from_disk
        loaded = load_from_disk(load_path)
        logger.info(f"HuggingFace dataset loaded from disk: {load_path}")
        logger.info(f"Dataset has {len(loaded)} entries")
        logger.info(f"Dataset features: {loaded.features}")
    except Exception as e:
        logger.error(f"Error loading HuggingFace dataset from disk: {e}")
        loaded = None
    return loaded
def save_data(self, data: Dict[str, List[Dict]], output_dir: str, format: str = "jsonl"):
    """Write each split in ``data`` to ``<output_dir>/<split_name>.<format>``.

    Args:
        data: Mapping of split name to a list of record dicts.
        output_dir: Destination directory; created (with parents) if missing.
        format: ``"jsonl"`` (one JSON object per line), ``"json"`` (one
            indented JSON array) or ``"csv"`` (via ``pandas.DataFrame.to_csv``
            without the index). JSON output keeps non-ASCII characters as-is;
            all files are written as UTF-8.

    Raises:
        ValueError: If ``format`` is not one of the supported values. This is
            checked before anything is created or written.
    """
    # ``format`` shadows the builtin but is part of the public signature.
    supported = ("jsonl", "json", "csv")
    if format not in supported:
        # Fail fast: an unknown format would otherwise write nothing and
        # leave no output file to report.
        raise ValueError(f"Unsupported output format: {format!r} (expected one of: {', '.join(supported)})")

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    for split_name, split_data in data.items():
        output_file = output_path / f"{split_name}.{format}"
        if format == "jsonl":
            with open(output_file, 'w', encoding='utf-8') as f:
                f.writelines(json.dumps(item, ensure_ascii=False) + '\n' for item in split_data)
        elif format == "json":
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(split_data, f, ensure_ascii=False, indent=2)
        else:  # csv
            pd.DataFrame(split_data).to_csv(output_file, index=False, encoding='utf-8')
        logger.info(f"Saved {len(split_data)} samples to {output_file}")
def run_pipeline(
    self,
    config: StylingConfig,
    output_format: str = "styling",
    save_splits: bool = True,
    create_hf_dataset: bool = False,
    save_hf_dataset: bool = False,
    hf_dataset_path: Optional[str] = None
) -> Dict[str, Any]:
    """Run the full styling pipeline: load, format, save and optionally export.

    Steps:
      1. ``self.load_and_preprocess(config)`` gives processed splits + analysis.
      2. ``output_format == "alpaca"`` converts records with
         ``self.convert_to_alpaca_format``; any other value keeps the
         processed records as they are.
      3. If ``save_splits``, the splits are written (default JSONL) directly
         into ``config.output_dir`` via ``self.save_data``.
      4. If ``create_hf_dataset``, all splits are concatenated (split order,
         then record order) and passed to ``self.convert_to_hf_dataset``.
         Records without an ``instruction`` key get ``config.instruction`` on
         a *copy*; the records returned under ``"data"`` are not modified.
      5. If both ``create_hf_dataset`` and ``save_hf_dataset`` are set, the
         dataset is saved to ``hf_dataset_path`` (default
         ``<output_dir>/hf_dataset``). ``save_hf_dataset`` alone does nothing.

    Returns:
        Summary dict with ``config``, ``analysis``, ``splits`` (record count
        per split), ``output_format``, ``output_dir``, ``data`` (the formatted
        splits) and ``instruction``; plus ``hf_dataset`` when one was created
        and ``hf_dataset_path`` when it was saved successfully.
    """
    logger.info("Starting styling pipeline...")

    processed_splits, analysis = self.load_and_preprocess(config)

    if output_format == "alpaca":
        formatted_splits = self.convert_to_alpaca_format(processed_splits, config)
    else:
        formatted_splits = processed_splits

    if save_splits:
        # Save directly in the output directory, not in a subdirectory.
        self.save_data(formatted_splits, str(Path(config.output_dir)))

    hf_dataset = None
    hf_dataset_save_path = None
    if create_hf_dataset:
        # Copy every record before filling in the default instruction so the
        # caller's data (returned below as "data") is left untouched.
        all_entries = []
        for split_data in formatted_splits.values():
            for item in split_data:
                entry = dict(item)
                entry.setdefault("instruction", config.instruction)
                all_entries.append(entry)

        hf_dataset = self.convert_to_hf_dataset(all_entries, config)
        logger.info(f"HuggingFace dataset created with {len(hf_dataset)} entries")
        logger.info(f"Dataset features: {hf_dataset.features}")

        if save_hf_dataset:
            if hf_dataset_path is None:
                # Default location lives under the configured output_dir.
                hf_dataset_path = str(Path(config.output_dir) / "hf_dataset")
            if self.save_hf_dataset_to_disk(hf_dataset, hf_dataset_path):
                hf_dataset_save_path = hf_dataset_path
                logger.info(f"HuggingFace dataset saved to: {hf_dataset_save_path}")
            else:
                logger.warning("Failed to save HuggingFace dataset to disk")

    result = {
        "config": config,
        "analysis": analysis,
        "splits": {split_name: len(split_data) for split_name, split_data in formatted_splits.items()},
        "output_format": output_format,
        "output_dir": config.output_dir,
        "data": formatted_splits,  # the actual processed records
        "instruction": config.instruction,
    }
    if hf_dataset is not None:
        result["hf_dataset"] = hf_dataset
    if hf_dataset_save_path:
        result["hf_dataset_path"] = hf_dataset_save_path

    logger.info("Styling pipeline completed successfully!")
    return result
# Helper functions
def create_huggingface_config(dataset_name: str, input_field: str = "text", output_field: str = "output", instruction: str = "Rewrite the following text in a formal style", **kwargs) -> StylingConfig:
    """Build a ``StylingConfig`` that reads from a Hugging Face dataset.

    Args:
        dataset_name: Dataset identifier, stored as ``dataset_name``.
        input_field: Column holding the source text.
        output_field: Column holding the styled text. This helper defaults to
            ``"output"``, whereas the ``StylingConfig`` field defaults to
            ``"styled_text"``.
        instruction: Style instruction attached to every example.
        **kwargs: Any other ``StylingConfig`` field, forwarded unchanged.

    Returns:
        A ``StylingConfig`` with ``data_source="huggingface"``.
    """
    settings = dict(
        dataset_name=dataset_name,
        input_field=input_field,
        output_field=output_field,
        instruction=instruction,
    )
    return StylingConfig(data_source="huggingface", **settings, **kwargs)
def create_custom_config(data_path: str, data_format: str = "jsonl", input_field: str = "text", output_field: str = "styled_text", instruction: str = "Rewrite the following text in a formal style", **kwargs) -> StylingConfig:
    """Build a ``StylingConfig`` that reads a local data file.

    Args:
        data_path: Path to the dataset file.
        data_format: File format string stored as ``data_format`` (e.g. ``"jsonl"``).
        input_field: Field holding the source text.
        output_field: Field holding the styled text.
        instruction: Style instruction attached to every example.
        **kwargs: Any other ``StylingConfig`` field, forwarded unchanged.

    Returns:
        A ``StylingConfig`` with ``data_source="custom"``.
    """
    field_values = {
        "data_source": "custom",
        "data_path": data_path,
        "data_format": data_format,
        "input_field": input_field,
        "output_field": output_field,
        "instruction": instruction,
    }
    return StylingConfig(**field_values, **kwargs)
def save_hf_dataset_to_disk(hf_dataset, save_path: str) -> bool:
    """Save ``hf_dataset`` with its ``save_to_disk`` method, reporting on stdout.

    Best effort: errors are printed and signalled by the return value rather
    than raised.

    Returns:
        ``True`` on success, ``False`` if saving (or the success print) raised.
    """
    try:
        hf_dataset.save_to_disk(save_path)
        print(f"HuggingFace dataset saved to disk at: {save_path}")
        ok = True
    except Exception as e:
        print(f"Error saving HuggingFace dataset to disk: {e}")
        ok = False
    return ok
def load_hf_dataset_from_disk(load_path: str):
    """Load a dataset saved with ``save_to_disk``, reporting on stdout.

    Best effort: any failure (including ``datasets`` not being importable)
    is printed and ``None`` is returned.
    """
    try:
        from datasets import load_from_disk
        loaded = load_from_disk(load_path)
        print(f"HuggingFace dataset loaded from disk: {load_path}")
        print(f"Dataset has {len(loaded)} entries")
        print(f"Dataset features: {loaded.features}")
    except Exception as e:
        print(f"Error loading HuggingFace dataset from disk: {e}")
        loaded = None
    return loaded
def load_yaml_config(config_path: str) -> Dict[str, Any]:
    """Read a YAML config file and flatten its sections into one settings dict.

    The sections ``task``, ``data``, ``model``, ``training`` and ``inference``
    are read according to ``section_keys`` below (result key -> key inside
    that YAML section). Keys that are absent, or explicitly null, are left out
    of the result, so callers using ``config_dict.get(key, default)`` get
    their default instead of ``None``. An empty file or an empty/null section
    contributes no keys.

    Args:
        config_path: Path to the YAML file, read as UTF-8 with ``yaml.safe_load``.

    Returns:
        Flat mapping of setting name to value.

    Raises:
        ValueError: If the document, or one of the sections above, is not a mapping.
        Exception: I/O and YAML parsing errors are logged and re-raised.
    """
    section_keys: Dict[str, Dict[str, str]] = {
        'task': {
            'task_name': 'name',
            'task_type': 'type',
        },
        'data': {
            'data_source': 'source',
            'dataset_name': 'dataset_name',
            'data_path': 'data_path',
            'data_format': 'data_format',
            'input_field': 'input_field',
            'output_field': 'output_field',
            'instruction': 'instruction',
            'max_samples': 'max_samples',
            'train_split': 'train_split',
            'validation_split': 'validation_split',
            'test_split': 'test_split',
            'clean_text': 'clean_text',
            'remove_special_chars': 'remove_special_chars',
            'lowercase': 'lowercase',
            'min_length': 'min_length',
            'max_length': 'max_length',
            'output_format': 'output_format',
            'output_dir': 'output_dir',
            'hf_split': 'hf_split',
            'hf_cache_dir': 'hf_cache_dir',
            'test_split_from': 'test_split_from',
            'val_split_from': 'val_split_from',
            'encoding': 'encoding',
            'delimiter': 'delimiter',
        },
        'model': {
            'model_name': 'name',
            'model_max_length': 'max_length',
        },
        'training': {
            'num_epochs': 'num_epochs',
            'batch_size': 'batch_size',
            'learning_rate': 'learning_rate',
            'weight_decay': 'weight_decay',
            'warmup_ratio': 'warmup_ratio',
            'lr_scheduler_type': 'lr_scheduler_type',
        },
        'inference': {
            'inference_batch_size': 'batch_size',
            'max_new_tokens': 'max_new_tokens',
            'temperature': 'temperature',
        },
    }

    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            yaml_data = yaml.safe_load(f)

        if yaml_data is None:
            # An empty YAML document parses to None.
            yaml_data = {}
        if not isinstance(yaml_data, dict):
            raise ValueError(
                f"Top level of YAML config must be a mapping, got {type(yaml_data).__name__}"
            )

        config_dict: Dict[str, Any] = {}
        for section_name, key_map in section_keys.items():
            section = yaml_data.get(section_name)
            if section is None:
                continue
            if not isinstance(section, dict):
                raise ValueError(
                    f"Section '{section_name}' in YAML config must be a mapping, "
                    f"got {type(section).__name__}"
                )
            for config_key, yaml_key in key_map.items():
                value = section.get(yaml_key)
                # Skip missing/null entries so they cannot mask downstream defaults.
                if value is not None:
                    config_dict[config_key] = value

        logger.info(f"Successfully parsed YAML configuration from: {config_path}")
        logger.info(f"Extracted {len(config_dict)} configuration parameters")
        return config_dict
    except Exception as e:
        logger.error(f"Error loading YAML config from {config_path}: {e}")
        raise
def main():
    """Command-line entry point for the styling data pipeline.

    Settings are layered, later layers winning: ``StylingConfig`` defaults,
    then the YAML file given by ``--config`` (via ``load_yaml_config``), then
    options given explicitly on the command line. A merged value of ``None``
    counts as unset and falls back to the ``StylingConfig`` default.

    After checking that a data source (plus its dataset name or data path) is
    configured, runs ``StylingDataPipeline.run_pipeline`` and prints a
    summary. Exits with status 1 if the YAML config cannot be loaded or the
    pipeline raises; ``argparse`` exits with status 2 on invalid arguments.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Set up logging
    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # Layer 1: optional YAML file.
    config_dict: Dict[str, Any] = {}
    if args.config:
        try:
            config_dict = load_yaml_config(args.config)
        except Exception as e:
            logger.error(f"Error loading YAML config: {e}")
            sys.exit(1)

    # Layer 2: explicit CLI options override YAML values.
    cli_overrides = _collect_cli_overrides(args)
    for key, value in cli_overrides.items():
        if config_dict.get(key) is not None:
            logger.info(f"Overriding YAML config '{key}' with CLI value: {value}")
        config_dict[key] = value

    # Validate required arguments (parser.error exits the process).
    data_source = config_dict.get('data_source')
    if not data_source:
        parser.error("--data-source is required (either in YAML config or CLI)")
    if data_source == "huggingface" and not config_dict.get('dataset_name'):
        parser.error("--dataset-name is required for HuggingFace datasets")
    if data_source == "custom" and not config_dict.get('data_path'):
        parser.error("--data-path is required for custom datasets")

    config = _styling_config_from_dict(config_dict)
    pipeline = StylingDataPipeline()

    try:
        print(f"Starting styling pipeline with {config.data_source} data source...")
        if args.config:
            print(f"Using YAML configuration: {args.config}")
        print(f"Style instruction: {config.instruction}")
        print()

        # HF dataset export is only requested from the CLI; when requested it
        # is also saved to disk.
        create_hf_dataset = cli_overrides.get('create_hf_dataset', False)
        result = pipeline.run_pipeline(
            config,
            config.output_format,
            save_splits=True,
            create_hf_dataset=create_hf_dataset,
            save_hf_dataset=create_hf_dataset,
            hf_dataset_path=cli_overrides.get('hf_dataset_path'),
        )

        print("✅ Pipeline completed successfully!")
        print(f" Data source: {config.data_source}")
        if config.data_source == "huggingface":
            print(f" Dataset: {config.dataset_name}")
        else:
            print(f" Data file: {config.data_path}")
        print(f" Total samples: {result['analysis']['overall']['total_samples']}")
        print(f" Split sizes: {result['analysis']['overall']['split_sizes']}")
        print(f" Output directory: {config.output_dir}")
        print(f" Style instruction: {config.instruction}")
    except Exception as e:
        import traceback

        print(f"❌ Error running pipeline: {e}")
        print("Full error traceback:")
        traceback.print_exc()
        sys.exit(1)


# argparse destinations copied into the settings dict, in declaration order.
_CLI_OVERRIDE_KEYS = (
    'data_source', 'dataset_name', 'data_path', 'data_format',
    'input_field', 'output_field', 'instruction',
    'max_samples', 'train_split', 'validation_split', 'test_split',
    'clean_text', 'remove_special_chars', 'lowercase',
    'min_length', 'max_length',
    'output_format', 'output_dir',
    'create_hf_dataset', 'hf_dataset_path',
    'log_level',
)

# Numeric options count as given whenever they are not None, so explicit
# zeros such as ``--test-split 0`` or ``--min-length 0`` are honoured.
_NUMERIC_CLI_OPTIONS = frozenset({
    'max_samples', 'train_split', 'validation_split', 'test_split',
    'min_length', 'max_length',
})

# Settings forwarded to StylingConfig; anything else in the merged dict
# (task/model/training/inference keys, CLI-only switches) is ignored here.
_STYLING_CONFIG_KEYS = (
    'data_source', 'dataset_name', 'data_path', 'data_format',
    'input_field', 'output_field', 'instruction',
    'max_samples', 'train_split', 'validation_split', 'test_split',
    'clean_text', 'remove_special_chars', 'lowercase',
    'min_length', 'max_length',
    'output_format', 'output_dir',
    'hf_split', 'hf_cache_dir', 'test_split_from', 'val_split_from',
    'encoding', 'delimiter',
)


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the styling pipeline CLI."""
    parser = argparse.ArgumentParser(description="Styling Data Processing Pipeline")
    # YAML configuration
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")
    # Data source arguments
    parser.add_argument("--data-source", choices=["huggingface", "custom"], help="Data source")
    parser.add_argument("--dataset-name", type=str, help="HuggingFace dataset name")
    parser.add_argument("--data-path", type=str, help="Path to custom data file")
    parser.add_argument("--data-format", choices=["jsonl", "csv", "json"], help="Data format")
    # Field mapping
    parser.add_argument("--input-field", type=str, help="Input field name")
    parser.add_argument("--output-field", type=str, help="Output field name")
    parser.add_argument("--instruction", type=str, help="Style instruction")
    # Data processing
    parser.add_argument("--max-samples", type=int, help="Maximum samples to process")
    parser.add_argument("--train-split", type=float, help="Training split ratio")
    parser.add_argument("--validation-split", type=float, help="Validation split ratio")
    parser.add_argument("--test-split", type=float, help="Test split ratio")
    # Text preprocessing
    parser.add_argument("--clean-text", action="store_true", help="Clean and normalize text")
    parser.add_argument("--remove-special-chars", action="store_true", help="Remove special characters")
    parser.add_argument("--lowercase", action="store_true", help="Convert text to lowercase")
    parser.add_argument("--min-length", type=int, help="Minimum text length")
    parser.add_argument("--max-length", type=int, help="Maximum text length")
    # Output configuration
    parser.add_argument("--output-format", choices=["styling", "alpaca"], help="Output format")
    parser.add_argument("--output-dir", type=str, help="Output directory")
    # HuggingFace dataset options
    parser.add_argument("--create-hf-dataset", action="store_true", help="Create HuggingFace dataset")
    parser.add_argument("--hf-dataset-path", type=str, help="Path to save HuggingFace dataset")
    # Logging
    parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO", help="Logging level")
    return parser


def _collect_cli_overrides(args: argparse.Namespace) -> Dict[str, Any]:
    """Return the settings given explicitly on the command line, in declaration order.

    Numeric options count whenever they are not None; text options count
    when non-empty; ``store_true`` flags count only when set (so a flag can
    enable, but never disable, a YAML setting).
    """
    overrides: Dict[str, Any] = {}
    for key in _CLI_OVERRIDE_KEYS:
        value = getattr(args, key)
        given = value is not None if key in _NUMERIC_CLI_OPTIONS else bool(value)
        if given:
            overrides[key] = value
    return overrides


def _styling_config_from_dict(config_dict: Dict[str, Any]) -> StylingConfig:
    """Build a StylingConfig from merged settings, keeping dataclass defaults for unset keys.

    Only keys listed in ``_STYLING_CONFIG_KEYS`` are forwarded. A key that is
    missing or whose value is ``None`` is left out so ``StylingConfig``'s own
    default applies instead of ``None``.
    """
    return StylingConfig(**{
        key: config_dict[key]
        for key in _STYLING_CONFIG_KEYS
        if config_dict.get(key) is not None
    })
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()