Files
DS-LLM-TEMPLATE-FINETUNING/pipelines/classification/data_processor.py
T

1073 lines
46 KiB
Python
Raw Normal View History

2025-08-06 22:45:37 +01:00
import json
import pandas as pd
import numpy as np
from pathlib import Path
from typing import Dict, List, Optional, Union, Any, Tuple
from datasets import Dataset, load_dataset
import os
from dataclasses import dataclass
from abc import ABC, abstractmethod
import logging
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import re
import argparse
import sys
import yaml
# Module-level logger for this pipeline. Its own level is forced to DEBUG so that
# DEBUG, INFO and ERROR records from this module all pass the logger's level check.
logger = logging.getLogger(name=__name__)
logger.setLevel("DEBUG")
@dataclass
class ClassificationConfig:
    """Settings that drive loading, splitting, cleaning and saving a
    text-classification dataset.

    Every field has a default, so a config only needs the values that differ
    from those defaults.
    """

    # --- Where the data comes from ------------------------------------------
    data_source: str = "huggingface"  # "huggingface" or "custom"
    dataset_name: Union[str, None] = None  # Hugging Face Hub dataset id
    data_path: Union[str, None] = None  # local file for custom datasets
    data_format: str = "jsonl"  # "jsonl", "csv" or "json"

    # --- Which raw fields hold the text / label / id ------------------------
    input_field: str = "text"
    label_field: str = "label"
    id_field: Union[str, None] = None  # copied into preprocessed items when present

    # --- Sampling and split fractions ---------------------------------------
    max_samples: Union[int, None] = None
    train_split: float = 0.8
    validation_split: float = 0.1
    test_split: float = 0.1

    # --- Text preprocessing --------------------------------------------------
    clean_text: bool = True
    remove_special_chars: bool = False
    lowercase: bool = True
    min_length: int = 10  # character bounds checked on the (cleaned) input text
    max_length: int = 1000

    # --- Label processing ----------------------------------------------------
    label_encoding: str = "auto"  # "auto", "numeric" or "string"
    multilabel: bool = False
    label_separator: str = ","  # separator for multilabel values

    # --- Output --------------------------------------------------------------
    output_format: str = "classification"  # or "instruction", "conversation", "qa"
    output_dir: str = "./data"

    # --- Hugging Face specific -----------------------------------------------
    hf_split: str = "train"
    hf_cache_dir: Union[str, None] = None

    # --- Where validation/test data come from --------------------------------
    test_split_from: str = "train"  # "train", "use_test_if_available" or "use_val_if_available"
    val_split_from: str = "train"  # "train" or "use_val_if_available"

    # --- Custom-file specific ------------------------------------------------
    encoding: str = "utf-8"
    delimiter: str = ","  # CSV field delimiter
class DataValidator:
    """Validates classification data quality and format."""

    @staticmethod
    def _field_names(config: ClassificationConfig, is_processed: bool) -> Tuple[str, str]:
        """Return the ``(input_field, label_field)`` keys to inspect.

        Preprocessed items use the fixed "input"/"label" keys written by the
        loaders; raw items use the names configured in ``config``.
        """
        if is_processed:
            return "input", "label"
        return config.input_field, config.label_field

    @staticmethod
    def _is_missing(value: Any) -> bool:
        """Return True for None, NaN and empty str/list/tuple/dict/set values.

        Unlike a plain falsiness test, 0 and False count as present because
        they are legitimate label values.
        """
        if value is None:
            return True
        if isinstance(value, float):
            return value != value  # NaN is the only float not equal to itself
        if isinstance(value, (str, list, tuple, dict, set)):
            return len(value) == 0
        return False

    @staticmethod
    def _text_length(value: Any) -> int:
        """Return the character length of ``value`` (0 for None; non-strings are stringified)."""
        if value is None:
            return 0
        return len(value) if isinstance(value, str) else len(str(value))

    @staticmethod
    def validate_classification_data(data: Dict[str, List[Dict]], config: ClassificationConfig, is_processed: bool = False) -> Tuple[bool, List[str]]:
        """Validate the train/validation/test splits of a classification dataset.

        Args:
            data: Mapping from split name to a list of item dicts.
            config: Pipeline configuration; its ``input_field``/``label_field``
                are used when ``is_processed`` is False.
            is_processed: True when items already use the "input"/"label" keys
                produced by the loaders' ``preprocess``.

        Returns:
            ``(is_valid, errors)``. If any of "train", "validation" or "test"
            is missing or empty, only those errors are returned; otherwise
            ``errors`` lists every per-item and per-split problem found.
        """
        from collections import Counter

        errors: List[str] = []
        for split in ("train", "validation", "test"):
            if split not in data or not data[split]:
                errors.append(f"Missing or empty '{split}' split")
        if errors:
            return False, errors

        total_samples = sum(len(split_data) for split_data in data.values())
        logger.info(f"Validating {total_samples} total samples across all splits...")

        input_field, label_field = DataValidator._field_names(config, is_processed)

        for split_name, split_data in data.items():
            logger.info(f"Validating {split_name} split with {len(split_data)} samples...")

            # Required fields
            missing_input_count = 0
            missing_label_count = 0
            for i, item in enumerate(split_data):
                if input_field not in item:
                    errors.append(f"Missing input field '{input_field}' in {split_name} split, item {i}")
                    missing_input_count += 1
                if label_field not in item:
                    errors.append(f"Missing label field '{label_field}' in {split_name} split, item {i}")
                    missing_label_count += 1
            logger.info(f"{split_name} - Items missing input field: {missing_input_count}")
            logger.info(f"{split_name} - Items missing label field: {missing_label_count}")

            # Input types and emptiness. Non-string inputs (including None) are
            # type errors; only string inputs are tested for blankness, so
            # .strip() is never called on a non-string.
            type_errors = 0
            empty_inputs = 0
            for i, item in enumerate(split_data):
                value = item.get(input_field, "")
                if not isinstance(value, str):
                    errors.append(f"Input field '{input_field}' must be string in {split_name} split, item {i}")
                    type_errors += 1
                elif not value.strip():
                    empty_inputs += 1
            logger.info(f"{split_name} - Type errors: {type_errors}")
            if empty_inputs > 0:
                errors.append(f"Found {empty_inputs} items with empty input text in {split_name} split")
            logger.info(f"{split_name} - Empty inputs: {empty_inputs}")

            # Label distribution (None labels are ignored)
            label_counts = Counter(
                item.get(label_field) for item in split_data if item.get(label_field) is not None
            )
            unique_labels = set(label_counts)
            logger.info(f"{split_name} - Found {len(unique_labels)} unique labels: {unique_labels}")
            logger.info(f"{split_name} - Label distribution: {dict(label_counts)}")
            if not unique_labels:
                errors.append(f"{split_name} split must have at least 1 label, found: {unique_labels}")

            # A few items for debugging
            if split_data:
                logger.info(f"Sample processed items from {split_name}:")
                for i, item in enumerate(split_data[:3]):
                    preview = str(item.get(input_field, ''))[:50]
                    logger.info(f" Item {i}: input='{preview}...', label='{item.get(label_field, '')}'")

        return len(errors) == 0, errors

    @staticmethod
    def analyze_dataset(data: Dict[str, List[Dict]], config: ClassificationConfig, is_processed: bool = False) -> Dict[str, Any]:
        """Summarise every split and the dataset as a whole.

        Returns a dict with:
            "splits": per split, ``total_samples``, ``unique_labels`` (count of
                distinct raw label values, None included), ``label_distribution``
                (str(label) -> count; labels whose str() coincide are summed),
                ``text_length_stats`` (min/max/mean/median character length,
                empty for an empty split) and ``missing_values`` (field -> count
                of None/NaN/empty values; 0 and False are not missing).
            "overall": ``total_samples``, ``all_unique_labels`` (number of
                distinct str(label) values across splits) and ``split_sizes``.
        """
        from collections import Counter

        input_field, label_field = DataValidator._field_names(config, is_processed)

        analysis: Dict[str, Any] = {
            "splits": {},
            "overall": {
                "total_samples": 0,
                "all_unique_labels": 0,
                "split_sizes": {}
            }
        }
        all_labels = set()

        for split_name, split_data in data.items():
            label_counts = Counter(item.get(label_field) for item in split_data)
            label_distribution: Dict[str, int] = {}
            for label, count in label_counts.items():
                key = str(label)
                label_distribution[key] = label_distribution.get(key, 0) + count
            all_labels.update(label_distribution)

            text_lengths = [DataValidator._text_length(item.get(input_field)) for item in split_data]
            text_length_stats: Dict[str, Any] = {}
            if text_lengths:
                text_length_stats = {
                    "min": min(text_lengths),
                    "max": max(text_lengths),
                    "mean": float(np.mean(text_lengths)),
                    "median": float(np.median(text_lengths))
                }

            missing_values = {
                field: sum(1 for item in split_data if DataValidator._is_missing(item.get(field)))
                for field in (input_field, label_field)
            }

            analysis["splits"][split_name] = {
                "total_samples": len(split_data),
                "unique_labels": len(label_counts),
                "label_distribution": label_distribution,
                "text_length_stats": text_length_stats,
                "missing_values": missing_values
            }
            analysis["overall"]["total_samples"] += len(split_data)
            analysis["overall"]["split_sizes"][split_name] = len(split_data)

        analysis["overall"]["all_unique_labels"] = len(all_labels)
        return analysis
class BaseDataLoader(ABC):
    """Interface shared by every data loader.

    A concrete loader must override both ``load`` and ``preprocess``; until it
    does, the abstract-method machinery prevents instantiation.
    """

    @abstractmethod
    def load(self, config: ClassificationConfig) -> Dict[str, List[Dict]]:
        """Fetch the data described by ``config`` as a split-name -> items mapping
        (train/validation/test)."""
        ...

    @abstractmethod
    def preprocess(self, data: Dict[str, List[Dict]], config: ClassificationConfig) -> Dict[str, List[Dict]]:
        """Run the preprocessing steps over every split in ``data`` and return the result."""
        ...
class HuggingFaceDataLoader(BaseDataLoader):
    """Load datasets from the Hugging Face Hub.

    The dataset's "train" split is required. Validation and test data are
    taken from the dataset's own splits or carved out of train, depending on
    ``config.val_split_from`` / ``config.test_split_from``.
    """

    # Seed for every random operation (splitting, sub-sampling) so runs are reproducible.
    _RANDOM_STATE = 42
    # Dataset split names treated as a validation split, in order of preference.
    _VALIDATION_SPLIT_NAMES = ("validation", "val")
    # Substrings used to suggest alternative field names when a configured field is absent.
    _TEXT_FIELD_HINTS = ('text', 'sentence', 'content', 'input', 'comment', 'message')
    _LABEL_FIELD_HINTS = ('label', 'class', 'category', 'target', 'emotion', 'labels')

    def load(self, config: ClassificationConfig) -> Dict[str, List[Dict]]:
        """Load a Hub dataset and return its train/validation/test item lists.

        Validation/test come from the dataset when the config asks for it and
        a matching split exists; a split already used as the test set is never
        reused for validation. Missing splits are carved out of train
        (stratified on ``config.label_field`` when possible). When
        ``config.max_samples`` is set, every split is capped by a seeded
        uniform sample. ``config`` itself is not modified.

        Raises:
            ValueError: If ``dataset_name`` is unset, the dataset has no
                "train" split, or train is empty but other splits must be
                carved from it.
        """
        if not config.dataset_name:
            raise ValueError("Dataset name is required for Hugging Face datasets")

        logger.info(f"Loading Hugging Face dataset: {config.dataset_name}")
        try:
            dataset = load_dataset(
                config.dataset_name,
                cache_dir=config.hf_cache_dir
            )
            available_splits = list(dataset.keys())
            logger.info(f"Available splits in dataset: {available_splits}")

            if "train" not in available_splits:
                logger.error("No 'train' split found in dataset!")
                logger.error(f"Available splits: {available_splits}")
                raise ValueError(f"Dataset {config.dataset_name} does not have a 'train' split")

            train_dataset = dataset["train"]
            logger.info(f"Using 'train' split with {len(train_dataset)} samples")
            splits_data: Dict[str, List[Dict]] = {
                "train": list(train_dataset),
                "validation": [],
                "test": []
            }

            # Resolve the test source first so validation can avoid reusing it.
            test_source = self._resolve_test_source(config, available_splits)
            val_source = self._resolve_validation_source(config, available_splits, test_source)

            if val_source is not None:
                val_dataset = dataset[val_source]
                logger.info(f"Using '{val_source}' split with {len(val_dataset)} samples")
                splits_data["validation"] = list(val_dataset)
            if test_source is not None:
                test_dataset = dataset[test_source]
                if test_source == "test":
                    logger.info(f"Using 'test' split with {len(test_dataset)} samples")
                else:
                    logger.info(f"Using '{test_source}' split as test with {len(test_dataset)} samples")
                splits_data["test"] = list(test_dataset)

            if not splits_data["validation"] or not splits_data["test"]:
                self._carve_missing_splits(splits_data, config)

            logger.info("Final split sizes:")
            logger.info(f" Train: {len(splits_data['train'])} samples")
            logger.info(f" Validation: {len(splits_data['validation'])} samples")
            logger.info(f" Test: {len(splits_data['test'])} samples")

            if config.max_samples:
                for split_name in list(splits_data):
                    split_data = splits_data[split_name]
                    if split_data:
                        limited = self._limit_samples(split_data, config.max_samples)
                        splits_data[split_name] = limited
                        logger.info(f"Limited {split_name} split from {len(split_data)} to {len(limited)} samples")

            self._log_field_diagnostics(splits_data, config)

            logger.info(f"Successfully loaded dataset {config.dataset_name}")
            return splits_data
        except Exception as e:
            logger.error(f"Error loading dataset {config.dataset_name}: {e}")
            raise

    @classmethod
    def _resolve_test_source(cls, config: ClassificationConfig, available_splits: List[str]) -> Optional[str]:
        """Return the dataset split to use as test data, or None to carve it from train."""
        if config.test_split_from == "use_test_if_available":
            if "test" in available_splits:
                return "test"
            logger.warning("No test split found in dataset. Will create from train split.")
            logger.info(f"Available splits: {available_splits}")
            logger.info(f"Will use {config.test_split * 100}% of train data for test")
            return None
        if config.test_split_from == "use_val_if_available":
            for name in cls._VALIDATION_SPLIT_NAMES:
                if name in available_splits:
                    return name
        logger.info(f"Will create test split from train data ({config.test_split * 100}%)")
        return None

    @classmethod
    def _resolve_validation_source(cls, config: ClassificationConfig, available_splits: List[str], test_source: Optional[str]) -> Optional[str]:
        """Return the dataset split to use as validation data, or None to carve it from train.

        A split already chosen as the test set is skipped so validation and
        test can never be the same data.
        """
        if config.val_split_from != "use_val_if_available":
            logger.info(f"Will create validation split from train data ({config.validation_split * 100}%)")
            return None

        present = [name for name in cls._VALIDATION_SPLIT_NAMES if name in available_splits]
        candidates = [name for name in present if name != test_source]
        if candidates:
            return candidates[0]

        if present:
            logger.warning(
                f"'{test_source}' split is already used as the test set; creating validation "
                "from train instead to avoid overlapping validation/test data"
            )
        else:
            logger.warning("No validation split found in dataset. Will create from train split.")
            logger.info(f"Available splits: {available_splits}")
        logger.info(f"Will use {config.validation_split * 100}% of train data for validation")
        return None

    @staticmethod
    def _normalized_fractions(config: ClassificationConfig) -> Tuple[float, float, float]:
        """Return (train, validation, test) fractions, rescaled only if they do not sum to 1.

        Uses a tolerance instead of exact float equality, and does not modify
        ``config``.
        """
        total = config.train_split + config.validation_split + config.test_split
        if total <= 0:
            raise ValueError(f"Split fractions must sum to a positive value (got {total})")
        if abs(total - 1.0) <= 1e-9:
            return config.train_split, config.validation_split, config.test_split
        logger.warning(f"Split percentages don't sum to 1.0 (got {total}). Normalizing...")
        return (
            config.train_split / total,
            config.validation_split / total,
            config.test_split / total
        )

    @classmethod
    def _split(cls, items: List[Dict], held_out_fraction: float, label_field: str) -> Tuple[List[Dict], List[Dict]]:
        """Split ``items`` into ``(kept, held_out)`` with ``held_out_fraction`` held out.

        Stratifies on ``label_field`` when the first item has that field; if
        stratification is impossible (e.g. a class with one member, or labels
        that cannot be sorted/hashed) it falls back to a random split.
        """
        if items and label_field in items[0]:
            labels = [item.get(label_field) for item in items]
            try:
                kept, held_out = train_test_split(
                    items,
                    test_size=held_out_fraction,
                    random_state=cls._RANDOM_STATE,
                    stratify=labels
                )
                return kept, held_out
            except (ValueError, TypeError) as e:
                logger.warning(f"Could not create stratified split: {e}. Using random split.")
        kept, held_out = train_test_split(
            items,
            test_size=held_out_fraction,
            random_state=cls._RANDOM_STATE
        )
        return kept, held_out

    def _carve_missing_splits(self, splits_data: Dict[str, List[Dict]], config: ClassificationConfig) -> None:
        """Fill the empty "validation"/"test" entries of ``splits_data`` from its train items.

        Mutates ``splits_data`` in place; carved items are removed from train.
        """
        train_data = splits_data["train"]
        if not train_data:
            raise ValueError(
                f"Dataset {config.dataset_name} has an empty 'train' split; "
                "cannot create validation/test splits from it"
            )

        _, val_frac, test_frac = self._normalized_fractions(config)
        need_val = not splits_data["validation"]
        need_test = not splits_data["test"]

        if need_val and need_test:
            # train -> (train, val+test), then val+test -> (val, test)
            held_out = val_frac + test_frac
            new_train, temp_data = self._split(train_data, held_out, config.label_field)
            new_val, new_test = self._split(temp_data, test_frac / held_out, config.label_field)
            splits_data.update(train=new_train, validation=new_val, test=new_test)
        elif need_val:
            new_train, new_val = self._split(train_data, val_frac, config.label_field)
            splits_data.update(train=new_train, validation=new_val)
        else:
            new_train, new_test = self._split(train_data, test_frac, config.label_field)
            splits_data.update(train=new_train, test=new_test)

    @classmethod
    def _limit_samples(cls, items: List[Dict], max_samples: int) -> List[Dict]:
        """Return at most ``max_samples`` items, drawn uniformly with a fixed seed.

        Kept items stay in their original relative order. A plain prefix is
        avoided because, if a split is stored grouped by label, a prefix could
        contain a single class.
        """
        if max_samples < 0:
            raise ValueError(f"max_samples must be positive, got {max_samples}")
        if len(items) <= max_samples:
            return list(items)
        rng = np.random.default_rng(cls._RANDOM_STATE)
        keep = np.sort(rng.choice(len(items), size=max_samples, replace=False))
        return [items[i] for i in keep.tolist()]

    @classmethod
    def _log_field_diagnostics(cls, splits_data: Dict[str, List[Dict]], config: ClassificationConfig) -> None:
        """Log one sample per split and suggest fields when the configured ones are absent."""
        for split_name, split_data in splits_data.items():
            if not split_data:
                continue
            first = split_data[0]
            fields = list(first.keys())
            logger.info(f"Sample data item from {split_name}: {first}")
            logger.info(f"Available fields in {split_name} split: {fields}")

            if config.input_field not in first:
                logger.warning(f"Input field '{config.input_field}' not found in {split_name}. Available fields: {fields}")
                text_fields = [f for f in fields if any(hint in f.lower() for hint in cls._TEXT_FIELD_HINTS)]
                if text_fields:
                    logger.info(f"Suggested text fields for {split_name}: {text_fields}")

            if config.label_field not in first:
                logger.warning(f"Label field '{config.label_field}' not found in {split_name}. Available fields: {fields}")
                label_fields = [f for f in fields if any(hint in f.lower() for hint in cls._LABEL_FIELD_HINTS)]
                if label_fields:
                    logger.info(f"Suggested label fields for {split_name}: {label_fields}")

    def preprocess(self, data: Dict[str, List[Dict]], config: ClassificationConfig) -> Dict[str, List[Dict]]:
        """Preprocess every split independently, logging diagnostics along the way.

        Returns a dict with the same split names whose items went through
        ``_preprocess_item``; items it rejects (returns None for) are dropped.
        """
        processed_splits = {}
        logger.info("=== PREPROCESSING DATA ===")

        for split_name, split_data in data.items():
            logger.info(f"Processing {split_name} split with {len(split_data)} items...")
            if split_data:
                self._log_raw_split(split_name, split_data, config)

            processed_data = []
            skipped_count = 0
            # Per-split counter that limits _preprocess_item's debug output to the first items.
            self._debug_count = 0
            for i, item in enumerate(split_data):
                processed_item = self._preprocess_item(item, config)
                if processed_item is not None:
                    processed_data.append(processed_item)
                else:
                    skipped_count += 1
                    if skipped_count <= 3:  # Log first few skipped items
                        logger.info(f"Skipped item {i} from {split_name}: {item}")

            processed_splits[split_name] = processed_data
            logger.info(f"{split_name} - Preprocessed {len(processed_data)} samples, skipped {skipped_count} samples")

            if processed_data:
                logger.info(f"=== SAMPLE PROCESSED DATA FROM {split_name.upper()} ===")
                for i, processed in enumerate(processed_data[:3]):
                    logger.info(f"Processed item {i} from {split_name}: {processed}")

        return processed_splits

    @staticmethod
    def _log_raw_split(split_name: str, split_data: List[Dict], config: ClassificationConfig) -> None:
        """Log field availability, missing-field counts and the first raw items of one split."""
        available_fields = set(split_data[0].keys())
        logger.info(f"Available fields in {split_name}: {available_fields}")
        logger.info(f"Looking for input field: '{config.input_field}', label field: '{config.label_field}'")
        if config.input_field not in available_fields:
            logger.error(f"Input field '{config.input_field}' not found in {split_name}. Available fields: {available_fields}")
        if config.label_field not in available_fields:
            logger.error(f"Label field '{config.label_field}' not found in {split_name}. Available fields: {available_fields}")

        missing_input = sum(1 for item in split_data if not item.get(config.input_field))
        missing_label = sum(1 for item in split_data if item.get(config.label_field) is None)
        logger.info(f"{split_name} - Items missing input field: {missing_input}")
        logger.info(f"{split_name} - Items missing label field: {missing_label}")

        logger.info(f"=== SAMPLE RAW DATA FROM {split_name.upper()} BEFORE PREPROCESSING ===")
        for i, item in enumerate(split_data[:3]):
            logger.info(f"Raw item {i} from {split_name}:")
            for key, value in item.items():
                if isinstance(value, str) and len(value) > 100:
                    logger.info(f" {key}: '{value[:100]}...'")
                else:
                    logger.info(f" {key}: {value}")

    @staticmethod
    def _is_missing(value: Any) -> bool:
        """Return True for None and scalar NaN-like values (NaN, NaT, pd.NA)."""
        return value is None or (pd.api.types.is_scalar(value) and bool(pd.isna(value)))

    def _preprocess_item(self, item: Dict, config: ClassificationConfig) -> Optional[Dict]:
        """Turn one raw item into ``{"input", "label"[, "id"]}``, or return None to skip it.

        An item is skipped when its label is absent, None, NaN or blank (it
        would otherwise become a bogus "" class), or when the optionally
        cleaned text length is outside ``[config.min_length, config.max_length]``.
        """
        raw_input = item.get(config.input_field, "")
        raw_label = item.get(config.label_field, "")

        self._debug_count = getattr(self, "_debug_count", 0) + 1
        verbose = self._debug_count <= 3
        if verbose:
            logger.debug(f"Processing item {self._debug_count}:")
            logger.debug(f" Looking for input field '{config.input_field}': {raw_input}")
            logger.debug(f" Looking for label field '{config.label_field}': {raw_label}")

        input_text = "" if self._is_missing(raw_input) else str(raw_input)
        label = "" if self._is_missing(raw_label) else str(raw_label)
        if verbose:
            logger.debug(f" After conversion - input: '{input_text[:50]}...', label: '{label}'")

        if not label.strip():
            if verbose:
                logger.debug(" Skipping - missing label")
            return None

        if config.clean_text:
            original_text = input_text
            input_text = self._clean_text(input_text, config)
            if verbose:
                logger.debug(f" After cleaning - original: '{original_text[:50]}...', cleaned: '{input_text[:50]}...'")

        if len(input_text) < config.min_length or len(input_text) > config.max_length:
            if verbose:
                logger.debug(f" Skipping - length {len(input_text)} not in range [{config.min_length}, {config.max_length}]")
            return None

        processed_item = {
            "input": input_text,
            "label": label
        }
        if config.id_field and config.id_field in item:
            processed_item["id"] = item[config.id_field]

        if verbose:
            logger.debug(f" Final processed item: {processed_item}")
        return processed_item

    def _clean_text(self, text: str, config: ClassificationConfig) -> str:
        """Collapse whitespace, then optionally lowercase and strip punctuation.

        Returns "" for non-string input.
        """
        if not isinstance(text, str):
            return ""
        text = re.sub(r'\s+', ' ', text).strip()
        if config.lowercase:
            text = text.lower()
        if config.remove_special_chars:
            text = re.sub(r'[^\w\s]', '', text)
            # Removing punctuation can leave doubled or edge spaces ("a - b" -> "a  b").
            text = re.sub(r'\s+', ' ', text).strip()
        return text
# NOTE: the module-level `create_huggingface_config` helper is defined further
# down, next to `create_custom_config`. An identical earlier copy that lived here
# was dead code (the later `def` rebinds the name), so it was removed.
class CustomDataLoader(BaseDataLoader):
    """Load custom datasets from local JSONL, CSV or JSON files and split them."""

    # Seed for every random operation (splitting, sub-sampling) so runs are reproducible.
    _RANDOM_STATE = 42

    def load(self, config: ClassificationConfig) -> Dict[str, List[Dict]]:
        """Read ``config.data_path`` and return train/validation/test item lists.

        When ``config.max_samples`` is set, a seeded uniform sample of that
        size is taken before splitting.

        Raises:
            ValueError: If ``data_path`` is unset, the format is unsupported,
                or no records could be split.
            FileNotFoundError: If the file does not exist.
        """
        if not config.data_path:
            raise ValueError("Data path is required for custom datasets")
        file_path = Path(config.data_path)
        if not file_path.exists():
            raise FileNotFoundError(f"Data file not found: {file_path}")
        logger.info(f"Loading custom dataset: {file_path}")

        readers = {
            "jsonl": self._load_jsonl,
            "csv": self._load_csv,
            "json": self._load_json
        }
        reader = readers.get(config.data_format)
        if reader is None:
            raise ValueError(f"Unsupported format: {config.data_format}")
        raw_data = reader(file_path, config)

        if config.max_samples:
            raw_data = self._limit_samples(raw_data, config.max_samples)
        logger.info(f"Loaded {len(raw_data)} samples from {file_path}")

        return self._create_splits(raw_data, config)

    def _create_splits(self, data: List[Dict], config: ClassificationConfig) -> Dict[str, List[Dict]]:
        """Split ``data`` into train/validation/test using the configured fractions.

        Fractions are rescaled (with a warning) only when they do not sum to 1.
        Each split is stratified on ``config.label_field`` when possible and
        random otherwise; ``random_state`` is fixed so results are reproducible.

        Raises:
            ValueError: If ``data`` is empty, the fractions do not sum to a
                positive value, or scikit-learn cannot produce the splits.
        """
        logger.info(f"Creating splits from {len(data)} samples...")
        if not data:
            raise ValueError("No samples to split: the data file produced no records")

        total = config.train_split + config.validation_split + config.test_split
        if total <= 0:
            raise ValueError(f"Split fractions must sum to a positive value (got {total})")
        val_frac, test_frac = config.validation_split, config.test_split
        if abs(total - 1.0) > 1e-9:
            logger.warning(f"Split percentages don't sum to 1.0 (got {total}). Normalizing...")
            val_frac, test_frac = val_frac / total, test_frac / total
        held_out = val_frac + test_frac

        # First split: train + (val+test); second split: val + test.
        train_data, temp_data = self._split(data, held_out, config.label_field)
        val_data, test_data = self._split(temp_data, test_frac / held_out, config.label_field)

        splits_data = {
            "train": train_data,
            "validation": val_data,
            "test": test_data
        }
        logger.info("Created splits:")
        logger.info(f" Train: {len(splits_data['train'])} samples")
        logger.info(f" Validation: {len(splits_data['validation'])} samples")
        logger.info(f" Test: {len(splits_data['test'])} samples")
        return splits_data

    @classmethod
    def _split(cls, items: List[Dict], held_out_fraction: float, label_field: str) -> Tuple[List[Dict], List[Dict]]:
        """Split ``items`` into ``(kept, held_out)``, stratified on ``label_field`` when possible.

        Falls back to a random split when stratification fails, including the
        TypeError raised for unorderable labels such as None mixed with strings.
        """
        labels = [item.get(label_field) for item in items]
        try:
            kept, held_out = train_test_split(
                items,
                test_size=held_out_fraction,
                random_state=cls._RANDOM_STATE,
                stratify=labels
            )
        except (ValueError, TypeError) as e:
            logger.warning(f"Could not create stratified splits: {e}. Using random splits.")
            kept, held_out = train_test_split(
                items,
                test_size=held_out_fraction,
                random_state=cls._RANDOM_STATE
            )
        return kept, held_out

    @classmethod
    def _limit_samples(cls, items: List[Dict], max_samples: int) -> List[Dict]:
        """Return at most ``max_samples`` items, drawn uniformly with a fixed seed.

        Kept items stay in their original relative order. A plain prefix is
        avoided because, in a file grouped by label, a prefix could contain a
        single class.
        """
        if max_samples < 0:
            raise ValueError(f"max_samples must be positive, got {max_samples}")
        if len(items) <= max_samples:
            return list(items)
        rng = np.random.default_rng(cls._RANDOM_STATE)
        keep = np.sort(rng.choice(len(items), size=max_samples, replace=False))
        return [items[i] for i in keep.tolist()]

    def _load_jsonl(self, file_path: Path, config: ClassificationConfig) -> List[Dict]:
        """Load a JSONL file, skipping blank lines, invalid JSON and non-object values (with warnings)."""
        data = []
        with open(file_path, 'r', encoding=config.encoding) as f:
            for line_num, line in enumerate(f, 1):
                if not line.strip():
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError as e:
                    logger.warning(f"Invalid JSON at line {line_num}: {e}")
                    continue
                if isinstance(record, dict):
                    data.append(record)
                else:
                    logger.warning(f"Skipping non-object JSON value at line {line_num}")
        return data

    def _load_csv(self, file_path: Path, config: ClassificationConfig) -> List[Dict]:
        """Load a CSV file into one dict per row (empty cells arrive as NaN)."""
        df = pd.read_csv(file_path, encoding=config.encoding, delimiter=config.delimiter)
        return df.to_dict('records')

    def _load_json(self, file_path: Path, config: ClassificationConfig) -> List[Dict]:
        """Load a JSON file holding a list of records, ``{"data": [...]}``, or a single record.

        Non-object entries are dropped with a warning.
        """
        with open(file_path, 'r', encoding=config.encoding) as f:
            data = json.load(f)
        if isinstance(data, dict) and "data" in data:
            data = data["data"]
        records = data if isinstance(data, list) else [data]
        kept = [record for record in records if isinstance(record, dict)]
        if len(kept) != len(records):
            logger.warning(f"Skipping {len(records) - len(kept)} non-object entries in {file_path}")
        return kept

    def preprocess(self, data: Dict[str, List[Dict]], config: ClassificationConfig) -> Dict[str, List[Dict]]:
        """Preprocess every split independently; items rejected by ``_preprocess_item`` are dropped."""
        processed_splits = {}
        logger.info("=== PREPROCESSING CUSTOM DATA ===")
        for split_name, split_data in data.items():
            logger.info(f"Processing {split_name} split with {len(split_data)} items...")
            processed_data = []
            skipped_count = 0
            # Reset debug counter for each split
            self._debug_count = 0
            for i, item in enumerate(split_data):
                processed_item = self._preprocess_item(item, config)
                if processed_item is not None:
                    processed_data.append(processed_item)
                else:
                    skipped_count += 1
                    if skipped_count <= 3:  # Log first few skipped items
                        logger.info(f"Skipped item {i} from {split_name}: {item}")
            processed_splits[split_name] = processed_data
            logger.info(f"{split_name} - Preprocessed {len(processed_data)} samples, skipped {skipped_count} samples")
        return processed_splits

    @staticmethod
    def _is_missing(value: Any) -> bool:
        """Return True for None and scalar NaN-like values (NaN, NaT, pd.NA), e.g. empty CSV cells."""
        return value is None or (pd.api.types.is_scalar(value) and bool(pd.isna(value)))

    def _preprocess_item(self, item: Dict, config: ClassificationConfig) -> Optional[Dict]:
        """Turn one raw item into ``{"input", "label"[, "id"]}``, or return None to skip it.

        Missing values (None/NaN) become "" rather than the string "nan". An
        item is skipped when its label is missing or blank, or when the
        optionally cleaned text length is outside
        ``[config.min_length, config.max_length]``.
        """
        raw_input = item.get(config.input_field, "")
        raw_label = item.get(config.label_field, "")
        input_text = "" if self._is_missing(raw_input) else str(raw_input)
        label = "" if self._is_missing(raw_label) else str(raw_label)

        # An absent/None/NaN/blank label would otherwise become a bogus "" class.
        if not label.strip():
            return None

        if config.clean_text:
            input_text = self._clean_text(input_text, config)

        if len(input_text) < config.min_length or len(input_text) > config.max_length:
            return None

        processed_item = {
            "input": input_text,
            "label": label
        }
        if config.id_field and config.id_field in item:
            processed_item["id"] = item[config.id_field]
        return processed_item

    def _clean_text(self, text: str, config: ClassificationConfig) -> str:
        """Collapse whitespace, then optionally lowercase and strip punctuation.

        Returns "" for non-string input.
        """
        if not isinstance(text, str):
            return ""
        text = re.sub(r'\s+', ' ', text).strip()
        if config.lowercase:
            text = text.lower()
        if config.remove_special_chars:
            text = re.sub(r'[^\w\s]', '', text)
            # Removing punctuation can leave doubled or edge spaces ("a - b" -> "a  b").
            text = re.sub(r'\s+', ' ', text).strip()
        return text
class ClassificationDataPipeline:
    """Main classification pipeline: load, preprocess, validate, analyse, format and save."""

    # File formats understood by ``save_data``.
    _SAVE_FORMATS = ("jsonl", "json", "csv")

    def __init__(self):
        self.validator = DataValidator()
        self.hf_loader = HuggingFaceDataLoader()
        self.custom_loader = CustomDataLoader()

    def create_config(
        self,
        data_source: str,
        dataset_name: Optional[str] = None,
        data_path: Optional[str] = None,
        input_field: str = "text",
        label_field: str = "label",
        **kwargs
    ) -> ClassificationConfig:
        """Create a ``ClassificationConfig``; extra keyword arguments set any other field."""
        return ClassificationConfig(
            data_source=data_source,
            dataset_name=dataset_name,
            data_path=data_path,
            input_field=input_field,
            label_field=label_field,
            **kwargs
        )

    def load_and_preprocess(self, config: ClassificationConfig) -> Tuple[Dict[str, List[Dict]], Dict[str, Any]]:
        """Load and preprocess the data with the loader for ``config.data_source``, then validate and analyse it.

        Returns:
            ``(processed_splits, analysis)``.

        Raises:
            ValueError: For an unsupported ``config.data_source`` or when the
                processed data fails validation (each error is logged first).
        """
        if config.data_source == "huggingface":
            loader = self.hf_loader
        elif config.data_source == "custom":
            loader = self.custom_loader
        else:
            raise ValueError(f"Unsupported data source: {config.data_source}")

        raw_splits = loader.load(config)
        processed_splits = loader.preprocess(raw_splits, config)

        is_valid, errors = self.validator.validate_classification_data(processed_splits, config, is_processed=True)
        if not is_valid:
            logger.error("Data validation failed:")
            for error in errors:
                logger.error(f" - {error}")
            raise ValueError("Data validation failed")

        analysis = self.validator.analyze_dataset(processed_splits, config, is_processed=True)
        return processed_splits, analysis

    def convert_to_classification_format(self, data: Dict[str, List[Dict]]) -> Dict[str, List[Dict]]:
        """Map each processed item to ``{"text": item["input"], "label": item["label"]}``.

        Other keys (such as "id") are not carried over.
        """
        return {
            split_name: [{"text": item["input"], "label": item["label"]} for item in split_data]
            for split_name, split_data in data.items()
        }

    def save_data(self, data: Dict[str, List[Dict]], output_dir: str, format: str = "jsonl"):
        """Write each split to ``<output_dir>/<split_name>.<format>``.

        Args:
            data: Mapping from split name to items.
            output_dir: Target directory; created (with parents) if needed.
            format: "jsonl", "json" or "csv".

        Raises:
            ValueError: If ``format`` is not supported. This is checked before
                anything is created or written (previously an unknown format
                crashed with UnboundLocalError after creating the directory).
        """
        if format not in self._SAVE_FORMATS:
            raise ValueError(f"Unsupported save format: {format!r} (expected one of {self._SAVE_FORMATS})")

        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        for split_name, split_data in data.items():
            output_file = output_path / f"{split_name}.{format}"
            if format == "jsonl":
                with open(output_file, 'w', encoding='utf-8') as f:
                    f.writelines(json.dumps(item, ensure_ascii=False) + '\n' for item in split_data)
            elif format == "json":
                with open(output_file, 'w', encoding='utf-8') as f:
                    json.dump(split_data, f, ensure_ascii=False, indent=2)
            else:  # csv
                pd.DataFrame(split_data).to_csv(output_file, index=False, encoding='utf-8')
            logger.info(f"Saved {len(split_data)} samples to {output_file}")

    def run_pipeline(
        self,
        config: ClassificationConfig,
        output_format: str = "classification",
        save_splits: bool = True
    ) -> Dict[str, Any]:
        """Run load -> preprocess -> validate -> format -> (optionally) save.

        When ``save_splits`` is True, files go to
        ``<config.output_dir>/<output_format>/`` as JSONL.

        Returns:
            A summary dict with "config", "analysis", "splits" (split sizes),
            "output_format", "output_dir" and "data" (the formatted splits).
        """
        logger.info("Starting classification pipeline...")

        processed_splits, analysis = self.load_and_preprocess(config)

        if output_format == "classification":
            formatted_splits = self.convert_to_classification_format(processed_splits)
        else:
            formatted_splits = processed_splits

        if save_splits:
            output_dir = Path(config.output_dir) / output_format
            self.save_data(formatted_splits, str(output_dir))

        result = {
            "config": config,
            "analysis": analysis,
            "splits": {
                split_name: len(split_data) for split_name, split_data in formatted_splits.items()
            },
            "output_format": output_format,
            "output_dir": config.output_dir,
            "data": formatted_splits  # Include the actual processed data
        }
        logger.info("Classification pipeline completed successfully!")
        return result
def create_huggingface_config(dataset_name: str, input_field: str = "text", label_field: str = "label", **kwargs) -> ClassificationConfig:
    """Build a ``ClassificationConfig`` for a Hugging Face Hub dataset.

    ``data_source`` is fixed to "huggingface"; any other config field can be
    supplied through ``kwargs``.
    """
    options = dict(kwargs)
    options.update(
        dataset_name=dataset_name,
        input_field=input_field,
        label_field=label_field,
    )
    return ClassificationConfig(data_source="huggingface", **options)
def create_custom_config(data_path: str, data_format: str = "jsonl", input_field: str = "text", label_field: str = "label", **kwargs) -> ClassificationConfig:
    """Build a ``ClassificationConfig`` that reads a local data file.

    ``data_source`` is fixed to "custom"; every other config field may be
    passed through ``kwargs``.
    """
    settings = {
        "data_path": data_path,
        "data_format": data_format,
        "input_field": input_field,
        "label_field": label_field,
    }
    settings.update(kwargs)
    return ClassificationConfig(data_source="custom", **settings)
# argparse destinations that may override the YAML ``data:`` section, mapped to
# the key they override there. Only ``data_source`` is stored under another
# name (``source``).
_CLI_OVERRIDE_KEYS = {
    'data_source': 'source',
    'dataset_name': 'dataset_name',
    'data_path': 'data_path',
    'data_format': 'data_format',
    'input_field': 'input_field',
    'label_field': 'label_field',
    'id_field': 'id_field',
    'max_samples': 'max_samples',
    'train_split': 'train_split',
    'validation_split': 'validation_split',
    'test_split': 'test_split',
    'clean_text': 'clean_text',
    'remove_special_chars': 'remove_special_chars',
    'lowercase': 'lowercase',
    'min_length': 'min_length',
    'max_length': 'max_length',
    'output_format': 'output_format',
    'output_dir': 'output_dir',
}


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the classification pipeline.

    Every overridable option defaults to ``None``. This distinguishes
    "not given" from an explicit value, including falsy values such as
    ``--test-split 0``. Each boolean preprocessing switch has a mutually
    exclusive ``--no-...`` counterpart, so the CLI can override a YAML value
    in either direction.
    """
    parser = argparse.ArgumentParser(description="Classification Data Processing Pipeline")

    # YAML configuration
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")

    # Data source arguments
    parser.add_argument("--data-source", choices=["huggingface", "custom"], help="Data source")
    parser.add_argument("--dataset-name", type=str, help="HuggingFace dataset name")
    parser.add_argument("--data-path", type=str, help="Path to custom data file")
    parser.add_argument("--data-format", choices=["jsonl", "csv", "json"], help="Data format")

    # Field mapping
    parser.add_argument("--input-field", type=str, help="Input field name")
    parser.add_argument("--label-field", type=str, help="Label field name")
    parser.add_argument("--id-field", type=str, help="Optional ID field name")

    # Data processing
    parser.add_argument("--max-samples", type=int, help="Maximum samples to process")
    parser.add_argument("--train-split", type=float, help="Training split ratio")
    parser.add_argument("--validation-split", type=float, help="Validation split ratio")
    parser.add_argument("--test-split", type=float, help="Test split ratio")

    # Text preprocessing: on/off switch pairs sharing one destination.
    for flag, dest, on_help, off_help in (
        ("clean-text", "clean_text", "Clean and normalize text", "Disable text cleaning"),
        ("remove-special-chars", "remove_special_chars", "Remove special characters", "Keep special characters"),
        ("lowercase", "lowercase", "Convert text to lowercase", "Preserve original casing"),
    ):
        group = parser.add_mutually_exclusive_group()
        group.add_argument(f"--{flag}", dest=dest, action="store_true", default=None, help=on_help)
        group.add_argument(f"--no-{flag}", dest=dest, action="store_false", default=None, help=off_help)
    parser.add_argument("--min-length", type=int, help="Minimum text length")
    parser.add_argument("--max-length", type=int, help="Maximum text length")

    # Output configuration
    parser.add_argument("--output-format", choices=["classification", "instruction", "conversation", "qa"], help="Output format")
    parser.add_argument("--output-dir", type=str, help="Output directory")

    # Logging
    parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO", help="Logging level")

    return parser


def _load_yaml_config(path: str) -> Dict[str, Any]:
    """Load the YAML file at *path* and return its top-level mapping.

    An empty file yields ``{}``. If the file cannot be read or parsed, or
    its top level is not a mapping, the error is logged and the process
    exits with status 1.
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            loaded = yaml.safe_load(f)
    except (OSError, UnicodeDecodeError, yaml.YAMLError) as e:
        logger.error(f"Error loading YAML config: {e}")
        sys.exit(1)

    # safe_load returns None for an empty document.
    if loaded is None:
        loaded = {}
    if not isinstance(loaded, dict):
        logger.error(f"Error loading YAML config: top level of {path} must be a mapping")
        sys.exit(1)

    logger.info(f"Loaded YAML configuration from: {path}")
    return loaded


def _collect_cli_overrides(args: argparse.Namespace) -> Dict[str, Any]:
    """Return the explicitly given CLI options, keyed by their YAML ``data:`` key.

    Options left at their ``None`` default are omitted. Falsy but explicit
    values (``0``, ``False``) are kept.
    """
    overrides = {}
    for dest, yaml_key in _CLI_OVERRIDE_KEYS.items():
        value = getattr(args, dest)
        if value is not None:
            overrides[yaml_key] = value
    return overrides


def main():
    """Command-line entry point for the classification data pipeline.

    Settings are read from the ``data:`` section of an optional YAML file
    (``--config``). Any option given explicitly on the command line
    overrides the matching key in that section: ``--data-source`` overrides
    ``data.source``, and every other option overrides the key of the same
    name.

    The merged settings are validated, turned into a ClassificationConfig and
    run through ``ClassificationDataPipeline.run_pipeline`` with saving
    enabled.

    Exit status: 1 if the YAML file cannot be loaded or the pipeline raises;
    argparse exits with status 2 on invalid or missing arguments.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Set up logging
    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    config_dict = _load_yaml_config(args.config) if args.config else {}

    data_section = config_dict.get('data') or {}
    if not isinstance(data_section, dict):
        parser.error("the 'data' section of the YAML config must be a mapping")
    # Copy so CLI overrides do not mutate the loaded YAML document.
    data_cfg = dict(data_section)

    # Merge CLI overrides into the same section that validation and config
    # construction read from.
    for key, value in _collect_cli_overrides(args).items():
        if key in data_cfg:
            logger.info(f"Overriding YAML config '{key}' with CLI value: {value}")
        data_cfg[key] = value

    # Validate required arguments
    source = data_cfg.get('source')
    if not source:
        parser.error("--data-source is required (either in YAML config or CLI)")
    if source == "huggingface" and not data_cfg.get('dataset_name'):
        parser.error("--dataset-name is required for HuggingFace datasets")
    if source == "custom" and not data_cfg.get('data_path'):
        parser.error("--data-path is required for custom datasets")

    # Create configuration object
    config = ClassificationConfig(
        data_source=source,
        dataset_name=data_cfg.get('dataset_name'),
        data_path=data_cfg.get('data_path'),
        data_format=data_cfg.get('data_format', 'jsonl'),
        input_field=data_cfg.get('input_field', 'text'),
        label_field=data_cfg.get('label_field', 'label'),
        id_field=data_cfg.get('id_field'),
        max_samples=data_cfg.get('max_samples'),
        train_split=data_cfg.get('train_split', 0.8),
        validation_split=data_cfg.get('validation_split', 0.1),
        test_split=data_cfg.get('test_split', 0.1),
        clean_text=data_cfg.get('clean_text', True),
        remove_special_chars=data_cfg.get('remove_special_chars', False),
        lowercase=data_cfg.get('lowercase', True),
        min_length=data_cfg.get('min_length', 10),
        max_length=data_cfg.get('max_length', 1000),
        label_encoding=data_cfg.get('label_encoding', 'auto'),
        multilabel=data_cfg.get('multilabel', False),
        label_separator=data_cfg.get('label_separator', ','),
        output_format=data_cfg.get('output_format', 'classification'),
        output_dir=data_cfg.get('output_dir', './data'),
        hf_split=data_cfg.get('hf_split', 'train'),
        hf_cache_dir=data_cfg.get('hf_cache_dir'),
        test_split_from=data_cfg.get('test_split_from', 'train'),
        val_split_from=data_cfg.get('val_split_from', 'train'),
        encoding=data_cfg.get('encoding', 'utf-8'),
        delimiter=data_cfg.get('delimiter', ',')
    )

    # Initialize pipeline
    pipeline = ClassificationDataPipeline()

    try:
        print(f"Starting classification pipeline with {config.data_source} data source...")
        if args.config:
            print(f"Using YAML configuration: {args.config}")
        print()

        result = pipeline.run_pipeline(config, config.output_format, save_splits=True)

        print("✅ Pipeline completed successfully!")
        print(f" Data source: {config.data_source}")
        if config.data_source == "huggingface":
            print(f" Dataset: {config.dataset_name}")
        else:
            print(f" Data file: {config.data_path}")
        print(f" Total samples: {result['analysis']['overall']['total_samples']}")
        print(f" Unique labels: {result['analysis']['overall']['all_unique_labels']}")
        print(f" Split sizes: {result['analysis']['overall']['split_sizes']}")
        # run_pipeline writes the splits into a per-format subdirectory.
        print(f" Output directory: {Path(config.output_dir) / config.output_format}")

    except Exception as e:
        # Top-level boundary: report and exit non-zero. The traceback is
        # still available with --log-level DEBUG.
        logger.debug("Pipeline failure traceback", exc_info=True)
        print(f"❌ Error running pipeline: {e}")
        sys.exit(1)
if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    main()