2025-08-28 14:12:30 +00:00
|
|
|
import json
|
|
|
|
|
import pandas as pd
|
|
|
|
|
import numpy as np
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import Dict, List, Optional, Union, Any, Tuple
|
|
|
|
|
from datasets import Dataset, load_dataset
|
|
|
|
|
import os
|
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
from sklearn.model_selection import train_test_split
|
|
|
|
|
import re
|
|
|
|
|
import argparse
|
|
|
|
|
import sys
|
|
|
|
|
import yaml
|
|
|
|
|
|
|
|
|
|
@dataclass
class InstructConfig:
    """Settings for the instruction fine-tuning data pipeline.

    Every field has a default, so ``InstructConfig()`` yields a config for a
    custom (local file) JSONL source.
    """

    # ---- Data source ----
    data_source: str = "custom"  # "huggingface" or "custom"
    dataset_name: Optional[str] = None  # dataset name, used by the Hugging Face loader
    data_path: Optional[str] = None  # local file path, used by the custom loader
    data_format: str = "jsonl"  # "jsonl" or "json"

    # ---- Field mapping ----
    conversation_field: str = "conversation"  # item key that holds the list of turns

    # ---- Sampling and split ratios ----
    max_samples: Optional[int] = None  # optional cap on loaded samples
    train_split: float = 0.8
    validation_split: float = 0.1
    test_split: float = 0.1

    # ---- Text preprocessing ----
    clean_text: bool = True  # normalize whitespace in turn content
    min_length: int = 10  # shortest accepted turn content, in characters
    max_length: int = 2048  # longest accepted turn content, in characters

    # ---- Output ----
    output_format: str = "conversation"  # "conversation" or "alpaca"
    output_dir: str = "./data/processed/instruct"

    # ---- Hugging Face specific ----
    hf_split: str = "train"
    hf_cache_dir: Optional[str] = None

    # ---- Split derivation ----
    test_split_from: str = "train"
    val_split_from: str = "train"

    # ---- Custom file specific ----
    encoding: str = "utf-8"  # text encoding used when reading local files
|
|
|
|
|
|
|
|
|
|
class ConversationValidator:
    """Validate and summarize conversation datasets organized by split.

    The data handled here is a mapping of split name to a list of items.
    Each item is expected to be a dict whose conversation field holds a list
    of turn dicts carrying ``"role"`` and ``"content"`` keys.
    """

    # Roles accepted in a conversation turn.
    _VALID_ROLES = ("user", "assistant", "system")
    # Key under which preprocessed items store their list of turns.
    _PROCESSED_FIELD = "conversation"

    @staticmethod
    def _resolve_field(config: Any, is_processed: bool) -> str:
        """Return the item key that holds the conversation turns.

        Processed items always use ``"conversation"``; raw items use the
        field named by ``config.conversation_field`` (falling back to
        ``"conversation"`` when the config does not provide one).
        """
        if is_processed:
            return ConversationValidator._PROCESSED_FIELD
        field = getattr(config, "conversation_field", None)
        return field or ConversationValidator._PROCESSED_FIELD

    @staticmethod
    def _turn_errors(turn: Any, j: int, i: int, split_name: str) -> List[str]:
        """Return the validation errors for one conversation turn."""
        if not isinstance(turn, dict):
            return [f"Each conversation turn must be a dict in {split_name} split, item {i}, turn {j}"]

        problems = []
        if "role" not in turn:
            problems.append(f"Missing 'role' field in conversation turn {j}, {split_name} split, item {i}")
        if "content" not in turn:
            problems.append(f"Missing 'content' field in conversation turn {j}, {split_name} split, item {i}")
        if "role" in turn and turn["role"] not in ConversationValidator._VALID_ROLES:
            problems.append(
                f"Invalid role '{turn['role']}' in conversation turn {j}, {split_name} split, item {i}. "
                "Must be 'user', 'assistant', or 'system'"
            )
        return problems

    @staticmethod
    def _print_preview(split_name: str, split_data: List[Any], conversation_field: str) -> None:
        """Print the first turns of the first two items for debugging.

        Malformed items or turns are skipped rather than raising, so the
        preview can never mask the validation result.
        """
        print(f"Sample conversation from {split_name}:")
        for i, item in enumerate(split_data[:2]):
            conversation = item.get(conversation_field, []) if isinstance(item, dict) else []
            if not isinstance(conversation, list):
                print(f" Item {i} conversation is not a list; preview skipped")
                continue
            print(f" Item {i} conversation length: {len(conversation)} turns")
            for j, turn in enumerate(conversation[:3]):  # show first 3 turns
                if not isinstance(turn, dict):
                    continue
                role = turn.get("role", "unknown")
                content = str(turn.get("content", ""))
                if len(content) > 100:
                    content = content[:100] + "..."
                print(f" Turn {j}: {role} -> '{content}'")

    @staticmethod
    def validate_conversation_data(data: Dict[str, List[Dict]], config: InstructConfig, is_processed: bool = False) -> Tuple[bool, List[str]]:
        """Validate the structure of every split in a conversation dataset.

        Args:
            data: Mapping of split name to list of items. Must contain the
                "train", "validation" and "test" keys; "train" must be
                non-empty.
            config: Pipeline configuration; only ``conversation_field`` is
                read, and only when ``is_processed`` is False.
            is_processed: True when items come from preprocessing and store
                their turns under ``"conversation"``.

        Returns:
            ``(is_valid, errors)`` where ``errors`` lists every problem found.
            Missing/empty required splits short-circuit before item checks.
        """
        errors: List[str] = []

        # Required splits are checked first; item checks are pointless without them.
        for split in ("train", "validation", "test"):
            if split not in data:
                errors.append(f"Missing '{split}' split")
            elif split == "train" and not data[split]:
                errors.append("Train split cannot be empty")
        if errors:
            return False, errors

        total_samples = sum(len(split_data) for split_data in data.values())
        print(f"Validating {total_samples} total samples across all splits...")

        conversation_field = ConversationValidator._resolve_field(config, is_processed)

        for split_name, split_data in data.items():
            if not split_data:
                print(f"Skipping validation for empty {split_name} split")
                continue

            print(f"Validating {split_name} split with {len(split_data)} samples...")
            missing_conversation_count = 0

            for i, item in enumerate(split_data):
                if not isinstance(item, dict):
                    errors.append(f"Item must be a dict in {split_name} split, item {i}")
                    continue
                if conversation_field not in item:
                    errors.append(f"Missing conversation field '{conversation_field}' in {split_name} split, item {i}")
                    missing_conversation_count += 1
                    continue

                conversation = item[conversation_field]
                if not isinstance(conversation, list):
                    errors.append(f"Conversation field must be a list in {split_name} split, item {i}")
                    continue

                for j, turn in enumerate(conversation):
                    errors.extend(ConversationValidator._turn_errors(turn, j, i, split_name))

            print(f"{split_name} - Items missing conversation field: {missing_conversation_count}")
            ConversationValidator._print_preview(split_name, split_data, conversation_field)

        return len(errors) == 0, errors

    @staticmethod
    def analyze_conversation_dataset(data: Dict[str, List[Dict]], config: InstructConfig, is_processed: bool = False) -> Dict[str, Any]:
        """Compute per-split and overall statistics for a conversation dataset.

        Args:
            data: Mapping of split name to list of items.
            config: Pipeline configuration; only ``conversation_field`` is
                read, and only when ``is_processed`` is False.
            is_processed: True when items store turns under ``"conversation"``.

        Returns:
            A dict with a ``"splits"`` entry (per-split sample count, turn
            statistics and missing-field count) and an ``"overall"`` entry
            (total samples, split sizes and aggregated turn statistics).
            Items whose conversation field is not a list contribute no turns.
        """
        valid_roles = ConversationValidator._VALID_ROLES
        conversation_field = ConversationValidator._resolve_field(config, is_processed)

        analysis: Dict[str, Any] = {
            "splits": {},
            "overall": {
                "total_samples": 0,
                "split_sizes": {},
                "conversation_stats": {
                    "total_turns": 0,
                    "avg_turns_per_conversation": 0,
                    "role_distribution": {role: 0 for role in valid_roles},
                },
            },
        }
        overall = analysis["overall"]

        total_turns = 0
        total_conversations = 0
        role_counts = {role: 0 for role in valid_roles}

        for split_name, split_data in data.items():
            if not split_data:
                analysis["splits"][split_name] = {
                    "total_samples": 0,
                    "conversation_stats": {},
                    "missing_values": {},
                }
                overall["split_sizes"][split_name] = 0
                continue

            split_analysis: Dict[str, Any] = {
                "total_samples": len(split_data),
                "conversation_stats": {},
                "missing_values": {},
            }

            split_role_counts = {role: 0 for role in valid_roles}
            conversation_lengths = []

            for item in split_data:
                conversation = item.get(conversation_field, []) if isinstance(item, dict) else None
                if not isinstance(conversation, list):
                    continue
                conversation_lengths.append(len(conversation))
                for turn in conversation:
                    if not isinstance(turn, dict):
                        continue
                    role = turn.get("role")
                    # The str check keeps unhashable roles from raising on dict lookup.
                    if isinstance(role, str) and role in split_role_counts:
                        split_role_counts[role] += 1

            split_turns = sum(conversation_lengths)
            if conversation_lengths:
                split_analysis["conversation_stats"] = {
                    "total_turns": split_turns,
                    # float() turns NumPy scalars into plain Python floats.
                    "avg_turns_per_conversation": float(np.mean(conversation_lengths)),
                    "min_turns": min(conversation_lengths),
                    "max_turns": max(conversation_lengths),
                    "median_turns": float(np.median(conversation_lengths)),
                    "role_distribution": split_role_counts,
                }

            # An item counts as missing when the field is absent or empty.
            missing_count = sum(
                1
                for item in split_data
                if not isinstance(item, dict) or not item.get(conversation_field)
            )
            split_analysis["missing_values"][conversation_field] = missing_count

            analysis["splits"][split_name] = split_analysis
            overall["total_samples"] += len(split_data)
            overall["split_sizes"][split_name] = len(split_data)

            total_turns += split_turns
            total_conversations += len(split_data)
            for role, count in split_role_counts.items():
                role_counts[role] += count

        if total_conversations > 0:
            stats = overall["conversation_stats"]
            stats["total_turns"] = total_turns
            stats["avg_turns_per_conversation"] = total_turns / total_conversations
            stats["role_distribution"] = role_counts

        return analysis
|
|
|
|
|
|
|
|
|
|
class BaseInstructDataLoader(ABC):
    """Interface that every instruction data loader implements.

    Concrete loaders provide ``load`` to obtain the raw splits and
    ``preprocess`` to transform those splits.
    """

    @abstractmethod
    def load(self, config: InstructConfig) -> Dict[str, List[Dict]]:
        """Return the raw data as a mapping of split name to list of items."""

    @abstractmethod
    def preprocess(self, data: Dict[str, List[Dict]], config: InstructConfig) -> Dict[str, List[Dict]]:
        """Return the preprocessed version of every split in ``data``."""
|
|
|
|
|
|
|
|
|
|
class HuggingFaceInstructDataLoader(BaseInstructDataLoader):
    """Load conversation datasets from Hugging Face Hub"""

    # Roles accepted in a conversation turn.
    _VALID_ROLES = ("user", "assistant", "system")

    def load(self, config: InstructConfig) -> Dict[str, List[Dict]]:
        """Load a dataset by name and return its train/validation/test splits.

        The "train" split is mandatory. "validation" and "test" are taken
        from the dataset when it provides splits with those names and are
        left empty otherwise. When ``config.max_samples`` is set, every
        non-empty split is truncated to that many leading samples.

        Args:
            config: Uses ``dataset_name``, ``hf_cache_dir`` and ``max_samples``.

        Returns:
            Dict with "train", "validation" and "test" lists of items.

        Raises:
            ValueError: If ``dataset_name`` is not set or the dataset has no
                "train" split.
            Exception: Any error raised while loading is logged and re-raised.
        """
        if not config.dataset_name:
            raise ValueError("Dataset name is required for Hugging Face datasets")

        print(f"Loading Hugging Face conversation dataset: {config.dataset_name}")

        try:
            dataset = load_dataset(
                config.dataset_name,
                cache_dir=config.hf_cache_dir,
            )

            available_splits = list(dataset.keys())
            print(f"Available splits in dataset: {available_splits}")

            if "train" not in available_splits:
                print("No 'train' split found in dataset!")
                raise ValueError(f"Dataset {config.dataset_name} does not have a 'train' split")

            train_dataset = dataset["train"]
            print(f"Using 'train' split with {len(train_dataset)} samples")
            splits_data: Dict[str, List[Dict]] = {
                "train": list(train_dataset),
                "validation": [],
                "test": [],
            }

            # Reuse the dataset's own validation/test splits when it ships them.
            for split_name in ("validation", "test"):
                if split_name in available_splits:
                    split_dataset = dataset[split_name]
                    print(f"Using '{split_name}' split with {len(split_dataset)} samples")
                    splits_data[split_name] = list(split_dataset)
                else:
                    print(f"No '{split_name}' split found in dataset; leaving it empty")

            if config.max_samples:
                for split_name, items in splits_data.items():
                    if not items:
                        continue
                    original_size = len(items)
                    splits_data[split_name] = items[:config.max_samples]
                    print(f"Limited {split_name} split from {original_size} to {len(splits_data[split_name])} samples")

            print(f"Successfully loaded dataset {config.dataset_name}")
            return splits_data

        except Exception as e:
            print(f"Error loading dataset {config.dataset_name}: {e}")
            raise

    def preprocess(self, data: Dict[str, List[Dict]], config: InstructConfig) -> Dict[str, List[Dict]]:
        """Preprocess each split independently.

        Items rejected by ``_preprocess_item`` are dropped; the number of
        kept and skipped items is printed per split.

        Returns:
            Dict with the same split names mapping to the kept items.
        """
        processed_splits = {}

        print(f"=== PREPROCESSING CONVERSATION DATA ===")

        for split_name, split_data in data.items():
            print(f"Processing {split_name} split with {len(split_data)} items...")

            candidates = (self._preprocess_item(item, config) for item in split_data)
            processed_data = [item for item in candidates if item is not None]

            processed_count = len(processed_data)
            skipped_count = len(split_data) - processed_count

            processed_splits[split_name] = processed_data
            print(f"{split_name} - Preprocessed {processed_count} samples, skipped {skipped_count} samples")

        return processed_splits

    def _preprocess_item(self, item: Dict, config: InstructConfig) -> Optional[Dict]:
        """Filter and normalize the turns of a single item.

        A turn is kept only if it is a dict with a valid "role", a non-None
        "content", and a content length within
        ``[config.min_length, config.max_length]``. The length is measured
        after stripping and, when ``config.clean_text`` is set, after
        whitespace normalization, so padding cannot satisfy ``min_length``.

        Returns:
            ``{"conversation": [...]}`` with the kept turns, or None when the
            item is malformed or fewer than two turns survive.
        """
        if not isinstance(item, dict):
            return None

        conversation = item.get(config.conversation_field, [])
        if not isinstance(conversation, list) or not conversation:
            return None

        valid_conversation = []
        for turn in conversation:
            if not isinstance(turn, dict):
                continue
            if "role" not in turn or "content" not in turn:
                continue
            if turn["role"] not in self._VALID_ROLES:
                continue
            if turn["content"] is None:
                # str(None) would otherwise become the literal text "None".
                continue

            content = str(turn["content"]).strip()
            if config.clean_text:
                content = self._clean_text(content)
            if not config.min_length <= len(content) <= config.max_length:
                continue

            valid_conversation.append({"role": turn["role"], "content": content})

        if len(valid_conversation) < 2:  # Need at least 2 turns for a conversation
            return None

        return {"conversation": valid_conversation}

    def _clean_text(self, text: str) -> str:
        """Collapse whitespace runs to single spaces and trim the ends.

        Returns an empty string when ``text`` is not a ``str``.
        """
        if not isinstance(text, str):
            return ""
        return re.sub(r'\s+', ' ', text).strip()
|
|
|
|
|
|
|
|
|
|
class CustomInstructDataLoader(BaseInstructDataLoader):
    """Load custom conversation datasets from local files"""

    # Roles accepted in a conversation turn.
    _VALID_ROLES = ("user", "assistant", "system")

    def load(self, config: InstructConfig) -> Dict[str, List[Dict]]:
        """Read a local JSONL/JSON file and split it into train/validation/test.

        Args:
            config: Uses ``data_path``, ``data_format``, ``encoding``,
                ``max_samples`` and the split ratios.

        Returns:
            Dict with "train", "validation" and "test" lists of items.

        Raises:
            ValueError: If ``data_path`` is not set or ``data_format`` is
                neither "jsonl" nor "json".
            FileNotFoundError: If the file does not exist.
        """
        if not config.data_path:
            raise ValueError("Data path is required for custom datasets")

        file_path = Path(config.data_path)
        if not file_path.exists():
            raise FileNotFoundError(f"Data file not found: {file_path}")

        print(f"Loading custom conversation dataset: {file_path}")

        if config.data_format == "jsonl":
            raw_data = self._load_jsonl(file_path, config)
        elif config.data_format == "json":
            raw_data = self._load_json(file_path, config)
        else:
            raise ValueError(f"Unsupported format: {config.data_format}")

        if config.max_samples:
            raw_data = raw_data[:config.max_samples]

        print(f"Loaded {len(raw_data)} conversation samples from {file_path}")

        return self._create_splits(raw_data, config)

    @staticmethod
    def _split_size(total: int, ratio: Optional[float]) -> int:
        """Return the sample count for a split ratio.

        A missing or non-positive ratio disables the split (size 0); any
        positive ratio yields at least one sample.
        """
        if not ratio or ratio <= 0:
            return 0
        return max(1, int(total * ratio))

    def _create_splits(self, data: List[Dict], config: InstructConfig) -> Dict[str, List[Dict]]:
        """Create train/validation/test splits from raw data.

        Datasets with fewer than 3 samples go entirely to "train". Datasets
        with fewer than 10 samples use 0.6/0.2/0.2 ratios instead of the
        configured ones. ``config`` is never modified. Validation and test
        sizes come from their ratios; train receives the remainder and is
        guaranteed at least one sample. Splitting uses ``train_test_split``
        with ``random_state=42``.
        """
        print(f"Creating splits from {len(data)} conversation samples...")

        if len(data) < 3:
            print(f"Dataset has only {len(data)} samples. Using all data for training.")
            return {"train": data, "validation": [], "test": []}

        total_samples = len(data)

        # Work on local copies so the caller's config keeps its ratios.
        train_ratio = config.train_split
        val_ratio = config.validation_split
        test_ratio = config.test_split
        if total_samples < 10:
            train_ratio, val_ratio, test_ratio = 0.6, 0.2, 0.2
            print(f"Small dataset detected. Adjusted split ratios to: train={train_ratio}, val={val_ratio}, test={test_ratio}")

        val_size = self._split_size(total_samples, val_ratio)
        test_size = self._split_size(total_samples, test_ratio)
        train_size = total_samples - val_size - test_size

        # Shrink the larger held-out split until train has at least 1 sample.
        while train_size < 1:
            if val_size > 1 and val_size >= test_size:
                val_size -= 1
            elif test_size > 1:
                test_size -= 1
            else:
                break
            train_size += 1

        print(f"Split sizes: train={train_size}, validation={val_size}, test={test_size}")

        if val_size == 0 and test_size == 0:
            train_data, val_data, test_data = data, [], []
        elif val_size == 0:
            train_data, test_data = train_test_split(data, test_size=test_size, random_state=42)
            val_data = []
        elif test_size == 0:
            train_data, val_data = train_test_split(data, test_size=val_size, random_state=42)
            test_data = []
        else:
            # Full three-way split: carve out the held-out pool, then divide it.
            train_data, temp_data = train_test_split(
                data,
                test_size=val_size + test_size,
                random_state=42,
            )
            val_data, test_data = train_test_split(
                temp_data,
                test_size=test_size,
                random_state=42,
            )

        splits_data = {
            "train": train_data,
            "validation": val_data,
            "test": test_data,
        }

        print(f"Created conversation splits:")
        print(f" Train: {len(splits_data['train'])} samples")
        print(f" Validation: {len(splits_data['validation'])} samples")
        print(f" Test: {len(splits_data['test'])} samples")

        return splits_data

    def _load_jsonl(self, file_path: Path, config: InstructConfig) -> List[Dict]:
        """Parse a JSONL file, one JSON value per non-blank line.

        Lines that fail to parse are reported and skipped (best effort).
        """
        data = []
        with open(file_path, 'r', encoding=config.encoding) as f:
            for line_num, line in enumerate(f, 1):
                if not line.strip():
                    continue
                try:
                    data.append(json.loads(line))
                except json.JSONDecodeError as e:
                    print(f"Invalid JSON at line {line_num}: {e}")
        return data

    def _load_json(self, file_path: Path, config: InstructConfig) -> List[Dict]:
        """Parse a JSON file into a list of items.

        A top-level list is returned as is, a dict with a "data" key yields
        that key's value, and anything else is wrapped in a one-element list.
        """
        with open(file_path, 'r', encoding=config.encoding) as f:
            data = json.load(f)

        if isinstance(data, list):
            return data
        if isinstance(data, dict) and "data" in data:
            return data["data"]
        return [data]

    def preprocess(self, data: Dict[str, List[Dict]], config: InstructConfig) -> Dict[str, List[Dict]]:
        """Preprocess each split independently.

        Items rejected by ``_preprocess_item`` are dropped; the first three
        rejected items of each split are printed for debugging, along with
        the kept/skipped totals.

        Returns:
            Dict with the same split names mapping to the kept items.
        """
        processed_splits = {}

        print(f"=== PREPROCESSING CUSTOM CONVERSATION DATA ===")

        for split_name, split_data in data.items():
            print(f"Processing {split_name} split with {len(split_data)} items...")

            processed_data = []
            skipped_count = 0

            for i, item in enumerate(split_data):
                processed_item = self._preprocess_item(item, config)
                if processed_item is not None:
                    processed_data.append(processed_item)
                    continue
                skipped_count += 1
                if skipped_count <= 3:  # Log first few skipped items
                    print(f"Skipped item {i} from {split_name}: {item}")

            processed_splits[split_name] = processed_data
            print(f"{split_name} - Preprocessed {len(processed_data)} samples, skipped {skipped_count} samples")

        return processed_splits

    def _preprocess_item(self, item: Dict, config: InstructConfig) -> Optional[Dict]:
        """Filter and normalize the turns of a single item.

        A turn is kept only if it is a dict with a valid "role", a non-None
        "content", and a content length within
        ``[config.min_length, config.max_length]``. The length is measured
        after stripping and, when ``config.clean_text`` is set, after
        whitespace normalization, so padding cannot satisfy ``min_length``.

        Returns:
            ``{"conversation": [...]}`` with the kept turns, or None when the
            item is malformed or fewer than two turns survive.
        """
        if not isinstance(item, dict):
            return None

        conversation = item.get(config.conversation_field, [])
        if not isinstance(conversation, list) or not conversation:
            return None

        valid_conversation = []
        for turn in conversation:
            if not isinstance(turn, dict):
                continue
            if "role" not in turn or "content" not in turn:
                continue
            if turn["role"] not in self._VALID_ROLES:
                continue
            if turn["content"] is None:
                # str(None) would otherwise become the literal text "None".
                continue

            content = str(turn["content"]).strip()
            if config.clean_text:
                content = self._clean_text(content)
            if not config.min_length <= len(content) <= config.max_length:
                continue

            valid_conversation.append({"role": turn["role"], "content": content})

        if len(valid_conversation) < 2:  # Need at least 2 turns for a conversation
            return None

        return {"conversation": valid_conversation}

    def _clean_text(self, text: str) -> str:
        """Collapse whitespace runs to single spaces and trim the ends.

        Returns an empty string when ``text`` is not a ``str``.
        """
        if not isinstance(text, str):
            return ""
        return re.sub(r'\s+', ' ', text).strip()
|
|
|
|
|
|
|
|
|
|
class InstructDataPipeline:
    """Main instruction fine-tuning data pipeline"""

    def __init__(self):
        """Create the validator and one loader per supported data source."""
        self.validator = ConversationValidator()
        self.hf_loader = HuggingFaceInstructDataLoader()
        self.custom_loader = CustomInstructDataLoader()

    def create_config(
        self,
        data_source: str,
        dataset_name: Optional[str] = None,
        data_path: Optional[str] = None,
        conversation_field: str = "conversation",
        **kwargs
    ) -> InstructConfig:
        """Build an ``InstructConfig``; extra keyword arguments are forwarded as fields."""
        return InstructConfig(
            data_source=data_source,
            dataset_name=dataset_name,
            data_path=data_path,
            conversation_field=conversation_field,
            **kwargs
        )

    def load_config_from_yaml(self, yaml_path: str) -> InstructConfig:
        """Build an ``InstructConfig`` from a YAML file.

        Values parsed by ``load_yaml_config`` override the dataclass
        defaults. Keys that are absent from the YAML arrive as ``None`` and
        are ignored so the defaults apply, as are keys that are not
        ``InstructConfig`` fields.

        Raises:
            Exception: Any error is logged and re-raised.
        """
        # Local import: the module only imports `dataclass` at file level.
        from dataclasses import fields

        try:
            config_dict = load_yaml_config(yaml_path)

            known_fields = {f.name for f in fields(InstructConfig)}
            overrides = {
                key: value
                for key, value in config_dict.items()
                if key in known_fields and value is not None
            }
            config = InstructConfig(**overrides)

            print(f"Configuration loaded from YAML: {yaml_path}")
            print(f"Output directory: {config.output_dir}")
            print(f"Conversation field: {config.conversation_field}")

            return config

        except Exception as e:
            print(f"Error loading configuration from YAML {yaml_path}: {e}")
            raise

    def load_and_preprocess(self, config: InstructConfig) -> Tuple[Dict[str, List[Dict]], Dict[str, Any]]:
        """Load, preprocess, validate and analyze the configured dataset.

        Returns:
            ``(processed_splits, analysis)``.

        Raises:
            ValueError: If ``config.data_source`` is unsupported or the
                processed data fails validation.
            Exception: Any error is logged and re-raised.
        """
        print(f"Starting conversation data loading and preprocessing...")
        print(f"Data source: {config.data_source}")

        # data_source -> (loader, label used in progress messages)
        loaders = {
            "huggingface": (self.hf_loader, "HuggingFace"),
            "custom": (self.custom_loader, "custom"),
        }

        try:
            if config.data_source not in loaders:
                raise ValueError(f"Unsupported data source: {config.data_source}")
            loader, label = loaders[config.data_source]

            print(f"Loading {label} conversation dataset...")
            raw_splits = loader.load(config)
            print(f"Preprocessing {label} conversation dataset...")
            processed_splits = loader.preprocess(raw_splits, config)

            print(f"Conversation data loading and preprocessing completed successfully")

            print("Validating processed conversation data...")
            is_valid, errors = self.validator.validate_conversation_data(processed_splits, config, is_processed=True)
            if not is_valid:
                print("Conversation data validation failed:")
                for error in errors:
                    print(f" - {error}")
                raise ValueError("Conversation data validation failed")
            print("Conversation data validation passed")

            print("Analyzing conversation dataset...")
            analysis = self.validator.analyze_conversation_dataset(processed_splits, config, is_processed=True)
            print("Conversation dataset analysis completed")

            return processed_splits, analysis

        except Exception as e:
            print(f"Error in load_and_preprocess: {e}")
            raise

    def save_data(self, data: Dict[str, List[Dict]], output_dir: str, format: str = "jsonl"):
        """Write each split to ``<output_dir>/<split_name>.<format>``.

        Args:
            data: Mapping of split name to list of JSON-serializable items.
            output_dir: Target directory; created if missing.
            format: "jsonl" (one item per line) or "json" (indented array).

        Raises:
            ValueError: If ``format`` is neither "jsonl" nor "json".
        """
        # Reject unknown formats up front; otherwise no file would be written
        # and the summary line below would fail on an unbound file name.
        if format not in ("jsonl", "json"):
            raise ValueError(f"Unsupported output format: {format}")

        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        for split_name, split_data in data.items():
            output_file = output_path / f"{split_name}.{format}"
            with open(output_file, 'w', encoding='utf-8') as f:
                if format == "jsonl":
                    for item in split_data:
                        f.write(json.dumps(item, ensure_ascii=False) + '\n')
                else:
                    json.dump(split_data, f, ensure_ascii=False, indent=2)

            print(f"Saved {len(split_data)} conversation samples to {output_file}")

    def run_pipeline(
        self,
        config: InstructConfig,
        save_splits: bool = True
    ) -> Dict[str, Any]:
        """Run the full pipeline and return a result summary.

        Args:
            config: Pipeline configuration.
            save_splits: When True, write the processed splits to
                ``config.output_dir`` as JSONL.

        Returns:
            Dict with the config, the analysis, per-split sample counts, the
            output format and directory, and the processed data itself.
        """
        print("Starting instruction data pipeline...")

        processed_splits, analysis = self.load_and_preprocess(config)

        if save_splits:
            self.save_data(processed_splits, str(Path(config.output_dir)))

        split_counts = {
            split_name: len(split_data)
            for split_name, split_data in processed_splits.items()
        }
        result = {
            "config": config,
            "analysis": analysis,
            "splits": split_counts,
            "output_format": config.output_format,
            "output_dir": config.output_dir,
            "data": processed_splits,  # Include the actual processed data
        }

        print("Instruction data pipeline completed successfully!")
        return result
|
|
|
|
|
|
|
|
|
|
def load_yaml_config(config_path: str) -> Dict[str, Any]:
    """Load a YAML configuration file and flatten it into pipeline settings.

    Two optional top-level sections are read:

    * ``task`` -- ``name`` and ``type`` become ``task_name`` / ``task_type``.
    * ``data`` -- ``source`` becomes ``data_source``; the remaining
      recognised keys are copied under their own names.

    Only settings that are actually present (non-null) in the file are
    returned. Leaving unset keys out, instead of storing them as ``None``,
    lets callers apply their own defaults with ``dict.get(key, default)``
    and keeps the reported parameter count accurate.

    Args:
        config_path: Path to the YAML file.

    Returns:
        A flat dictionary of the settings found in the file; empty when
        the file is empty or contains neither section.

    Raises:
        ValueError: If the top level of the document, or one of the
            ``task`` / ``data`` sections, is not a mapping.
        Exception: Any error raised while opening or parsing the file is
            printed and re-raised unchanged.
    """
    # Keys of the ``data`` section that are copied under the same name.
    # NOTE(review): main() also looks up hf_split, hf_cache_dir,
    # test_split_from and val_split_from, which are never read from YAML
    # here -- confirm whether they should be part of the ``data`` section.
    data_keys = (
        'dataset_name',
        'data_path',
        'data_format',
        'conversation_field',
        'max_samples',
        'train_split',
        'validation_split',
        'test_split',
        'clean_text',
        'min_length',
        'max_length',
        'output_format',
        'output_dir',
        'encoding',
    )

    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            yaml_data = yaml.safe_load(f)

        # An empty or comment-only document parses to None; treat it as
        # an empty configuration rather than failing on ``in``.
        if yaml_data is None:
            yaml_data = {}
        if not isinstance(yaml_data, dict):
            raise ValueError(
                f"Top level of {config_path} must be a mapping, "
                f"got {type(yaml_data).__name__}"
            )

        def read_section(name: str) -> Dict[str, Any]:
            # A section written with no content (e.g. ``data:``) is None.
            section = yaml_data.get(name)
            if section is None:
                return {}
            if not isinstance(section, dict):
                raise ValueError(
                    f"'{name}' section of {config_path} must be a mapping, "
                    f"got {type(section).__name__}"
                )
            return section

        # Handle task section
        task_data = read_section('task')
        extracted = {
            'task_name': task_data.get('name'),
            'task_type': task_data.get('type'),
        }

        # Handle data section
        data_config = read_section('data')
        extracted['data_source'] = data_config.get('source')
        for key in data_keys:
            extracted[key] = data_config.get(key)

        # Extract configuration from YAML structure, keeping only the
        # settings the file really provides.
        config_dict = {
            key: value for key, value in extracted.items() if value is not None
        }

        print(f"Successfully parsed YAML configuration from: {config_path}")
        print(f"Extracted {len(config_dict)} configuration parameters")

        return config_dict

    except Exception as e:
        print(f"Error loading YAML config from {config_path}: {e}")
        raise
|
|
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point with YAML configuration support.

    Settings are resolved in this order of precedence:

    1. command-line arguments,
    2. values from the YAML file given with ``--config``,
    3. built-in defaults.

    A setting that is missing, or present with a ``None`` value, at one
    level falls through to the next level. The resolved settings are used
    to build an ``InstructConfig``, which is passed to
    ``InstructDataPipeline.run_pipeline`` with ``save_splits=True``.

    The process exits with status 1 when the YAML file cannot be loaded
    or the pipeline raises; argument validation failures are reported
    through ``parser.error``.
    """
    parser = argparse.ArgumentParser(description="Instruction Data Processing Pipeline")

    # YAML configuration
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")

    # Data source arguments
    parser.add_argument("--data-source", choices=["huggingface", "custom"], help="Data source")
    parser.add_argument("--dataset-name", type=str, help="HuggingFace dataset name")
    parser.add_argument("--data-path", type=str, help="Path to custom data file")
    parser.add_argument("--data-format", choices=["jsonl", "json"], help="Data format")

    # Field mapping
    parser.add_argument("--conversation-field", type=str, help="Conversation field name")

    # Data processing
    parser.add_argument("--max-samples", type=int, help="Maximum samples to process")
    parser.add_argument("--train-split", type=float, help="Training split ratio")
    parser.add_argument("--validation-split", type=float, help="Validation split ratio")
    parser.add_argument("--test-split", type=float, help="Test split ratio")

    # Output configuration
    parser.add_argument("--output-dir", type=str, help="Output directory")

    # Logging
    # NOTE: --log-level is accepted but not applied anywhere in this
    # function; no logging configuration is performed here.
    parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO", help="Logging level")

    args = parser.parse_args()

    # Load configuration
    config_dict = {}

    # Load YAML config if provided
    if args.config:
        try:
            config_dict = load_yaml_config(args.config)
        except Exception as e:
            print(f"Error loading YAML config: {e}")
            sys.exit(1)

    # Override YAML config with CLI arguments (similar to styling pipeline)
    overridable = (
        'data_source',
        'dataset_name',
        'data_path',
        'data_format',
        'conversation_field',
        'max_samples',
        'train_split',
        'validation_split',
        'test_split',
        'output_dir',
    )
    # A split ratio of 0 is a meaningful value, so these options count as
    # "given" whenever argparse received them (absent options are None).
    # The remaining options keep their truthiness check, so an empty
    # string or ``--max-samples 0`` still means "not given".
    ratio_options = {'train_split', 'validation_split', 'test_split'}

    cli_overrides = {}
    for name in overridable:
        value = getattr(args, name)
        given = (value is not None) if name in ratio_options else bool(value)
        if given:
            cli_overrides[name] = value

    # Merge configurations
    for key, value in cli_overrides.items():
        if key in config_dict:
            print(f"Overriding YAML config '{key}' with CLI value: {value}")
        config_dict[key] = value

    # Validate required arguments
    if not config_dict.get('data_source'):
        parser.error("--data-source is required (either in YAML config or CLI)")

    if config_dict.get('data_source') == "huggingface" and not config_dict.get('dataset_name'):
        parser.error("--dataset-name is required for HuggingFace datasets")

    if config_dict.get('data_source') == "custom" and not config_dict.get('data_path'):
        parser.error("--data-path is required for custom datasets")

    def setting(key, default=None):
        # dict.get(key, default) would hand back a stored None untouched;
        # treat None the same as a missing key so the default applies.
        value = config_dict.get(key)
        return default if value is None else value

    # Create configuration object
    config = InstructConfig(
        data_source=setting('data_source', 'custom'),
        dataset_name=setting('dataset_name'),
        data_path=setting('data_path'),
        data_format=setting('data_format', 'jsonl'),
        conversation_field=setting('conversation_field', 'conversation'),
        max_samples=setting('max_samples'),
        train_split=setting('train_split', 0.8),
        validation_split=setting('validation_split', 0.1),
        test_split=setting('test_split', 0.1),
        clean_text=setting('clean_text', True),
        min_length=setting('min_length', 10),
        max_length=setting('max_length', 2048),
        output_format=setting('output_format', 'conversation'),
        output_dir=setting('output_dir', './data/processed/instruct'),
        hf_split=setting('hf_split', 'train'),
        hf_cache_dir=setting('hf_cache_dir'),
        test_split_from=setting('test_split_from', 'train'),
        val_split_from=setting('val_split_from', 'train'),
        encoding=setting('encoding', 'utf-8')
    )

    # Initialize pipeline
    pipeline = InstructDataPipeline()

    try:
        print(f"Starting instruction data pipeline with {config.data_source} data source...")
        if args.config:
            print(f"Using YAML configuration: {args.config}")
        print(f"Conversation field: {config.conversation_field}")
        print()

        result = pipeline.run_pipeline(config, save_splits=True)

        print("✅ Pipeline completed successfully!")
        print(f" Data source: {config.data_source}")
        if config.data_source == "huggingface":
            print(f" Dataset: {config.dataset_name}")
        else:
            print(f" Data file: {config.data_path}")
        print(f" Total samples: {result['analysis']['overall']['total_samples']}")
        print(f" Split sizes: {result['analysis']['overall']['split_sizes']}")
        print(f" Output directory: {config.output_dir}")
        print(f" Conversation stats: {result['analysis']['overall']['conversation_stats']}")

    except Exception as e:
        print(f"❌ Error running pipeline: {e}")
        import traceback
        print("Full error traceback:")
        traceback.print_exc()
        sys.exit(1)
|
|
|
|
|
|
|
|
|
|
# Run the command-line entry point only when this file is executed as a
# script; importing it as a module must not start the pipeline.
if __name__ == "__main__":
    main()
|