Files
DS-LLM-TEMPLATE-FINETUNING/pipelines/instruct/data_processor.py
T
2025-08-28 16:46:24 +00:00

914 lines
37 KiB
Python

import json
import pandas as pd
import numpy as np
from pathlib import Path
from typing import Dict, List, Optional, Union, Any, Tuple
from datasets import Dataset, load_dataset
import os
from dataclasses import dataclass
from abc import ABC, abstractmethod
from sklearn.model_selection import train_test_split
import re
import argparse
import sys
import yaml
@dataclass
class InstructConfig:
    """Settings that drive the instruction fine-tuning data pipeline.

    Field order is part of the interface (dataclass positional construction),
    so new fields should be appended at the end.
    """

    # --- where the raw data comes from ---
    data_source: str = "custom"  # "huggingface" or "custom"
    dataset_name: Optional[str] = None  # Hub dataset id, used for Hugging Face sources
    data_path: Optional[str] = None  # local file path, used for custom sources
    data_format: str = "jsonl"  # "jsonl" or "json"

    # --- field mapping ---
    conversation_field: str = "conversation"  # key holding the list of conversation turns

    # --- sample cap and split ratios ---
    max_samples: Optional[int] = None
    train_split: float = 0.8
    validation_split: float = 0.1
    test_split: float = 0.1

    # --- per-turn text filtering ---
    clean_text: bool = True
    min_length: int = 10
    max_length: int = 2048

    # --- output ---
    output_format: str = "conversation"  # "conversation" or "alpaca"
    output_dir: str = "./data/processed/instruct"

    # --- Hugging Face loading options ---
    hf_split: str = "train"
    hf_cache_dir: Optional[str] = None

    # --- source split names for the held-out splits ---
    test_split_from: str = "train"
    val_split_from: str = "train"

    # --- decoding of local files ---
    encoding: str = "utf-8"
class ConversationValidator:
    """Validates and summarizes conversation datasets split into train/validation/test."""

    _VALID_ROLES = ("user", "assistant", "system")

    @staticmethod
    def _field_name(config: InstructConfig, is_processed: bool) -> str:
        """Return the key holding the turn list for raw or preprocessed items."""
        # Preprocessing always re-emits items under the fixed "conversation" key,
        # while raw items use whatever key the configuration maps.
        return "conversation" if is_processed else config.conversation_field

    @staticmethod
    def _preview(content: Any, limit: int = 100) -> str:
        """Return ``content`` as text, truncated to ``limit`` chars plus "..." if longer."""
        text = content if isinstance(content, str) else str(content)
        return text[:limit] + "..." if len(text) > limit else text

    @staticmethod
    def _print_samples(split_name: str, split_data: List[Dict], conversation_field: str) -> None:
        """Print up to 2 items (first 3 turns each) of a split for debugging.

        Tolerates malformed items and turns so that previewing never crashes
        on data the validator has just reported as invalid.
        """
        print(f"Sample conversation from {split_name}:")
        for i, item in enumerate(split_data[:2]):
            conversation = item.get(conversation_field, []) if isinstance(item, dict) else []
            if not isinstance(conversation, list):
                print(f" Item {i} conversation is not a list ({type(conversation).__name__})")
                continue
            print(f" Item {i} conversation length: {len(conversation)} turns")
            for j, turn in enumerate(conversation[:3]):
                if not isinstance(turn, dict):
                    print(f" Turn {j}: <invalid turn of type {type(turn).__name__}>")
                    continue
                role = turn.get("role", "unknown")
                content = ConversationValidator._preview(turn.get("content", ""))
                print(f" Turn {j}: {role} -> '{content}'")

    @staticmethod
    def validate_conversation_data(data: Dict[str, List[Dict]], config: InstructConfig, is_processed: bool = False) -> Tuple[bool, List[str]]:
        """Validate the structure of every split.

        Args:
            data: Mapping of split name to list of items; "train", "validation"
                and "test" must all be present and "train" must be non-empty.
            config: Pipeline configuration; only ``conversation_field`` is read,
                and only when ``is_processed`` is False.
            is_processed: True for preprocessor output (items keyed by
                "conversation"), False for raw items (keyed by
                ``config.conversation_field``).

        Returns:
            ``(is_valid, errors)`` where ``errors`` lists every problem found.
        """
        errors: List[str] = []
        for split in ("train", "validation", "test"):
            if split not in data:
                errors.append(f"Missing '{split}' split")
            elif split == "train" and not data[split]:
                errors.append("Train split cannot be empty")
        if errors:
            return False, errors

        total_samples = sum(len(split_data) for split_data in data.values())
        print(f"Validating {total_samples} total samples across all splits...")
        conversation_field = ConversationValidator._field_name(config, is_processed)
        valid_roles = ConversationValidator._VALID_ROLES

        for split_name, split_data in data.items():
            if not split_data:
                print(f"Skipping validation for empty {split_name} split")
                continue
            print(f"Validating {split_name} split with {len(split_data)} samples...")
            missing_conversation_count = 0
            for i, item in enumerate(split_data):
                if not isinstance(item, dict):
                    errors.append(f"Item must be a dict in {split_name} split, item {i}")
                    continue
                if conversation_field not in item:
                    errors.append(f"Missing conversation field '{conversation_field}' in {split_name} split, item {i}")
                    missing_conversation_count += 1
                    continue
                conversation = item[conversation_field]
                if not isinstance(conversation, list):
                    errors.append(f"Conversation field must be a list in {split_name} split, item {i}")
                    continue
                for j, turn in enumerate(conversation):
                    if not isinstance(turn, dict):
                        errors.append(f"Each conversation turn must be a dict in {split_name} split, item {i}, turn {j}")
                        continue
                    if "role" not in turn:
                        errors.append(f"Missing 'role' field in conversation turn {j}, {split_name} split, item {i}")
                    if "content" not in turn:
                        errors.append(f"Missing 'content' field in conversation turn {j}, {split_name} split, item {i}")
                    if "role" in turn and turn["role"] not in valid_roles:
                        errors.append(f"Invalid role '{turn['role']}' in conversation turn {j}, {split_name} split, item {i}. Must be 'user', 'assistant', or 'system'")
            print(f"{split_name} - Items missing conversation field: {missing_conversation_count}")
            ConversationValidator._print_samples(split_name, split_data, conversation_field)
        return not errors, errors

    @staticmethod
    def analyze_conversation_dataset(data: Dict[str, List[Dict]], config: InstructConfig, is_processed: bool = False) -> Dict[str, Any]:
        """Compute per-split and overall turn/role statistics.

        Args:
            data: Mapping of split name to list of items.
            config: Pipeline configuration; only ``conversation_field`` is read,
                and only when ``is_processed`` is False.
            is_processed: Selects the item key, as in ``validate_conversation_data``.

        Returns:
            Dict with "splits" (per-split sample counts, conversation stats and
            missing-value counts) and "overall" (totals, split sizes, average
            turns per item and role distribution).
        """
        analysis = {
            "splits": {},
            "overall": {
                "total_samples": 0,
                "split_sizes": {},
                "conversation_stats": {
                    "total_turns": 0,
                    "avg_turns_per_conversation": 0,
                    "role_distribution": {"user": 0, "assistant": 0, "system": 0}
                }
            }
        }
        conversation_field = ConversationValidator._field_name(config, is_processed)
        total_turns = 0
        total_conversations = 0
        role_counts = {"user": 0, "assistant": 0, "system": 0}

        for split_name, split_data in data.items():
            split_analysis = {
                "total_samples": len(split_data),
                "conversation_stats": {},
                "missing_values": {}
            }
            analysis["splits"][split_name] = split_analysis
            analysis["overall"]["split_sizes"][split_name] = len(split_data)
            if not split_data:
                continue

            split_role_counts = {"user": 0, "assistant": 0, "system": 0}
            conversation_lengths = []
            missing_count = 0
            for item in split_data:
                conversation = item.get(conversation_field) if isinstance(item, dict) else None
                if not conversation:
                    missing_count += 1
                if not isinstance(conversation, list):
                    continue
                conversation_lengths.append(len(conversation))
                for turn in conversation:
                    role = turn.get("role") if isinstance(turn, dict) else None
                    # Guard against unhashable role values before the dict lookup.
                    if isinstance(role, str) and role in split_role_counts:
                        split_role_counts[role] += 1

            split_turns = sum(conversation_lengths)
            if conversation_lengths:
                split_analysis["conversation_stats"] = {
                    "total_turns": split_turns,
                    "avg_turns_per_conversation": np.mean(conversation_lengths),
                    "min_turns": min(conversation_lengths),
                    "max_turns": max(conversation_lengths),
                    "median_turns": np.median(conversation_lengths),
                    "role_distribution": split_role_counts
                }
            split_analysis["missing_values"][conversation_field] = missing_count

            analysis["overall"]["total_samples"] += len(split_data)
            total_turns += split_turns
            total_conversations += len(split_data)
            for role, count in split_role_counts.items():
                role_counts[role] += count

        if total_conversations > 0:
            overall_stats = analysis["overall"]["conversation_stats"]
            overall_stats["total_turns"] = total_turns
            overall_stats["avg_turns_per_conversation"] = total_turns / total_conversations
            overall_stats["role_distribution"] = role_counts
        return analysis
class BaseInstructDataLoader(ABC):
    """Interface shared by every conversation-data loader in this pipeline.

    A loader owns two stages: fetching raw records grouped by split, and
    preprocessing those records split by split.
    """

    @abstractmethod
    def load(self, config: InstructConfig) -> Dict[str, List[Dict]]:
        """Fetch the data and return it keyed by split name (train/validation/test)."""

    @abstractmethod
    def preprocess(self, data: Dict[str, List[Dict]], config: InstructConfig) -> Dict[str, List[Dict]]:
        """Run preprocessing over each split in ``data`` and return the result."""
class HuggingFaceInstructDataLoader(BaseInstructDataLoader):
    """Load conversation datasets from the Hugging Face Hub."""

    # Native split names treated as validation / test data, checked in order.
    _VALIDATION_ALIASES = ("validation", "valid", "val", "dev")
    _TEST_ALIASES = ("test",)

    def load(self, config: InstructConfig) -> Dict[str, List[Dict]]:
        """Load a Hub dataset and return its train/validation/test splits.

        The primary (training) split is ``config.hf_split`` (default "train").
        For validation and test:

        * if ``val_split_from`` / ``test_split_from`` names another split that
          exists in the dataset, that split is used unchanged;
        * otherwise a native split with a conventional name is used if present;
        * otherwise, when the ``*_split_from`` value is "train" (the default) or
          the primary split name, a ``validation_split`` / ``test_split``
          fraction is held out of the primary split with ``train_test_split``
          (``random_state=42``, as the custom loader does).

        ``max_samples`` caps every split read from the Hub before any hold-out
        is carved.

        Raises:
            ValueError: ``dataset_name`` is unset or the primary split is missing.
        """
        if not config.dataset_name:
            raise ValueError("Dataset name is required for Hugging Face datasets")
        print(f"Loading Hugging Face conversation dataset: {config.dataset_name}")
        try:
            dataset = load_dataset(
                config.dataset_name,
                cache_dir=config.hf_cache_dir
            )
            available_splits = list(dataset.keys())
            print(f"Available splits in dataset: {available_splits}")

            primary = config.hf_split or "train"
            if primary not in available_splits:
                print(f"No '{primary}' split found in dataset!")
                raise ValueError(f"Dataset {config.dataset_name} does not have a '{primary}' split")

            splits_data = {
                "train": self._read_split(dataset, primary, "train", config.max_samples),
                "validation": [],
                "test": []
            }

            # Split names that mean "carve the hold-out out of the primary split".
            carve_sources = {"train", primary}
            carve_ratios = {}
            for target, aliases, split_from, ratio in (
                ("validation", self._VALIDATION_ALIASES, config.val_split_from, config.validation_split),
                ("test", self._TEST_ALIASES, config.test_split_from, config.test_split),
            ):
                source = self._find_source_split(available_splits, aliases, split_from, primary, carve_sources)
                if source is not None:
                    splits_data[target] = self._read_split(dataset, source, target, config.max_samples)
                elif split_from in carve_sources:
                    carve_ratios[target] = ratio
                else:
                    print(f"Requested {target} source split '{split_from}' not found; leaving {target} empty")

            self._carve_from_train(splits_data, carve_ratios)

            print(f"Successfully loaded dataset {config.dataset_name}")
            return splits_data
        except Exception as e:
            print(f"Error loading dataset {config.dataset_name}: {e}")
            raise

    @staticmethod
    def _read_split(dataset, source: str, target: str, max_samples: Optional[int]) -> List[Dict]:
        """Materialize split ``source`` as a list of rows, capped at ``max_samples``."""
        rows = list(dataset[source])
        print(f"Using '{source}' split with {len(rows)} samples for {target}")
        if max_samples and len(rows) > max_samples:
            original_size = len(rows)
            rows = rows[:max_samples]
            print(f"Limited {target} split from {original_size} to {len(rows)} samples")
        return rows

    @staticmethod
    def _find_source_split(available_splits: List[str], aliases: Tuple[str, ...], split_from: Optional[str], primary: str, carve_sources: set) -> Optional[str]:
        """Return the name of an existing split to use unchanged, or None to fall back to carving."""
        # An explicitly requested, non-primary source split takes precedence.
        if split_from and split_from not in carve_sources and split_from in available_splits:
            return split_from
        for name in aliases:
            if name in available_splits and name != primary:
                return name
        return None

    @staticmethod
    def _carve_from_train(splits_data: Dict[str, List[Dict]], carve_ratios: Dict[str, float]) -> None:
        """Move a ``ratio`` fraction of the train split into each named held-out split.

        Sizes are computed from the train size before any hold-out is taken;
        train always keeps at least one sample.
        """
        train = splits_data["train"]
        total = len(train)
        for target, ratio in carve_ratios.items():
            count = int(total * (ratio or 0))
            if count < 1 or count >= len(train):
                if ratio:
                    print(f"Not enough training samples to hold out a {target} split; leaving it empty")
                continue
            train, held_out = train_test_split(train, test_size=count, random_state=42)
            splits_data[target] = held_out
            print(f"Held out {len(held_out)} samples from train as the {target} split")
        splits_data["train"] = train

    def preprocess(self, data: Dict[str, List[Dict]], config: InstructConfig) -> Dict[str, List[Dict]]:
        """Preprocess each split independently, dropping items that yield no valid conversation."""
        processed_splits = {}
        print("=== PREPROCESSING CONVERSATION DATA ===")
        for split_name, split_data in data.items():
            print(f"Processing {split_name} split with {len(split_data)} items...")
            processed_data = []
            skipped_count = 0
            for item in split_data:
                processed_item = self._preprocess_item(item, config)
                if processed_item is None:
                    skipped_count += 1
                else:
                    processed_data.append(processed_item)
            processed_splits[split_name] = processed_data
            print(f"{split_name} - Preprocessed {len(processed_data)} samples, skipped {skipped_count} samples")
        return processed_splits

    def _preprocess_item(self, item: Dict, config: InstructConfig) -> Optional[Dict]:
        """Normalize one record into ``{"conversation": [{"role", "content"}, ...]}``.

        Turns are dropped when they are not dicts, lack "role"/"content", carry
        a role other than user/assistant/system, or when their (cleaned) content
        length falls outside ``[min_length, max_length]``.

        Returns:
            The normalized item, or None if the record has no turn list or fewer
            than two turns survive.
        """
        conversation = item.get(config.conversation_field, []) if isinstance(item, dict) else []
        if not isinstance(conversation, list) or not conversation:
            return None
        valid_conversation = []
        for turn in conversation:
            if not isinstance(turn, dict) or "role" not in turn or "content" not in turn:
                continue
            if turn["role"] not in ("user", "assistant", "system"):
                continue
            content = str(turn["content"]).strip()
            if config.clean_text:
                content = self._clean_text(content)
            # Measure length on the text that will actually be written out.
            if not config.min_length <= len(content) <= config.max_length:
                continue
            valid_conversation.append({
                "role": turn["role"],
                "content": content
            })
        if len(valid_conversation) < 2:  # Need at least 2 turns for a conversation
            return None
        return {"conversation": valid_conversation}

    def _clean_text(self, text: str) -> str:
        """Collapse every whitespace run to one space and strip the ends; non-str input gives ""."""
        if not isinstance(text, str):
            return ""
        return re.sub(r'\s+', ' ', text).strip()
class CustomInstructDataLoader(BaseInstructDataLoader):
    """Load custom conversation datasets from local JSON / JSONL files."""

    def load(self, config: InstructConfig) -> Dict[str, List[Dict]]:
        """Read ``config.data_path``, apply ``max_samples``, and split into train/validation/test.

        Raises:
            ValueError: ``data_path`` is unset or ``data_format`` is not "jsonl"/"json".
            FileNotFoundError: ``data_path`` does not exist.
        """
        if not config.data_path:
            raise ValueError("Data path is required for custom datasets")
        file_path = Path(config.data_path)
        if not file_path.exists():
            raise FileNotFoundError(f"Data file not found: {file_path}")
        print(f"Loading custom conversation dataset: {file_path}")
        if config.data_format == "jsonl":
            raw_data = self._load_jsonl(file_path, config)
        elif config.data_format == "json":
            raw_data = self._load_json(file_path, config)
        else:
            raise ValueError(f"Unsupported format: {config.data_format}")
        if config.max_samples:
            raw_data = raw_data[:config.max_samples]
        print(f"Loaded {len(raw_data)} conversation samples from {file_path}")
        return self._create_splits(raw_data, config)

    @staticmethod
    def _held_out_size(total_samples: int, ratio: Optional[float]) -> int:
        """Sample count for a held-out split: 0 for a missing/non-positive ratio, else at least 1."""
        if not ratio or ratio <= 0:
            return 0
        return max(1, int(total_samples * ratio))

    def _create_splits(self, data: List[Dict], config: InstructConfig) -> Dict[str, List[Dict]]:
        """Shuffle-split ``data`` into train/validation/test lists (``random_state=42``).

        Fewer than 3 records all go to train. Below 10 records the ratios
        0.6/0.2/0.2 are used instead of the configured ones; ``config`` itself
        is never modified. A held-out split with a positive ratio gets at least
        one record, a ratio of 0 leaves it empty, and train always keeps at
        least one record.
        """
        print(f"Creating splits from {len(data)} conversation samples...")
        if len(data) < 3:
            print(f"Dataset has only {len(data)} samples. Using all data for training.")
            return {
                "train": data,
                "validation": [],
                "test": []
            }

        total_samples = len(data)
        train_ratio, val_ratio, test_ratio = config.train_split, config.validation_split, config.test_split
        if total_samples < 10:
            # Local override only: mutating config would leak into later runs.
            train_ratio, val_ratio, test_ratio = 0.6, 0.2, 0.2
            print(f"Small dataset detected. Adjusted split ratios to: train={train_ratio}, val={val_ratio}, test={test_ratio}")

        val_size = self._held_out_size(total_samples, val_ratio)
        test_size = self._held_out_size(total_samples, test_ratio)
        # Give samples back to train (validation first) until train has at least one.
        while total_samples - val_size - test_size < 1:
            if val_size > 1:
                val_size -= 1
            elif test_size > 1:
                test_size -= 1
            else:
                break  # cannot happen for total_samples >= 3; guarantees termination
        train_size = total_samples - val_size - test_size
        print(f"Split sizes: train={train_size}, validation={val_size}, test={test_size}")

        if val_size == 0 and test_size == 0:
            splits_data = {
                "train": data,
                "validation": [],
                "test": []
            }
        elif val_size == 0:
            train_data, test_data = train_test_split(data, test_size=test_size, random_state=42)
            splits_data = {
                "train": train_data,
                "validation": [],
                "test": test_data
            }
        elif test_size == 0:
            train_data, val_data = train_test_split(data, test_size=val_size, random_state=42)
            splits_data = {
                "train": train_data,
                "validation": val_data,
                "test": []
            }
        else:
            # Full three-way split: hold out val+test together, then divide them.
            train_data, temp_data = train_test_split(
                data,
                test_size=val_size + test_size,
                random_state=42
            )
            val_data, test_data = train_test_split(
                temp_data,
                test_size=test_size,
                random_state=42
            )
            splits_data = {
                "train": train_data,
                "validation": val_data,
                "test": test_data
            }
        print("Created conversation splits:")
        print(f" Train: {len(splits_data['train'])} samples")
        print(f" Validation: {len(splits_data['validation'])} samples")
        print(f" Test: {len(splits_data['test'])} samples")
        return splits_data

    def _load_jsonl(self, file_path: Path, config: InstructConfig) -> List[Dict]:
        """Parse one JSON value per non-blank line; malformed lines are reported and skipped."""
        data = []
        with open(file_path, 'r', encoding=config.encoding) as f:
            for line_num, line in enumerate(f, 1):
                if not line.strip():
                    continue
                try:
                    data.append(json.loads(line))
                except json.JSONDecodeError as e:
                    print(f"Invalid JSON at line {line_num}: {e}")
        return data

    def _load_json(self, file_path: Path, config: InstructConfig) -> List[Dict]:
        """Parse a JSON file: a top-level list, a ``{"data": [...]}`` wrapper, or a single record."""
        with open(file_path, 'r', encoding=config.encoding) as f:
            data = json.load(f)
        if isinstance(data, list):
            return data
        if isinstance(data, dict) and "data" in data:
            return data["data"]
        return [data]

    def preprocess(self, data: Dict[str, List[Dict]], config: InstructConfig) -> Dict[str, List[Dict]]:
        """Preprocess each split independently, logging the first three skipped items per split."""
        processed_splits = {}
        print("=== PREPROCESSING CUSTOM CONVERSATION DATA ===")
        for split_name, split_data in data.items():
            print(f"Processing {split_name} split with {len(split_data)} items...")
            processed_data = []
            skipped_count = 0
            for i, item in enumerate(split_data):
                processed_item = self._preprocess_item(item, config)
                if processed_item is not None:
                    processed_data.append(processed_item)
                    continue
                skipped_count += 1
                if skipped_count <= 3:  # Log first few skipped items
                    print(f"Skipped item {i} from {split_name}: {item}")
            processed_splits[split_name] = processed_data
            print(f"{split_name} - Preprocessed {len(processed_data)} samples, skipped {skipped_count} samples")
        return processed_splits

    def _preprocess_item(self, item: Dict, config: InstructConfig) -> Optional[Dict]:
        """Normalize one record into ``{"conversation": [{"role", "content"}, ...]}``.

        Turns are dropped when they are not dicts, lack "role"/"content", carry
        a role other than user/assistant/system, or when their (cleaned) content
        length falls outside ``[min_length, max_length]``.

        Returns:
            The normalized item, or None if the record is not a dict, has no
            turn list, or fewer than two turns survive.
        """
        conversation = item.get(config.conversation_field, []) if isinstance(item, dict) else []
        if not isinstance(conversation, list) or not conversation:
            return None
        valid_conversation = []
        for turn in conversation:
            if not isinstance(turn, dict) or "role" not in turn or "content" not in turn:
                continue
            if turn["role"] not in ("user", "assistant", "system"):
                continue
            content = str(turn["content"]).strip()
            if config.clean_text:
                content = self._clean_text(content)
            # Measure length on the text that will actually be written out.
            if not config.min_length <= len(content) <= config.max_length:
                continue
            valid_conversation.append({
                "role": turn["role"],
                "content": content
            })
        if len(valid_conversation) < 2:  # Need at least 2 turns for a conversation
            return None
        return {"conversation": valid_conversation}

    def _clean_text(self, text: str) -> str:
        """Collapse every whitespace run to one space and strip the ends; non-str input gives ""."""
        if not isinstance(text, str):
            return ""
        return re.sub(r'\s+', ' ', text).strip()
class InstructDataPipeline:
    """Instruction fine-tuning data pipeline: load -> preprocess -> validate -> analyze -> save."""

    # File formats accepted by save_data.
    _SAVE_FORMATS = ("jsonl", "json")

    def __init__(self):
        self.validator = ConversationValidator()
        self.hf_loader = HuggingFaceInstructDataLoader()
        self.custom_loader = CustomInstructDataLoader()

    def create_config(
        self,
        data_source: str,
        dataset_name: Optional[str] = None,
        data_path: Optional[str] = None,
        conversation_field: str = "conversation",
        **kwargs
    ) -> InstructConfig:
        """Build an InstructConfig; extra keyword arguments are passed through unchanged."""
        return InstructConfig(
            data_source=data_source,
            dataset_name=dataset_name,
            data_path=data_path,
            conversation_field=conversation_field,
            **kwargs
        )

    def load_config_from_yaml(self, yaml_path: str) -> InstructConfig:
        """Build an InstructConfig from a YAML file via ``load_yaml_config``.

        Keys that are absent or null in the YAML fall back to the defaults below.
        """
        try:
            # load_yaml_config may report absent keys as None; drop them so the
            # .get() defaults apply instead of None leaking into the config.
            config_dict = {
                key: value for key, value in load_yaml_config(yaml_path).items() if value is not None
            }
            config = InstructConfig(
                data_source=config_dict.get('data_source', 'custom'),
                dataset_name=config_dict.get('dataset_name'),
                data_path=config_dict.get('data_path'),
                data_format=config_dict.get('data_format', 'jsonl'),
                conversation_field=config_dict.get('conversation_field', 'conversation'),
                max_samples=config_dict.get('max_samples'),
                train_split=config_dict.get('train_split', 0.8),
                validation_split=config_dict.get('validation_split', 0.1),
                test_split=config_dict.get('test_split', 0.1),
                clean_text=config_dict.get('clean_text', True),
                min_length=config_dict.get('min_length', 10),
                max_length=config_dict.get('max_length', 2048),
                output_format=config_dict.get('output_format', 'conversation'),
                output_dir=config_dict.get('output_dir', './data/processed/instruct'),
                hf_split=config_dict.get('hf_split', 'train'),
                hf_cache_dir=config_dict.get('hf_cache_dir'),
                test_split_from=config_dict.get('test_split_from', 'train'),
                val_split_from=config_dict.get('val_split_from', 'train'),
                encoding=config_dict.get('encoding', 'utf-8')
            )
            print(f"Configuration loaded from YAML: {yaml_path}")
            print(f"Output directory: {config.output_dir}")
            print(f"Conversation field: {config.conversation_field}")
            return config
        except Exception as e:
            print(f"Error loading configuration from YAML {yaml_path}: {e}")
            raise

    def load_and_preprocess(self, config: InstructConfig) -> Tuple[Dict[str, List[Dict]], Dict[str, Any]]:
        """Load, preprocess, validate and analyze the data described by ``config``.

        Returns:
            ``(processed_splits, analysis)``.

        Raises:
            ValueError: unsupported ``data_source`` or failed validation.
        """
        print("Starting conversation data loading and preprocessing...")
        print(f"Data source: {config.data_source}")
        try:
            if config.data_source == "huggingface":
                print("Loading HuggingFace conversation dataset...")
                raw_splits = self.hf_loader.load(config)
                print("Preprocessing HuggingFace conversation dataset...")
                processed_splits = self.hf_loader.preprocess(raw_splits, config)
            elif config.data_source == "custom":
                print("Loading custom conversation dataset...")
                raw_splits = self.custom_loader.load(config)
                print("Preprocessing custom conversation dataset...")
                processed_splits = self.custom_loader.preprocess(raw_splits, config)
            else:
                raise ValueError(f"Unsupported data source: {config.data_source}")
            print("Conversation data loading and preprocessing completed successfully")

            print("Validating processed conversation data...")
            is_valid, errors = self.validator.validate_conversation_data(processed_splits, config, is_processed=True)
            if not is_valid:
                print("Conversation data validation failed:")
                for error in errors:
                    print(f" - {error}")
                raise ValueError("Conversation data validation failed")
            print("Conversation data validation passed")

            print("Analyzing conversation dataset...")
            analysis = self.validator.analyze_conversation_dataset(processed_splits, config, is_processed=True)
            print("Conversation dataset analysis completed")
            return processed_splits, analysis
        except Exception as e:
            print(f"Error in load_and_preprocess: {e}")
            raise

    def save_data(self, data: Dict[str, List[Dict]], output_dir: str, format: str = "jsonl"):
        """Write each split to ``<output_dir>/<split>.<format>`` (UTF-8, non-ASCII kept).

        Args:
            data: Mapping of split name to list of JSON-serializable items.
            output_dir: Target directory; created (with parents) if missing.
            format: "jsonl" (one item per line) or "json" (indented list).

        Raises:
            ValueError: ``format`` is not "jsonl" or "json".
        """
        # Check up front: previously an unknown format created the directory and
        # then failed on an unbound output_file.
        if format not in self._SAVE_FORMATS:
            raise ValueError(f"Unsupported output format: {format!r} (expected 'jsonl' or 'json')")
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        for split_name, split_data in data.items():
            output_file = output_path / f"{split_name}.{format}"
            with open(output_file, 'w', encoding='utf-8') as f:
                if format == "jsonl":
                    f.writelines(json.dumps(item, ensure_ascii=False) + '\n' for item in split_data)
                else:
                    json.dump(split_data, f, ensure_ascii=False, indent=2)
            print(f"Saved {len(split_data)} conversation samples to {output_file}")

    def run_pipeline(
        self,
        config: InstructConfig,
        save_splits: bool = True
    ) -> Dict[str, Any]:
        """Run the full pipeline and return a summary including the processed data.

        When ``save_splits`` is True the splits are written as JSONL to
        ``config.output_dir``.
        """
        print("Starting instruction data pipeline...")
        processed_splits, analysis = self.load_and_preprocess(config)
        if save_splits:
            self.save_data(processed_splits, str(Path(config.output_dir)))
        result = {
            "config": config,
            "analysis": analysis,
            "splits": {
                split_name: len(split_data) for split_name, split_data in processed_splits.items()
            },
            "output_format": config.output_format,
            "output_dir": config.output_dir,
            "data": processed_splits,  # Include the actual processed data
        }
        print("Instruction data pipeline completed successfully!")
        return result
def load_yaml_config(config_path: str) -> Dict[str, Any]:
    """Read a pipeline YAML file and flatten it into InstructConfig keyword names.

    Recognized layout::

        task:
          name: ...            # -> task_name
          type: ...            # -> task_type
        data:
          source: ...          # -> data_source
          dataset_name: ...    # other data keys keep their own names
          ...

    Only keys that are present with a non-null value are returned, so callers'
    ``dict.get(key, default)`` fallbacks take effect for anything omitted.
    An empty file yields an empty dict.

    Raises:
        OSError / yaml.YAMLError: the file cannot be read or parsed.
        ValueError: the top-level YAML value is not a mapping.
    """
    # YAML key under "data" -> configuration key.
    data_key_map = {
        'source': 'data_source',
        'dataset_name': 'dataset_name',
        'data_path': 'data_path',
        'data_format': 'data_format',
        'conversation_field': 'conversation_field',
        'max_samples': 'max_samples',
        'train_split': 'train_split',
        'validation_split': 'validation_split',
        'test_split': 'test_split',
        'clean_text': 'clean_text',
        'min_length': 'min_length',
        'max_length': 'max_length',
        'output_format': 'output_format',
        'output_dir': 'output_dir',
        'encoding': 'encoding',
        'hf_split': 'hf_split',
        'hf_cache_dir': 'hf_cache_dir',
        'test_split_from': 'test_split_from',
        'val_split_from': 'val_split_from',
    }
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            yaml_data = yaml.safe_load(f)
        # An empty document parses to None; treat it as an empty configuration.
        if yaml_data is None:
            yaml_data = {}
        if not isinstance(yaml_data, dict):
            raise ValueError(f"Top-level YAML content must be a mapping, got {type(yaml_data).__name__}")

        config_dict = {}
        # `or {}` also covers a section key present with no value (null).
        task_data = yaml_data.get('task') or {}
        config_dict['task_name'] = task_data.get('name')
        config_dict['task_type'] = task_data.get('type')

        data_config = yaml_data.get('data') or {}
        for yaml_key, config_key in data_key_map.items():
            config_dict[config_key] = data_config.get(yaml_key)

        # Drop absent/null entries so they cannot override dataclass defaults.
        config_dict = {key: value for key, value in config_dict.items() if value is not None}

        print(f"Successfully parsed YAML configuration from: {config_path}")
        print(f"Extracted {len(config_dict)} configuration parameters")
        return config_dict
    except Exception as e:
        print(f"Error loading YAML config from {config_path}: {e}")
        raise
def main():
    """Command-line entry point for the instruction data pipeline.

    Settings come from an optional YAML file (``--config``) and are then
    overridden by any CLI flags that were given. Exits with status 1 if the
    YAML file cannot be loaded or the pipeline fails; ``parser.error`` exits
    when a required setting is missing.
    """
    parser = argparse.ArgumentParser(description="Instruction Data Processing Pipeline")
    # YAML configuration
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")
    # Data source arguments
    parser.add_argument("--data-source", choices=["huggingface", "custom"], help="Data source")
    parser.add_argument("--dataset-name", type=str, help="HuggingFace dataset name")
    parser.add_argument("--data-path", type=str, help="Path to custom data file")
    parser.add_argument("--data-format", choices=["jsonl", "json"], help="Data format")
    # Field mapping
    parser.add_argument("--conversation-field", type=str, help="Conversation field name")
    # Data processing
    parser.add_argument("--max-samples", type=int, help="Maximum samples to process")
    parser.add_argument("--train-split", type=float, help="Training split ratio")
    parser.add_argument("--validation-split", type=float, help="Validation split ratio")
    parser.add_argument("--test-split", type=float, help="Test split ratio")
    # Output configuration
    parser.add_argument("--output-dir", type=str, help="Output directory")
    # Logging. NOTE: accepted for CLI compatibility but not applied yet;
    # this module reports progress with print().
    parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO", help="Logging level")
    args = parser.parse_args()

    config_dict = {}
    if args.config:
        try:
            config_dict = load_yaml_config(args.config)
        except Exception as e:
            print(f"Error loading YAML config: {e}")
            sys.exit(1)
    # Null YAML values would otherwise shadow the defaults passed to .get() below.
    config_dict = {key: value for key, value in config_dict.items() if value is not None}

    # CLI flags override YAML. Compare with None rather than truthiness so that
    # explicit zero values such as ``--test-split 0`` are honored.
    cli_overrides = {
        'data_source': args.data_source,
        'dataset_name': args.dataset_name,
        'data_path': args.data_path,
        'data_format': args.data_format,
        'conversation_field': args.conversation_field,
        'max_samples': args.max_samples,
        'train_split': args.train_split,
        'validation_split': args.validation_split,
        'test_split': args.test_split,
        'output_dir': args.output_dir,
    }
    for key, value in cli_overrides.items():
        if value is None:
            continue
        if key in config_dict:
            print(f"Overriding YAML config '{key}' with CLI value: {value}")
        config_dict[key] = value

    # Validate required arguments
    if not config_dict.get('data_source'):
        parser.error("--data-source is required (either in YAML config or CLI)")
    if config_dict.get('data_source') == "huggingface" and not config_dict.get('dataset_name'):
        parser.error("--dataset-name is required for HuggingFace datasets")
    if config_dict.get('data_source') == "custom" and not config_dict.get('data_path'):
        parser.error("--data-path is required for custom datasets")

    config = InstructConfig(
        data_source=config_dict.get('data_source', 'custom'),
        dataset_name=config_dict.get('dataset_name'),
        data_path=config_dict.get('data_path'),
        data_format=config_dict.get('data_format', 'jsonl'),
        conversation_field=config_dict.get('conversation_field', 'conversation'),
        max_samples=config_dict.get('max_samples'),
        train_split=config_dict.get('train_split', 0.8),
        validation_split=config_dict.get('validation_split', 0.1),
        test_split=config_dict.get('test_split', 0.1),
        clean_text=config_dict.get('clean_text', True),
        min_length=config_dict.get('min_length', 10),
        max_length=config_dict.get('max_length', 2048),
        output_format=config_dict.get('output_format', 'conversation'),
        output_dir=config_dict.get('output_dir', './data/processed/instruct'),
        hf_split=config_dict.get('hf_split', 'train'),
        hf_cache_dir=config_dict.get('hf_cache_dir'),
        test_split_from=config_dict.get('test_split_from', 'train'),
        val_split_from=config_dict.get('val_split_from', 'train'),
        encoding=config_dict.get('encoding', 'utf-8')
    )

    pipeline = InstructDataPipeline()
    try:
        print(f"Starting instruction data pipeline with {config.data_source} data source...")
        if args.config:
            print(f"Using YAML configuration: {args.config}")
        print(f"Conversation field: {config.conversation_field}")
        print()
        result = pipeline.run_pipeline(config, save_splits=True)
        overall = result['analysis']['overall']
        print("✅ Pipeline completed successfully!")
        print(f" Data source: {config.data_source}")
        if config.data_source == "huggingface":
            print(f" Dataset: {config.dataset_name}")
        else:
            print(f" Data file: {config.data_path}")
        print(f" Total samples: {overall['total_samples']}")
        print(f" Split sizes: {overall['split_sizes']}")
        print(f" Output directory: {config.output_dir}")
        print(f" Conversation stats: {overall['conversation_stats']}")
    except Exception as e:
        print(f"❌ Error running pipeline: {e}")
        import traceback
        print("Full error traceback:")
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()