# File viewer metadata: 1073 lines, 46 KiB, Python
|
|
import json
|
||
|
|
import pandas as pd
|
||
|
|
import numpy as np
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Dict, List, Optional, Union, Any, Tuple
|
||
|
|
from datasets import Dataset, load_dataset
|
||
|
|
import os
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from abc import ABC, abstractmethod
|
||
|
|
import logging
|
||
|
|
from sklearn.model_selection import train_test_split
|
||
|
|
from sklearn.preprocessing import LabelEncoder
|
||
|
|
import re
|
||
|
|
import argparse
|
||
|
|
import sys
|
||
|
|
import yaml
|
||
|
|
|
||
|
|
# Module-level logger used by every class and helper in this file.
logger = logging.getLogger(__name__)
# DEBUG is the lowest standard level, so DEBUG, INFO, WARNING and ERROR
# records emitted through this logger are all passed on to its handlers.
logger.setLevel(logging.DEBUG)
|
||
|
|
|
||
|
|
@dataclass
class ClassificationConfig:
    """Configuration for classification tasks.

    A plain dataclass holding every setting read by the loaders, the
    validator and the pipeline in this file. Field order and defaults are
    part of the public interface (instances may be built positionally).
    """

    # Data source configuration
    data_source: str = "huggingface"  # "huggingface" or "custom"
    dataset_name: Optional[str] = None  # For Hugging Face datasets
    data_path: Optional[str] = None  # For custom datasets
    data_format: str = "jsonl"  # jsonl, csv, json

    # Field mapping
    input_field: str = "text"  # Field containing input text
    label_field: str = "label"  # Field containing labels
    id_field: Optional[str] = None  # Optional ID field

    # Data processing
    max_samples: Optional[int] = None  # None (or 0) means "no limit"
    train_split: float = 0.8
    validation_split: float = 0.1
    test_split: float = 0.1

    # Text preprocessing
    clean_text: bool = True
    remove_special_chars: bool = False
    lowercase: bool = True
    min_length: int = 10  # Minimum input length in characters (after cleaning)
    max_length: int = 1000  # Maximum input length in characters (after cleaning)

    # Label processing
    label_encoding: str = "auto"  # auto, numeric, string
    multilabel: bool = False
    label_separator: str = ","  # For multilabel datasets

    # Output configuration
    # NOTE(review): the original comment listed "instruction, conversation, qa"
    # as values, yet the default is "classification" - confirm the accepted set.
    output_format: str = "classification"
    output_dir: str = "./data"

    # Hugging Face specific
    hf_split: str = "train"
    hf_cache_dir: Optional[str] = None

    # Split configuration - flexible split handling
    test_split_from: str = "train"  # "train", "use_test_if_available", or "use_val_if_available"
    val_split_from: str = "train"  # "train", "use_val_if_available"

    # Custom data specific
    encoding: str = "utf-8"
    delimiter: str = ","  # For CSV files
|
||
|
|
|
||
|
|
|
||
|
|
class DataValidator:
    """Validates classification data quality and format.

    Both public entry points are static methods that take a mapping of
    split name to a list of sample dicts. When ``is_processed`` is true the
    samples are read through the ``"input"`` / ``"label"`` keys; otherwise
    the key names come from ``config.input_field`` / ``config.label_field``.
    """

    @staticmethod
    def _field_names(config: "ClassificationConfig", is_processed: bool) -> Tuple[str, str]:
        """Return the (input, label) keys to read for raw or processed data."""
        if is_processed:
            return "input", "label"
        return config.input_field, config.label_field

    @staticmethod
    def _text_length(value: Any) -> int:
        """Return the length of an input value, treating ``None`` as empty."""
        if value is None:
            return 0
        if isinstance(value, str):
            return len(value)
        return len(str(value))

    @staticmethod
    def _is_missing(value: Any) -> bool:
        """Return True for ``None`` or an empty string/container.

        Falsy-but-meaningful values such as the label ``0`` are NOT missing.
        """
        if value is None:
            return True
        return isinstance(value, (str, list, tuple, dict, set)) and len(value) == 0

    @staticmethod
    def validate_classification_data(
        data: Dict[str, List[Dict]],
        config: "ClassificationConfig",
        is_processed: bool = False,
    ) -> Tuple[bool, List[str]]:
        """Validate classification dataset splits.

        Args:
            data: Mapping of split name to a list of sample dicts.
            config: Supplies the raw field names when ``is_processed`` is false.
            is_processed: True when samples use the ``"input"``/``"label"`` keys.

        Returns:
            ``(is_valid, errors)`` where ``errors`` lists every problem found.
            If any of the ``train``/``validation``/``test`` splits is missing
            or empty, only those errors are returned and no item is inspected.
        """
        errors: List[str] = []

        # Check if we have the expected splits
        for split in ("train", "validation", "test"):
            if split not in data or not data[split]:
                errors.append(f"Missing or empty '{split}' split")

        if errors:
            return False, errors

        total_samples = sum(len(split_data) for split_data in data.values())
        logger.info(f"Validating {total_samples} total samples across all splits...")

        # Raw data uses the configured field names, processed data the fixed ones
        input_field, label_field = DataValidator._field_names(config, is_processed)

        # Validate each split
        for split_name, split_data in data.items():
            logger.info(f"Validating {split_name} split with {len(split_data)} samples...")

            # Check required fields
            missing_input_count = 0
            missing_label_count = 0
            for i, item in enumerate(split_data):
                if input_field not in item:
                    errors.append(f"Missing input field '{input_field}' in {split_name} split, item {i}")
                    missing_input_count += 1
                if label_field not in item:
                    errors.append(f"Missing label field '{label_field}' in {split_name} split, item {i}")
                    missing_label_count += 1

            logger.info(f"{split_name} - Items missing input field: {missing_input_count}")
            logger.info(f"{split_name} - Items missing label field: {missing_label_count}")

            # Check data types
            type_errors = 0
            for i, item in enumerate(split_data):
                if not isinstance(item.get(input_field, ""), str):
                    errors.append(f"Input field '{input_field}' must be string in {split_name} split, item {i}")
                    type_errors += 1

            logger.info(f"{split_name} - Type errors: {type_errors}")

            # Check for empty inputs. Non-string inputs were already reported
            # as type errors above; calling .strip() on them would crash.
            empty_inputs = 0
            for item in split_data:
                value = item.get(input_field, "")
                if isinstance(value, str) and not value.strip():
                    empty_inputs += 1
            if empty_inputs > 0:
                errors.append(f"Found {empty_inputs} items with empty input text in {split_name} split")

            logger.info(f"{split_name} - Empty inputs: {empty_inputs}")

            # Check label distribution (single counting pass instead of
            # calling list.count once per distinct label)
            labels = [item.get(label_field) for item in split_data if item.get(label_field) is not None]
            label_counts: Dict[Any, int] = {}
            for label in labels:
                label_counts[label] = label_counts.get(label, 0) + 1
            unique_labels = set(labels)

            logger.info(f"{split_name} - Found {len(unique_labels)} unique labels: {unique_labels}")
            distribution = {label: label_counts[label] for label in unique_labels}
            logger.info(f"{split_name} - Label distribution: {distribution}")

            if len(unique_labels) < 1:
                errors.append(f"{split_name} split must have at least 1 label, found: {unique_labels}")

            # Show sample of processed data for debugging
            if split_data:
                logger.info(f"Sample processed items from {split_name}:")
                for i, item in enumerate(split_data[:3]):
                    # str() keeps the preview safe for non-string inputs
                    preview = str(item.get(input_field, ""))[:50]
                    label_value = item.get(label_field, "")
                    logger.info(f" Item {i}: input='{preview}...', label='{label_value}'")

        return len(errors) == 0, errors

    @staticmethod
    def analyze_dataset(
        data: Dict[str, List[Dict]],
        config: "ClassificationConfig",
        is_processed: bool = False,
    ) -> Dict[str, Any]:
        """Analyze dataset characteristics across all splits.

        Args:
            data: Mapping of split name to a list of sample dicts.
            config: Supplies the raw field names when ``is_processed`` is false.
            is_processed: True when samples use the ``"input"``/``"label"`` keys.

        Returns:
            A dict with a ``"splits"`` entry (per-split sample count, number of
            unique labels, label distribution keyed by ``str(label)``, text
            length min/max/mean/median, and missing-value counts per field)
            and an ``"overall"`` entry (total samples, number of distinct
            labels across splits, and per-split sizes).
        """
        # Determine field names based on whether data is processed or not
        input_field, label_field = DataValidator._field_names(config, is_processed)

        all_labels = set()
        analysis: Dict[str, Any] = {
            "splits": {},
            "overall": {
                "total_samples": 0,
                "all_unique_labels": 0,  # filled in after every split is seen
                "split_sizes": {},
            },
        }

        # Analyze each split
        for split_name, split_data in data.items():
            labels = [item.get(label_field) for item in split_data]

            # Label distribution, accumulated so labels sharing a string
            # form (e.g. 1 and "1") are summed rather than overwritten
            label_distribution: Dict[str, int] = {}
            for label in labels:
                key = str(label)
                label_distribution[key] = label_distribution.get(key, 0) + 1
            all_labels.update(label_distribution)

            split_analysis: Dict[str, Any] = {
                "total_samples": len(split_data),
                "unique_labels": len(set(labels)),
                "label_distribution": label_distribution,
                "text_length_stats": {},
                "missing_values": {},
            }

            # Text length statistics
            text_lengths = [DataValidator._text_length(item.get(input_field, "")) for item in split_data]
            if text_lengths:
                split_analysis["text_length_stats"] = {
                    "min": min(text_lengths),
                    "max": max(text_lengths),
                    "mean": float(np.mean(text_lengths)),
                    "median": float(np.median(text_lengths)),
                }

            # Missing values
            for field in (input_field, label_field):
                split_analysis["missing_values"][field] = sum(
                    1 for item in split_data if DataValidator._is_missing(item.get(field))
                )

            analysis["splits"][split_name] = split_analysis
            analysis["overall"]["total_samples"] += len(split_data)
            analysis["overall"]["split_sizes"][split_name] = len(split_data)

        analysis["overall"]["all_unique_labels"] = len(all_labels)
        return analysis
|
||
|
|
|
||
|
|
|
||
|
|
class BaseDataLoader(ABC):
    """Abstract base class for data loaders.

    Concrete loaders implement ``load`` to produce raw splits and
    ``preprocess`` to turn those splits into cleaned samples.
    """

    @abstractmethod
    def load(self, config: ClassificationConfig) -> Dict[str, List[Dict]]:
        """Load data and return dictionary with train/val/test splits"""
        pass

    @abstractmethod
    def preprocess(self, data: Dict[str, List[Dict]], config: ClassificationConfig) -> Dict[str, List[Dict]]:
        """Apply preprocessing steps to all splits"""
        pass
|
||
|
|
|
||
|
|
|
||
|
|
class HuggingFaceDataLoader(BaseDataLoader):
    """Load datasets from Hugging Face Hub.

    ``load`` returns raw ``train`` / ``validation`` / ``test`` lists of
    sample dicts; ``preprocess`` converts them into ``{"input", "label"}``
    dicts (plus ``"id"`` when configured), dropping samples whose cleaned
    text length falls outside ``[config.min_length, config.max_length]``.
    """

    # Substrings used to suggest alternative column names in log output.
    _TEXT_FIELD_HINTS = ('text', 'sentence', 'content', 'input', 'comment', 'message')
    _LABEL_FIELD_HINTS = ('label', 'class', 'category', 'target', 'emotion', 'labels')

    def load(self, config: ClassificationConfig) -> Dict[str, List[Dict]]:
        """Load dataset from Hugging Face Hub with flexible split handling.

        The dataset's own ``validation``/``val`` and ``test`` splits are used
        when ``config.val_split_from`` / ``config.test_split_from`` request
        them and they exist; any split still empty afterwards is carved out
        of ``train``. ``config.max_samples`` then truncates every split.

        Raises:
            ValueError: if ``config.dataset_name`` is unset, the dataset has
                no ``train`` split, ``train`` is empty while a split must be
                carved from it, or the split fractions are unusable.
        """
        if not config.dataset_name:
            raise ValueError("Dataset name is required for Hugging Face datasets")

        logger.info(f"Loading Hugging Face dataset: {config.dataset_name}")

        try:
            # Load every split first so we can see which ones are available
            dataset = load_dataset(
                config.dataset_name,
                cache_dir=config.hf_cache_dir,
            )

            # Log available splits
            available_splits = list(dataset.keys())
            logger.info(f"Available splits in dataset: {available_splits}")

            # Handle train split
            if "train" not in available_splits:
                logger.error("No 'train' split found in dataset!")
                logger.error(f"Available splits: {available_splits}")
                raise ValueError(f"Dataset {config.dataset_name} does not have a 'train' split")

            train_dataset = dataset["train"]
            logger.info(f"Using 'train' split with {len(train_dataset)} samples")

            splits_data: Dict[str, List[Dict]] = {
                "train": list(train_dataset),
                "validation": self._pick_validation_split(dataset, available_splits, config),
                "test": self._pick_test_split(dataset, available_splits, config),
            }

            # Both options resolve to the same upstream split, so validation
            # and test would be identical - make that visible to the user.
            if (
                config.val_split_from == "use_val_if_available"
                and config.test_split_from == "use_val_if_available"
                and splits_data["validation"]
                and splits_data["test"]
            ):
                logger.warning(
                    "Validation and test splits are both taken from the dataset's "
                    "validation split and therefore contain the same samples."
                )

            # If we need to create splits from train data
            self._fill_missing_splits(splits_data, config)

            logger.info("Final split sizes:")
            logger.info(f" Train: {len(splits_data['train'])} samples")
            logger.info(f" Validation: {len(splits_data['validation'])} samples")
            logger.info(f" Test: {len(splits_data['test'])} samples")

            # Apply max_samples limit to each split if specified
            if config.max_samples:
                for split_name in splits_data:
                    if splits_data[split_name]:
                        original_size = len(splits_data[split_name])
                        splits_data[split_name] = splits_data[split_name][:config.max_samples]
                        logger.info(f"Limited {split_name} split from {original_size} to {len(splits_data[split_name])} samples")

            # Log dataset info for debugging
            for split_name, split_data in splits_data.items():
                if split_data:
                    self._log_split_fields(split_name, split_data[0], config)

            logger.info(f"Successfully loaded dataset {config.dataset_name}")
            return splits_data

        except Exception as e:
            # Top-level boundary: log with the dataset name, then re-raise.
            logger.error(f"Error loading dataset {config.dataset_name}: {e}")
            raise

    def _pick_validation_split(self, dataset: Any, available_splits: List[str], config: ClassificationConfig) -> List[Dict]:
        """Return the dataset's own validation samples, or [] to carve from train."""
        if config.val_split_from == "use_val_if_available":
            # "validation" takes precedence over the shorter "val" name
            for name in ("validation", "val"):
                if name in available_splits:
                    val_dataset = dataset[name]
                    logger.info(f"Using '{name}' split with {len(val_dataset)} samples")
                    return list(val_dataset)
            logger.warning("No validation split found in dataset. Will create from train split.")
            logger.info(f"Available splits: {available_splits}")
            logger.info(f"Will use {config.validation_split * 100}% of train data for validation")
        else:
            logger.info(f"Will create validation split from train data ({config.validation_split * 100}%)")
        return []

    def _pick_test_split(self, dataset: Any, available_splits: List[str], config: ClassificationConfig) -> List[Dict]:
        """Return the dataset's own test samples, or [] to carve from train."""
        source = config.test_split_from
        if source == "use_test_if_available" and "test" in available_splits:
            test_dataset = dataset["test"]
            logger.info(f"Using 'test' split with {len(test_dataset)} samples")
            return list(test_dataset)
        if source == "use_val_if_available":
            for name in ("validation", "val"):
                if name in available_splits:
                    test_dataset = dataset[name]
                    logger.info(f"Using '{name}' split as test with {len(test_dataset)} samples")
                    return list(test_dataset)
        if source == "use_test_if_available":
            logger.warning("No test split found in dataset. Will create from train split.")
            logger.info(f"Available splits: {available_splits}")
            logger.info(f"Will use {config.test_split * 100}% of train data for test")
        else:
            logger.info(f"Will create test split from train data ({config.test_split * 100}%)")
        return []

    @staticmethod
    def _normalized_fractions(config: ClassificationConfig) -> Tuple[float, float, float]:
        """Return (train, validation, test) fractions scaled to sum to 1.

        The configuration object itself is left untouched.
        """
        total = config.train_split + config.validation_split + config.test_split
        if total <= 0:
            raise ValueError(f"Split percentages must sum to a positive value (got {total})")
        # Tolerance instead of '!= 1.0': e.g. 0.7 + 0.2 + 0.1 is not exactly 1.0
        if abs(total - 1.0) > 1e-9:
            logger.warning(f"Split percentages don't sum to 1.0 (got {total}). Normalizing...")
            return (
                config.train_split / total,
                config.validation_split / total,
                config.test_split / total,
            )
        return config.train_split, config.validation_split, config.test_split

    @staticmethod
    def _split(items: List[Dict], holdout_fraction: float, label_field: str) -> Tuple[List[Dict], List[Dict]]:
        """Split ``items`` in two, stratified by label when possible.

        Stratification is attempted only when the first item carries the
        label field; if it fails (e.g. a class with a single member) the
        split falls back to a plain random split with the same seed.
        """
        if items and label_field in items[0]:
            labels = [item.get(label_field) for item in items]
            try:
                return tuple(train_test_split(items, test_size=holdout_fraction, random_state=42, stratify=labels))
            except (ValueError, TypeError) as exc:
                logger.warning(f"Could not create stratified splits: {exc}. Using random splits.")
        return tuple(train_test_split(items, test_size=holdout_fraction, random_state=42))

    def _fill_missing_splits(self, splits_data: Dict[str, List[Dict]], config: ClassificationConfig) -> None:
        """Carve any empty validation/test split out of the train split (in place)."""
        need_val = not splits_data["validation"]
        need_test = not splits_data["test"]
        if not (need_val or need_test):
            return

        train_data = splits_data["train"]
        if not train_data:
            raise ValueError("Cannot create validation/test splits: the 'train' split is empty")

        _, val_fraction, test_fraction = self._normalized_fractions(config)

        if need_val and need_test:
            holdout = val_fraction + test_fraction
            if holdout <= 0:
                raise ValueError("validation_split + test_split must be greater than 0")
            # First split: train + (val+test); second split: val + test
            new_train, temp_data = self._split(train_data, holdout, config.label_field)
            new_val, new_test = self._split(temp_data, test_fraction / holdout, config.label_field)
            splits_data["train"] = new_train
            splits_data["validation"] = new_val
            splits_data["test"] = new_test
        elif need_val:
            # Only need to create val from train
            new_train, new_val = self._split(train_data, val_fraction, config.label_field)
            splits_data["train"] = new_train
            splits_data["validation"] = new_val
        else:
            # Only need to create test from train
            new_train, new_test = self._split(train_data, test_fraction, config.label_field)
            splits_data["train"] = new_train
            splits_data["test"] = new_test

    def _log_split_fields(self, split_name: str, sample: Dict, config: ClassificationConfig) -> None:
        """Log one sample and warn (with suggestions) about missing configured fields."""
        fields = list(sample.keys())
        logger.info(f"Sample data item from {split_name}: {sample}")
        logger.info(f"Available fields in {split_name} split: {fields}")

        # Check if the required fields exist
        if config.input_field not in sample:
            logger.warning(f"Input field '{config.input_field}' not found in {split_name}. Available fields: {fields}")
            # Suggest alternative fields
            text_fields = [f for f in fields if any(keyword in f.lower() for keyword in self._TEXT_FIELD_HINTS)]
            if text_fields:
                logger.info(f"Suggested text fields for {split_name}: {text_fields}")
        if config.label_field not in sample:
            logger.warning(f"Label field '{config.label_field}' not found in {split_name}. Available fields: {fields}")
            # Suggest alternative fields
            label_fields = [f for f in fields if any(keyword in f.lower() for keyword in self._LABEL_FIELD_HINTS)]
            if label_fields:
                logger.info(f"Suggested label fields for {split_name}: {label_fields}")

    def preprocess(self, data: Dict[str, List[Dict]], config: ClassificationConfig) -> Dict[str, List[Dict]]:
        """Apply preprocessing steps to all splits separately.

        Returns a new mapping with the same split names; each value holds the
        items for which ``_preprocess_item`` did not return ``None``.
        """
        processed_splits = {}

        logger.info("=== PREPROCESSING DATA ===")

        for split_name, split_data in data.items():
            logger.info(f"Processing {split_name} split with {len(split_data)} items...")

            # Log field availability for debugging
            if split_data:
                available_fields = set(split_data[0].keys())
                logger.info(f"Available fields in {split_name}: {available_fields}")
                logger.info(f"Looking for input field: '{config.input_field}', label field: '{config.label_field}'")

                if config.input_field not in available_fields:
                    logger.error(f"Input field '{config.input_field}' not found in {split_name}. Available fields: {available_fields}")
                if config.label_field not in available_fields:
                    logger.error(f"Label field '{config.label_field}' not found in {split_name}. Available fields: {available_fields}")

            # Count items with missing fields
            missing_input = sum(1 for item in split_data if config.input_field not in item or not item.get(config.input_field))
            missing_label = sum(1 for item in split_data if config.label_field not in item or item.get(config.label_field) is None)

            logger.info(f"{split_name} - Items missing input field: {missing_input}")
            logger.info(f"{split_name} - Items missing label field: {missing_label}")

            # Show sample of raw data before preprocessing
            logger.info(f"=== SAMPLE RAW DATA FROM {split_name.upper()} BEFORE PREPROCESSING ===")
            for i, item in enumerate(split_data[:3]):
                logger.info(f"Raw item {i} from {split_name}:")
                for key, value in item.items():
                    if isinstance(value, str) and len(value) > 100:
                        logger.info(f" {key}: '{value[:100]}...'")
                    else:
                        logger.info(f" {key}: {value}")

            # Process each item in the split
            processed_data = []
            processed_count = 0
            skipped_count = 0

            # Reset debug counter for each split
            self._debug_count = 0

            for i, item in enumerate(split_data):
                processed_item = self._preprocess_item(item, config)
                if processed_item is not None:
                    processed_data.append(processed_item)
                    processed_count += 1
                else:
                    skipped_count += 1
                    if skipped_count <= 3:  # Log first few skipped items
                        logger.info(f"Skipped item {i} from {split_name}: {item}")

            processed_splits[split_name] = processed_data
            logger.info(f"{split_name} - Preprocessed {processed_count} samples, skipped {skipped_count} samples")

            # Show sample of processed data
            if processed_data:
                logger.info(f"=== SAMPLE PROCESSED DATA FROM {split_name.upper()} ===")
                for i, processed_item in enumerate(processed_data[:3]):
                    logger.info(f"Processed item {i} from {split_name}: {processed_item}")

        return processed_splits

    def _preprocess_item(self, item: Dict, config: ClassificationConfig) -> Optional[Dict]:
        """Preprocess a single item.

        Returns ``{"input", "label"}`` (plus ``"id"`` when ``config.id_field``
        is present in the item), or ``None`` when the text length is outside
        ``[config.min_length, config.max_length]``.
        """
        # Extract input and label
        input_text = item.get(config.input_field, "")
        label = item.get(config.label_field, "")

        # Count calls so only the first few items of a split are logged
        self._debug_count = getattr(self, "_debug_count", 0) + 1
        verbose = self._debug_count <= 3

        if verbose:
            logger.debug(f"Processing item {self._debug_count}:")
            logger.debug(f" Looking for input field '{config.input_field}': {input_text}")
            logger.debug(f" Looking for label field '{config.label_field}': {label}")

        # Handle None values
        # NOTE(review): a missing label becomes the empty string and the item
        # is still kept - confirm whether unlabeled items should be dropped.
        if input_text is None:
            input_text = ""
        if label is None:
            label = ""

        # Convert to string if needed
        input_text = str(input_text)
        label = str(label)

        if verbose:
            logger.debug(f" After conversion - input: '{input_text[:50]}...', label: '{label}'")

        # Clean text if requested
        if config.clean_text:
            original_text = input_text
            input_text = self._clean_text(input_text, config)
            if verbose:
                logger.debug(f" After cleaning - original: '{original_text[:50]}...', cleaned: '{input_text[:50]}...'")

        # Check length constraints
        if len(input_text) < config.min_length or len(input_text) > config.max_length:
            if verbose:
                logger.debug(f" Skipping - length {len(input_text)} not in range [{config.min_length}, {config.max_length}]")
            return None

        # Create processed item
        processed_item = {
            "input": input_text,
            "label": label,
        }

        # Add ID if available
        if config.id_field and config.id_field in item:
            processed_item["id"] = item[config.id_field]

        if verbose:
            logger.debug(f" Final processed item: {processed_item}")

        return processed_item

    def _clean_text(self, text: str, config: ClassificationConfig) -> str:
        """Clean and normalize text.

        Collapses whitespace runs to single spaces and trims the ends, then
        optionally lowercases and strips every character that is neither a
        word character nor whitespace. Non-string input yields ``""``.
        """
        if not isinstance(text, str):
            return ""

        # Remove extra whitespace
        text = re.sub(r'\s+', ' ', text).strip()

        # Convert to lowercase if requested
        if config.lowercase:
            text = text.lower()

        # Remove special characters if requested
        if config.remove_special_chars:
            text = re.sub(r'[^\w\s]', '', text)

        return text
|
||
|
|
|
||
|
|
|
||
|
|
def create_huggingface_config(dataset_name: str, input_field: str = "text", label_field: str = "label", **kwargs) -> ClassificationConfig:
    """Helper function to create a HuggingFace configuration.

    Builds a ``ClassificationConfig`` with ``data_source="huggingface"``;
    any extra keyword arguments are forwarded to the dataclass unchanged.
    """
    return ClassificationConfig(
        data_source="huggingface",
        dataset_name=dataset_name,
        input_field=input_field,
        label_field=label_field,
        **kwargs,
    )
|
||
|
|
|
||
|
|
|
||
|
|
class CustomDataLoader(BaseDataLoader):
    """Load custom datasets from local files.

    Supports the ``jsonl``, ``csv`` and ``json`` values of
    ``config.data_format``. ``load`` reads the file and splits it into
    train/validation/test; ``preprocess`` converts each sample into an
    ``{"input", "label"}`` dict (plus ``"id"`` when configured).
    """

    def load(self, config: ClassificationConfig) -> Dict[str, List[Dict]]:
        """Load custom dataset from local file and create splits.

        Raises:
            ValueError: if ``config.data_path`` is unset, the format is not
                supported, or no split can be created from the loaded data.
            FileNotFoundError: if the data file does not exist.
        """
        if not config.data_path:
            raise ValueError("Data path is required for custom datasets")

        file_path = Path(config.data_path)

        if not file_path.exists():
            raise FileNotFoundError(f"Data file not found: {file_path}")

        logger.info(f"Loading custom dataset: {file_path}")

        if config.data_format == "jsonl":
            raw_data = self._load_jsonl(file_path, config)
        elif config.data_format == "csv":
            raw_data = self._load_csv(file_path, config)
        elif config.data_format == "json":
            raw_data = self._load_json(file_path, config)
        else:
            raise ValueError(f"Unsupported format: {config.data_format}")

        # The limit is applied to the whole file, before splitting
        if config.max_samples:
            raw_data = raw_data[:config.max_samples]

        logger.info(f"Loaded {len(raw_data)} samples from {file_path}")

        # Create splits from the raw data
        return self._create_splits(raw_data, config)

    def _create_splits(self, data: List[Dict], config: ClassificationConfig) -> Dict[str, List[Dict]]:
        """Create train/validation/test splits from raw data.

        A label-stratified split is attempted first; if it cannot be
        produced, both splits are redone as plain random splits (seed 42).

        Raises:
            ValueError: if ``data`` is empty or ``validation_split +
                test_split`` is not greater than 0.
        """
        logger.info(f"Creating splits from {len(data)} samples...")

        if not data:
            raise ValueError("Cannot create splits from an empty dataset")

        holdout_fraction = config.validation_split + config.test_split
        if holdout_fraction <= 0:
            raise ValueError("validation_split + test_split must be greater than 0")
        test_fraction = config.test_split / holdout_fraction

        # Create stratified splits if possible
        try:
            labels = [item.get(config.label_field) for item in data]

            # First split: train + (val+test)
            train_data, temp_data = train_test_split(
                data,
                test_size=holdout_fraction,
                random_state=42,
                stratify=labels,
            )

            # Second split: val + test
            temp_labels = [item.get(config.label_field) for item in temp_data]
            val_data, test_data = train_test_split(
                temp_data,
                test_size=test_fraction,
                random_state=42,
                stratify=temp_labels,
            )

        # TypeError covers labels that cannot be compared with each other
        # (e.g. a missing label next to string labels).
        except (ValueError, TypeError) as e:
            logger.warning(f"Could not create stratified splits: {e}. Using random splits.")
            # Fallback to random splits
            train_data, temp_data = train_test_split(data, test_size=holdout_fraction, random_state=42)
            val_data, test_data = train_test_split(temp_data, test_size=test_fraction, random_state=42)

        splits_data = {
            "train": train_data,
            "validation": val_data,
            "test": test_data,
        }

        logger.info("Created splits:")
        logger.info(f" Train: {len(splits_data['train'])} samples")
        logger.info(f" Validation: {len(splits_data['validation'])} samples")
        logger.info(f" Test: {len(splits_data['test'])} samples")

        return splits_data

    def _load_jsonl(self, file_path: Path, config: ClassificationConfig) -> List[Dict]:
        """Load JSONL file.

        Blank lines are ignored. Lines that are not valid JSON, or whose
        value is not a JSON object, are skipped with a warning.
        """
        data = []
        with open(file_path, 'r', encoding=config.encoding) as f:
            for line_num, line in enumerate(f, 1):
                if not line.strip():
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError as e:
                    logger.warning(f"Invalid JSON at line {line_num}: {e}")
                    continue
                if not isinstance(record, dict):
                    # Later stages call .get() on every record
                    logger.warning(f"Skipping non-object JSON value at line {line_num}")
                    continue
                data.append(record)
        return data

    def _load_csv(self, file_path: Path, config: ClassificationConfig) -> List[Dict]:
        """Load CSV file into a list of row dicts.

        Empty cells are returned as ``None`` rather than NaN, so they are
        handled like missing values instead of becoming the text ``"nan"``.
        """
        df = pd.read_csv(file_path, encoding=config.encoding, delimiter=config.delimiter)
        return [
            {key: (None if pd.isna(value) else value) for key, value in record.items()}
            for record in df.to_dict('records')
        ]

    def _load_json(self, file_path: Path, config: ClassificationConfig) -> List[Dict]:
        """Load JSON file.

        A top-level list is returned as is; a dict with a ``"data"`` key
        returns that value; anything else is wrapped in a one-item list.
        """
        with open(file_path, 'r', encoding=config.encoding) as f:
            data = json.load(f)

        if isinstance(data, list):
            return data
        if isinstance(data, dict) and "data" in data:
            return data["data"]
        return [data]

    def preprocess(self, data: Dict[str, List[Dict]], config: ClassificationConfig) -> Dict[str, List[Dict]]:
        """Apply preprocessing steps to all splits separately.

        Returns a new mapping with the same split names; each value holds the
        items for which ``_preprocess_item`` did not return ``None``.
        """
        processed_splits = {}

        logger.info("=== PREPROCESSING CUSTOM DATA ===")

        for split_name, split_data in data.items():
            logger.info(f"Processing {split_name} split with {len(split_data)} items...")

            processed_data = []
            processed_count = 0
            skipped_count = 0

            # Reset debug counter for each split
            self._debug_count = 0

            for i, item in enumerate(split_data):
                processed_item = self._preprocess_item(item, config)
                if processed_item is not None:
                    processed_data.append(processed_item)
                    processed_count += 1
                else:
                    skipped_count += 1
                    if skipped_count <= 3:  # Log first few skipped items
                        logger.info(f"Skipped item {i} from {split_name}: {item}")

            processed_splits[split_name] = processed_data
            logger.info(f"{split_name} - Preprocessed {processed_count} samples, skipped {skipped_count} samples")

        return processed_splits

    def _preprocess_item(self, item: Dict, config: ClassificationConfig) -> Optional[Dict]:
        """Preprocess a single item.

        Returns ``{"input", "label"}`` (plus ``"id"`` when ``config.id_field``
        is present in the item), or ``None`` when the text length is outside
        ``[config.min_length, config.max_length]``.
        """
        # Extract input and label
        input_text = item.get(config.input_field, "")
        label = item.get(config.label_field, "")

        # Handle None values
        if input_text is None:
            input_text = ""
        if label is None:
            label = ""

        # Convert to string if needed
        input_text = str(input_text)
        label = str(label)

        # Clean text if requested
        if config.clean_text:
            input_text = self._clean_text(input_text, config)

        # Check length constraints
        if len(input_text) < config.min_length or len(input_text) > config.max_length:
            return None

        # Create processed item
        processed_item = {
            "input": input_text,
            "label": label,
        }

        # Add ID if available
        if config.id_field and config.id_field in item:
            processed_item["id"] = item[config.id_field]

        return processed_item

    def _clean_text(self, text: str, config: ClassificationConfig) -> str:
        """Clean and normalize text.

        Collapses whitespace runs to single spaces and trims the ends, then
        optionally lowercases and strips every character that is neither a
        word character nor whitespace. Non-string input yields ``""``.
        """
        if not isinstance(text, str):
            return ""

        # Remove extra whitespace
        text = re.sub(r'\s+', ' ', text).strip()

        # Convert to lowercase if requested
        if config.lowercase:
            text = text.lower()

        # Remove special characters if requested
        if config.remove_special_chars:
            text = re.sub(r'[^\w\s]', '', text)

        return text
|
||
|
|
|
||
|
|
|
||
|
|
class ClassificationDataPipeline:
    """Main classification pipeline.

    Ties together a data validator and the HuggingFace / custom loaders to
    load, preprocess, validate, analyze, reformat and save dataset splits.
    """

    def __init__(self):
        self.validator = DataValidator()
        self.hf_loader = HuggingFaceDataLoader()
        self.custom_loader = CustomDataLoader()

    def create_config(
        self,
        data_source: str,
        dataset_name: Optional[str] = None,
        data_path: Optional[str] = None,
        input_field: str = "text",
        label_field: str = "label",
        **kwargs
    ) -> ClassificationConfig:
        """Create classification configuration.

        Extra keyword arguments are forwarded unchanged to
        ``ClassificationConfig``.
        """
        return ClassificationConfig(
            data_source=data_source,
            dataset_name=dataset_name,
            data_path=data_path,
            input_field=input_field,
            label_field=label_field,
            **kwargs
        )

    def load_and_preprocess(self, config: ClassificationConfig) -> Tuple[Dict[str, List[Dict]], Dict[str, Any]]:
        """Load, preprocess, validate and analyze the data described by *config*.

        Returns:
            A ``(processed_splits, analysis)`` tuple.

        Raises:
            ValueError: If ``config.data_source`` is neither ``"huggingface"``
                nor ``"custom"``, or if the validator rejects the processed
                data (each validation error is logged first).
        """
        # Pick the loader that matches the configured data source.
        if config.data_source == "huggingface":
            loader = self.hf_loader
        elif config.data_source == "custom":
            loader = self.custom_loader
        else:
            raise ValueError(f"Unsupported data source: {config.data_source}")

        raw_splits = loader.load(config)
        processed_splits = loader.preprocess(raw_splits, config)

        # Validate processed data
        is_valid, errors = self.validator.validate_classification_data(processed_splits, config, is_processed=True)
        if not is_valid:
            logger.error("Data validation failed:")
            for error in errors:
                logger.error(f"  - {error}")
            raise ValueError("Data validation failed")

        # Analyze dataset
        analysis = self.validator.analyze_dataset(processed_splits, config, is_processed=True)

        return processed_splits, analysis

    def convert_to_classification_format(self, data: Dict[str, List[Dict]]) -> Dict[str, List[Dict]]:
        """Convert processed splits to ``{"text", "label"}`` records.

        Each item's ``"input"`` becomes ``"text"``; only ``"text"`` and
        ``"label"`` are kept, so any other key (such as ``"id"``) is dropped.
        """
        return {
            split_name: [
                {"text": item["input"], "label": item["label"]}
                for item in split_data
            ]
            for split_name, split_data in data.items()
        }

    def save_data(self, data: Dict[str, List[Dict]], output_dir: str, format: str = "jsonl"):
        """Save processed data splits to files.

        One file named ``<split_name>.<format>`` is written per split inside
        *output_dir*, which is created if it does not exist.

        Args:
            data: Mapping of split name to its list of records.
            output_dir: Directory that receives the files.
            format: One of ``"jsonl"``, ``"json"`` or ``"csv"``.

        Raises:
            ValueError: If *format* is not one of the supported formats.
        """
        supported_formats = ("jsonl", "json", "csv")
        # Reject unknown formats before touching the filesystem; previously
        # such a call wrote nothing and then failed on an unbound local.
        if format not in supported_formats:
            raise ValueError(
                f"Unsupported output format: {format!r} "
                f"(expected one of {', '.join(supported_formats)})"
            )

        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        for split_name, split_data in data.items():
            output_file = output_path / f"{split_name}.{format}"

            if format == "jsonl":
                with open(output_file, 'w', encoding='utf-8') as f:
                    for item in split_data:
                        f.write(json.dumps(item, ensure_ascii=False) + '\n')
            elif format == "json":
                with open(output_file, 'w', encoding='utf-8') as f:
                    json.dump(split_data, f, ensure_ascii=False, indent=2)
            else:  # "csv"
                df = pd.DataFrame(split_data)
                df.to_csv(output_file, index=False)

            logger.info(f"Saved {len(split_data)} samples to {output_file}")

    def run_pipeline(
        self,
        config: ClassificationConfig,
        output_format: str = "classification",
        save_splits: bool = True
    ) -> Dict[str, Any]:
        """Run complete classification pipeline.

        Args:
            config: Pipeline configuration; ``config.output_dir`` is read for
                the save location and the result summary.
            output_format: ``"classification"`` converts the splits with
                ``convert_to_classification_format``; any other value passes
                the processed splits through unchanged.
            save_splits: When true, the splits are saved (as JSONL, the
                ``save_data`` default) under ``<config.output_dir>/<output_format>``.

        Returns:
            A dict with the keys ``config``, ``analysis``, ``splits`` (split
            name -> sample count), ``output_format``, ``output_dir`` and
            ``data`` (the formatted splits themselves).
        """
        logger.info("Starting classification pipeline...")

        # Load and preprocess data
        processed_splits, analysis = self.load_and_preprocess(config)

        # Convert to desired output format
        if output_format == "classification":
            formatted_splits = self.convert_to_classification_format(processed_splits)
        else:
            formatted_splits = processed_splits

        # Save data if requested
        if save_splits:
            output_dir = Path(config.output_dir) / output_format
            self.save_data(formatted_splits, str(output_dir))

        # Create result summary
        result = {
            "config": config,
            "analysis": analysis,
            "splits": {
                split_name: len(split_data) for split_name, split_data in formatted_splits.items()
            },
            "output_format": output_format,
            "output_dir": config.output_dir,
            "data": formatted_splits  # Include the actual processed data
        }

        logger.info("Classification pipeline completed successfully!")
        return result
|
||
|
|
|
||
|
|
|
||
|
|
def create_huggingface_config(
    dataset_name: str,
    input_field: str = "text",
    label_field: str = "label",
    **kwargs
) -> ClassificationConfig:
    """Build a ``ClassificationConfig`` whose data source is ``"huggingface"``.

    Any extra keyword arguments are handed straight to
    ``ClassificationConfig``.
    """
    hf_config = ClassificationConfig(
        data_source="huggingface",
        dataset_name=dataset_name,
        input_field=input_field,
        label_field=label_field,
        **kwargs
    )
    return hf_config
|
||
|
|
|
||
|
|
|
||
|
|
def create_custom_config(
    data_path: str,
    data_format: str = "jsonl",
    input_field: str = "text",
    label_field: str = "label",
    **kwargs
) -> ClassificationConfig:
    """Build a ``ClassificationConfig`` whose data source is ``"custom"``.

    Any extra keyword arguments are handed straight to
    ``ClassificationConfig``.
    """
    custom_config = ClassificationConfig(
        data_source="custom",
        data_path=data_path,
        data_format=data_format,
        input_field=input_field,
        label_field=label_field,
        **kwargs
    )
    return custom_config
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """Main function with YAML configuration support.

    Settings are read from the ``data`` section of the optional YAML file
    given with ``--config``; command-line options then override individual
    entries of that section. The merged settings are turned into a
    ``ClassificationConfig`` and the pipeline is run, saving the splits.

    Exits with status 1 when the YAML file cannot be loaded or the pipeline
    raises, and via ``parser.error`` when a required setting is missing.
    """
    parser = argparse.ArgumentParser(description="Classification Data Processing Pipeline")

    # YAML configuration
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")

    # Data source arguments
    parser.add_argument("--data-source", choices=["huggingface", "custom"], help="Data source")
    parser.add_argument("--dataset-name", type=str, help="HuggingFace dataset name")
    parser.add_argument("--data-path", type=str, help="Path to custom data file")
    parser.add_argument("--data-format", choices=["jsonl", "csv", "json"], help="Data format")

    # Field mapping
    parser.add_argument("--input-field", type=str, help="Input field name")
    parser.add_argument("--label-field", type=str, help="Label field name")
    parser.add_argument("--id-field", type=str, help="Optional ID field name")

    # Data processing
    parser.add_argument("--max-samples", type=int, help="Maximum samples to process")
    parser.add_argument("--train-split", type=float, help="Training split ratio")
    parser.add_argument("--validation-split", type=float, help="Validation split ratio")
    parser.add_argument("--test-split", type=float, help="Test split ratio")

    # Text preprocessing
    parser.add_argument("--clean-text", action="store_true", help="Clean and normalize text")
    parser.add_argument("--remove-special-chars", action="store_true", help="Remove special characters")
    parser.add_argument("--lowercase", action="store_true", help="Convert text to lowercase")
    parser.add_argument("--min-length", type=int, help="Minimum text length")
    parser.add_argument("--max-length", type=int, help="Maximum text length")

    # Output configuration
    parser.add_argument("--output-format", choices=["classification", "instruction", "conversation", "qa"], help="Output format")
    parser.add_argument("--output-dir", type=str, help="Output directory")

    # Logging
    parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO", help="Logging level")

    args = parser.parse_args()

    # Set up logging
    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # Load configuration
    config_dict = {}

    # Load YAML config if provided
    if args.config:
        try:
            with open(args.config, 'r', encoding='utf-8') as f:
                # An empty YAML document parses to None; treat it as "no settings".
                config_dict = yaml.safe_load(f) or {}
            logger.info(f"Loaded YAML configuration from: {args.config}")
        except Exception as e:
            logger.error(f"Error loading YAML config: {e}")
            sys.exit(1)

        if not isinstance(config_dict, dict):
            logger.error("Error loading YAML config: top level must be a mapping")
            sys.exit(1)

    # Every setting lives in the nested 'data' section, which is what the
    # validation and ClassificationConfig construction below read from.
    data_cfg = config_dict.get('data')
    if data_cfg is None:
        data_cfg = {}
        config_dict['data'] = data_cfg
    elif not isinstance(data_cfg, dict):
        logger.error("Error loading YAML config: 'data' section must be a mapping")
        sys.exit(1)

    # Collect CLI overrides, keyed by the name used inside the 'data' section.
    cli_overrides = {}

    # --data-source is stored under the key 'source'.
    if args.data_source is not None:
        cli_overrides['source'] = args.data_source

    # Options that carry a value: compare against None so that explicit
    # zeros (e.g. --min-length 0, --test-split 0) are honoured.
    value_options = (
        'dataset_name', 'data_path', 'data_format',
        'input_field', 'label_field', 'id_field',
        'max_samples', 'train_split', 'validation_split', 'test_split',
        'min_length', 'max_length',
        'output_format', 'output_dir',
    )
    for name in value_options:
        value = getattr(args, name)
        if value is not None:
            cli_overrides[name] = value

    # store_true flags can only switch a setting on; when absent they leave
    # the YAML value (or the default used below) untouched.
    for name in ('clean_text', 'remove_special_chars', 'lowercase'):
        if getattr(args, name):
            cli_overrides[name] = True

    # Merge configurations
    for key, value in cli_overrides.items():
        if key in data_cfg:
            logger.info(f"Overriding YAML config '{key}' with CLI value: {value}")
        data_cfg[key] = value

    # Validate required arguments
    if not data_cfg.get('source'):
        parser.error("--data-source is required (either in YAML config or CLI)")

    if data_cfg.get('source') == "huggingface" and not data_cfg.get('dataset_name'):
        parser.error("--dataset-name is required for HuggingFace datasets")

    if data_cfg.get('source') == "custom" and not data_cfg.get('data_path'):
        parser.error("--data-path is required for custom datasets")

    # Create configuration object
    config = ClassificationConfig(
        data_source=data_cfg.get('source', 'huggingface'),
        dataset_name=data_cfg.get('dataset_name'),
        data_path=data_cfg.get('data_path'),
        data_format=data_cfg.get('data_format', 'jsonl'),
        input_field=data_cfg.get('input_field', 'text'),
        label_field=data_cfg.get('label_field', 'label'),
        id_field=data_cfg.get('id_field'),
        max_samples=data_cfg.get('max_samples'),
        train_split=data_cfg.get('train_split', 0.8),
        validation_split=data_cfg.get('validation_split', 0.1),
        test_split=data_cfg.get('test_split', 0.1),
        clean_text=data_cfg.get('clean_text', True),
        remove_special_chars=data_cfg.get('remove_special_chars', False),
        lowercase=data_cfg.get('lowercase', True),
        min_length=data_cfg.get('min_length', 10),
        max_length=data_cfg.get('max_length', 1000),
        label_encoding=data_cfg.get('label_encoding', 'auto'),
        multilabel=data_cfg.get('multilabel', False),
        label_separator=data_cfg.get('label_separator', ','),
        output_format=data_cfg.get('output_format', 'classification'),
        output_dir=data_cfg.get('output_dir', './data'),
        hf_split=data_cfg.get('hf_split', 'train'),
        hf_cache_dir=data_cfg.get('hf_cache_dir'),
        test_split_from=data_cfg.get('test_split_from', 'train'),
        val_split_from=data_cfg.get('val_split_from', 'train'),
        encoding=data_cfg.get('encoding', 'utf-8'),
        delimiter=data_cfg.get('delimiter', ',')
    )

    # Initialize pipeline
    pipeline = ClassificationDataPipeline()

    # Top-level boundary: report any pipeline failure and exit non-zero.
    try:
        print(f"Starting classification pipeline with {config.data_source} data source...")
        if args.config:
            print(f"Using YAML configuration: {args.config}")
        print()

        result = pipeline.run_pipeline(config, config.output_format, save_splits=True)

        print(f"✅ Pipeline completed successfully!")
        print(f"   Data source: {config.data_source}")
        if config.data_source == "huggingface":
            print(f"   Dataset: {config.dataset_name}")
        else:
            print(f"   Data file: {config.data_path}")
        print(f"   Total samples: {result['analysis']['overall']['total_samples']}")
        print(f"   Unique labels: {result['analysis']['overall']['all_unique_labels']}")
        print(f"   Split sizes: {result['analysis']['overall']['split_sizes']}")
        print(f"   Output directory: {config.output_dir}")

    except Exception as e:
        print(f"❌ Error running pipeline: {e}")
        sys.exit(1)
|
||
|
|
|
||
|
|
|
||
|
|
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|