1513 lines
65 KiB
Python
1513 lines
65 KiB
Python
import json
|
|
import pandas as pd
|
|
import numpy as np
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Union, Any, Tuple
|
|
from datasets import Dataset, load_dataset
|
|
import os
|
|
from dataclasses import dataclass
|
|
from abc import ABC, abstractmethod
|
|
import logging
|
|
from sklearn.model_selection import train_test_split
|
|
import re
|
|
import argparse
|
|
import sys
|
|
import yaml
|
|
|
|
# Module-wide logger. The logger itself is set to DEBUG so the per-item
# traces emitted by the loaders are not filtered here (attached handlers may
# still apply their own level).
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.DEBUG)
|
|
|
|
@dataclass
class StylingConfig:
    """All settings used to load, preprocess and format styling data.

    The fields fall into these groups, in declaration order:

    * data source: where raw records come from (Hub dataset or local file);
    * field mapping: which record keys hold the source and styled text, plus
      the style instruction;
    * splitting: optional sample cap and train/validation/test ratios;
    * text preprocessing: whitespace cleanup, casing, character filtering
      and length bounds;
    * output: output format name and directory;
    * Hugging Face / split-source / custom-file specifics;
    * prompt: the Alpaca-style template (three ``{}`` slots, filled in the
      order instruction, input, response) and the EOS token.
    """

    # --- data source ------------------------------------------------------
    data_source: str = "huggingface"  # "huggingface" or "custom"
    dataset_name: Optional[str] = None  # Hub dataset id (Hugging Face source)
    data_path: Optional[str] = None  # local file path (custom source)
    data_format: str = "jsonl"  # one of: jsonl, csv, json

    # --- field mapping (which record keys hold source / styled text) -------
    input_field: str = "text"  # e.g. "text", "source"
    output_field: str = "styled_text"  # e.g. "styled_text", "target"
    instruction: str = "Rewrite the following text in a formal style"  # style instruction (from YAML)

    # --- splitting --------------------------------------------------------
    max_samples: Optional[int] = None
    train_split: float = 0.8
    validation_split: float = 0.1
    test_split: float = 0.1

    # --- text preprocessing -----------------------------------------------
    clean_text: bool = True
    remove_special_chars: bool = False
    lowercase: bool = False  # off by default so the original casing survives styling
    min_length: int = 10
    max_length: int = 1000

    # --- output -------------------------------------------------------------
    output_format: str = "styling"  # output layout name (its consumer is not shown here)
    output_dir: str = "./data"

    # --- Hugging Face specifics ---------------------------------------------
    hf_split: str = "train"
    hf_cache_dir: Optional[str] = None

    # --- where evaluation splits come from ----------------------------------
    # "train" carves the split out of the train data; the loaders also
    # recognise "use_test_if_available" / "use_val_if_available".
    test_split_from: str = "train"
    val_split_from: str = "train"

    # --- custom-file specifics ----------------------------------------------
    encoding: str = "utf-8"
    delimiter: str = ","  # CSV only

    # --- prompt -------------------------------------------------------------
    alpaca_prompt: str = (
        "Below is an instruction that describes a task, paired with an input that "
        "provides further context. Write a response that follows the instruction\n"
        "\n"
        "### Instruction:\n"
        "{}\n"
        "\n"
        "### Input:\n"
        "{}\n"
        "\n"
        "### Response:\n"
        "{}"
    )
    eos_token: str = "<|eot_id|>"  # appended as the end-of-sequence marker
|
|
|
|
class DataValidator:
    """Validates styling data quality and format.

    Both public methods take a mapping of split name -> list of record dicts.
    With ``is_processed=True`` records are expected to use the fixed
    ``"input"``/``"output"`` keys produced by the loaders' ``preprocess``;
    otherwise the keys named by ``config.input_field`` / ``config.output_field``
    are inspected.
    """

    @staticmethod
    def _field_names(config: StylingConfig, is_processed: bool) -> Tuple[str, str]:
        """Return the (input, output) record keys to inspect."""
        if is_processed:
            return "input", "output"
        return config.input_field, config.output_field

    @staticmethod
    def _is_blank(value: Any) -> bool:
        """Return True for None and for strings that are empty after ``strip()``.

        Non-string values are not "blank"; they are reported as type errors.
        """
        return value is None or (isinstance(value, str) and not value.strip())

    @staticmethod
    def _text_length(value: Any) -> int:
        """Return the character length of ``value``; None counts as 0, other non-strings are str()-ed."""
        if value is None:
            return 0
        return len(value if isinstance(value, str) else str(value))

    @staticmethod
    def _preview(value: Any, width: int = 50) -> str:
        """Return the first ``width`` characters of ``value`` for log output (None -> "")."""
        return ("" if value is None else str(value))[:width]

    @staticmethod
    def validate_styling_data(data: Dict[str, List[Dict]], config: StylingConfig, is_processed: bool = False) -> Tuple[bool, List[str]]:
        """Validate styling dataset splits.

        Checks that the ``train``/``validation``/``test`` keys exist and that
        ``train`` is non-empty (validation/test may be empty for small
        datasets). Then, per non-empty split, it checks that every item has the
        input and output keys, that their values are strings, and that they
        are not blank.

        Args:
            data: Split name -> list of record dicts.
            config: Supplies the raw field names when ``is_processed`` is False.
            is_processed: Inspect ``"input"``/``"output"`` instead of the
                configured field names.

        Returns:
            ``(is_valid, errors)``, where ``is_valid`` is True iff ``errors``
            is empty.
        """
        errors: List[str] = []

        for split in ("train", "validation", "test"):
            if split not in data:
                errors.append(f"Missing '{split}' split")
            elif split == "train" and not data[split]:
                # Validation/test are allowed to be empty for small datasets.
                errors.append("Train split cannot be empty")

        if errors:
            return False, errors

        total_samples = sum(len(split_data) for split_data in data.values())
        logger.info(f"Validating {total_samples} total samples across all splits...")

        input_field, output_field = DataValidator._field_names(config, is_processed)

        for split_name, split_data in data.items():
            if not split_data:
                logger.info(f"Skipping validation for empty {split_name} split")
                continue

            logger.info(f"Validating {split_name} split with {len(split_data)} samples...")

            # Required keys present?
            missing_input_count = 0
            missing_output_count = 0
            for i, item in enumerate(split_data):
                if input_field not in item:
                    errors.append(f"Missing input field '{input_field}' in {split_name} split, item {i}")
                    missing_input_count += 1
                if output_field not in item:
                    errors.append(f"Missing output field '{output_field}' in {split_name} split, item {i}")
                    missing_output_count += 1

            logger.info(f"{split_name} - Items missing input field: {missing_input_count}")
            logger.info(f"{split_name} - Items missing output field: {missing_output_count}")

            # Values must be strings (a missing key was already reported above).
            type_errors = 0
            for i, item in enumerate(split_data):
                if not isinstance(item.get(input_field, ""), str):
                    errors.append(f"Input field '{input_field}' must be string in {split_name} split, item {i}")
                    type_errors += 1
                if not isinstance(item.get(output_field, ""), str):
                    errors.append(f"Output field '{output_field}' must be string in {split_name} split, item {i}")
                    type_errors += 1

            logger.info(f"{split_name} - Type errors: {type_errors}")

            # Blank values. _is_blank tolerates None and non-strings; the
            # previous `.strip()` call crashed on them.
            empty_inputs = sum(1 for item in split_data if DataValidator._is_blank(item.get(input_field, "")))
            empty_outputs = sum(1 for item in split_data if DataValidator._is_blank(item.get(output_field, "")))

            if empty_inputs > 0:
                errors.append(f"Found {empty_inputs} items with empty input text in {split_name} split")
            if empty_outputs > 0:
                errors.append(f"Found {empty_outputs} items with empty output text in {split_name} split")

            logger.info(f"{split_name} - Empty inputs: {empty_inputs}")
            logger.info(f"{split_name} - Empty outputs: {empty_outputs}")

            logger.info(f"Sample processed items from {split_name}:")
            for i, item in enumerate(split_data[:3]):
                input_preview = DataValidator._preview(item.get(input_field))
                output_preview = DataValidator._preview(item.get(output_field))
                logger.info(f" Item {i}: input='{input_preview}...', output='{output_preview}...'")

        return not errors, errors

    @staticmethod
    def analyze_dataset(data: Dict[str, List[Dict]], config: StylingConfig, is_processed: bool = False) -> Dict[str, Any]:
        """Analyze dataset characteristics across all splits.

        Returns:
            ``{"splits": {name: {"total_samples", "text_length_stats",
            "missing_values"}}, "overall": {"total_samples", "split_sizes"}}``.

            * ``text_length_stats`` is keyed by ``"input"``/``"output"`` and
              holds min/max/mean/median character lengths as plain Python
              numbers.
            * ``missing_values`` is keyed by the actual field names and
              counts falsy values.
            * Empty splits get empty stats dicts.
        """
        analysis: Dict[str, Any] = {
            "splits": {},
            "overall": {
                "total_samples": 0,
                "split_sizes": {},
            },
        }

        input_field, output_field = DataValidator._field_names(config, is_processed)

        for split_name, split_data in data.items():
            split_analysis: Dict[str, Any] = {
                "total_samples": len(split_data),
                "text_length_stats": {},
                "missing_values": {},
            }

            if split_data:
                for label, field in (("input", input_field), ("output", output_field)):
                    lengths = [DataValidator._text_length(item.get(field, "")) for item in split_data]
                    # float() turns numpy scalars into plain floats so the
                    # result can be dumped by json/yaml safe dumpers.
                    split_analysis["text_length_stats"][label] = {
                        "min": min(lengths),
                        "max": max(lengths),
                        "mean": float(np.mean(lengths)),
                        "median": float(np.median(lengths)),
                    }

                for field in (input_field, output_field):
                    split_analysis["missing_values"][field] = sum(1 for item in split_data if not item.get(field))

            analysis["splits"][split_name] = split_analysis
            analysis["overall"]["total_samples"] += len(split_data)
            analysis["overall"]["split_sizes"][split_name] = len(split_data)

        return analysis
|
|
|
|
class BaseDataLoader(ABC):
    """Interface shared by all styling data loaders.

    A concrete loader implements two steps:

    * ``load`` reads raw records and groups them into train/validation/test
      lists;
    * ``preprocess`` turns those raw records into cleaned items, split by
      split.
    """

    @abstractmethod
    def load(self, config: StylingConfig) -> Dict[str, List[Dict]]:
        """Load data and return dictionary with train/val/test splits"""

    @abstractmethod
    def preprocess(self, data: Dict[str, List[Dict]], config: StylingConfig) -> Dict[str, List[Dict]]:
        """Apply preprocessing steps to all splits"""
|
|
|
|
|
|
class HuggingFaceDataLoader(BaseDataLoader):
    """Load styling datasets from the Hugging Face Hub.

    ``load`` returns raw records split into train/validation/test.
    ``preprocess`` maps every record to ``{"input": str, "output": str}``,
    applying optional cleaning and a length filter.
    """

    # Compiled once and shared by every _clean_text call.
    _WHITESPACE_RE = re.compile(r'\s+')
    _SPECIAL_CHARS_RE = re.compile(r'[^\w\s]')

    # Number of items per split whose processing is traced at DEBUG level.
    _DEBUG_ITEMS = 3

    def load(self, config: StylingConfig) -> Dict[str, List[Dict]]:
        """Load a dataset from the Hub and return train/validation/test splits.

        The dataset's ``train`` split is required. Validation/test splits are
        taken from the dataset when ``config.val_split_from`` /
        ``config.test_split_from`` request it and such a split exists. Any
        split still empty afterwards is carved out of the train data.
        ``config`` is read but never modified.

        Args:
            config: ``dataset_name`` is required. ``hf_cache_dir``, the split
                sources, the split ratios and ``max_samples`` are honoured.

        Returns:
            ``{"train": [...], "validation": [...], "test": [...]}`` holding
            raw record dicts.

        Raises:
            ValueError: ``dataset_name`` is missing, or the dataset has no
                ``train`` split.
            Exception: errors from ``load_dataset`` / ``train_test_split``
                are logged and re-raised.
        """
        if not config.dataset_name:
            raise ValueError("Dataset name is required for Hugging Face datasets")

        logger.info(f"Loading Hugging Face dataset: {config.dataset_name}")

        try:
            dataset = load_dataset(
                config.dataset_name,
                cache_dir=config.hf_cache_dir
            )

            available_splits = list(dataset.keys())
            logger.info(f"Available splits in dataset: {available_splits}")

            if "train" not in available_splits:
                logger.error("No 'train' split found in dataset!")
                logger.error(f"Available splits: {available_splits}")
                raise ValueError(f"Dataset {config.dataset_name} does not have a 'train' split")

            train_dataset = dataset["train"]
            logger.info(f"Using 'train' split with {len(train_dataset)} samples")

            splits_data: Dict[str, List[Dict]] = {"train": list(train_dataset)}
            splits_data["validation"] = self._take_validation_split(dataset, available_splits, config)
            splits_data["test"] = self._take_test_split(dataset, available_splits, config)

            self._fill_missing_splits(splits_data, config)

            logger.info("Final split sizes:")
            logger.info(f" Train: {len(splits_data['train'])} samples")
            logger.info(f" Validation: {len(splits_data['validation'])} samples")
            logger.info(f" Test: {len(splits_data['test'])} samples")

            # max_samples caps each split independently, after splitting.
            if config.max_samples:
                for split_name, split_data in splits_data.items():
                    if split_data:
                        original_size = len(split_data)
                        splits_data[split_name] = split_data[:config.max_samples]
                        logger.info(f"Limited {split_name} split from {original_size} to {len(splits_data[split_name])} samples")

            self._log_field_hints(splits_data, config)

            logger.info(f"Successfully loaded dataset {config.dataset_name}")
            return splits_data

        except Exception as e:
            logger.error(f"Error loading dataset {config.dataset_name}: {e}")
            raise

    @staticmethod
    def _take_validation_split(dataset: Any, available_splits: List[str], config: StylingConfig) -> List[Dict]:
        """Return the dataset's own validation split if requested and present, else ``[]``."""
        if config.val_split_from == "use_val_if_available":
            for name in ("validation", "val"):
                if name in available_splits:
                    val_dataset = dataset[name]
                    logger.info(f"Using '{name}' split with {len(val_dataset)} samples")
                    return list(val_dataset)
            logger.warning("No validation split found in dataset. Will create from train split.")
            logger.info(f"Available splits: {available_splits}")
            logger.info(f"Will use {config.validation_split * 100}% of train data for validation")
        else:
            logger.info(f"Will create validation split from train data ({config.validation_split * 100}%)")
        return []

    @staticmethod
    def _take_test_split(dataset: Any, available_splits: List[str], config: StylingConfig) -> List[Dict]:
        """Return the split to use as test if requested and present, else ``[]``.

        ``"use_test_if_available"`` takes the dataset's ``test`` split.
        ``"use_val_if_available"`` takes ``validation`` (or ``val``) as test.
        """
        source = config.test_split_from

        if source == "use_test_if_available":
            if "test" in available_splits:
                test_dataset = dataset["test"]
                logger.info(f"Using 'test' split with {len(test_dataset)} samples")
                return list(test_dataset)
            logger.warning("No test split found in dataset. Will create from train split.")
            logger.info(f"Available splits: {available_splits}")
            logger.info(f"Will use {config.test_split * 100}% of train data for test")
            return []

        if source == "use_val_if_available":
            for name in ("validation", "val"):
                if name in available_splits:
                    test_dataset = dataset[name]
                    logger.info(f"Using '{name}' split as test with {len(test_dataset)} samples")
                    if config.val_split_from == "use_val_if_available":
                        # Validation picks the same split name first, so both
                        # evaluation sets are the same records.
                        logger.warning(f"'{name}' split is used for both validation and test; evaluation sets will overlap")
                    return list(test_dataset)

        logger.info(f"Will create test split from train data ({config.test_split * 100}%)")
        return []

    @staticmethod
    def _eval_ratios(config: StylingConfig) -> Tuple[float, float]:
        """Return (validation, test) ratios, normalised so all three split ratios sum to 1.

        The config is not modified. The train share is always the remainder,
        so only the evaluation ratios are returned.
        """
        total = config.train_split + config.validation_split + config.test_split
        # Tolerance: e.g. 0.7 + 0.2 + 0.1 == 0.9999999999999999 in floating point.
        if total > 0 and abs(total - 1.0) > 1e-9:
            logger.warning(f"Split percentages don't sum to 1.0 (got {total}). Normalizing...")
            return config.validation_split / total, config.test_split / total
        return config.validation_split, config.test_split

    @staticmethod
    def _fill_missing_splits(splits_data: Dict[str, List[Dict]], config: StylingConfig) -> None:
        """Carve empty validation/test splits out of ``splits_data["train"]`` in place.

        Splits that already came from the dataset are never touched. With
        fewer than 3 train records nothing is carved.
        """
        need_val = not splits_data["validation"]
        need_test = not splits_data["test"]
        if not (need_val or need_test):
            return

        train_data = splits_data["train"]
        n = len(train_data)

        if n < 3:
            # Too small to split. Keep any evaluation split loaded from the dataset.
            logger.warning(f"Dataset has only {n} samples. Using all data for training.")
            return

        val_ratio, test_ratio = HuggingFaceDataLoader._eval_ratios(config)

        if need_val and need_test:
            if n < 10:
                # Use more conservative ratios for small datasets (local only).
                val_ratio = test_ratio = 0.2
                logger.info("Small dataset detected. Adjusted split ratios to: train=0.6, val=0.2, test=0.2")

            min_size = max(1, int(n * 0.1))
            val_size = max(min_size, int(n * val_ratio))
            test_size = max(min_size, int(n * test_ratio))
            train_size = n - val_size - test_size

            if train_size < 1:
                # Shrink held-out splits (validation first) until train has a
                # sample. n >= 3, so val = test = 1 always leaves train >= 1.
                while train_size < 1 and (val_size > 1 or test_size > 1):
                    if val_size > 1:
                        val_size -= 1
                    else:
                        test_size -= 1
                    train_size += 1
                logger.info(f"Adjusted split sizes: train={train_size}, val={val_size}, test={test_size}")

            new_train, held_out = train_test_split(
                train_data,
                test_size=val_size + test_size,
                random_state=42
            )
            # An exact integer size avoids ceil() rounding of a float ratio.
            new_val, new_test = train_test_split(
                held_out,
                test_size=test_size,
                random_state=42
            )
            splits_data["train"] = new_train
            splits_data["validation"] = new_val
            splits_data["test"] = new_test

        elif need_val:
            # Cap at n - 1 so the remaining train split is never empty.
            val_size = min(n - 1, max(1, int(n * val_ratio)))
            new_train, new_val = train_test_split(train_data, test_size=val_size, random_state=42)
            splits_data["train"] = new_train
            splits_data["validation"] = new_val

        else:
            test_size = min(n - 1, max(1, int(n * test_ratio)))
            new_train, new_test = train_test_split(train_data, test_size=test_size, random_state=42)
            splits_data["train"] = new_train
            splits_data["test"] = new_test

    @staticmethod
    def _log_field_hints(splits_data: Dict[str, List[Dict]], config: StylingConfig) -> None:
        """Log the first record of each split and suggest fields when the configured ones are absent."""
        text_keywords = ['text', 'sentence', 'content', 'input', 'comment', 'message']
        output_keywords = ['output', 'response', 'result', 'target', 'styled']

        for split_name, split_data in splits_data.items():
            if not split_data:
                continue

            first = split_data[0]
            keys = list(first.keys())
            logger.info(f"Sample data item from {split_name}: {first}")
            logger.info(f"Available fields in {split_name} split: {keys}")

            if config.input_field not in first:
                logger.warning(f"Input field '{config.input_field}' not found in {split_name}. Available fields: {keys}")
                text_fields = [f for f in keys if any(k in f.lower() for k in text_keywords)]
                if text_fields:
                    logger.info(f"Suggested text fields for {split_name}: {text_fields}")

            if config.output_field not in first:
                logger.warning(f"Output field '{config.output_field}' not found in {split_name}. Available fields: {keys}")
                output_fields = [f for f in keys if any(k in f.lower() for k in output_keywords)]
                if output_fields:
                    logger.info(f"Suggested output fields for {split_name}: {output_fields}")

    def preprocess(self, data: Dict[str, List[Dict]], config: StylingConfig) -> Dict[str, List[Dict]]:
        """Preprocess every split independently.

        Each raw record is passed to ``_preprocess_item``. Records it rejects
        (returns ``None``) are dropped and counted. Field diagnostics and the
        first few raw/processed records of each split are logged.

        Returns:
            Split name -> list of ``{"input": str, "output": str}`` dicts.
        """
        processed_splits: Dict[str, List[Dict]] = {}

        logger.info("=== PREPROCESSING DATA ===")

        for split_name, split_data in data.items():
            logger.info(f"Processing {split_name} split with {len(split_data)} items...")

            if split_data:
                self._log_raw_overview(split_name, split_data, config)

            processed_data: List[Dict] = []
            skipped_count = 0

            # _preprocess_item traces only the first few items of each split.
            self._debug_count = 0

            for i, item in enumerate(split_data):
                processed_item = self._preprocess_item(item, config)
                if processed_item is None:
                    skipped_count += 1
                    if skipped_count <= 3:  # log only the first few skipped items
                        logger.info(f"Skipped item {i} from {split_name}: {item}")
                else:
                    processed_data.append(processed_item)

            processed_splits[split_name] = processed_data
            logger.info(f"{split_name} - Preprocessed {len(processed_data)} samples, skipped {skipped_count} samples")

            if processed_data:
                logger.info(f"=== SAMPLE PROCESSED DATA FROM {split_name.upper()} ===")
                for i, processed_item in enumerate(processed_data[:3]):
                    logger.info(f"Processed item {i} from {split_name}: {processed_item}")

        return processed_splits

    @staticmethod
    def _log_raw_overview(split_name: str, split_data: List[Dict], config: StylingConfig) -> None:
        """Log field availability, missing-value counts and a few raw records for one split."""
        available_fields = set(split_data[0].keys())
        logger.info(f"Available fields in {split_name}: {available_fields}")
        logger.info(f"Looking for input field: '{config.input_field}', output field: '{config.output_field}'")

        if config.input_field not in available_fields:
            logger.error(f"Input field '{config.input_field}' not found in {split_name}. Available fields: {available_fields}")
        if config.output_field not in available_fields:
            logger.error(f"Output field '{config.output_field}' not found in {split_name}. Available fields: {available_fields}")

        # A missing key and a falsy value both count as "missing".
        missing_input = sum(1 for item in split_data if not item.get(config.input_field))
        missing_output = sum(1 for item in split_data if not item.get(config.output_field))
        logger.info(f"{split_name} - Items missing input field: {missing_input}")
        logger.info(f"{split_name} - Items missing output field: {missing_output}")

        logger.info(f"=== SAMPLE RAW DATA FROM {split_name.upper()} BEFORE PREPROCESSING ===")
        for i, item in enumerate(split_data[:3]):
            logger.info(f"Raw item {i} from {split_name}:")
            for key, value in item.items():
                if isinstance(value, str) and len(value) > 100:
                    logger.info(f" {key}: '{value[:100]}...'")
                else:
                    logger.info(f" {key}: {value}")

    @staticmethod
    def _to_text(value: Any) -> str:
        """Convert a raw field value to text; None and float NaN become ""."""
        # `value != value` is True only for NaN, which str() would turn into "nan".
        if value is None or (isinstance(value, float) and value != value):
            return ""
        return str(value)

    def _preprocess_item(self, item: Dict, config: StylingConfig) -> Optional[Dict]:
        """Turn one raw record into ``{"input": ..., "output": ...}`` or reject it.

        The values under ``config.input_field`` / ``config.output_field`` are
        converted to text (a missing key, None or NaN becomes "") and cleaned
        when ``config.clean_text`` is set. The item is rejected (``None``)
        when either length lies outside
        ``[config.min_length, config.max_length]``.
        """
        raw_input = item.get(config.input_field, "")
        raw_output = item.get(config.output_field, "")

        # Per-split counter (reset in preprocess) that limits DEBUG tracing.
        self._debug_count = getattr(self, '_debug_count', 0) + 1
        trace = self._debug_count <= self._DEBUG_ITEMS

        if trace:
            logger.debug(f"Processing item {self._debug_count}:")
            logger.debug(f" Looking for input field '{config.input_field}': {raw_input}")
            logger.debug(f" Looking for output field '{config.output_field}': {raw_output}")

        input_text = self._to_text(raw_input)
        output_text = self._to_text(raw_output)

        if trace:
            logger.debug(f" After conversion - input: '{input_text[:50]}...', output: '{output_text[:50]}...'")

        if config.clean_text:
            cleaned_input = self._clean_text(input_text, config)
            cleaned_output = self._clean_text(output_text, config)
            if trace:
                logger.debug(f" After cleaning - input: '{input_text[:50]}...' -> '{cleaned_input[:50]}...'")
                logger.debug(f" After cleaning - output: '{output_text[:50]}...' -> '{cleaned_output[:50]}...'")
            input_text, output_text = cleaned_input, cleaned_output

        for label, text in (("input", input_text), ("output", output_text)):
            if not config.min_length <= len(text) <= config.max_length:
                if trace:
                    logger.debug(f" Skipping - {label} length {len(text)} not in range [{config.min_length}, {config.max_length}]")
                return None

        # Internal processing always uses the "input"/"output" keys.
        processed_item = {
            "input": input_text,
            "output": output_text
        }

        if trace:
            logger.debug(f" Final processed item: {processed_item}")

        return processed_item

    def _clean_text(self, text: str, config: StylingConfig) -> str:
        """Collapse whitespace runs to single spaces and trim.

        Optionally lowercases (``config.lowercase``) and strips non-word,
        non-space characters (``config.remove_special_chars``). Non-strings
        yield "".
        """
        if not isinstance(text, str):
            return ""

        text = self._WHITESPACE_RE.sub(' ', text).strip()

        if config.lowercase:
            text = text.lower()

        if config.remove_special_chars:
            text = self._SPECIAL_CHARS_RE.sub('', text)

        return text
|
|
|
|
|
|
class CustomDataLoader(BaseDataLoader):
    """Load custom styling datasets from local JSONL, CSV or JSON files.

    ``load`` reads the file, applies ``max_samples`` and splits the records
    into train/validation/test. ``preprocess`` maps every record to
    ``{"input": str, "output": str}``, applying optional cleaning and a length
    filter.
    """

    # Compiled once and shared by every _clean_text call.
    _WHITESPACE_RE = re.compile(r'\s+')
    _SPECIAL_CHARS_RE = re.compile(r'[^\w\s]')

    def load(self, config: StylingConfig) -> Dict[str, List[Dict]]:
        """Load a local dataset file and split it.

        Raises:
            ValueError: ``data_path`` is missing or ``data_format`` is not
                jsonl/csv/json.
            FileNotFoundError: ``data_path`` does not exist.
        """
        if not config.data_path:
            raise ValueError("Data path is required for custom datasets")

        file_path = Path(config.data_path)

        if not file_path.exists():
            raise FileNotFoundError(f"Data file not found: {file_path}")

        logger.info(f"Loading custom dataset: {file_path}")

        readers = {
            "jsonl": self._load_jsonl,
            "csv": self._load_csv,
            "json": self._load_json,
        }
        reader = readers.get(config.data_format)
        if reader is None:
            raise ValueError(f"Unsupported format: {config.data_format}")
        raw_data = reader(file_path, config)

        # Here max_samples caps the whole file, before splitting.
        if config.max_samples:
            raw_data = raw_data[:config.max_samples]

        logger.info(f"Loaded {len(raw_data)} samples from {file_path}")

        return self._create_splits(raw_data, config)

    @staticmethod
    def _eval_ratios(config: StylingConfig) -> Tuple[float, float]:
        """Return (validation, test) ratios, normalised so all three split ratios sum to 1.

        The config is not modified. The train share is always the remainder.
        """
        total = config.train_split + config.validation_split + config.test_split
        # Tolerance: e.g. 0.7 + 0.2 + 0.1 == 0.9999999999999999 in floating point.
        if total > 0 and abs(total - 1.0) > 1e-9:
            logger.warning(f"Split percentages don't sum to 1.0 (got {total}). Normalizing...")
            return config.validation_split / total, config.test_split / total
        return config.validation_split, config.test_split

    def _create_splits(self, data: List[Dict], config: StylingConfig) -> Dict[str, List[Dict]]:
        """Split raw records into train/validation/test (``random_state=42``).

        With fewer than 3 records everything goes to train. Otherwise
        validation and test each get at least one record (and at least 10% of
        the data), and train keeps at least one record. ``config`` is read but
        never modified.
        """
        logger.info(f"Creating splits from {len(data)} samples...")

        total_samples = len(data)

        if total_samples < 3:
            logger.warning(f"Dataset has only {total_samples} samples. Using all data for training.")
            return {
                "train": data,
                "validation": [],
                "test": []
            }

        val_ratio, test_ratio = self._eval_ratios(config)

        if total_samples < 10:
            # Use more conservative ratios for small datasets (local only).
            val_ratio = test_ratio = 0.2
            logger.info("Small dataset detected. Adjusted split ratios to: train=0.6, val=0.2, test=0.2")

        min_size = max(1, int(total_samples * 0.1))  # at least 1 sample per held-out split
        val_size = max(min_size, int(total_samples * val_ratio))
        test_size = max(min_size, int(total_samples * test_ratio))
        train_size = total_samples - val_size - test_size

        if train_size < 1:
            # Shrink held-out splits (validation first) until train has a sample.
            # total_samples >= 3, so val = test = 1 always leaves train >= 1.
            while train_size < 1 and (val_size > 1 or test_size > 1):
                if val_size > 1:
                    val_size -= 1
                else:
                    test_size -= 1
                train_size += 1
            logger.info(f"Adjusted split sizes to ensure train has at least 1 sample: train={train_size}, val={val_size}, test={test_size}")

        logger.info(f"Split sizes: train={train_size}, validation={val_size}, test={test_size}")

        # val_size and test_size are both >= 1 here, so a full three-way split always applies.
        train_data, held_out = train_test_split(
            data,
            test_size=val_size + test_size,
            random_state=42
        )
        val_data, test_data = train_test_split(
            held_out,
            test_size=test_size,
            random_state=42
        )

        splits_data = {
            "train": train_data,
            "validation": val_data,
            "test": test_data
        }

        logger.info("Created splits:")
        logger.info(f" Train: {len(splits_data['train'])} samples")
        logger.info(f" Validation: {len(splits_data['validation'])} samples")
        logger.info(f" Test: {len(splits_data['test'])} samples")

        return splits_data

    def _load_jsonl(self, file_path: Path, config: StylingConfig) -> List[Dict]:
        """Load a JSONL file.

        Blank lines are ignored. Lines that are not valid JSON or not JSON
        objects are skipped with a warning.
        """
        data = []
        with open(file_path, 'r', encoding=config.encoding) as f:
            for line_num, line in enumerate(f, 1):
                if not line.strip():
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError as e:
                    logger.warning(f"Invalid JSON at line {line_num}: {e}")
                    continue
                if not isinstance(record, dict):
                    # Non-object values would crash later on record.get(...).
                    logger.warning(f"Skipping non-object JSON value at line {line_num}")
                    continue
                data.append(record)
        return data

    def _load_csv(self, file_path: Path, config: StylingConfig) -> List[Dict]:
        """Load a CSV file as a list of row dicts (empty cells arrive as NaN)."""
        df = pd.read_csv(file_path, encoding=config.encoding, delimiter=config.delimiter)
        return df.to_dict('records')

    def _load_json(self, file_path: Path, config: StylingConfig) -> List[Dict]:
        """Load a JSON file.

        Accepts a top-level list of records, an object with a ``"data"`` key
        holding the records, or a single record object. Non-object entries are
        dropped with a warning.
        """
        with open(file_path, 'r', encoding=config.encoding) as f:
            data = json.load(f)

        records = data["data"] if isinstance(data, dict) and "data" in data else data
        if not isinstance(records, list):
            records = [records]

        kept = [record for record in records if isinstance(record, dict)]
        if len(kept) != len(records):
            logger.warning(f"Skipped {len(records) - len(kept)} non-object entries in {file_path}")
        return kept

    def preprocess(self, data: Dict[str, List[Dict]], config: StylingConfig) -> Dict[str, List[Dict]]:
        """Preprocess every split independently.

        Records rejected by ``_preprocess_item`` are dropped; the first few
        of them are logged.

        Returns:
            Split name -> list of ``{"input": str, "output": str}`` dicts.
        """
        processed_splits: Dict[str, List[Dict]] = {}

        logger.info("=== PREPROCESSING CUSTOM DATA ===")

        for split_name, split_data in data.items():
            logger.info(f"Processing {split_name} split with {len(split_data)} items...")

            processed_data: List[Dict] = []
            skipped_count = 0

            for i, item in enumerate(split_data):
                processed_item = self._preprocess_item(item, config)
                if processed_item is None:
                    skipped_count += 1
                    if skipped_count <= 3:  # log only the first few skipped items
                        logger.info(f"Skipped item {i} from {split_name}: {item}")
                else:
                    processed_data.append(processed_item)

            processed_splits[split_name] = processed_data
            logger.info(f"{split_name} - Preprocessed {len(processed_data)} samples, skipped {skipped_count} samples")

        return processed_splits

    @staticmethod
    def _to_text(value: Any) -> str:
        """Convert a raw field value to text; None and float NaN become ""."""
        # pandas reports empty CSV cells as float NaN; str(NaN) would be "nan".
        if value is None or (isinstance(value, float) and value != value):
            return ""
        return str(value)

    def _preprocess_item(self, item: Dict, config: StylingConfig) -> Optional[Dict]:
        """Turn one raw record into ``{"input": ..., "output": ...}`` or reject it.

        The item is rejected (``None``) when either text's length, after
        optional cleaning, lies outside
        ``[config.min_length, config.max_length]``.
        """
        input_text = self._to_text(item.get(config.input_field, ""))
        output_text = self._to_text(item.get(config.output_field, ""))

        if config.clean_text:
            input_text = self._clean_text(input_text, config)
            output_text = self._clean_text(output_text, config)

        for text in (input_text, output_text):
            if not config.min_length <= len(text) <= config.max_length:
                return None

        # Internal processing always uses the "input"/"output" keys.
        return {
            "input": input_text,
            "output": output_text
        }

    def _clean_text(self, text: str, config: StylingConfig) -> str:
        """Collapse whitespace runs to single spaces and trim.

        Optionally lowercases (``config.lowercase``) and strips non-word,
        non-space characters (``config.remove_special_chars``). Non-strings
        yield "".
        """
        if not isinstance(text, str):
            return ""

        text = self._WHITESPACE_RE.sub(' ', text).strip()

        if config.lowercase:
            text = text.lower()

        if config.remove_special_chars:
            text = self._SPECIAL_CHARS_RE.sub('', text)

        return text
|
|
|
|
|
|
class StylingDataPipeline:
|
|
"""Main styling pipeline"""
|
|
|
|
def __init__(self):
    """Set up the pipeline's validator and its two data loaders.

    Attributes created:
        validator: ``DataValidator`` used for data-quality checks.
        hf_loader: ``HuggingFaceDataLoader`` for Hub datasets.
        custom_loader: ``CustomDataLoader`` for local files.
    """
    self.validator, self.hf_loader, self.custom_loader = (
        DataValidator(),
        HuggingFaceDataLoader(),
        CustomDataLoader(),
    )
|
|
|
|
def create_config(
    self,
    data_source: str,
    dataset_name: Optional[str] = None,
    data_path: Optional[str] = None,
    input_field: str = "input",
    output_field: str = "output",
    instruction: str = "Rewrite the following text in a formal style",
    **kwargs
) -> StylingConfig:
    """Build a ``StylingConfig`` from explicit arguments.

    Note that ``input_field``/``output_field`` default to ``"input"``/``"output"``
    here, not to the dataclass defaults. Any other ``StylingConfig`` field may
    be passed through ``kwargs``; unknown names raise ``TypeError`` from the
    dataclass constructor.
    """
    settings: Dict[str, Any] = {
        "data_source": data_source,
        "dataset_name": dataset_name,
        "data_path": data_path,
        "input_field": input_field,
        "output_field": output_field,
        "instruction": instruction,
    }
    # kwargs cannot repeat the named parameters above, so update() never overrides them.
    settings.update(kwargs)
    return StylingConfig(**settings)
|
|
|
|
def load_config_from_yaml(self, yaml_path: str) -> StylingConfig:
|
|
"""Load configuration from YAML file"""
|
|
try:
|
|
config_dict = load_yaml_config(yaml_path)
|
|
|
|
# Create configuration object from YAML data
|
|
config = StylingConfig(
|
|
data_source=config_dict.get('data_source', 'custom'),
|
|
dataset_name=config_dict.get('dataset_name'),
|
|
data_path=config_dict.get('data_path'),
|
|
data_format=config_dict.get('data_format', 'jsonl'),
|
|
input_field=config_dict.get('input_field', 'text'),
|
|
output_field=config_dict.get('output_field', 'styled_text'),
|
|
instruction=config_dict.get('instruction', 'Rewrite the following text in a formal style'),
|
|
max_samples=config_dict.get('max_samples'),
|
|
train_split=config_dict.get('train_split', 0.8),
|
|
validation_split=config_dict.get('validation_split', 0.1),
|
|
test_split=config_dict.get('test_split', 0.1),
|
|
clean_text=config_dict.get('clean_text', True),
|
|
remove_special_chars=config_dict.get('remove_special_chars', False),
|
|
lowercase=config_dict.get('lowercase', False),
|
|
min_length=config_dict.get('min_length', 10),
|
|
max_length=config_dict.get('max_length', 1000),
|
|
output_format=config_dict.get('output_format', 'styling'),
|
|
output_dir=config_dict.get('output_dir', './data'),
|
|
hf_split=config_dict.get('hf_split', 'train'),
|
|
hf_cache_dir=config_dict.get('hf_cache_dir'),
|
|
test_split_from=config_dict.get('test_split_from', 'train'),
|
|
val_split_from=config_dict.get('val_split_from', 'train'),
|
|
encoding=config_dict.get('encoding', 'utf-8'),
|
|
delimiter=config_dict.get('delimiter', ',')
|
|
)
|
|
|
|
logger.info(f"Configuration loaded from YAML: {yaml_path}")
|
|
logger.info(f"Output directory: {config.output_dir}")
|
|
logger.info(f"Instruction: {config.instruction}")
|
|
|
|
return config
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error loading configuration from YAML {yaml_path}: {e}")
|
|
raise
|
|
|
|
def load_and_preprocess(self, config: StylingConfig) -> Tuple[Dict[str, List[Dict]], Dict[str, Any]]:
|
|
"""Load and preprocess data"""
|
|
|
|
logger.info(f"Starting data loading and preprocessing...")
|
|
logger.info(f"Data source: {config.data_source}")
|
|
|
|
try:
|
|
# Load data
|
|
if config.data_source == "huggingface":
|
|
logger.info("Loading HuggingFace dataset...")
|
|
raw_splits = self.hf_loader.load(config)
|
|
logger.info("Preprocessing HuggingFace dataset...")
|
|
processed_splits = self.hf_loader.preprocess(raw_splits, config)
|
|
elif config.data_source == "custom":
|
|
logger.info("Loading custom dataset...")
|
|
raw_splits = self.custom_loader.load(config)
|
|
logger.info("Preprocessing custom dataset...")
|
|
processed_splits = self.custom_loader.preprocess(raw_splits, config)
|
|
else:
|
|
raise ValueError(f"Unsupported data source: {config.data_source}")
|
|
|
|
logger.info(f"Data loading and preprocessing completed successfully")
|
|
logger.info(f"Raw splits: {list(raw_splits.keys())}")
|
|
logger.info(f"Processed splits: {list(processed_splits.keys())}")
|
|
|
|
# Validate processed data
|
|
logger.info("Validating processed data...")
|
|
is_valid, errors = self.validator.validate_styling_data(processed_splits, config, is_processed=True)
|
|
if not is_valid:
|
|
logger.error("Data validation failed:")
|
|
for error in errors:
|
|
logger.error(f" - {error}")
|
|
raise ValueError("Data validation failed")
|
|
|
|
logger.info("Data validation passed")
|
|
|
|
# Analyze dataset
|
|
logger.info("Analyzing dataset...")
|
|
analysis = self.validator.analyze_dataset(processed_splits, config, is_processed=True)
|
|
logger.info("Dataset analysis completed")
|
|
|
|
return processed_splits, analysis
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error in load_and_preprocess: {e}")
|
|
raise
|
|
|
|
def convert_to_alpaca_format(self, data: Dict[str, List[Dict]], config: StylingConfig) -> Dict[str, List[Dict]]:
|
|
"""Convert styling data to Alpaca format with instruction"""
|
|
alpaca_splits = {}
|
|
|
|
for split_name, split_data in data.items():
|
|
alpaca_data = []
|
|
for item in split_data:
|
|
# Ensure input and output fields exist, default to empty string if missing
|
|
input_text = item.get("input", "")
|
|
output_text = item.get("output", "")
|
|
|
|
# Handle None values
|
|
if input_text is None:
|
|
input_text = ""
|
|
if output_text is None:
|
|
output_text = ""
|
|
|
|
# Convert to string if needed
|
|
input_text = str(input_text)
|
|
output_text = str(output_text)
|
|
|
|
alpaca_data.append({
|
|
"instruction": config.instruction,
|
|
"input": input_text,
|
|
"output": output_text
|
|
})
|
|
alpaca_splits[split_name] = alpaca_data
|
|
|
|
return alpaca_splits
|
|
|
|
def format_for_training(self, data: Dict[str, List[Dict]], config: StylingConfig) -> Dict[str, List[str]]:
|
|
"""Format entries for training using Alpaca prompt format"""
|
|
formatted_splits = {}
|
|
|
|
for split_name, split_data in data.items():
|
|
formatted_texts = []
|
|
for item in split_data:
|
|
# Ensure input and output fields exist, default to empty string if missing
|
|
input_text = item.get("input", "")
|
|
output_text = item.get("output", "")
|
|
|
|
# Handle None values
|
|
if input_text is None:
|
|
input_text = ""
|
|
if output_text is None:
|
|
output_text = ""
|
|
|
|
# Convert to string if needed
|
|
input_text = str(input_text)
|
|
output_text = str(output_text)
|
|
|
|
text = config.alpaca_prompt.format(
|
|
config.instruction,
|
|
input_text,
|
|
output_text
|
|
) + config.eos_token
|
|
formatted_texts.append(text)
|
|
formatted_splits[split_name] = formatted_texts
|
|
|
|
return formatted_splits
|
|
|
|
def convert_to_hf_dataset(self, dataset_entries: List[Dict], config: StylingConfig):
|
|
"""Convert dataset entries to HuggingFace dataset format with text formatting"""
|
|
from datasets import Dataset
|
|
|
|
# Create HuggingFace dataset from list of dictionaries
|
|
hf_dataset = Dataset.from_list(dataset_entries)
|
|
|
|
# Apply formatting function to generate the text field
|
|
def formatting_prompts_func(examples):
|
|
instructions = examples["instruction"]
|
|
inputs = examples["input"]
|
|
outputs = examples["output"]
|
|
texts = []
|
|
|
|
for instruction, input_text, output in zip(instructions, inputs, outputs):
|
|
# Handle None values and ensure strings
|
|
if input_text is None:
|
|
input_text = ""
|
|
if output is None:
|
|
output = ""
|
|
|
|
# Convert to string if needed
|
|
input_text = str(input_text)
|
|
output = str(output)
|
|
|
|
# Use the config's EOS token and alpaca prompt
|
|
text = config.alpaca_prompt.format(instruction, input_text, output) + config.eos_token
|
|
texts.append(text)
|
|
|
|
return {"text": texts}
|
|
|
|
# Apply the formatting function
|
|
formatted_dataset = hf_dataset.map(formatting_prompts_func, batched=True)
|
|
|
|
return formatted_dataset
|
|
|
|
def save_hf_dataset_to_disk(self, hf_dataset, save_path: str):
|
|
"""Save HuggingFace dataset to disk"""
|
|
try:
|
|
hf_dataset.save_to_disk(save_path)
|
|
logger.info(f"HuggingFace dataset saved to disk at: {save_path}")
|
|
return True
|
|
except Exception as e:
|
|
logger.error(f"Error saving HuggingFace dataset to disk: {e}")
|
|
return False
|
|
|
|
def load_hf_dataset_from_disk(self, load_path: str):
|
|
"""Load HuggingFace dataset from disk"""
|
|
try:
|
|
from datasets import load_from_disk
|
|
hf_dataset = load_from_disk(load_path)
|
|
logger.info(f"HuggingFace dataset loaded from disk: {load_path}")
|
|
logger.info(f"Dataset has {len(hf_dataset)} entries")
|
|
logger.info(f"Dataset features: {hf_dataset.features}")
|
|
return hf_dataset
|
|
except Exception as e:
|
|
logger.error(f"Error loading HuggingFace dataset from disk: {e}")
|
|
return None
|
|
|
|
def save_data(self, data: Dict[str, List[Dict]], output_dir: str, format: str = "jsonl"):
|
|
"""Save processed data splits to files"""
|
|
output_path = Path(output_dir)
|
|
output_path.mkdir(parents=True, exist_ok=True)
|
|
|
|
for split_name, split_data in data.items():
|
|
if format == "jsonl":
|
|
output_file = output_path / f"{split_name}.jsonl"
|
|
with open(output_file, 'w', encoding='utf-8') as f:
|
|
for item in split_data:
|
|
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
|
elif format == "json":
|
|
output_file = output_path / f"{split_name}.json"
|
|
with open(output_file, 'w', encoding='utf-8') as f:
|
|
json.dump(split_data, f, ensure_ascii=False, indent=2)
|
|
elif format == "csv":
|
|
output_file = output_path / f"{split_name}.csv"
|
|
df = pd.DataFrame(split_data)
|
|
df.to_csv(output_file, index=False)
|
|
|
|
logger.info(f"Saved {len(split_data)} samples to {output_file}")
|
|
|
|
def run_pipeline(
|
|
self,
|
|
config: StylingConfig,
|
|
output_format: str = "styling",
|
|
save_splits: bool = True,
|
|
create_hf_dataset: bool = False,
|
|
save_hf_dataset: bool = False,
|
|
hf_dataset_path: str = None
|
|
) -> Dict[str, Any]:
|
|
"""Run complete styling pipeline"""
|
|
|
|
logger.info("Starting styling pipeline...")
|
|
|
|
# Load and preprocess data
|
|
processed_splits, analysis = self.load_and_preprocess(config)
|
|
|
|
# Convert to desired output format
|
|
if output_format == "alpaca":
|
|
formatted_splits = self.convert_to_alpaca_format(processed_splits, config)
|
|
else:
|
|
formatted_splits = processed_splits
|
|
|
|
# Save data if requested
|
|
if save_splits:
|
|
# Save directly in the output directory, not in a subdirectory
|
|
output_dir = Path(config.output_dir)
|
|
self.save_data(formatted_splits, str(output_dir))
|
|
|
|
# Convert to HuggingFace dataset if requested
|
|
hf_dataset = None
|
|
hf_dataset_save_path = None
|
|
if create_hf_dataset:
|
|
# Flatten all splits into one list for HF dataset
|
|
all_entries = []
|
|
for split_name, split_data in formatted_splits.items():
|
|
for item in split_data:
|
|
# Ensure we have the instruction field
|
|
if "instruction" not in item:
|
|
item["instruction"] = config.instruction
|
|
all_entries.append(item)
|
|
|
|
hf_dataset = self.convert_to_hf_dataset(all_entries, config)
|
|
logger.info(f"HuggingFace dataset created with {len(hf_dataset)} entries")
|
|
logger.info(f"Dataset features: {hf_dataset.features}")
|
|
|
|
# Save HuggingFace dataset to disk if requested
|
|
if save_hf_dataset:
|
|
if hf_dataset_path is None:
|
|
# Generate default path using the YAML output_dir
|
|
hf_dataset_path = str(Path(config.output_dir) / "hf_dataset")
|
|
|
|
success = self.save_hf_dataset_to_disk(hf_dataset, hf_dataset_path)
|
|
if success:
|
|
hf_dataset_save_path = hf_dataset_path
|
|
logger.info(f"HuggingFace dataset saved to: {hf_dataset_save_path}")
|
|
else:
|
|
logger.warning("Failed to save HuggingFace dataset to disk")
|
|
|
|
# Create result summary
|
|
result = {
|
|
"config": config,
|
|
"analysis": analysis,
|
|
"splits": {
|
|
split_name: len(split_data) for split_name, split_data in formatted_splits.items()
|
|
},
|
|
"output_format": output_format,
|
|
"output_dir": config.output_dir,
|
|
"data": formatted_splits, # Include the actual processed data
|
|
"instruction": config.instruction
|
|
}
|
|
|
|
# Add HuggingFace dataset info to result if created
|
|
if hf_dataset is not None:
|
|
result["hf_dataset"] = hf_dataset
|
|
if hf_dataset_save_path:
|
|
result["hf_dataset_path"] = hf_dataset_save_path
|
|
|
|
logger.info("Styling pipeline completed successfully!")
|
|
return result
|
|
|
|
# Helper functions
|
|
def create_huggingface_config(
    dataset_name: str,
    input_field: str = "text",
    output_field: str = "output",
    instruction: str = "Rewrite the following text in a formal style",
    **kwargs,
) -> StylingConfig:
    """Build a StylingConfig whose data source is a Hugging Face dataset.

    ``data_source`` is fixed to ``"huggingface"``. Any additional keyword
    arguments are forwarded unchanged to ``StylingConfig``.
    """
    field_settings = dict(
        dataset_name=dataset_name,
        input_field=input_field,
        output_field=output_field,
        instruction=instruction,
    )
    return StylingConfig(data_source="huggingface", **field_settings, **kwargs)
|
|
|
|
|
|
def create_custom_config(
    data_path: str,
    data_format: str = "jsonl",
    input_field: str = "text",
    output_field: str = "styled_text",
    instruction: str = "Rewrite the following text in a formal style",
    **kwargs,
) -> StylingConfig:
    """Build a StylingConfig that reads a local/custom data file.

    ``data_source`` is fixed to ``"custom"``. Any additional keyword
    arguments are forwarded unchanged to ``StylingConfig``.
    """
    file_settings = dict(
        data_path=data_path,
        data_format=data_format,
        input_field=input_field,
        output_field=output_field,
        instruction=instruction,
    )
    return StylingConfig(data_source="custom", **file_settings, **kwargs)
|
|
|
|
|
|
def save_hf_dataset_to_disk(hf_dataset, save_path: str) -> bool:
    """Persist ``hf_dataset`` to ``save_path`` through its ``save_to_disk`` method.

    A status line is printed either way. The result is ``True`` when saving
    succeeded and ``False`` if anything raised.
    """
    try:
        hf_dataset.save_to_disk(save_path)
        print(f"HuggingFace dataset saved to disk at: {save_path}")
    except Exception as exc:
        print(f"Error saving HuggingFace dataset to disk: {exc}")
        return False
    return True
|
|
|
|
|
|
def load_hf_dataset_from_disk(load_path: str):
    """Load a dataset previously written with ``save_to_disk``.

    On success, prints the path, the entry count and the features, then
    returns the dataset. Any failure (including ``datasets`` not being
    importable) is printed and ``None`` is returned instead of raising.
    """
    try:
        from datasets import load_from_disk

        loaded = load_from_disk(load_path)
        print(f"HuggingFace dataset loaded from disk: {load_path}")
        print(f"Dataset has {len(loaded)} entries")
        print(f"Dataset features: {loaded.features}")
    except Exception as exc:
        print(f"Error loading HuggingFace dataset from disk: {exc}")
        return None
    return loaded
|
|
|
|
|
|
def load_yaml_config(config_path: str) -> Dict[str, Any]:
    """Load a sectioned YAML config file and flatten it into a single dict.

    Recognized top-level sections are ``task``, ``data``, ``model``,
    ``training`` and ``inference``. Their keys are renamed to flat config
    keys according to the table below. Only keys that are present with a
    non-``None`` value are included. Consumers can therefore rely on
    ``config_dict.get(key, default)`` to apply their defaults.

    An empty file yields ``{}``. A missing or empty (``null``) section is
    treated as empty.

    Args:
        config_path: Path of the YAML file (read as UTF-8 with ``yaml.safe_load``).

    Returns:
        The flattened configuration dict.

    Raises:
        ValueError: If the document or one of the sections is not a mapping.
        Exception: Any I/O or YAML parsing error. Errors are logged and re-raised.
    """
    # section name -> {key inside the section: flat config key}
    section_key_map = {
        'task': {
            'name': 'task_name',
            'type': 'task_type',
        },
        'data': {
            'source': 'data_source',
            'dataset_name': 'dataset_name',
            'data_path': 'data_path',
            'data_format': 'data_format',
            'input_field': 'input_field',
            'output_field': 'output_field',
            'instruction': 'instruction',
            'max_samples': 'max_samples',
            'train_split': 'train_split',
            'validation_split': 'validation_split',
            'test_split': 'test_split',
            'clean_text': 'clean_text',
            'remove_special_chars': 'remove_special_chars',
            'lowercase': 'lowercase',
            'min_length': 'min_length',
            'max_length': 'max_length',
            'output_format': 'output_format',
            'output_dir': 'output_dir',
            'hf_split': 'hf_split',
            'hf_cache_dir': 'hf_cache_dir',
            'test_split_from': 'test_split_from',
            'val_split_from': 'val_split_from',
            'encoding': 'encoding',
            'delimiter': 'delimiter',
        },
        'model': {
            'name': 'model_name',
            'max_length': 'model_max_length',
        },
        'training': {
            'num_epochs': 'num_epochs',
            'batch_size': 'batch_size',
            'learning_rate': 'learning_rate',
            'weight_decay': 'weight_decay',
            'warmup_ratio': 'warmup_ratio',
            'lr_scheduler_type': 'lr_scheduler_type',
        },
        'inference': {
            'batch_size': 'inference_batch_size',
            'max_new_tokens': 'max_new_tokens',
            'temperature': 'temperature',
        },
    }

    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document.
            yaml_data = yaml.safe_load(f) or {}

        if not isinstance(yaml_data, dict):
            raise ValueError(
                f"Top level of YAML config must be a mapping, got {type(yaml_data).__name__}"
            )

        config_dict: Dict[str, Any] = {}
        for section, key_map in section_key_map.items():
            section_data = yaml_data.get(section) or {}
            if not isinstance(section_data, dict):
                raise ValueError(
                    f"YAML section '{section}' must be a mapping, got {type(section_data).__name__}"
                )
            for yaml_key, config_key in key_map.items():
                value = section_data.get(yaml_key)
                # Skip unset values so callers' .get(key, default) still applies defaults.
                if value is not None:
                    config_dict[config_key] = value

        logger.info(f"Successfully parsed YAML configuration from: {config_path}")
        logger.info(f"Extracted {len(config_dict)} configuration parameters")

        return config_dict

    except Exception as e:
        logger.error(f"Error loading YAML config from {config_path}: {e}")
        raise
|
|
|
|
|
|
def main():
    """Command-line entry point for the styling data pipeline.

    Configuration is layered, with later layers winning: built-in defaults,
    then the YAML file given with ``--config``, then explicit CLI flags. The
    resulting StylingConfig is run through ``StylingDataPipeline``. The
    process exits with status 1 if the YAML cannot be loaded or the pipeline
    raises.
    """
    parser = argparse.ArgumentParser(description="Styling Data Processing Pipeline")

    # YAML configuration
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")

    # Data source arguments
    parser.add_argument("--data-source", choices=["huggingface", "custom"], help="Data source")
    parser.add_argument("--dataset-name", type=str, help="HuggingFace dataset name")
    parser.add_argument("--data-path", type=str, help="Path to custom data file")
    parser.add_argument("--data-format", choices=["jsonl", "csv", "json"], help="Data format")

    # Field mapping
    parser.add_argument("--input-field", type=str, help="Input field name")
    parser.add_argument("--output-field", type=str, help="Output field name")
    parser.add_argument("--instruction", type=str, help="Style instruction")

    # Data processing
    parser.add_argument("--max-samples", type=int, help="Maximum samples to process")
    parser.add_argument("--train-split", type=float, help="Training split ratio")
    parser.add_argument("--validation-split", type=float, help="Validation split ratio")
    parser.add_argument("--test-split", type=float, help="Test split ratio")

    # Text preprocessing
    parser.add_argument("--clean-text", action="store_true", help="Clean and normalize text")
    parser.add_argument("--remove-special-chars", action="store_true", help="Remove special characters")
    parser.add_argument("--lowercase", action="store_true", help="Convert text to lowercase")
    parser.add_argument("--min-length", type=int, help="Minimum text length")
    parser.add_argument("--max-length", type=int, help="Maximum text length")

    # Output configuration
    parser.add_argument("--output-format", choices=["styling", "alpaca"], help="Output format")
    parser.add_argument("--output-dir", type=str, help="Output directory")

    # HuggingFace dataset options
    parser.add_argument("--create-hf-dataset", action="store_true", help="Create HuggingFace dataset")
    parser.add_argument("--hf-dataset-path", type=str, help="Path to save HuggingFace dataset")

    # Logging
    parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO", help="Logging level")

    args = parser.parse_args()

    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # Base layer: YAML config, if provided.
    config_dict: Dict[str, Any] = {}
    if args.config:
        try:
            config_dict = load_yaml_config(args.config)
        except Exception as e:
            logger.error(f"Error loading YAML config: {e}")
            sys.exit(1)

    # Value options override the YAML whenever they were given. They are
    # compared with None rather than tested for truthiness, so that falsy
    # values such as ``--min-length 0`` or ``--test-split 0`` still apply.
    value_options = (
        'data_source', 'dataset_name', 'data_path', 'data_format',
        'input_field', 'output_field', 'instruction',
        'max_samples', 'train_split', 'validation_split', 'test_split',
        'min_length', 'max_length', 'output_format', 'output_dir',
    )
    # store_true flags can only switch a setting on.
    flag_options = ('clean_text', 'remove_special_chars', 'lowercase')

    cli_overrides: Dict[str, Any] = {}
    for key in value_options:
        value = getattr(args, key)
        if value is not None:
            cli_overrides[key] = value
    for key in flag_options:
        if getattr(args, key):
            cli_overrides[key] = True

    for key, value in cli_overrides.items():
        if key in config_dict:
            logger.info(f"Overriding YAML config '{key}' with CLI value: {value}")
        config_dict[key] = value

    # Validate required arguments
    if not config_dict.get('data_source'):
        parser.error("--data-source is required (either in YAML config or CLI)")
    if config_dict.get('data_source') == "huggingface" and not config_dict.get('dataset_name'):
        parser.error("--dataset-name is required for HuggingFace datasets")
    if config_dict.get('data_source') == "custom" and not config_dict.get('data_path'):
        parser.error("--data-path is required for custom datasets")

    def get(key: str, default: Any = None) -> Any:
        # A key that is present with a None value falls back to the default too.
        value = config_dict.get(key)
        return default if value is None else value

    config = StylingConfig(
        data_source=get('data_source', 'huggingface'),
        dataset_name=get('dataset_name'),
        data_path=get('data_path'),
        data_format=get('data_format', 'jsonl'),
        input_field=get('input_field', 'text'),
        output_field=get('output_field', 'styled_text'),
        instruction=get('instruction', 'Rewrite the following text in a formal style'),
        max_samples=get('max_samples'),
        train_split=get('train_split', 0.8),
        validation_split=get('validation_split', 0.1),
        test_split=get('test_split', 0.1),
        clean_text=get('clean_text', True),
        remove_special_chars=get('remove_special_chars', False),
        lowercase=get('lowercase', False),
        min_length=get('min_length', 10),
        max_length=get('max_length', 1000),
        output_format=get('output_format', 'styling'),
        output_dir=get('output_dir', './data'),
        hf_split=get('hf_split', 'train'),
        hf_cache_dir=get('hf_cache_dir'),
        test_split_from=get('test_split_from', 'train'),
        val_split_from=get('val_split_from', 'train'),
        encoding=get('encoding', 'utf-8'),
        delimiter=get('delimiter', ','),
    )

    pipeline = StylingDataPipeline()

    try:
        print(f"Starting styling pipeline with {config.data_source} data source...")
        if args.config:
            print(f"Using YAML configuration: {args.config}")
        print(f"Style instruction: {config.instruction}")
        print()

        # When an HF dataset is created it is also saved to disk.
        create_hf_dataset = bool(args.create_hf_dataset)

        result = pipeline.run_pipeline(
            config,
            config.output_format,
            save_splits=True,
            create_hf_dataset=create_hf_dataset,
            save_hf_dataset=create_hf_dataset,
            hf_dataset_path=args.hf_dataset_path,
        )

        print("✅ Pipeline completed successfully!")
        print(f"   Data source: {config.data_source}")
        if config.data_source == "huggingface":
            print(f"   Dataset: {config.dataset_name}")
        else:
            print(f"   Data file: {config.data_path}")
        print(f"   Total samples: {result['analysis']['overall']['total_samples']}")
        print(f"   Split sizes: {result['analysis']['overall']['split_sizes']}")
        print(f"   Output directory: {config.output_dir}")
        print(f"   Style instruction: {config.instruction}")

    except Exception as e:
        print(f"❌ Error running pipeline: {e}")
        import traceback
        print("Full error traceback:")
        traceback.print_exc()
        sys.exit(1)
|
|
|
|
|
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|