# 2025-08-13 21:17:01 +01:00
import json
import pandas as pd
import numpy as np
from pathlib import Path
from typing import Dict , List , Optional , Union , Any , Tuple
from datasets import Dataset , load_dataset
import os
from dataclasses import dataclass
from abc import ABC , abstractmethod
import logging
from sklearn . model_selection import train_test_split
import re
import argparse
import sys
import yaml
# Module-level logger, named after this module via __name__.
logger = logging . getLogger ( __name__ )
# Lower this logger's threshold to DEBUG so the debug traces emitted by the
# loaders below are not dropped at the logger level.
logger . setLevel ( logging . DEBUG )
@dataclass
class StylingConfig:
    """Settings bundle for one styling-data preparation run.

    Every field has a default, so ``StylingConfig()`` is valid. The field
    order below is also the positional-argument order of the generated
    constructor and must not be rearranged.
    """

    # ----- data source -----
    data_source: str = "huggingface"  # "huggingface" or "custom"
    dataset_name: Optional[str] = None  # dataset id, used by the Hugging Face loader
    data_path: Optional[str] = None  # local file, used by the custom loader
    data_format: str = "jsonl"  # jsonl, csv or json (custom loader)

    # ----- field mapping: which record keys hold source / styled text -----
    input_field: str = "text"  # key of the source text (e.g. "text", "source")
    output_field: str = "styled_text"  # key of the styled text (e.g. "styled_text", "target")
    instruction: str = "Rewrite the following text in a formal style"  # style instruction

    # ----- sampling and split ratios -----
    max_samples: Optional[int] = None  # None / 0 means "no limit"
    train_split: float = 0.8
    validation_split: float = 0.1
    test_split: float = 0.1

    # ----- text preprocessing -----
    clean_text: bool = True  # run the loaders' text-cleaning step
    remove_special_chars: bool = False
    lowercase: bool = False  # original casing is kept for styling
    min_length: int = 10  # items shorter than this (in characters) are skipped
    max_length: int = 1000  # items longer than this (in characters) are skipped

    # ----- output -----
    output_format: str = "styling"  # original note lists: instruction, conversation, qa
    output_dir: str = "./data"

    # ----- Hugging Face specific -----
    hf_split: str = "train"
    hf_cache_dir: Optional[str] = None  # forwarded as cache_dir to load_dataset

    # ----- where validation / test data comes from -----
    # The Hugging Face loader recognises "use_val_if_available" and
    # "use_test_if_available"; any other value means "carve it out of train".
    test_split_from: str = "train"
    val_split_from: str = "train"

    # ----- custom file specific -----
    encoding: str = "utf-8"
    delimiter: str = ","  # column separator for CSV files

    # ----- Alpaca prompt -----
    # Template with three positional "{}" slots (instruction, input, response).
    # NOTE(review): blank lines between the sections may have been lost when
    # this file was reformatted - confirm against the intended prompt layout.
    alpaca_prompt: str = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that follows the instruction
### Instruction:
{}
### Input:
{}
### Response:
{}"""
    eos_token: str = "<|eot_id|>"  # end-of-sequence marker
class DataValidator:
    """Validates styling data quality and format.

    Both public methods are static and work on a mapping of split name to a
    list of record dictionaries. With ``is_processed=True`` the records are
    read through the fixed keys ``"input"`` / ``"output"``; otherwise the keys
    come from ``config.input_field`` / ``config.output_field``.
    """

    @staticmethod
    def _field_names(config, is_processed):
        """Return the (input, output) record keys to inspect."""
        if is_processed:
            return "input", "output"
        return config.input_field, config.output_field

    @staticmethod
    def _is_blank(item, field):
        """True when the field is absent or a whitespace-only string.

        Non-string values are not counted here: they are already reported by
        the type check, and calling ``.strip()`` on them would raise.
        """
        value = item.get(field, "")
        return isinstance(value, str) and not value.strip()

    @staticmethod
    def _text_length(value):
        """Length of a field value; ``None`` counts as 0, non-strings via ``str()``."""
        if value is None:
            return 0
        if isinstance(value, str):
            return len(value)
        return len(str(value))

    @staticmethod
    def validate_styling_data(data: Dict[str, List[Dict]], config: StylingConfig, is_processed: bool = False) -> Tuple[bool, List[str]]:
        """Validate styling dataset splits.

        Args:
            data: Mapping of split name to records. Must contain "train",
                "validation" and "test"; only "train" has to be non-empty.
            config: Supplies the field names when ``is_processed`` is False.
            is_processed: Whether records already use "input"/"output" keys.

        Returns:
            ``(is_valid, errors)`` where ``errors`` lists every problem found.
            Malformed values (``None``, numbers, ...) are reported as errors
            rather than raising.
        """
        errors: List[str] = []

        # Structural check first: without the three splits nothing else runs.
        for split in ("train", "validation", "test"):
            if split not in data:
                errors.append(f"Missing '{split}' split")
            elif split == "train" and not data[split]:
                errors.append("Train split cannot be empty")
            # Validation and test may be empty for small datasets.
        if errors:
            return False, errors

        total_samples = sum(len(split_data) for split_data in data.values())
        logger.info(f"Validating {total_samples} total samples across all splits...")

        input_field, output_field = DataValidator._field_names(config, is_processed)

        for split_name, split_data in data.items():
            if not split_data:
                logger.info(f"Skipping validation for empty {split_name} split")
                continue
            logger.info(f"Validating {split_name} split with {len(split_data)} samples...")

            # Required fields.
            missing_input_count = 0
            missing_output_count = 0
            for i, item in enumerate(split_data):
                if input_field not in item:
                    errors.append(f"Missing input field '{input_field}' in {split_name} split, item {i}")
                    missing_input_count += 1
                if output_field not in item:
                    errors.append(f"Missing output field '{output_field}' in {split_name} split, item {i}")
                    missing_output_count += 1
            logger.info(f"{split_name} - Items missing input field: {missing_input_count}")
            logger.info(f"{split_name} - Items missing output field: {missing_output_count}")

            # Data types (a missing field defaults to "" and is not a type error).
            type_errors = 0
            for i, item in enumerate(split_data):
                if not isinstance(item.get(input_field, ""), str):
                    errors.append(f"Input field '{input_field}' must be string in {split_name} split, item {i}")
                    type_errors += 1
                if not isinstance(item.get(output_field, ""), str):
                    errors.append(f"Output field '{output_field}' must be string in {split_name} split, item {i}")
                    type_errors += 1
            logger.info(f"{split_name} - Type errors: {type_errors}")

            # Empty inputs / outputs.
            empty_inputs = sum(1 for item in split_data if DataValidator._is_blank(item, input_field))
            empty_outputs = sum(1 for item in split_data if DataValidator._is_blank(item, output_field))
            if empty_inputs > 0:
                errors.append(f"Found {empty_inputs} items with empty input text in {split_name} split")
            if empty_outputs > 0:
                errors.append(f"Found {empty_outputs} items with empty output text in {split_name} split")
            logger.info(f"{split_name} - Empty inputs: {empty_inputs}")
            logger.info(f"{split_name} - Empty outputs: {empty_outputs}")

            # Preview a few items; str() keeps the preview safe for non-string values.
            logger.info(f"Sample processed items from {split_name}:")
            for i, item in enumerate(split_data[:3]):
                input_preview = str(item.get(input_field, ""))[:50]
                output_preview = str(item.get(output_field, ""))[:50]
                logger.info(f"Item {i}: input='{input_preview}...', output='{output_preview}...'")

        return len(errors) == 0, errors

    @staticmethod
    def analyze_dataset(data: Dict[str, List[Dict]], config: StylingConfig, is_processed: bool = False) -> Dict[str, Any]:
        """Analyze dataset characteristics across all splits.

        Args:
            data: Mapping of split name to records.
            config: Supplies the field names when ``is_processed`` is False.
            is_processed: Whether records already use "input"/"output" keys.

        Returns:
            ``{"splits": {name: {...}}, "overall": {"total_samples", "split_sizes"}}``.
            Each split entry holds ``total_samples``, ``text_length_stats``
            (min/max/mean/median per "input"/"output") and ``missing_values``
            (count of falsy values per field). Empty splits get empty stats.
        """
        analysis: Dict[str, Any] = {
            "splits": {},
            "overall": {
                "total_samples": 0,
                "split_sizes": {},
            },
        }
        input_field, output_field = DataValidator._field_names(config, is_processed)

        for split_name, split_data in data.items():
            sample_count = len(split_data) if split_data else 0
            split_analysis: Dict[str, Any] = {
                "total_samples": sample_count,
                "text_length_stats": {},
                "missing_values": {},
            }

            if sample_count:
                # Length statistics for both sides of each pair.
                for label, field in (("input", input_field), ("output", output_field)):
                    text_lengths = [DataValidator._text_length(item.get(field, "")) for item in split_data]
                    split_analysis["text_length_stats"][label] = {
                        "min": min(text_lengths),
                        "max": max(text_lengths),
                        "mean": np.mean(text_lengths),
                        "median": np.median(text_lengths),
                    }
                # Missing values: absent key, None or empty string.
                for field in (input_field, output_field):
                    split_analysis["missing_values"][field] = sum(
                        1 for item in split_data if not item.get(field)
                    )

            analysis["splits"][split_name] = split_analysis
            analysis["overall"]["total_samples"] += sample_count
            analysis["overall"]["split_sizes"][split_name] = sample_count

        return analysis
class BaseDataLoader(ABC):
    """Interface shared by all dataset loaders.

    Subclasses must implement both ``load`` and ``preprocess``; each works
    with a mapping of split name to a list of record dictionaries.
    """

    @abstractmethod
    def load(self, config: StylingConfig) -> Dict[str, List[Dict]]:
        """Read the data described by ``config`` and return its train/val/test splits."""

    @abstractmethod
    def preprocess(self, data: Dict[str, List[Dict]], config: StylingConfig) -> Dict[str, List[Dict]]:
        """Run the preprocessing steps over every split and return the result."""
class HuggingFaceDataLoader(BaseDataLoader):
    """Load datasets from Hugging Face Hub."""

    # Name fragments used to suggest alternative fields when a configured
    # field is not present in the data.
    _INPUT_HINTS = ("text", "sentence", "content", "input", "comment", "message")
    _OUTPUT_HINTS = ("output", "response", "result", "target", "styled")

    def load(self, config: StylingConfig) -> Dict[str, List[Dict]]:
        """Load dataset from Hugging Face Hub with flexible split handling.

        The "train" split is mandatory. Validation/test data is taken from the
        dataset when ``config.val_split_from`` / ``config.test_split_from``
        ask for it and the split exists; whatever is still missing is carved
        out of "train" with a fixed random seed.

        Args:
            config: Run configuration; ``dataset_name`` is required.

        Returns:
            Dict with the keys "train", "validation" and "test".

        Raises:
            ValueError: ``dataset_name`` is empty, the dataset has no "train"
                split, or the split ratios do not sum to a positive value.
            Exception: Any loading error is logged and re-raised unchanged.
        """
        if not config.dataset_name:
            raise ValueError("Dataset name is required for Hugging Face datasets")

        logger.info(f"Loading Hugging Face dataset: {config.dataset_name}")
        try:
            # Load every split so we can see what the dataset offers.
            dataset = load_dataset(config.dataset_name, cache_dir=config.hf_cache_dir)
            available_splits = list(dataset.keys())
            logger.info(f"Available splits in dataset: {available_splits}")

            splits_data = self._collect_hub_splits(dataset, available_splits, config)
            if not splits_data["validation"] or not splits_data["test"]:
                self._derive_missing_splits(splits_data, config)

            logger.info("Final split sizes:")
            logger.info(f"Train: {len(splits_data['train'])} samples")
            logger.info(f"Validation: {len(splits_data['validation'])} samples")
            logger.info(f"Test: {len(splits_data['test'])} samples")

            # The limit applies to each split individually.
            if config.max_samples:
                for split_name, records in splits_data.items():
                    if records:
                        original_size = len(records)
                        splits_data[split_name] = records[: config.max_samples]
                        logger.info(f"Limited {split_name} split from {original_size} to {len(splits_data[split_name])} samples")

            self._log_field_hints(splits_data, config)
            logger.info(f"Successfully loaded dataset {config.dataset_name}")
            return splits_data
        except Exception as e:
            logger.error(f"Error loading dataset {config.dataset_name}: {e}")
            raise

    def _collect_hub_splits(self, dataset, available_splits: List[str], config: StylingConfig) -> Dict[str, List[Dict]]:
        """Copy the splits the dataset already provides into plain lists."""
        splits_data: Dict[str, List[Dict]] = {"train": [], "validation": [], "test": []}

        # Train is mandatory.
        if "train" not in available_splits:
            logger.error("No 'train' split found in dataset!")
            logger.error(f"Available splits: {available_splits}")
            raise ValueError(f"Dataset {config.dataset_name} does not have a 'train' split")
        train_dataset = dataset["train"]
        logger.info(f"Using 'train' split with {len(train_dataset)} samples")
        splits_data["train"] = list(train_dataset)

        # Validation: reuse the dataset's own split when asked to and available.
        wants_hub_val = config.val_split_from == "use_val_if_available"
        if wants_hub_val and "validation" in available_splits:
            val_dataset = dataset["validation"]
            logger.info(f"Using 'validation' split with {len(val_dataset)} samples")
            splits_data["validation"] = list(val_dataset)
        elif wants_hub_val and "val" in available_splits:
            val_dataset = dataset["val"]
            logger.info(f"Using 'val' split with {len(val_dataset)} samples")
            splits_data["validation"] = list(val_dataset)
        elif wants_hub_val:
            logger.warning("No validation split found in dataset. Will create from train split.")
            logger.info(f"Available splits: {available_splits}")
            logger.info(f"Will use {config.validation_split * 100}% of train data for validation")
        else:
            logger.info(f"Will create validation split from train data ({config.validation_split * 100}%)")

        # Test: the dataset's test split, or its validation split when
        # test_split_from is "use_val_if_available".
        test_from = config.test_split_from
        if test_from == "use_test_if_available" and "test" in available_splits:
            test_dataset = dataset["test"]
            logger.info(f"Using 'test' split with {len(test_dataset)} samples")
            splits_data["test"] = list(test_dataset)
        elif test_from == "use_val_if_available" and "validation" in available_splits:
            test_dataset = dataset["validation"]
            logger.info(f"Using 'validation' split as test with {len(test_dataset)} samples")
            splits_data["test"] = list(test_dataset)
        elif test_from == "use_val_if_available" and "val" in available_splits:
            test_dataset = dataset["val"]
            logger.info(f"Using 'val' split as test with {len(test_dataset)} samples")
            splits_data["test"] = list(test_dataset)
        elif test_from == "use_test_if_available":
            logger.warning("No test split found in dataset. Will create from train split.")
            logger.info(f"Available splits: {available_splits}")
            logger.info(f"Will use {config.test_split * 100}% of train data for test")
        else:
            logger.info(f"Will create test split from train data ({config.test_split * 100}%)")

        return splits_data

    def _derive_missing_splits(self, splits_data: Dict[str, List[Dict]], config: StylingConfig) -> None:
        """Fill an empty validation and/or test split from train, in place.

        Note: as before, this may rewrite the ratio fields on ``config``
        (normalisation, and the 0.6/0.2/0.2 override for small datasets).
        """
        train_data = splits_data["train"]
        sample_count = len(train_data)

        if sample_count < 3:
            # Too little data to split. A validation/test split that was
            # already loaded from the Hub is kept, not discarded.
            logger.warning(f"Dataset has only {sample_count} samples. Using all data for training.")
            return

        ratio_sum = config.train_split + config.validation_split + config.test_split
        if ratio_sum <= 0:
            raise ValueError(f"Split percentages must sum to a positive value (got {ratio_sum})")
        # Tolerance avoids a spurious warning on float round-off (e.g. 0.7 + 0.2 + 0.1).
        if abs(ratio_sum - 1.0) > 1e-9:
            logger.warning(f"Split percentages don't sum to 1.0 (got {ratio_sum}). Normalizing...")
            config.train_split = config.train_split / ratio_sum
            config.validation_split = config.validation_split / ratio_sum
            config.test_split = config.test_split / ratio_sum

        need_val = not splits_data["validation"]
        need_test = not splits_data["test"]

        if need_val and need_test:
            if sample_count < 10:
                # For small datasets, use more conservative splits.
                config.train_split = 0.6
                config.validation_split = 0.2
                config.test_split = 0.2
                logger.info(f"Small dataset detected. Adjusted split ratios to: train={config.train_split}, val={config.validation_split}, test={config.test_split}")

            # Each held-out split gets at least 10% of the data and at least 1 sample.
            min_val_size = max(1, int(sample_count * 0.1))
            min_test_size = max(1, int(sample_count * 0.1))
            val_size = max(min_val_size, int(sample_count * config.validation_split))
            test_size = max(min_test_size, int(sample_count * config.test_split))
            train_size = sample_count - val_size - test_size

            # Shrink the held-out splits until train keeps at least one sample.
            while train_size < 1:
                if val_size > 1:
                    val_size -= 1
                elif test_size > 1:
                    test_size -= 1
                else:
                    break
                train_size += 1
            logger.info(f"Adjusted split sizes: train={train_size}, val={val_size}, test={test_size}")

            # First split: train vs. (val + test).
            new_train, temp_data = train_test_split(
                train_data,
                test_size=val_size + test_size,
                random_state=42,
            )
            # Second split: val vs. test. An absolute count is passed because a
            # float fraction can round up to one sample too many.
            new_val, new_test = train_test_split(
                temp_data,
                test_size=test_size,
                random_state=42,
            )
            splits_data["train"] = new_train
            splits_data["validation"] = new_val
            splits_data["test"] = new_test
        elif need_val:
            # Only validation has to come from train.
            val_size = max(1, int(sample_count * config.validation_split))
            new_train, new_val = train_test_split(train_data, test_size=val_size, random_state=42)
            splits_data["train"] = new_train
            splits_data["validation"] = new_val
        elif need_test:
            # Only test has to come from train.
            test_size = max(1, int(sample_count * config.test_split))
            new_train, new_test = train_test_split(train_data, test_size=test_size, random_state=42)
            splits_data["train"] = new_train
            splits_data["test"] = new_test

    def _log_field_hints(self, splits_data: Dict[str, List[Dict]], config: StylingConfig) -> None:
        """Log one sample per split and suggest fields when the configured ones are absent."""
        for split_name, split_data in splits_data.items():
            if not split_data:
                continue
            first_item = split_data[0]
            field_names = list(first_item.keys())
            logger.info(f"Sample data item from {split_name}: {first_item}")
            logger.info(f"Available fields in {split_name} split: {field_names}")

            if config.input_field not in first_item:
                logger.warning(f"Input field '{config.input_field}' not found in {split_name}. Available fields: {field_names}")
                text_fields = [f for f in field_names if any(keyword in f.lower() for keyword in self._INPUT_HINTS)]
                if text_fields:
                    logger.info(f"Suggested text fields for {split_name}: {text_fields}")

            if config.output_field not in first_item:
                logger.warning(f"Output field '{config.output_field}' not found in {split_name}. Available fields: {field_names}")
                output_fields = [f for f in field_names if any(keyword in f.lower() for keyword in self._OUTPUT_HINTS)]
                if output_fields:
                    logger.info(f"Suggested output fields for {split_name}: {output_fields}")

    def preprocess(self, data: Dict[str, List[Dict]], config: StylingConfig) -> Dict[str, List[Dict]]:
        """Apply preprocessing steps to all splits separately.

        Args:
            data: Mapping of split name to raw records.
            config: Field names, cleaning switches and length limits.

        Returns:
            Mapping of split name to records of the form
            ``{"input": ..., "output": ...}``; items rejected by
            ``_preprocess_item`` are dropped.
        """
        processed_splits: Dict[str, List[Dict]] = {}
        logger.info("=== PREPROCESSING DATA ===")

        for split_name, split_data in data.items():
            logger.info(f"Processing {split_name} split with {len(split_data)} items...")

            if split_data:
                # Field availability, judged from the first record.
                available_fields = set(split_data[0].keys())
                logger.info(f"Available fields in {split_name}: {available_fields}")
                logger.info(f"Looking for input field: '{config.input_field}', output field: '{config.output_field}'")
                if config.input_field not in available_fields:
                    logger.error(f"Input field '{config.input_field}' not found in {split_name}. Available fields: {available_fields}")
                if config.output_field not in available_fields:
                    logger.error(f"Output field '{config.output_field}' not found in {split_name}. Available fields: {available_fields}")

                # Absent keys and falsy values both count as missing.
                missing_input = sum(1 for item in split_data if not item.get(config.input_field))
                missing_output = sum(1 for item in split_data if not item.get(config.output_field))
                logger.info(f"{split_name} - Items missing input field: {missing_input}")
                logger.info(f"{split_name} - Items missing output field: {missing_output}")

                logger.info(f"=== SAMPLE RAW DATA FROM {split_name.upper()} BEFORE PREPROCESSING ===")
                for i, item in enumerate(split_data[:3]):
                    logger.info(f"Raw item {i} from {split_name}:")
                    for key, value in item.items():
                        if isinstance(value, str) and len(value) > 100:
                            logger.info(f"{key}: '{value[:100]}...'")
                        else:
                            logger.info(f"{key}: {value}")

            processed_data = []
            skipped_count = 0
            # Restart the per-item debug counter for each split.
            self._debug_count = 0
            for i, item in enumerate(split_data):
                processed_item = self._preprocess_item(item, config)
                if processed_item is not None:
                    processed_data.append(processed_item)
                    continue
                skipped_count += 1
                if skipped_count <= 3:  # only the first few skips are logged
                    logger.info(f"Skipped item {i} from {split_name}: {item}")

            processed_splits[split_name] = processed_data
            logger.info(f"{split_name} - Preprocessed {len(processed_data)} samples, skipped {skipped_count} samples")

            if processed_data:
                logger.info(f"=== SAMPLE PROCESSED DATA FROM {split_name.upper()} ===")
                for i, item in enumerate(processed_data[:3]):
                    logger.info(f"Processed item {i} from {split_name}: {item}")

        return processed_splits

    def _preprocess_item(self, item: Dict, config: StylingConfig) -> Optional[Dict]:
        """Preprocess a single item.

        Returns ``{"input": ..., "output": ...}``, or ``None`` when either
        text falls outside ``[config.min_length, config.max_length]``.
        The first three items handled per split are traced at DEBUG level.
        """
        input_text = item.get(config.input_field, "")
        output_text = item.get(config.output_field, "")

        # Count calls so that only the first few items are traced.
        self._debug_count = getattr(self, "_debug_count", 0) + 1
        trace = self._debug_count <= 3
        if trace:
            logger.debug(f"Processing item {self._debug_count}:")
            logger.debug(f"Looking for input field '{config.input_field}': {input_text}")
            logger.debug(f"Looking for output field '{config.output_field}': {output_text}")

        # None becomes an empty string; every other value is stringified.
        input_text = "" if input_text is None else str(input_text)
        output_text = "" if output_text is None else str(output_text)
        if trace:
            logger.debug(f"After conversion - input: '{input_text[:50]}...', output: '{output_text[:50]}...'")

        if config.clean_text:
            original_input = input_text
            original_output = output_text
            input_text = self._clean_text(input_text, config)
            output_text = self._clean_text(output_text, config)
            if trace:
                logger.debug(f"After cleaning - input: '{original_input[:50]}...' -> '{input_text[:50]}...'")
                logger.debug(f"After cleaning - output: '{original_output[:50]}...' -> '{output_text[:50]}...'")

        # Length constraints apply to both sides of the pair.
        if not config.min_length <= len(input_text) <= config.max_length:
            if trace:
                logger.debug(f"Skipping - input length {len(input_text)} not in range [{config.min_length}, {config.max_length}]")
            return None
        if not config.min_length <= len(output_text) <= config.max_length:
            if trace:
                logger.debug(f"Skipping - output length {len(output_text)} not in range [{config.min_length}, {config.max_length}]")
            return None

        # Internal processing always uses the keys "input" and "output".
        processed_item = {"input": input_text, "output": output_text}
        if trace:
            logger.debug(f"Final processed item: {processed_item}")
        return processed_item

    def _clean_text(self, text: str, config: StylingConfig) -> str:
        """Clean and normalize text; non-string input yields an empty string."""
        if not isinstance(text, str):
            return ""
        # Collapse every whitespace run into a single space and trim the ends.
        text = re.sub(r"\s+", " ", text).strip()
        if config.lowercase:
            text = text.lower()
        if config.remove_special_chars:
            # NOTE(review): replacement reconstructed as "" (delete the
            # character); the source formatting could not distinguish "" from
            # " " - confirm.
            text = re.sub(r"[^\w\s]", "", text)
        return text
class CustomDataLoader(BaseDataLoader):
    """Load custom datasets from local files (JSONL, CSV or JSON)."""

    def load(self, config: StylingConfig) -> Dict[str, List[Dict]]:
        """Load custom dataset from local file and create splits.

        Args:
            config: Run configuration; ``data_path`` is required and
                ``data_format`` selects the parser.

        Returns:
            Dict with the keys "train", "validation" and "test".

        Raises:
            ValueError: ``data_path`` is empty or ``data_format`` is unsupported.
            FileNotFoundError: ``data_path`` does not exist.
        """
        if not config.data_path:
            raise ValueError("Data path is required for custom datasets")

        file_path = Path(config.data_path)
        if not file_path.exists():
            raise FileNotFoundError(f"Data file not found: {file_path}")

        logger.info(f"Loading custom dataset: {file_path}")
        if config.data_format == "jsonl":
            raw_data = self._load_jsonl(file_path, config)
        elif config.data_format == "csv":
            raw_data = self._load_csv(file_path, config)
        elif config.data_format == "json":
            raw_data = self._load_json(file_path, config)
        else:
            raise ValueError(f"Unsupported format: {config.data_format}")

        # The limit is applied before splitting.
        if config.max_samples:
            raw_data = raw_data[: config.max_samples]
        logger.info(f"Loaded {len(raw_data)} samples from {file_path}")

        return self._create_splits(raw_data, config)

    def _create_splits(self, data: List[Dict], config: StylingConfig) -> Dict[str, List[Dict]]:
        """Create train/validation/test splits from raw data.

        Fewer than 3 samples all go to train. Otherwise validation and test
        each get at least one sample (and at least 10% of the data). Datasets
        with fewer than 10 samples overwrite the ratios on ``config`` with
        0.6/0.2/0.2, as before.
        """
        total_samples = len(data)
        logger.info(f"Creating splits from {total_samples} samples...")

        if total_samples < 3:
            logger.warning(f"Dataset has only {total_samples} samples. Using all data for training.")
            return {"train": data, "validation": [], "test": []}

        # Minimum guarantees for the held-out splits.
        min_val_size = max(1, int(total_samples * 0.1))
        min_test_size = max(1, int(total_samples * 0.1))

        if total_samples < 10:
            # For small datasets, use more conservative splits.
            config.train_split = 0.6
            config.validation_split = 0.2
            config.test_split = 0.2
            logger.info(f"Small dataset detected. Adjusted split ratios to: train={config.train_split}, val={config.validation_split}, test={config.test_split}")

        val_size = max(min_val_size, int(total_samples * config.validation_split))
        test_size = max(min_test_size, int(total_samples * config.test_split))
        train_size = total_samples - val_size - test_size

        if train_size < 1:
            # Shrink the held-out splits until train keeps at least one
            # sample; a single step is not always enough.
            while train_size < 1:
                if val_size > 1:
                    val_size -= 1
                elif test_size > 1:
                    test_size -= 1
                else:
                    break
                train_size += 1
            logger.info(f"Adjusted split sizes to ensure train has at least 1 sample: train={train_size}, val={val_size}, test={test_size}")

        logger.info(f"Split sizes: train={train_size}, validation={val_size}, test={test_size}")

        # Both held-out sizes are >= 1 here, so a three-way split always applies.
        # First split: train vs. (val + test).
        train_data, temp_data = train_test_split(
            data,
            test_size=val_size + test_size,
            random_state=42,
        )
        # Second split: val vs. test.
        val_data, test_data = train_test_split(
            temp_data,
            test_size=test_size,
            random_state=42,
        )
        splits_data = {"train": train_data, "validation": val_data, "test": test_data}

        logger.info("Created splits:")
        logger.info(f"Train: {len(splits_data['train'])} samples")
        logger.info(f"Validation: {len(splits_data['validation'])} samples")
        logger.info(f"Test: {len(splits_data['test'])} samples")
        return splits_data

    def _load_jsonl(self, file_path: Path, config: StylingConfig) -> List[Dict]:
        """Load JSONL file; blank lines are ignored, invalid lines are logged and skipped."""
        data = []
        with open(file_path, "r", encoding=config.encoding) as f:
            for line_num, line in enumerate(f, 1):
                if not line.strip():
                    continue
                try:
                    data.append(json.loads(line))
                except json.JSONDecodeError as e:
                    logger.warning(f"Invalid JSON at line {line_num}: {e}")
        return data

    def _load_csv(self, file_path: Path, config: StylingConfig) -> List[Dict]:
        """Load CSV file as a list of row dictionaries.

        Empty cells are returned as ``None`` instead of NaN, so that
        ``_preprocess_item`` treats them as missing rather than turning them
        into the literal text "nan".
        """
        df = pd.read_csv(file_path, encoding=config.encoding, delimiter=config.delimiter)
        # Cast to object first so that None can be stored in numeric columns.
        df = df.astype(object).where(df.notna(), None)
        return df.to_dict("records")

    def _load_json(self, file_path: Path, config: StylingConfig) -> List[Dict]:
        """Load JSON file.

        A top-level list is returned as is, a dict with a "data" key yields
        that value, and anything else is wrapped in a one-element list.
        """
        with open(file_path, "r", encoding=config.encoding) as f:
            data = json.load(f)
        if isinstance(data, list):
            return data
        if isinstance(data, dict) and "data" in data:
            return data["data"]
        return [data]

    def preprocess(self, data: Dict[str, List[Dict]], config: StylingConfig) -> Dict[str, List[Dict]]:
        """Apply preprocessing steps to all splits separately.

        Returns a mapping of split name to ``{"input", "output"}`` records;
        items rejected by ``_preprocess_item`` are dropped.
        """
        processed_splits: Dict[str, List[Dict]] = {}
        logger.info("=== PREPROCESSING CUSTOM DATA ===")

        for split_name, split_data in data.items():
            logger.info(f"Processing {split_name} split with {len(split_data)} items...")
            processed_data = []
            skipped_count = 0
            # Restart the debug counter for each split.
            self._debug_count = 0
            for i, item in enumerate(split_data):
                processed_item = self._preprocess_item(item, config)
                if processed_item is not None:
                    processed_data.append(processed_item)
                    continue
                skipped_count += 1
                if skipped_count <= 3:  # only the first few skips are logged
                    logger.info(f"Skipped item {i} from {split_name}: {item}")

            processed_splits[split_name] = processed_data
            logger.info(f"{split_name} - Preprocessed {len(processed_data)} samples, skipped {skipped_count} samples")

        return processed_splits

    def _preprocess_item(self, item: Dict, config: StylingConfig) -> Optional[Dict]:
        """Preprocess a single item.

        Returns ``{"input": ..., "output": ...}``, or ``None`` when either
        text falls outside ``[config.min_length, config.max_length]``.
        """
        input_text = item.get(config.input_field, "")
        output_text = item.get(config.output_field, "")

        # None becomes an empty string; every other value is stringified.
        input_text = "" if input_text is None else str(input_text)
        output_text = "" if output_text is None else str(output_text)

        if config.clean_text:
            input_text = self._clean_text(input_text, config)
            output_text = self._clean_text(output_text, config)

        # Length constraints apply to both sides of the pair.
        if not config.min_length <= len(input_text) <= config.max_length:
            return None
        if not config.min_length <= len(output_text) <= config.max_length:
            return None

        # Internal processing always uses the keys "input" and "output".
        return {"input": input_text, "output": output_text}

    def _clean_text(self, text: str, config: StylingConfig) -> str:
        """Clean and normalize text; non-string input yields an empty string."""
        if not isinstance(text, str):
            return ""
        # Collapse every whitespace run into a single space and trim the ends.
        text = re.sub(r"\s+", " ", text).strip()
        if config.lowercase:
            text = text.lower()
        if config.remove_special_chars:
            # NOTE(review): replacement reconstructed as "" (delete the
            # character); the source formatting could not distinguish "" from
            # " " - confirm.
            text = re.sub(r"[^\w\s]", "", text)
        return text
class StylingDataPipeline :
""" Main styling pipeline """
def __init__(self):
    """Create the validator and the two loaders this pipeline works with."""
    # Validation helper.
    self.validator = DataValidator()
    # One loader per supported data source.
    self.hf_loader = HuggingFaceDataLoader()
    self.custom_loader = CustomDataLoader()
def create_config(
    self,
    data_source: str,
    dataset_name: Optional[str] = None,
    data_path: Optional[str] = None,
    input_field: str = "input",
    output_field: str = "output",
    instruction: str = "Rewrite the following text in a formal style",
    **kwargs
) -> StylingConfig:
    """Build a ``StylingConfig`` from explicit arguments.

    The named parameters are passed through under the same field names; any
    additional keyword arguments are forwarded to ``StylingConfig`` unchanged.
    """
    named_fields = {
        "data_source": data_source,
        "dataset_name": dataset_name,
        "data_path": data_path,
        "input_field": input_field,
        "output_field": output_field,
        "instruction": instruction,
    }
    return StylingConfig(**named_fields, **kwargs)
def load_config_from_yaml ( self , yaml_path : str ) - > StylingConfig :
""" Load configuration from YAML file """
try :
config_dict = load_yaml_config ( yaml_path )
# Create configuration object from YAML data
config = StylingConfig (
data_source = config_dict . get ( ' data_source ' , ' custom ' ) ,
dataset_name = config_dict . get ( ' dataset_name ' ) ,
data_path = config_dict . get ( ' data_path ' ) ,
data_format = config_dict . get ( ' data_format ' , ' jsonl ' ) ,
input_field = config_dict . get ( ' input_field ' , ' text ' ) ,
output_field = config_dict . get ( ' output_field ' , ' styled_text ' ) ,
instruction = config_dict . get ( ' instruction ' , ' Rewrite the following text in a formal style ' ) ,
max_samples = config_dict . get ( ' max_samples ' ) ,
train_split = config_dict . get ( ' train_split ' , 0.8 ) ,
validation_split = config_dict . get ( ' validation_split ' , 0.1 ) ,
test_split = config_dict . get ( ' test_split ' , 0.1 ) ,
clean_text = config_dict . get ( ' clean_text ' , True ) ,
remove_special_chars = config_dict . get ( ' remove_special_chars ' , False ) ,
lowercase = config_dict . get ( ' lowercase ' , False ) ,
min_length = config_dict . get ( ' min_length ' , 10 ) ,
max_length = config_dict . get ( ' max_length ' , 1000 ) ,
output_format = config_dict . get ( ' output_format ' , ' styling ' ) ,
output_dir = config_dict . get ( ' output_dir ' , ' ./data ' ) ,
hf_split = config_dict . get ( ' hf_split ' , ' train ' ) ,
hf_cache_dir = config_dict . get ( ' hf_cache_dir ' ) ,
test_split_from = config_dict . get ( ' test_split_from ' , ' train ' ) ,
val_split_from = config_dict . get ( ' val_split_from ' , ' train ' ) ,
encoding = config_dict . get ( ' encoding ' , ' utf-8 ' ) ,
delimiter = config_dict . get ( ' delimiter ' , ' , ' )
)
logger . info ( f " Configuration loaded from YAML: { yaml_path } " )
logger . info ( f " Output directory: { config . output_dir } " )
logger . info ( f " Instruction: { config . instruction } " )
return config
except Exception as e :
logger . error ( f " Error loading configuration from YAML { yaml_path } : { e } " )
raise
def load_and_preprocess ( self , config : StylingConfig ) - > Tuple [ Dict [ str , List [ Dict ] ] , Dict [ str , Any ] ] :
""" Load and preprocess data """
2025-08-13 23:50:20 +00:00
logger . info ( f " Starting data loading and preprocessing... " )
logger . info ( f " Data source: { config . data_source } " )
2025-08-13 21:17:01 +01:00
2025-08-13 23:50:20 +00:00
try :
# Load data
if config . data_source == " huggingface " :
logger . info ( " Loading HuggingFace dataset... " )
raw_splits = self . hf_loader . load ( config )
logger . info ( " Preprocessing HuggingFace dataset... " )
processed_splits = self . hf_loader . preprocess ( raw_splits , config )
elif config . data_source == " custom " :
logger . info ( " Loading custom dataset... " )
raw_splits = self . custom_loader . load ( config )
logger . info ( " Preprocessing custom dataset... " )
processed_splits = self . custom_loader . preprocess ( raw_splits , config )
else :
raise ValueError ( f " Unsupported data source: { config . data_source } " )
logger . info ( f " Data loading and preprocessing completed successfully " )
logger . info ( f " Raw splits: { list ( raw_splits . keys ( ) ) } " )
logger . info ( f " Processed splits: { list ( processed_splits . keys ( ) ) } " )
# Validate processed data
logger . info ( " Validating processed data... " )
is_valid , errors = self . validator . validate_styling_data ( processed_splits , config , is_processed = True )
if not is_valid :
logger . error ( " Data validation failed: " )
for error in errors :
logger . error ( f " - { error } " )
raise ValueError ( " Data validation failed " )
logger . info ( " Data validation passed " )
# Analyze dataset
logger . info ( " Analyzing dataset... " )
analysis = self . validator . analyze_dataset ( processed_splits , config , is_processed = True )
logger . info ( " Dataset analysis completed " )
return processed_splits , analysis
except Exception as e :
logger . error ( f " Error in load_and_preprocess: { e } " )
raise
2025-08-13 21:17:01 +01:00
def convert_to_alpaca_format ( self , data : Dict [ str , List [ Dict ] ] , config : StylingConfig ) - > Dict [ str , List [ Dict ] ] :
""" Convert styling data to Alpaca format with instruction """
alpaca_splits = { }
for split_name , split_data in data . items ( ) :
alpaca_data = [ ]
for item in split_data :
# Ensure input and output fields exist, default to empty string if missing
input_text = item . get ( " input " , " " )
output_text = item . get ( " output " , " " )
# Handle None values
if input_text is None :
input_text = " "
if output_text is None :
output_text = " "
# Convert to string if needed
input_text = str ( input_text )
output_text = str ( output_text )
alpaca_data . append ( {
" instruction " : config . instruction ,
" input " : input_text ,
" output " : output_text
} )
alpaca_splits [ split_name ] = alpaca_data
return alpaca_splits
def format_for_training ( self , data : Dict [ str , List [ Dict ] ] , config : StylingConfig ) - > Dict [ str , List [ str ] ] :
""" Format entries for training using Alpaca prompt format """
formatted_splits = { }
for split_name , split_data in data . items ( ) :
formatted_texts = [ ]
for item in split_data :
# Ensure input and output fields exist, default to empty string if missing
input_text = item . get ( " input " , " " )
output_text = item . get ( " output " , " " )
# Handle None values
if input_text is None :
input_text = " "
if output_text is None :
output_text = " "
# Convert to string if needed
input_text = str ( input_text )
output_text = str ( output_text )
text = config . alpaca_prompt . format (
config . instruction ,
input_text ,
output_text
) + config . eos_token
formatted_texts . append ( text )
formatted_splits [ split_name ] = formatted_texts
return formatted_splits
def convert_to_hf_dataset ( self , dataset_entries : List [ Dict ] , config : StylingConfig ) :
""" Convert dataset entries to HuggingFace dataset format with text formatting """
from datasets import Dataset
# Create HuggingFace dataset from list of dictionaries
hf_dataset = Dataset . from_list ( dataset_entries )
# Apply formatting function to generate the text field
def formatting_prompts_func ( examples ) :
instructions = examples [ " instruction " ]
inputs = examples [ " input " ]
outputs = examples [ " output " ]
texts = [ ]
for instruction , input_text , output in zip ( instructions , inputs , outputs ) :
# Handle None values and ensure strings
if input_text is None :
input_text = " "
if output is None :
output = " "
# Convert to string if needed
input_text = str ( input_text )
output = str ( output )
# Use the config's EOS token and alpaca prompt
text = config . alpaca_prompt . format ( instruction , input_text , output ) + config . eos_token
texts . append ( text )
return { " text " : texts }
# Apply the formatting function
formatted_dataset = hf_dataset . map ( formatting_prompts_func , batched = True )
return formatted_dataset
def save_hf_dataset_to_disk ( self , hf_dataset , save_path : str ) :
""" Save HuggingFace dataset to disk """
try :
hf_dataset . save_to_disk ( save_path )
logger . info ( f " HuggingFace dataset saved to disk at: { save_path } " )
return True
except Exception as e :
logger . error ( f " Error saving HuggingFace dataset to disk: { e } " )
return False
def load_hf_dataset_from_disk ( self , load_path : str ) :
""" Load HuggingFace dataset from disk """
try :
from datasets import load_from_disk
hf_dataset = load_from_disk ( load_path )
logger . info ( f " HuggingFace dataset loaded from disk: { load_path } " )
logger . info ( f " Dataset has { len ( hf_dataset ) } entries " )
logger . info ( f " Dataset features: { hf_dataset . features } " )
return hf_dataset
except Exception as e :
logger . error ( f " Error loading HuggingFace dataset from disk: { e } " )
return None
def save_data ( self , data : Dict [ str , List [ Dict ] ] , output_dir : str , format : str = " jsonl " ) :
""" Save processed data splits to files """
output_path = Path ( output_dir )
output_path . mkdir ( parents = True , exist_ok = True )
for split_name , split_data in data . items ( ) :
if format == " jsonl " :
output_file = output_path / f " { split_name } .jsonl "
with open ( output_file , ' w ' , encoding = ' utf-8 ' ) as f :
for item in split_data :
f . write ( json . dumps ( item , ensure_ascii = False ) + ' \n ' )
elif format == " json " :
output_file = output_path / f " { split_name } .json "
with open ( output_file , ' w ' , encoding = ' utf-8 ' ) as f :
json . dump ( split_data , f , ensure_ascii = False , indent = 2 )
elif format == " csv " :
output_file = output_path / f " { split_name } .csv "
df = pd . DataFrame ( split_data )
df . to_csv ( output_file , index = False )
logger . info ( f " Saved { len ( split_data ) } samples to { output_file } " )
def run_pipeline (
self ,
config : StylingConfig ,
output_format : str = " styling " ,
save_splits : bool = True ,
create_hf_dataset : bool = False ,
save_hf_dataset : bool = False ,
hf_dataset_path : str = None
) - > Dict [ str , Any ] :
""" Run complete styling pipeline """
logger . info ( " Starting styling pipeline... " )
# Load and preprocess data
processed_splits , analysis = self . load_and_preprocess ( config )
# Convert to desired output format
if output_format == " alpaca " :
formatted_splits = self . convert_to_alpaca_format ( processed_splits , config )
else :
formatted_splits = processed_splits
# Save data if requested
if save_splits :
# Save directly in the output directory, not in a subdirectory
output_dir = Path ( config . output_dir )
self . save_data ( formatted_splits , str ( output_dir ) )
# Convert to HuggingFace dataset if requested
hf_dataset = None
hf_dataset_save_path = None
if create_hf_dataset :
# Flatten all splits into one list for HF dataset
all_entries = [ ]
for split_name , split_data in formatted_splits . items ( ) :
for item in split_data :
# Ensure we have the instruction field
if " instruction " not in item :
item [ " instruction " ] = config . instruction
all_entries . append ( item )
hf_dataset = self . convert_to_hf_dataset ( all_entries , config )
logger . info ( f " HuggingFace dataset created with { len ( hf_dataset ) } entries " )
logger . info ( f " Dataset features: { hf_dataset . features } " )
# Save HuggingFace dataset to disk if requested
if save_hf_dataset :
if hf_dataset_path is None :
# Generate default path using the YAML output_dir
hf_dataset_path = str ( Path ( config . output_dir ) / " hf_dataset " )
success = self . save_hf_dataset_to_disk ( hf_dataset , hf_dataset_path )
if success :
hf_dataset_save_path = hf_dataset_path
logger . info ( f " HuggingFace dataset saved to: { hf_dataset_save_path } " )
else :
logger . warning ( " Failed to save HuggingFace dataset to disk " )
# Create result summary
result = {
" config " : config ,
" analysis " : analysis ,
" splits " : {
split_name : len ( split_data ) for split_name , split_data in formatted_splits . items ( )
} ,
" output_format " : output_format ,
" output_dir " : config . output_dir ,
" data " : formatted_splits , # Include the actual processed data
" instruction " : config . instruction
}
# Add HuggingFace dataset info to result if created
if hf_dataset is not None :
result [ " hf_dataset " ] = hf_dataset
if hf_dataset_save_path :
result [ " hf_dataset_path " ] = hf_dataset_save_path
logger . info ( " Styling pipeline completed successfully! " )
return result
# Helper functions
def create_huggingface_config(
    dataset_name: str,
    input_field: str = "text",
    output_field: str = "output",
    instruction: str = "Rewrite the following text in a formal style",
    **kwargs
) -> StylingConfig:
    """Build a ``StylingConfig`` whose data source is a HuggingFace dataset.

    Args:
        dataset_name: Name of the HuggingFace dataset.
        input_field: Dataset field holding the source text.
        output_field: Dataset field holding the styled text.
        instruction: Style instruction stored on the configuration.
        **kwargs: Further ``StylingConfig`` fields, passed through unchanged.

    Returns:
        The configuration with ``data_source`` fixed to ``"huggingface"``.
    """
    return StylingConfig(
        data_source="huggingface",
        dataset_name=dataset_name,
        input_field=input_field,
        output_field=output_field,
        instruction=instruction,
        **kwargs
    )
def create_custom_config(
    data_path: str,
    data_format: str = "jsonl",
    input_field: str = "text",
    output_field: str = "styled_text",
    instruction: str = "Rewrite the following text in a formal style",
    **kwargs
) -> StylingConfig:
    """Build a ``StylingConfig`` whose data source is a custom data file.

    Args:
        data_path: Path of the custom data file.
        data_format: Format of that file (the default is ``"jsonl"``).
        input_field: Field holding the source text.
        output_field: Field holding the styled text.
        instruction: Style instruction stored on the configuration.
        **kwargs: Further ``StylingConfig`` fields, passed through unchanged.

    Returns:
        The configuration with ``data_source`` fixed to ``"custom"``.
    """
    return StylingConfig(
        data_source="custom",
        data_path=data_path,
        data_format=data_format,
        input_field=input_field,
        output_field=output_field,
        instruction=instruction,
        **kwargs
    )
def save_hf_dataset_to_disk(hf_dataset, save_path: str) -> bool:
    """Save a HuggingFace dataset to ``save_path``, reporting progress on stdout.

    Args:
        hf_dataset: Object exposing ``save_to_disk(path)``.
        save_path: Destination path handed to ``save_to_disk``.

    Returns:
        ``True`` if the save succeeded, ``False`` if any exception was raised
        (the exception is printed, not propagated).
    """
    saved = False
    try:
        hf_dataset.save_to_disk(save_path)
        print(f"HuggingFace dataset saved to disk at: {save_path}")
        saved = True
    except Exception as exc:
        print(f"Error saving HuggingFace dataset to disk: {exc}")
    return saved
def load_hf_dataset_from_disk(load_path: str):
    """Load a HuggingFace dataset from ``load_path``, reporting on stdout.

    Args:
        load_path: Path handed to ``datasets.load_from_disk``.

    Returns:
        The loaded dataset, or ``None`` if anything raised along the way,
        including the import of ``datasets`` (the exception is printed, not
        propagated).
    """
    loaded = None
    try:
        from datasets import load_from_disk
        candidate = load_from_disk(load_path)
        print(f"HuggingFace dataset loaded from disk: {load_path}")
        print(f"Dataset has {len(candidate)} entries")
        print(f"Dataset features: {candidate.features}")
        # Only publish the dataset once every step above has succeeded.
        loaded = candidate
    except Exception as exc:
        print(f"Error loading HuggingFace dataset from disk: {exc}")
    return loaded
def load_yaml_config(config_path: str) -> Dict[str, Any]:
    """Load a YAML configuration file and flatten it into one dict.

    The YAML is expected to be a mapping with optional ``task``, ``data``,
    ``model``, ``training`` and ``inference`` sections. Selected keys of each
    section are copied into the result under flat names (for example
    ``data.source`` becomes ``data_source`` and ``model.name`` becomes
    ``model_name``).

    Keys that are absent or null in the YAML are omitted from the result, so
    that callers using ``config_dict.get(key, default)`` actually receive
    their default instead of ``None``.

    Args:
        config_path: Path of the YAML file.

    Returns:
        The flattened configuration. Empty for an empty YAML file.

    Raises:
        ValueError: If the top level or a known section is not a mapping.
        Exception: Any I/O or YAML parsing error, logged and re-raised.
    """
    # flat result key -> key inside the YAML section
    section_keys = {
        'task': {
            'task_name': 'name',
            'task_type': 'type',
        },
        'data': {
            'data_source': 'source',
            'dataset_name': 'dataset_name',
            'data_path': 'data_path',
            'data_format': 'data_format',
            'input_field': 'input_field',
            'output_field': 'output_field',
            'instruction': 'instruction',
            'max_samples': 'max_samples',
            'train_split': 'train_split',
            'validation_split': 'validation_split',
            'test_split': 'test_split',
            'clean_text': 'clean_text',
            'remove_special_chars': 'remove_special_chars',
            'lowercase': 'lowercase',
            'min_length': 'min_length',
            'max_length': 'max_length',
            'output_format': 'output_format',
            'output_dir': 'output_dir',
            'hf_split': 'hf_split',
            'hf_cache_dir': 'hf_cache_dir',
            'test_split_from': 'test_split_from',
            'val_split_from': 'val_split_from',
            'encoding': 'encoding',
            'delimiter': 'delimiter',
        },
        'model': {
            'model_name': 'name',
            'model_max_length': 'max_length',
        },
        'training': {
            'num_epochs': 'num_epochs',
            'batch_size': 'batch_size',
            'learning_rate': 'learning_rate',
            'weight_decay': 'weight_decay',
            'warmup_ratio': 'warmup_ratio',
            'lr_scheduler_type': 'lr_scheduler_type',
        },
        'inference': {
            'inference_batch_size': 'batch_size',
            'max_new_tokens': 'max_new_tokens',
            'temperature': 'temperature',
        },
    }
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document.
            yaml_data = yaml.safe_load(f) or {}
        if not isinstance(yaml_data, dict):
            raise ValueError("Top level of the YAML configuration must be a mapping")

        config_dict = {}
        for section_name, key_map in section_keys.items():
            # A section written as a bare "data:" parses to None; treat it
            # like a missing section.
            section = yaml_data.get(section_name)
            if section is None:
                continue
            if not isinstance(section, dict):
                raise ValueError(f"YAML section '{section_name}' must be a mapping")
            for flat_key, yaml_key in key_map.items():
                value = section.get(yaml_key)
                if value is not None:
                    config_dict[flat_key] = value

        logger.info(f"Successfully parsed YAML configuration from: {config_path}")
        logger.info(f"Extracted {len(config_dict)} configuration parameters")
        return config_dict
    except Exception as e:
        logger.error(f"Error loading YAML config from {config_path}: {e}")
        raise
def main():
    """Command-line entry point with YAML configuration support.

    Configuration is assembled from an optional YAML file (``--config``)
    overlaid with any command-line options that were given; anything still
    unset (or null in the YAML) falls back to the defaults listed below.
    The pipeline is then run and a short summary printed.

    Exits with status 1 if the YAML cannot be loaded or the pipeline fails,
    and via ``parser.error`` if a required setting is missing.
    """
    parser = argparse.ArgumentParser(description="Styling Data Processing Pipeline")
    # YAML configuration
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")
    # Data source arguments
    parser.add_argument("--data-source", choices=["huggingface", "custom"], help="Data source")
    parser.add_argument("--dataset-name", type=str, help="HuggingFace dataset name")
    parser.add_argument("--data-path", type=str, help="Path to custom data file")
    parser.add_argument("--data-format", choices=["jsonl", "csv", "json"], help="Data format")
    # Field mapping
    parser.add_argument("--input-field", type=str, help="Input field name")
    parser.add_argument("--output-field", type=str, help="Output field name")
    parser.add_argument("--instruction", type=str, help="Style instruction")
    # Data processing
    parser.add_argument("--max-samples", type=int, help="Maximum samples to process")
    parser.add_argument("--train-split", type=float, help="Training split ratio")
    parser.add_argument("--validation-split", type=float, help="Validation split ratio")
    parser.add_argument("--test-split", type=float, help="Test split ratio")
    # Text preprocessing
    parser.add_argument("--clean-text", action="store_true", help="Clean and normalize text")
    parser.add_argument("--remove-special-chars", action="store_true", help="Remove special characters")
    parser.add_argument("--lowercase", action="store_true", help="Convert text to lowercase")
    parser.add_argument("--min-length", type=int, help="Minimum text length")
    parser.add_argument("--max-length", type=int, help="Maximum text length")
    # Output configuration
    parser.add_argument("--output-format", choices=["styling", "alpaca"], help="Output format")
    parser.add_argument("--output-dir", type=str, help="Output directory")
    # HuggingFace dataset options
    parser.add_argument("--create-hf-dataset", action="store_true", help="Create HuggingFace dataset")
    parser.add_argument("--hf-dataset-path", type=str, help="Path to save HuggingFace dataset")
    # Logging
    parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO", help="Logging level")
    args = parser.parse_args()

    # Set up logging
    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # Load YAML config if provided
    config_dict = {}
    if args.config:
        try:
            config_dict = load_yaml_config(args.config)
        except Exception as e:
            logger.error(f"Error loading YAML config: {e}")
            sys.exit(1)

    # Collect CLI overrides. store_true flags only override when set; value
    # options override whenever they were given, which is tested with
    # "is not None" so that legitimate falsy values such as --min-length 0
    # are not silently dropped.
    option_names = (
        'data_source', 'dataset_name', 'data_path', 'data_format',
        'input_field', 'output_field', 'instruction',
        'max_samples', 'train_split', 'validation_split', 'test_split',
        'clean_text', 'remove_special_chars', 'lowercase',
        'min_length', 'max_length',
        'output_format', 'output_dir',
        'create_hf_dataset', 'hf_dataset_path',
        'log_level',
    )
    flag_names = {'clean_text', 'remove_special_chars', 'lowercase', 'create_hf_dataset'}
    cli_overrides = {}
    for name in option_names:
        value = getattr(args, name)
        if name in flag_names:
            if value:
                cli_overrides[name] = True
        elif value is not None:
            cli_overrides[name] = value

    # Merge configurations
    for key, value in cli_overrides.items():
        if key in config_dict:
            logger.info(f"Overriding YAML config '{key}' with CLI value: {value}")
        config_dict[key] = value

    # Validate required arguments
    if not config_dict.get('data_source'):
        parser.error("--data-source is required (either in YAML config or CLI)")
    if config_dict.get('data_source') == "huggingface" and not config_dict.get('dataset_name'):
        parser.error("--dataset-name is required for HuggingFace datasets")
    if config_dict.get('data_source') == "custom" and not config_dict.get('data_path'):
        parser.error("--data-path is required for custom datasets")

    def setting(key, default=None):
        # A key can be present with a null value (e.g. left blank in the
        # YAML); treat that the same as "not set" so the default applies.
        value = config_dict.get(key)
        return default if value is None else value

    # Create configuration object
    config = StylingConfig(
        data_source=setting('data_source', 'huggingface'),
        dataset_name=setting('dataset_name'),
        data_path=setting('data_path'),
        data_format=setting('data_format', 'jsonl'),
        input_field=setting('input_field', 'text'),
        output_field=setting('output_field', 'styled_text'),
        instruction=setting('instruction', 'Rewrite the following text in a formal style'),
        max_samples=setting('max_samples'),
        train_split=setting('train_split', 0.8),
        validation_split=setting('validation_split', 0.1),
        test_split=setting('test_split', 0.1),
        clean_text=setting('clean_text', True),
        remove_special_chars=setting('remove_special_chars', False),
        lowercase=setting('lowercase', False),
        min_length=setting('min_length', 10),
        max_length=setting('max_length', 1000),
        output_format=setting('output_format', 'styling'),
        output_dir=setting('output_dir', './data'),
        hf_split=setting('hf_split', 'train'),
        hf_cache_dir=setting('hf_cache_dir'),
        test_split_from=setting('test_split_from', 'train'),
        val_split_from=setting('val_split_from', 'train'),
        encoding=setting('encoding', 'utf-8'),
        delimiter=setting('delimiter', ',')
    )

    # Initialize pipeline
    pipeline = StylingDataPipeline()
    try:
        print(f"Starting styling pipeline with {config.data_source} data source...")
        if args.config:
            print(f"Using YAML configuration: {args.config}")
        print(f"Style instruction: {config.instruction}")
        print()

        # HuggingFace dataset creation is controlled from the CLI only
        create_hf_dataset = cli_overrides.get('create_hf_dataset', False)
        hf_dataset_path = cli_overrides.get('hf_dataset_path')
        # If creating HF dataset, also save it by default
        save_hf_dataset = create_hf_dataset

        result = pipeline.run_pipeline(
            config,
            config.output_format,
            save_splits=True,
            create_hf_dataset=create_hf_dataset,
            save_hf_dataset=save_hf_dataset,
            hf_dataset_path=hf_dataset_path
        )

        print("✅ Pipeline completed successfully!")
        print(f"Data source: {config.data_source}")
        if config.data_source == "huggingface":
            print(f"Dataset: {config.dataset_name}")
        else:
            print(f"Data file: {config.data_path}")
        print(f"Total samples: {result['analysis']['overall']['total_samples']}")
        print(f"Split sizes: {result['analysis']['overall']['split_sizes']}")
        print(f"Output directory: {config.output_dir}")
        print(f"Style instruction: {config.instruction}")
    except Exception as e:
        print(f"❌ Error running pipeline: {e}")
        import traceback
        print("Full error traceback:")
        traceback.print_exc()
        sys.exit(1)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()