Files
DS-LLM-TEMPLATE-FINETUNING/pipelines/classification/data_processor.py
T
OwusuBlessing fef3f5ae35 initial setupt
2025-08-06 22:45:37 +01:00

1073 lines
46 KiB
Python

import json
import pandas as pd
import numpy as np
from pathlib import Path
from typing import Dict, List, Optional, Union, Any, Tuple
from datasets import Dataset, load_dataset
import os
from dataclasses import dataclass
from abc import ABC, abstractmethod
import logging
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import re
import argparse
import sys
import yaml
# Module-wide logger, named after this module.
logger = logging.getLogger(name=__name__)
# Pin this logger's own threshold to DEBUG so DEBUG/INFO/ERROR records from this
# module are never filtered out at the logger itself.
# NOTE(review): when records propagate, ancestor logger levels are not consulted,
# so these DEBUG records may still reach root handlers even if the root logger is
# configured at INFO — confirm this is intended.
logger.setLevel("DEBUG")
@dataclass
class ClassificationConfig:
    """Settings that drive loading, splitting, cleaning and saving a classification dataset.

    Every field has a default, so a config only needs the options that differ
    from those defaults (typically ``dataset_name`` or ``data_path`` plus the
    field names).
    """

    # --- Where the data comes from ---
    # "huggingface" reads ``dataset_name`` from the Hub; "custom" reads ``data_path``.
    data_source: str = "huggingface"
    dataset_name: Union[str, None] = None  # Hub dataset identifier (data_source == "huggingface")
    data_path: Union[str, None] = None  # local file path (data_source == "custom")
    data_format: str = "jsonl"  # local file format: jsonl, csv or json

    # --- Which record keys hold what ---
    input_field: str = "text"  # key holding the input text
    label_field: str = "label"  # key holding the label
    id_field: Union[str, None] = None  # optional key copied into processed records as "id"

    # --- Sampling and split fractions ---
    max_samples: Union[int, None] = None
    train_split: float = 0.8
    validation_split: float = 0.1
    test_split: float = 0.1

    # --- Text normalisation ---
    clean_text: bool = True
    remove_special_chars: bool = False
    lowercase: bool = True
    min_length: int = 10  # records whose (cleaned) text is shorter are dropped
    max_length: int = 1000  # records whose (cleaned) text is longer are dropped

    # --- Label handling ---
    label_encoding: str = "auto"  # auto, numeric or string
    multilabel: bool = False
    label_separator: str = ","  # separator between labels in multilabel datasets

    # --- Output ---
    output_format: str = "classification"  # e.g. classification, instruction, conversation, qa
    output_dir: str = "./data"

    # --- Hugging Face options ---
    hf_split: str = "train"
    hf_cache_dir: Union[str, None] = None

    # --- Where validation/test records come from ---
    # test_split_from: "train", "use_test_if_available" or "use_val_if_available"
    test_split_from: str = "train"
    # val_split_from: "train" or "use_val_if_available"
    val_split_from: str = "train"

    # --- Custom-file options ---
    encoding: str = "utf-8"
    delimiter: str = ","  # CSV delimiter
class DataValidator:
    """Checks and summarises classification data stored as ``{split_name: [record, ...]}``."""

    @staticmethod
    def validate_classification_data(data: Dict[str, List[Dict]], config: ClassificationConfig, is_processed: bool = False) -> Tuple[bool, List[str]]:
        """Validate the train/validation/test splits of a classification dataset.

        Args:
            data: Mapping from split name to a list of record dicts.
            config: Pipeline configuration; ``config.input_field`` names the text
                key when the data has not been preprocessed.
            is_processed: When True, records are expected to carry their text
                under the preprocessed ``"input"`` key instead.

        Returns:
            ``(is_valid, errors)``. If any of the ``train``/``validation``/``test``
            splits is missing or empty, only those errors are returned and no
            per-record checks run.
        """
        from collections import Counter

        errors = []

        # All three standard splits must exist and be non-empty before anything else is checked.
        for split in ["train", "validation", "test"]:
            if split not in data or not data[split]:
                errors.append(f"Missing or empty '{split}' split")
        if errors:
            return False, errors

        total_samples = sum(len(split_data) for split_data in data.values())
        logger.info(f"Validating {total_samples} total samples across all splits...")

        # Preprocessed records store their text under "input"; raw ones use the configured key.
        input_field = "input" if is_processed else config.input_field
        label_field = "label"  # the label key is "label" both before and after preprocessing

        for split_name, split_data in data.items():
            logger.info(f"Validating {split_name} split with {len(split_data)} samples...")

            # Required fields
            missing_input_count = 0
            missing_label_count = 0
            for i, item in enumerate(split_data):
                if input_field not in item:
                    errors.append(f"Missing input field '{input_field}' in {split_name} split, item {i}")
                    missing_input_count += 1
                if label_field not in item:
                    errors.append(f"Missing label field '{label_field}' in {split_name} split, item {i}")
                    missing_label_count += 1
            logger.info(f"{split_name} - Items missing input field: {missing_input_count}")
            logger.info(f"{split_name} - Items missing label field: {missing_label_count}")

            # Input type and emptiness in a single pass. A missing input counts as ""
            # (empty). Non-string inputs are reported as type errors and are never
            # passed to .strip(), which would raise AttributeError on None/int.
            type_errors = 0
            empty_inputs = 0
            for i, item in enumerate(split_data):
                value = item.get(input_field, "")
                if not isinstance(value, str):
                    errors.append(f"Input field '{input_field}' must be string in {split_name} split, item {i}")
                    type_errors += 1
                elif not value.strip():
                    empty_inputs += 1
            logger.info(f"{split_name} - Type errors: {type_errors}")
            if empty_inputs > 0:
                errors.append(f"Found {empty_inputs} items with empty input text in {split_name} split")
            logger.info(f"{split_name} - Empty inputs: {empty_inputs}")

            # Label distribution, ignoring None labels. Counter replaces the
            # O(n * k) per-label list.count() loop.
            label_counts = Counter(
                item.get(label_field) for item in split_data if item.get(label_field) is not None
            )
            unique_labels = set(label_counts)
            logger.info(f"{split_name} - Found {len(unique_labels)} unique labels: {unique_labels}")
            logger.info(f"{split_name} - Label distribution: {dict(label_counts)}")
            if len(unique_labels) < 1:
                errors.append(f"{split_name} split must have at least 1 label, found: {unique_labels}")

            # Show a few records for debugging; str() keeps the preview safe for non-string inputs.
            if split_data:
                logger.info(f"Sample processed items from {split_name}:")
                for i, item in enumerate(split_data[:3]):
                    preview = str(item.get(input_field, ''))[:50]
                    logger.info(f" Item {i}: input='{preview}...', label='{item.get(label_field, '')}'")

        return len(errors) == 0, errors

    @staticmethod
    def analyze_dataset(data: Dict[str, List[Dict]], config: ClassificationConfig, is_processed: bool = False) -> Dict[str, Any]:
        """Summarise each split and the dataset as a whole.

        Args:
            data: Mapping from split name to a list of record dicts.
            config: Pipeline configuration (``config.input_field`` is used for raw data).
            is_processed: When True, the text is read from the ``"input"`` key.

        Returns:
            ``{"splits": {...}, "overall": {...}}``. Each split entry holds
            ``total_samples``, ``unique_labels`` (number of distinct raw label
            values, None included), ``label_distribution`` (counts keyed by
            ``str(label)``), ``text_length_stats`` (min/max/mean/median character
            length; empty dict for an empty split) and ``missing_values`` (count
            of falsy values per field). ``overall`` holds ``total_samples``,
            ``all_unique_labels`` (number of distinct stringified labels across
            all splits) and ``split_sizes``.
        """
        from collections import Counter

        input_field = "input" if is_processed else config.input_field
        label_field = "label"  # the label key is "label" both before and after preprocessing

        analysis = {
            "splits": {},
            "overall": {
                "total_samples": 0,
                "all_unique_labels": set(),
                "split_sizes": {}
            }
        }

        for split_name, split_data in data.items():
            label_counts = Counter(item.get(label_field) for item in split_data)

            # Distinct raw labels can share a string form (e.g. 1 and "1"); sum
            # their counts instead of letting one overwrite the other.
            label_distribution = {}
            for label, count in label_counts.items():
                key = str(label)
                label_distribution[key] = label_distribution.get(key, 0) + count
                analysis["overall"]["all_unique_labels"].add(key)

            # Non-string inputs count as length 0 rather than raising (len(None) fails).
            text_lengths = [
                len(value) if isinstance(value, str) else 0
                for value in (item.get(input_field, "") for item in split_data)
            ]
            text_length_stats = {}
            if text_lengths:
                text_length_stats = {
                    "min": min(text_lengths),
                    "max": max(text_lengths),
                    "mean": np.mean(text_lengths),
                    "median": np.median(text_lengths)
                }

            missing_values = {
                field: sum(1 for item in split_data if not item.get(field))
                for field in [input_field, label_field]
            }

            analysis["splits"][split_name] = {
                "total_samples": len(split_data),
                "unique_labels": len(label_counts),
                "label_distribution": label_distribution,
                "text_length_stats": text_length_stats,
                "missing_values": missing_values
            }
            analysis["overall"]["total_samples"] += len(split_data)
            analysis["overall"]["split_sizes"][split_name] = len(split_data)

        # Report the number of distinct labels rather than the set itself.
        analysis["overall"]["all_unique_labels"] = len(analysis["overall"]["all_unique_labels"])
        return analysis
class BaseDataLoader(ABC):
    """Interface shared by all data sources: first load the raw splits, then preprocess them."""

    @abstractmethod
    def load(self, config: ClassificationConfig) -> Dict[str, List[Dict]]:
        """Return the raw records as a mapping of split name (train/validation/test) to record lists."""
        ...

    @abstractmethod
    def preprocess(self, data: Dict[str, List[Dict]], config: ClassificationConfig) -> Dict[str, List[Dict]]:
        """Apply the loader's preprocessing to every split in ``data`` and return the result."""
        ...
class HuggingFaceDataLoader(BaseDataLoader):
    """Load a dataset from the Hugging Face Hub and turn it into train/validation/test splits."""

    # Patterns used by _clean_text, compiled once instead of on every record.
    _WHITESPACE_RE = re.compile(r'\s+')
    _SPECIAL_CHARS_RE = re.compile(r'[^\w\s]')

    def load(self, config: ClassificationConfig) -> Dict[str, List[Dict]]:
        """Load ``config.dataset_name`` from the Hub and return train/validation/test splits.

        The dataset's ``train`` split is always used. Validation/test records come
        from the dataset's own ``validation``/``val``/``test`` splits when
        ``config.val_split_from`` / ``config.test_split_from`` request it and such
        a split exists; any split still empty is carved out of the train records.
        When ``config.max_samples`` is set, each split is then truncated to that
        many records.

        Raises:
            ValueError: If ``config.dataset_name`` is empty, the dataset has no
                ``train`` split, or the missing splits cannot be derived.
            Exception: Any exception raised while loading or splitting is logged
                and re-raised.
        """
        if not config.dataset_name:
            raise ValueError("Dataset name is required for Hugging Face datasets")
        logger.info(f"Loading Hugging Face dataset: {config.dataset_name}")
        try:
            # Load all splits so we can see which ones the dataset provides.
            dataset = load_dataset(
                config.dataset_name,
                cache_dir=config.hf_cache_dir
            )
            available_splits = list(dataset.keys())
            logger.info(f"Available splits in dataset: {available_splits}")

            splits_data = {
                "train": [],
                "validation": [],
                "test": []
            }

            # Train split (mandatory)
            if "train" not in available_splits:
                logger.error("No 'train' split found in dataset!")
                logger.error(f"Available splits: {available_splits}")
                raise ValueError(f"Dataset {config.dataset_name} does not have a 'train' split")
            train_dataset = dataset["train"]
            logger.info(f"Using 'train' split with {len(train_dataset)} samples")
            splits_data["train"] = list(train_dataset)

            # Validation split
            val_source = None
            if config.val_split_from == "use_val_if_available":
                val_source = self._first_present(("validation", "val"), available_splits)
            if val_source:
                val_dataset = dataset[val_source]
                logger.info(f"Using '{val_source}' split with {len(val_dataset)} samples")
                splits_data["validation"] = list(val_dataset)
            elif config.val_split_from == "use_val_if_available":
                logger.warning("No validation split found in dataset. Will create from train split.")
                logger.info(f"Available splits: {available_splits}")
                logger.info(f"Will use {config.validation_split * 100}% of train data for validation")
            else:
                logger.info(f"Will create validation split from train data ({config.validation_split * 100}%)")

            # Test split
            test_source = None
            if config.test_split_from == "use_test_if_available" and "test" in available_splits:
                test_source = "test"
            elif config.test_split_from == "use_val_if_available":
                test_source = self._first_present(("validation", "val"), available_splits)
            if test_source == "test":
                test_dataset = dataset["test"]
                logger.info(f"Using 'test' split with {len(test_dataset)} samples")
                splits_data["test"] = list(test_dataset)
            elif test_source:
                test_dataset = dataset[test_source]
                logger.info(f"Using '{test_source}' split as test with {len(test_dataset)} samples")
                splits_data["test"] = list(test_dataset)
            elif config.test_split_from == "use_test_if_available":
                logger.warning("No test split found in dataset. Will create from train split.")
                logger.info(f"Available splits: {available_splits}")
                logger.info(f"Will use {config.test_split * 100}% of train data for test")
            else:
                logger.info(f"Will create test split from train data ({config.test_split * 100}%)")

            # Both settings can point at the same dataset split, which leaks
            # validation data into the test set.
            if val_source and val_source == test_source:
                logger.warning(
                    f"Validation and test data both come from the '{val_source}' split; "
                    "test results will not be independent of validation"
                )

            if not splits_data["validation"] or not splits_data["test"]:
                self._derive_missing_splits(splits_data, config)

            logger.info("Final split sizes:")
            logger.info(f"  Train: {len(splits_data['train'])} samples")
            logger.info(f"  Validation: {len(splits_data['validation'])} samples")
            logger.info(f"  Test: {len(splits_data['test'])} samples")

            # Apply max_samples to each split independently (after splitting).
            if config.max_samples:
                for split_name, split_data in splits_data.items():
                    if split_data:
                        original_size = len(split_data)
                        splits_data[split_name] = split_data[:config.max_samples]
                        logger.info(f"Limited {split_name} split from {original_size} to {len(splits_data[split_name])} samples")

            self._log_field_hints(splits_data, config)

            logger.info(f"Successfully loaded dataset {config.dataset_name}")
            return splits_data
        except Exception as e:
            logger.error(f"Error loading dataset {config.dataset_name}: {e}")
            raise

    @staticmethod
    def _first_present(candidates: Tuple[str, ...], available_splits: List[str]) -> Optional[str]:
        """Return the first name in ``candidates`` that is in ``available_splits``, else None."""
        return next((name for name in candidates if name in available_splits), None)

    @staticmethod
    def _split_off(records: List[Dict], fraction: float, label_field: str) -> Tuple[List[Dict], List[Dict]]:
        """Split ``records`` into ``(remaining, held_out)`` with ``fraction`` of them held out.

        Stratifies on ``label_field`` when the first record has that key. If
        scikit-learn rejects the stratification (e.g. a class with a single
        member, or label values that cannot be compared), it falls back to a
        plain random split with the same seed instead of aborting the load.
        """
        stratify = [item.get(label_field) for item in records] if label_field in records[0] else None
        try:
            return train_test_split(records, test_size=fraction, random_state=42, stratify=stratify)
        except (ValueError, TypeError) as e:
            if stratify is None:
                raise
            logger.warning(f"Could not create stratified split: {e}. Using random split.")
            return train_test_split(records, test_size=fraction, random_state=42)

    def _derive_missing_splits(self, splits_data: Dict[str, List[Dict]], config: ClassificationConfig) -> None:
        """Fill the empty ``validation``/``test`` entries of ``splits_data`` from its train records, in place.

        The configured fractions are normalised locally when they do not sum to
        1; ``config`` itself is left unchanged.

        Raises:
            ValueError: If the train split is empty or the fractions needed for
                the split are not positive.
        """
        import math

        train_data = splits_data["train"]
        if not train_data:
            raise ValueError("Cannot create validation/test splits: the train split is empty")

        val_frac = config.validation_split
        test_frac = config.test_split
        total = config.train_split + val_frac + test_frac
        if total <= 0:
            raise ValueError(f"Split fractions must sum to a positive value (got {total})")
        # Tolerate float rounding (e.g. 0.7 + 0.15 + 0.15) instead of exact comparison.
        if not math.isclose(total, 1.0):
            logger.warning(f"Split percentages don't sum to 1.0 (got {total}). Normalizing...")
            val_frac /= total
            test_frac /= total

        label_field = config.label_field
        if not splits_data["validation"] and not splits_data["test"]:
            # Train -> train + holdout, then holdout -> validation + test.
            held_out = val_frac + test_frac
            if held_out <= 0:
                raise ValueError("validation_split + test_split must be positive to create validation and test splits from train")
            new_train, temp_data = self._split_off(train_data, held_out, label_field)
            new_val, new_test = self._split_off(temp_data, test_frac / held_out, label_field)
            splits_data.update(train=new_train, validation=new_val, test=new_test)
        elif not splits_data["validation"]:
            new_train, new_val = self._split_off(train_data, val_frac, label_field)
            splits_data.update(train=new_train, validation=new_val)
        else:
            # Only reached when the test split is the missing one.
            new_train, new_test = self._split_off(train_data, test_frac, label_field)
            splits_data.update(train=new_train, test=new_test)

    @staticmethod
    def _log_field_hints(splits_data: Dict[str, List[Dict]], config: ClassificationConfig) -> None:
        """Log a sample record per split and suggest likely fields when the configured ones are absent."""
        for split_name, split_data in splits_data.items():
            if not split_data:
                continue
            first = split_data[0]
            fields = list(first.keys())
            logger.info(f"Sample data item from {split_name}: {first}")
            logger.info(f"Available fields in {split_name} split: {fields}")
            if config.input_field not in first:
                logger.warning(f"Input field '{config.input_field}' not found in {split_name}. Available fields: {fields}")
                text_fields = [f for f in fields if any(keyword in f.lower() for keyword in ['text', 'sentence', 'content', 'input', 'comment', 'message'])]
                if text_fields:
                    logger.info(f"Suggested text fields for {split_name}: {text_fields}")
            if config.label_field not in first:
                logger.warning(f"Label field '{config.label_field}' not found in {split_name}. Available fields: {fields}")
                label_fields = [f for f in fields if any(keyword in f.lower() for keyword in ['label', 'class', 'category', 'target', 'emotion', 'labels'])]
                if label_fields:
                    logger.info(f"Suggested label fields for {split_name}: {label_fields}")

    def preprocess(self, data: Dict[str, List[Dict]], config: ClassificationConfig) -> Dict[str, List[Dict]]:
        """Preprocess every split independently and return the kept records per split.

        Each record is reduced by ``_preprocess_item`` to ``{"input", "label"}``
        (plus ``"id"`` when configured); records it rejects are dropped and
        counted. Raw and processed samples are logged to make field-mapping
        mistakes easy to spot.
        """
        processed_splits = {}
        logger.info("=== PREPROCESSING DATA ===")
        for split_name, split_data in data.items():
            logger.info(f"Processing {split_name} split with {len(split_data)} items...")
            if split_data:
                self._log_raw_overview(split_name, split_data, config)

            processed_data = []
            skipped_count = 0
            # _preprocess_item emits debug logs only for the first 3 items of each split.
            self._debug_count = 0
            for i, item in enumerate(split_data):
                processed_item = self._preprocess_item(item, config)
                if processed_item is None:
                    skipped_count += 1
                    if skipped_count <= 3:  # log only the first few skipped items
                        logger.info(f"Skipped item {i} from {split_name}: {item}")
                    continue
                processed_data.append(processed_item)

            processed_splits[split_name] = processed_data
            logger.info(f"{split_name} - Preprocessed {len(processed_data)} samples, skipped {skipped_count} samples")

            if processed_data:
                logger.info(f"=== SAMPLE PROCESSED DATA FROM {split_name.upper()} ===")
                for i, processed_item in enumerate(processed_data[:3]):
                    logger.info(f"Processed item {i} from {split_name}: {processed_item}")
        return processed_splits

    @staticmethod
    def _log_raw_overview(split_name: str, split_data: List[Dict], config: ClassificationConfig) -> None:
        """Log field availability, missing-field counts and the first 3 raw records of a non-empty split."""
        available_fields = set(split_data[0].keys())
        logger.info(f"Available fields in {split_name}: {available_fields}")
        logger.info(f"Looking for input field: '{config.input_field}', label field: '{config.label_field}'")
        if config.input_field not in available_fields:
            logger.error(f"Input field '{config.input_field}' not found in {split_name}. Available fields: {available_fields}")
        if config.label_field not in available_fields:
            logger.error(f"Label field '{config.label_field}' not found in {split_name}. Available fields: {available_fields}")

        missing_input = sum(1 for item in split_data if config.input_field not in item or not item.get(config.input_field))
        missing_label = sum(1 for item in split_data if config.label_field not in item or item.get(config.label_field) is None)
        logger.info(f"{split_name} - Items missing input field: {missing_input}")
        logger.info(f"{split_name} - Items missing label field: {missing_label}")

        logger.info(f"=== SAMPLE RAW DATA FROM {split_name.upper()} BEFORE PREPROCESSING ===")
        for i, item in enumerate(split_data[:3]):
            logger.info(f"Raw item {i} from {split_name}:")
            for key, value in item.items():
                if isinstance(value, str) and len(value) > 100:
                    logger.info(f" {key}: '{value[:100]}...'")
                else:
                    logger.info(f" {key}: {value}")

    def _preprocess_item(self, item: Dict, config: ClassificationConfig) -> Optional[Dict]:
        """Turn one raw record into ``{"input": str, "label": str[, "id": ...]}``.

        Missing/None input and label become ``""``; other values are converted
        with ``str()``. Returns None when the (optionally cleaned) text length is
        outside ``[config.min_length, config.max_length]``.
        """
        input_text = item.get(config.input_field, "")
        label = item.get(config.label_field, "")

        # Per-split counter (reset in preprocess) so only the first few items are debug-logged.
        self._debug_count = getattr(self, '_debug_count', 0) + 1
        verbose = self._debug_count <= 3
        if verbose:
            logger.debug(f"Processing item {self._debug_count}:")
            logger.debug(f"  Looking for input field '{config.input_field}': {input_text}")
            logger.debug(f"  Looking for label field '{config.label_field}': {label}")

        input_text = "" if input_text is None else str(input_text)
        label = "" if label is None else str(label)
        if verbose:
            logger.debug(f"  After conversion - input: '{input_text[:50]}...', label: '{label}'")

        if config.clean_text:
            original_text = input_text
            input_text = self._clean_text(input_text, config)
            if verbose:
                logger.debug(f"  After cleaning - original: '{original_text[:50]}...', cleaned: '{input_text[:50]}...'")

        if len(input_text) < config.min_length or len(input_text) > config.max_length:
            if verbose:
                logger.debug(f"  Skipping - length {len(input_text)} not in range [{config.min_length}, {config.max_length}]")
            return None

        processed_item = {
            "input": input_text,
            "label": label
        }
        if config.id_field and config.id_field in item:
            processed_item["id"] = item[config.id_field]
        if verbose:
            logger.debug(f"  Final processed item: {processed_item}")
        return processed_item

    def _clean_text(self, text: str, config: ClassificationConfig) -> str:
        """Collapse whitespace, then optionally lowercase and strip non-word, non-space characters."""
        if not isinstance(text, str):
            return ""
        text = self._WHITESPACE_RE.sub(' ', text).strip()
        if config.lowercase:
            text = text.lower()
        if config.remove_special_chars:
            text = self._SPECIAL_CHARS_RE.sub('', text)
        return text
def create_huggingface_config(dataset_name: str, input_field: str = "text", label_field: str = "label", **kwargs) -> ClassificationConfig:
    """Build a ClassificationConfig that reads ``dataset_name`` from the Hugging Face Hub.

    Any extra keyword arguments are forwarded to ClassificationConfig; passing
    ``data_source`` among them raises TypeError (duplicate keyword).

    NOTE(review): an identical ``create_huggingface_config`` is defined again
    later in this module, and that later definition replaces this one when the
    module is imported, so one of the two copies can be removed.
    """
    hub_settings = {
        "data_source": "huggingface",
        "dataset_name": dataset_name,
        "input_field": input_field,
        "label_field": label_field,
    }
    return ClassificationConfig(**hub_settings, **kwargs)
class CustomDataLoader(BaseDataLoader):
    """Load a classification dataset from a local JSONL, CSV or JSON file and split it."""

    # Patterns used by _clean_text, compiled once instead of on every record.
    _WHITESPACE_RE = re.compile(r'\s+')
    _SPECIAL_CHARS_RE = re.compile(r'[^\w\s]')

    def load(self, config: ClassificationConfig) -> Dict[str, List[Dict]]:
        """Read ``config.data_path`` and return train/validation/test splits.

        ``config.max_samples`` truncates the whole file *before* splitting.

        Raises:
            ValueError: If no path is configured, the format is unsupported, or
                the data cannot be split.
            FileNotFoundError: If the file does not exist.
        """
        if not config.data_path:
            raise ValueError("Data path is required for custom datasets")
        file_path = Path(config.data_path)
        if not file_path.exists():
            raise FileNotFoundError(f"Data file not found: {file_path}")
        logger.info(f"Loading custom dataset: {file_path}")

        loaders = {
            "jsonl": self._load_jsonl,
            "csv": self._load_csv,
            "json": self._load_json,
        }
        loader = loaders.get(config.data_format)
        if loader is None:
            raise ValueError(f"Unsupported format: {config.data_format}")
        raw_data = loader(file_path, config)

        if config.max_samples:
            raw_data = raw_data[:config.max_samples]
        logger.info(f"Loaded {len(raw_data)} samples from {file_path}")
        return self._create_splits(raw_data, config)

    def _create_splits(self, data: List[Dict], config: ClassificationConfig) -> Dict[str, List[Dict]]:
        """Split ``data`` into train/validation/test, stratified by label when possible.

        ``validation_split + test_split`` of the records is held out and then
        divided between validation and test; the train split is whatever
        remains (``config.train_split`` is not consulted). If stratification is
        rejected (e.g. a class with one member, or mixed label types), a plain
        random split with the same seed is used instead.

        Raises:
            ValueError: If ``data`` is empty or the held-out fraction is not positive.
        """
        logger.info(f"Creating splits from {len(data)} samples...")
        if not data:
            raise ValueError("Cannot create splits: no samples were loaded")
        held_out = config.validation_split + config.test_split
        if held_out <= 0:
            raise ValueError("validation_split + test_split must be positive to create validation and test splits")
        test_share = config.test_split / held_out

        try:
            labels = [item.get(config.label_field) for item in data]
            train_data, temp_data = train_test_split(
                data,
                test_size=held_out,
                random_state=42,
                stratify=labels
            )
            temp_labels = [item.get(config.label_field) for item in temp_data]
            val_data, test_data = train_test_split(
                temp_data,
                test_size=test_share,
                random_state=42,
                stratify=temp_labels
            )
        except (ValueError, TypeError) as e:
            # TypeError covers labels that cannot be sorted together (e.g. None mixed with str).
            logger.warning(f"Could not create stratified splits: {e}. Using random splits.")
            train_data, temp_data = train_test_split(data, test_size=held_out, random_state=42)
            val_data, test_data = train_test_split(temp_data, test_size=test_share, random_state=42)

        splits_data = {
            "train": train_data,
            "validation": val_data,
            "test": test_data
        }
        logger.info("Created splits:")
        logger.info(f"  Train: {len(splits_data['train'])} samples")
        logger.info(f"  Validation: {len(splits_data['validation'])} samples")
        logger.info(f"  Test: {len(splits_data['test'])} samples")
        return splits_data

    def _load_jsonl(self, file_path: Path, config: ClassificationConfig) -> List[Dict]:
        """Load a JSONL file, skipping blank lines, invalid JSON and non-object lines with a warning."""
        data = []
        with open(file_path, 'r', encoding=config.encoding) as f:
            for line_num, line in enumerate(f, 1):
                if not line.strip():
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError as e:
                    logger.warning(f"Invalid JSON at line {line_num}: {e}")
                    continue
                # Later stages call record.get(...); a bare list/number/string would crash there.
                if not isinstance(record, dict):
                    logger.warning(f"Skipping line {line_num}: expected a JSON object, got {type(record).__name__}")
                    continue
                data.append(record)
        return data

    def _load_csv(self, file_path: Path, config: ClassificationConfig) -> List[Dict]:
        """Load a CSV file into a list of row dicts; empty cells become None."""
        import math

        df = pd.read_csv(file_path, encoding=config.encoding, delimiter=config.delimiter)
        # pandas stores empty cells as float NaN; str(NaN) would later turn a
        # missing label into a bogus "nan" class, so normalise NaN to None.
        return [
            {
                key: (None if isinstance(value, float) and math.isnan(value) else value)
                for key, value in row.items()
            }
            for row in df.to_dict('records')
        ]

    def _load_json(self, file_path: Path, config: ClassificationConfig) -> List[Dict]:
        """Load a JSON file holding a list of records, ``{"data": [...]}``, or a single record."""
        with open(file_path, 'r', encoding=config.encoding) as f:
            data = json.load(f)
        if isinstance(data, list):
            return data
        if isinstance(data, dict) and "data" in data:
            return data["data"]
        return [data]

    def preprocess(self, data: Dict[str, List[Dict]], config: ClassificationConfig) -> Dict[str, List[Dict]]:
        """Preprocess every split independently, dropping records ``_preprocess_item`` rejects."""
        processed_splits = {}
        logger.info("=== PREPROCESSING CUSTOM DATA ===")
        for split_name, split_data in data.items():
            logger.info(f"Processing {split_name} split with {len(split_data)} items...")
            processed_data = []
            skipped_count = 0
            # Reset per split (kept for parity with the Hugging Face loader).
            self._debug_count = 0
            for i, item in enumerate(split_data):
                processed_item = self._preprocess_item(item, config)
                if processed_item is None:
                    skipped_count += 1
                    if skipped_count <= 3:  # log only the first few skipped items
                        logger.info(f"Skipped item {i} from {split_name}: {item}")
                    continue
                processed_data.append(processed_item)
            processed_splits[split_name] = processed_data
            logger.info(f"{split_name} - Preprocessed {len(processed_data)} samples, skipped {skipped_count} samples")
        return processed_splits

    def _preprocess_item(self, item: Dict, config: ClassificationConfig) -> Optional[Dict]:
        """Turn one raw record into ``{"input": str, "label": str[, "id": ...]}``.

        Missing/None input and label become ``""``; other values are converted
        with ``str()``. Returns None when the (optionally cleaned) text length is
        outside ``[config.min_length, config.max_length]``.
        """
        input_text = item.get(config.input_field, "")
        label = item.get(config.label_field, "")
        input_text = "" if input_text is None else str(input_text)
        label = "" if label is None else str(label)

        if config.clean_text:
            input_text = self._clean_text(input_text, config)

        if len(input_text) < config.min_length or len(input_text) > config.max_length:
            return None

        processed_item = {
            "input": input_text,
            "label": label
        }
        if config.id_field and config.id_field in item:
            processed_item["id"] = item[config.id_field]
        return processed_item

    def _clean_text(self, text: str, config: ClassificationConfig) -> str:
        """Collapse whitespace, then optionally lowercase and strip non-word, non-space characters."""
        if not isinstance(text, str):
            return ""
        text = self._WHITESPACE_RE.sub(' ', text).strip()
        if config.lowercase:
            text = text.lower()
        if config.remove_special_chars:
            text = self._SPECIAL_CHARS_RE.sub('', text)
        return text
class ClassificationDataPipeline:
    """Orchestrates loading, preprocessing, validation, analysis, formatting and saving."""

    # File formats that save_data knows how to write.
    _SAVE_FORMATS = ("jsonl", "json", "csv")

    def __init__(self):
        self.validator = DataValidator()
        self.hf_loader = HuggingFaceDataLoader()
        self.custom_loader = CustomDataLoader()

    def create_config(
        self,
        data_source: str,
        dataset_name: Optional[str] = None,
        data_path: Optional[str] = None,
        input_field: str = "text",
        label_field: str = "label",
        **kwargs
    ) -> ClassificationConfig:
        """Build a ClassificationConfig; extra keyword arguments are passed straight through."""
        return ClassificationConfig(
            data_source=data_source,
            dataset_name=dataset_name,
            data_path=data_path,
            input_field=input_field,
            label_field=label_field,
            **kwargs
        )

    def load_and_preprocess(self, config: ClassificationConfig) -> Tuple[Dict[str, List[Dict]], Dict[str, Any]]:
        """Load, preprocess, validate and analyse the data described by ``config``.

        Returns:
            ``(processed_splits, analysis)`` where ``analysis`` comes from
            ``DataValidator.analyze_dataset``.

        Raises:
            ValueError: For an unknown ``config.data_source``, or when the
                processed data fails validation (each error is logged first).
        """
        if config.data_source == "huggingface":
            loader = self.hf_loader
        elif config.data_source == "custom":
            loader = self.custom_loader
        else:
            raise ValueError(f"Unsupported data source: {config.data_source}")
        raw_splits = loader.load(config)
        processed_splits = loader.preprocess(raw_splits, config)

        is_valid, errors = self.validator.validate_classification_data(processed_splits, config, is_processed=True)
        if not is_valid:
            logger.error("Data validation failed:")
            for error in errors:
                logger.error(f"  - {error}")
            raise ValueError("Data validation failed")

        analysis = self.validator.analyze_dataset(processed_splits, config, is_processed=True)
        return processed_splits, analysis

    def convert_to_classification_format(self, data: Dict[str, List[Dict]]) -> Dict[str, List[Dict]]:
        """Rename preprocessed ``{"input", "label"}`` records to ``{"text", "label"}`` (other keys are dropped)."""
        return {
            split_name: [{"text": item["input"], "label": item["label"]} for item in split_data]
            for split_name, split_data in data.items()
        }

    def save_data(self, data: Dict[str, List[Dict]], output_dir: str, format: str = "jsonl"):
        """Write each split to ``<output_dir>/<split_name>.<format>`` (UTF-8).

        Args:
            data: Mapping from split name to records.
            output_dir: Directory to write into; created if missing.
            format: One of ``"jsonl"``, ``"json"`` or ``"csv"``.

        Raises:
            ValueError: If ``format`` is not supported. Checked before any
                directory or file is created.
        """
        # Reject unknown formats up front; otherwise output_file would be unbound
        # (or stale from a previous split) when it is logged below.
        if format not in self._SAVE_FORMATS:
            raise ValueError(f"Unsupported output format: {format}. Expected one of {self._SAVE_FORMATS}")

        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        for split_name, split_data in data.items():
            output_file = output_path / f"{split_name}.{format}"
            if format == "jsonl":
                with open(output_file, 'w', encoding='utf-8') as f:
                    for item in split_data:
                        f.write(json.dumps(item, ensure_ascii=False) + '\n')
            elif format == "json":
                with open(output_file, 'w', encoding='utf-8') as f:
                    json.dump(split_data, f, ensure_ascii=False, indent=2)
            else:  # csv
                pd.DataFrame(split_data).to_csv(output_file, index=False)
            logger.info(f"Saved {len(split_data)} samples to {output_file}")

    def run_pipeline(
        self,
        config: ClassificationConfig,
        output_format: str = "classification",
        save_splits: bool = True
    ) -> Dict[str, Any]:
        """Run load -> preprocess -> validate -> format -> (optionally) save.

        With ``output_format == "classification"`` records are renamed to
        ``{"text", "label"}``; any other value keeps the preprocessed records.
        When ``save_splits`` is True, files are written as JSONL under
        ``<config.output_dir>/<output_format>``; note that the returned
        ``"output_dir"`` is ``config.output_dir`` itself, not that subdirectory.

        Returns:
            A summary dict with the config, analysis, per-split sizes, output
            format, output directory and the formatted data.
        """
        logger.info("Starting classification pipeline...")
        processed_splits, analysis = self.load_and_preprocess(config)

        if output_format == "classification":
            formatted_splits = self.convert_to_classification_format(processed_splits)
        else:
            formatted_splits = processed_splits

        if save_splits:
            output_dir = Path(config.output_dir) / output_format
            self.save_data(formatted_splits, str(output_dir))

        result = {
            "config": config,
            "analysis": analysis,
            "splits": {
                split_name: len(split_data) for split_name, split_data in formatted_splits.items()
            },
            "output_format": output_format,
            "output_dir": config.output_dir,
            "data": formatted_splits  # the actual processed records
        }
        logger.info("Classification pipeline completed successfully!")
        return result
def create_huggingface_config(dataset_name: str, input_field: str = "text", label_field: str = "label", **kwargs) -> ClassificationConfig:
    """Return a ClassificationConfig whose data source is the Hugging Face Hub.

    Args:
        dataset_name: Hub dataset identifier.
        input_field: Record key holding the input text.
        label_field: Record key holding the label.
        **kwargs: Any other ClassificationConfig fields. Passing ``data_source``
            here raises TypeError (duplicate keyword).
    """
    return ClassificationConfig(
        **{
            "data_source": "huggingface",
            "dataset_name": dataset_name,
            "input_field": input_field,
            "label_field": label_field,
        },
        **kwargs,
    )
def create_custom_config(data_path: str, data_format: str = "jsonl", input_field: str = "text", label_field: str = "label", **kwargs) -> ClassificationConfig:
    """Return a ClassificationConfig that reads a local data file.

    Args:
        data_path: Path to the local file.
        data_format: File format ("jsonl", "csv" or "json").
        input_field: Record key holding the input text.
        label_field: Record key holding the label.
        **kwargs: Any other ClassificationConfig fields. Passing ``data_source``
            here raises TypeError (duplicate keyword).
    """
    file_settings = {
        "data_source": "custom",
        "data_path": data_path,
        "data_format": data_format,
        "input_field": input_field,
        "label_field": label_field,
    }
    return ClassificationConfig(**file_settings, **kwargs)
def main():
    """Command-line entry point: build a ClassificationConfig and run the pipeline.

    Settings are read from the ``data`` section of an optional YAML file
    given with ``--config``. Every CLI flag that is explicitly passed
    overrides the matching key in that section. ``--data-source`` maps to
    ``data.source``; every other flag maps to the key of the same name.
    Flags that are not passed leave the YAML value (or the built-in default)
    untouched.

    Exits with status 1 when the YAML file cannot be loaded or the pipeline
    raises. Exits with status 2 (via ``parser.error``) when a required data
    source setting is missing.
    """
    parser = argparse.ArgumentParser(description="Classification Data Processing Pipeline")

    # YAML configuration
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")

    # Data source arguments
    parser.add_argument("--data-source", choices=["huggingface", "custom"], help="Data source")
    parser.add_argument("--dataset-name", type=str, help="HuggingFace dataset name")
    parser.add_argument("--data-path", type=str, help="Path to custom data file")
    parser.add_argument("--data-format", choices=["jsonl", "csv", "json"], help="Data format")

    # Field mapping
    parser.add_argument("--input-field", type=str, help="Input field name")
    parser.add_argument("--label-field", type=str, help="Label field name")
    parser.add_argument("--id-field", type=str, help="Optional ID field name")

    # Data processing
    parser.add_argument("--max-samples", type=int, help="Maximum samples to process")
    parser.add_argument("--train-split", type=float, help="Training split ratio")
    parser.add_argument("--validation-split", type=float, help="Validation split ratio")
    parser.add_argument("--test-split", type=float, help="Test split ratio")

    # Text preprocessing. Boolean flags default to None (not False) so that
    # "flag not given" is distinguishable from "explicitly disabled"; the
    # --no-* variants allow turning off settings whose default is True.
    parser.add_argument("--clean-text", action="store_true", default=None, help="Clean and normalize text")
    parser.add_argument("--no-clean-text", dest="clean_text", action="store_false", default=None, help="Disable text cleaning")
    parser.add_argument("--remove-special-chars", action="store_true", default=None, help="Remove special characters")
    parser.add_argument("--lowercase", action="store_true", default=None, help="Convert text to lowercase")
    parser.add_argument("--no-lowercase", dest="lowercase", action="store_false", default=None, help="Keep original letter case")
    parser.add_argument("--min-length", type=int, help="Minimum text length")
    parser.add_argument("--max-length", type=int, help="Maximum text length")

    # Output configuration
    parser.add_argument("--output-format", choices=["classification", "instruction", "conversation", "qa"], help="Output format")
    parser.add_argument("--output-dir", type=str, help="Output directory")

    # Logging
    parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO", help="Logging level")

    args = parser.parse_args()

    # Set up logging
    log_level = getattr(logging, args.log_level)
    logging.basicConfig(
        level=log_level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    # This module's logger is set to DEBUG at import time, and records that
    # propagate to the root handler are not re-filtered by the root level,
    # so without this --log-level would not affect this module's output.
    logger.setLevel(log_level)

    # Load YAML config if provided
    config_dict = {}
    if args.config:
        try:
            with open(args.config, 'r', encoding='utf-8') as f:
                # safe_load returns None for an empty document
                config_dict = yaml.safe_load(f) or {}
            logger.info(f"Loaded YAML configuration from: {args.config}")
        except Exception as e:
            logger.error(f"Error loading YAML config: {e}")
            sys.exit(1)

    if not isinstance(config_dict, dict):
        logger.error("Error loading YAML config: top level must be a mapping")
        sys.exit(1)
    # Every setting below is read from the 'data' section, so CLI overrides
    # must be applied there as well.
    data_cfg = config_dict.get('data') or {}
    if not isinstance(data_cfg, dict):
        logger.error("Error loading YAML config: 'data' section must be a mapping")
        sys.exit(1)

    # argparse dest -> key inside the YAML 'data' section.
    cli_to_yaml_key = {
        "data_source": "source",
        "dataset_name": "dataset_name",
        "data_path": "data_path",
        "data_format": "data_format",
        "input_field": "input_field",
        "label_field": "label_field",
        "id_field": "id_field",
        "max_samples": "max_samples",
        "train_split": "train_split",
        "validation_split": "validation_split",
        "test_split": "test_split",
        "clean_text": "clean_text",
        "remove_special_chars": "remove_special_chars",
        "lowercase": "lowercase",
        "min_length": "min_length",
        "max_length": "max_length",
        "output_format": "output_format",
        "output_dir": "output_dir",
    }

    # Override YAML config with CLI arguments. Compare against None (not
    # truthiness) so explicit 0 / 0.0 / False values still take effect.
    for arg_name, yaml_key in cli_to_yaml_key.items():
        value = getattr(args, arg_name)
        if value is None:
            continue
        if yaml_key in data_cfg:
            logger.info(f"Overriding YAML config '{yaml_key}' with CLI value: {value}")
        data_cfg[yaml_key] = value

    # Validate required arguments
    source = data_cfg.get('source')
    if not source:
        parser.error("--data-source is required (either in YAML config or CLI)")
    if source == "huggingface" and not data_cfg.get('dataset_name'):
        parser.error("--dataset-name is required for HuggingFace datasets")
    if source == "custom" and not data_cfg.get('data_path'):
        parser.error("--data-path is required for custom datasets")

    # Create configuration object
    config = ClassificationConfig(
        data_source=data_cfg.get('source', 'huggingface'),
        dataset_name=data_cfg.get('dataset_name'),
        data_path=data_cfg.get('data_path'),
        data_format=data_cfg.get('data_format', 'jsonl'),
        input_field=data_cfg.get('input_field', 'text'),
        label_field=data_cfg.get('label_field', 'label'),
        id_field=data_cfg.get('id_field'),
        max_samples=data_cfg.get('max_samples'),
        train_split=data_cfg.get('train_split', 0.8),
        validation_split=data_cfg.get('validation_split', 0.1),
        test_split=data_cfg.get('test_split', 0.1),
        clean_text=data_cfg.get('clean_text', True),
        remove_special_chars=data_cfg.get('remove_special_chars', False),
        lowercase=data_cfg.get('lowercase', True),
        min_length=data_cfg.get('min_length', 10),
        max_length=data_cfg.get('max_length', 1000),
        label_encoding=data_cfg.get('label_encoding', 'auto'),
        multilabel=data_cfg.get('multilabel', False),
        label_separator=data_cfg.get('label_separator', ','),
        output_format=data_cfg.get('output_format', 'classification'),
        output_dir=data_cfg.get('output_dir', './data'),
        hf_split=data_cfg.get('hf_split', 'train'),
        hf_cache_dir=data_cfg.get('hf_cache_dir'),
        test_split_from=data_cfg.get('test_split_from', 'train'),
        val_split_from=data_cfg.get('val_split_from', 'train'),
        encoding=data_cfg.get('encoding', 'utf-8'),
        delimiter=data_cfg.get('delimiter', ',')
    )

    # Initialize pipeline
    pipeline = ClassificationDataPipeline()

    try:
        print(f"Starting classification pipeline with {config.data_source} data source...")
        if args.config:
            print(f"Using YAML configuration: {args.config}")
        print()

        result = pipeline.run_pipeline(config, config.output_format, save_splits=True)

        print("✅ Pipeline completed successfully!")
        print(f"   Data source: {config.data_source}")
        if config.data_source == "huggingface":
            print(f"   Dataset: {config.dataset_name}")
        else:
            print(f"   Data file: {config.data_path}")
        print(f"   Total samples: {result['analysis']['overall']['total_samples']}")
        print(f"   Unique labels: {result['analysis']['overall']['all_unique_labels']}")
        print(f"   Split sizes: {result['analysis']['overall']['split_sizes']}")
        print(f"   Output directory: {config.output_dir}")

    except Exception as e:
        # Keep the console message short, but keep the traceback available
        # when running with --log-level DEBUG.
        logger.debug("Pipeline failure traceback:", exc_info=True)
        print(f"❌ Error running pipeline: {e}")
        sys.exit(1)
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()