instruct mode added to pipeline

This commit is contained in:
Your Name
2025-08-28 16:46:24 +00:00
parent 78d519efbf
commit 77c563f358
16 changed files with 19404 additions and 161 deletions
+62 -66
View File
@@ -7,16 +7,12 @@ from datasets import Dataset, load_dataset
import os
from dataclasses import dataclass
from abc import ABC, abstractmethod
import logging
from sklearn.model_selection import train_test_split
import re
import argparse
import sys
import yaml
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@dataclass
class InstructConfig:
"""Configuration for instruction fine-tuning tasks"""
@@ -75,7 +71,7 @@ class ConversationValidator:
return False, errors
total_samples = sum(len(split_data) for split_data in data.values())
logger.info(f"Validating {total_samples} total samples across all splits...")
print(f"Validating {total_samples} total samples across all splits...")
# Determine field names based on whether data is processed or not
conversation_field = "conversation" if not is_processed else "conversation"
@@ -83,10 +79,10 @@ class ConversationValidator:
# Validate each split
for split_name, split_data in data.items():
if not split_data:
logger.info(f"Skipping validation for empty {split_name} split")
print(f"Skipping validation for empty {split_name} split")
continue
logger.info(f"Validating {split_name} split with {len(split_data)} samples...")
print(f"Validating {split_name} split with {len(split_data)} samples...")
# Check required fields
missing_conversation_count = 0
@@ -117,19 +113,19 @@ class ConversationValidator:
if "role" in turn and turn["role"] not in ["user", "assistant", "system"]:
errors.append(f"Invalid role '{turn['role']}' in conversation turn {j}, {split_name} split, item {i}. Must be 'user', 'assistant', or 'system'")
logger.info(f"{split_name} - Items missing conversation field: {missing_conversation_count}")
print(f"{split_name} - Items missing conversation field: {missing_conversation_count}")
# Show sample of processed data for debugging
if split_data:
logger.info(f"Sample conversation from {split_name}:")
print(f"Sample conversation from {split_name}:")
for i in range(min(2, len(split_data))):
item = split_data[i]
conversation = item.get(conversation_field, [])
logger.info(f" Item {i} conversation length: {len(conversation)} turns")
print(f" Item {i} conversation length: {len(conversation)} turns")
for j, turn in enumerate(conversation[:3]): # Show first 3 turns
role = turn.get("role", "unknown")
content = turn.get("content", "")[:100] + "..." if len(turn.get("content", "")) > 100 else turn.get("content", "")
logger.info(f" Turn {j}: {role} -> '{content}'")
print(f" Turn {j}: {role} -> '{content}'")
return len(errors) == 0, errors
@@ -242,7 +238,7 @@ class HuggingFaceInstructDataLoader(BaseInstructDataLoader):
if not config.dataset_name:
raise ValueError("Dataset name is required for Hugging Face datasets")
logger.info(f"Loading Hugging Face conversation dataset: {config.dataset_name}")
print(f"Loading Hugging Face conversation dataset: {config.dataset_name}")
try:
dataset = load_dataset(
@@ -251,7 +247,7 @@ class HuggingFaceInstructDataLoader(BaseInstructDataLoader):
)
available_splits = list(dataset.keys())
logger.info(f"Available splits in dataset: {available_splits}")
print(f"Available splits in dataset: {available_splits}")
splits_data = {
"train": [],
@@ -262,10 +258,10 @@ class HuggingFaceInstructDataLoader(BaseInstructDataLoader):
# Handle train split
if "train" in available_splits:
train_dataset = dataset["train"]
logger.info(f"Using 'train' split with {len(train_dataset)} samples")
print(f"Using 'train' split with {len(train_dataset)} samples")
splits_data["train"] = list(train_dataset)
else:
logger.error("No 'train' split found in dataset!")
print("No 'train' split found in dataset!")
raise ValueError(f"Dataset {config.dataset_name} does not have a 'train' split")
# Handle validation and test splits (similar logic to styling pipeline)
@@ -277,23 +273,23 @@ class HuggingFaceInstructDataLoader(BaseInstructDataLoader):
if splits_data[split_name]:
original_size = len(splits_data[split_name])
splits_data[split_name] = splits_data[split_name][:config.max_samples]
logger.info(f"Limited {split_name} split from {original_size} to {len(splits_data[split_name])} samples")
print(f"Limited {split_name} split from {original_size} to {len(splits_data[split_name])} samples")
logger.info(f"Successfully loaded dataset {config.dataset_name}")
print(f"Successfully loaded dataset {config.dataset_name}")
return splits_data
except Exception as e:
logger.error(f"Error loading dataset {config.dataset_name}: {e}")
print(f"Error loading dataset {config.dataset_name}: {e}")
raise
def preprocess(self, data: Dict[str, List[Dict]], config: InstructConfig) -> Dict[str, List[Dict]]:
"""Apply preprocessing steps to all splits separately"""
processed_splits = {}
logger.info(f"=== PREPROCESSING CONVERSATION DATA ===")
print(f"=== PREPROCESSING CONVERSATION DATA ===")
for split_name, split_data in data.items():
logger.info(f"Processing {split_name} split with {len(split_data)} items...")
print(f"Processing {split_name} split with {len(split_data)} items...")
processed_data = []
processed_count = 0
@@ -308,7 +304,7 @@ class HuggingFaceInstructDataLoader(BaseInstructDataLoader):
skipped_count += 1
processed_splits[split_name] = processed_data
logger.info(f"{split_name} - Preprocessed {processed_count} samples, skipped {skipped_count} samples")
print(f"{split_name} - Preprocessed {processed_count} samples, skipped {skipped_count} samples")
return processed_splits
@@ -368,7 +364,7 @@ class CustomInstructDataLoader(BaseInstructDataLoader):
if not file_path.exists():
raise FileNotFoundError(f"Data file not found: {file_path}")
logger.info(f"Loading custom conversation dataset: {file_path}")
print(f"Loading custom conversation dataset: {file_path}")
if config.data_format == "jsonl":
raw_data = self._load_jsonl(file_path, config)
@@ -380,7 +376,7 @@ class CustomInstructDataLoader(BaseInstructDataLoader):
if config.max_samples:
raw_data = raw_data[:config.max_samples]
logger.info(f"Loaded {len(raw_data)} conversation samples from {file_path}")
print(f"Loaded {len(raw_data)} conversation samples from {file_path}")
# Create splits from the raw data
splits_data = self._create_splits(raw_data, config)
@@ -389,11 +385,11 @@ class CustomInstructDataLoader(BaseInstructDataLoader):
def _create_splits(self, data: List[Dict], config: InstructConfig) -> Dict[str, List[Dict]]:
"""Create train/validation/test splits from raw data"""
logger.info(f"Creating splits from {len(data)} conversation samples...")
print(f"Creating splits from {len(data)} conversation samples...")
# Handle very small datasets
if len(data) < 3:
logger.warning(f"Dataset has only {len(data)} samples. Using all data for training.")
print(f"Dataset has only {len(data)} samples. Using all data for training.")
return {
"train": data,
"validation": [],
@@ -408,7 +404,7 @@ class CustomInstructDataLoader(BaseInstructDataLoader):
config.train_split = 0.6
config.validation_split = 0.2
config.test_split = 0.2
logger.info(f"Small dataset detected. Adjusted split ratios to: train={config.train_split}, val={config.validation_split}, test={config.test_split}")
print(f"Small dataset detected. Adjusted split ratios to: train={config.train_split}, val={config.validation_split}, test={config.test_split}")
val_size = max(1, int(total_samples * config.validation_split))
test_size = max(1, int(total_samples * config.test_split))
@@ -423,7 +419,7 @@ class CustomInstructDataLoader(BaseInstructDataLoader):
test_size -= 1
train_size += 1
logger.info(f"Split sizes: train={train_size}, validation={val_size}, test={test_size}")
print(f"Split sizes: train={train_size}, validation={val_size}, test={test_size}")
# Create splits
if val_size == 0 and test_size == 0:
@@ -466,10 +462,10 @@ class CustomInstructDataLoader(BaseInstructDataLoader):
"test": test_data
}
logger.info(f"Created conversation splits:")
logger.info(f" Train: {len(splits_data['train'])} samples")
logger.info(f" Validation: {len(splits_data['validation'])} samples")
logger.info(f" Test: {len(splits_data['test'])} samples")
print(f"Created conversation splits:")
print(f" Train: {len(splits_data['train'])} samples")
print(f" Validation: {len(splits_data['validation'])} samples")
print(f" Test: {len(splits_data['test'])} samples")
return splits_data
@@ -482,7 +478,7 @@ class CustomInstructDataLoader(BaseInstructDataLoader):
try:
data.append(json.loads(line))
except json.JSONDecodeError as e:
logger.warning(f"Invalid JSON at line {line_num}: {e}")
print(f"Invalid JSON at line {line_num}: {e}")
return data
def _load_json(self, file_path: Path, config: InstructConfig) -> List[Dict]:
@@ -501,10 +497,10 @@ class CustomInstructDataLoader(BaseInstructDataLoader):
"""Apply preprocessing steps to all splits separately"""
processed_splits = {}
logger.info(f"=== PREPROCESSING CUSTOM CONVERSATION DATA ===")
print(f"=== PREPROCESSING CUSTOM CONVERSATION DATA ===")
for split_name, split_data in data.items():
logger.info(f"Processing {split_name} split with {len(split_data)} items...")
print(f"Processing {split_name} split with {len(split_data)} items...")
processed_data = []
processed_count = 0
@@ -518,10 +514,10 @@ class CustomInstructDataLoader(BaseInstructDataLoader):
else:
skipped_count += 1
if skipped_count <= 3: # Log first few skipped items
logger.info(f"Skipped item {i} from {split_name}: {item}")
print(f"Skipped item {i} from {split_name}: {item}")
processed_splits[split_name] = processed_data
logger.info(f"{split_name} - Preprocessed {processed_count} samples, skipped {skipped_count} samples")
print(f"{split_name} - Preprocessed {processed_count} samples, skipped {skipped_count} samples")
return processed_splits
@@ -621,59 +617,59 @@ class InstructDataPipeline:
encoding=config_dict.get('encoding', 'utf-8')
)
logger.info(f"Configuration loaded from YAML: {yaml_path}")
logger.info(f"Output directory: {config.output_dir}")
logger.info(f"Conversation field: {config.conversation_field}")
print(f"Configuration loaded from YAML: {yaml_path}")
print(f"Output directory: {config.output_dir}")
print(f"Conversation field: {config.conversation_field}")
return config
except Exception as e:
logger.error(f"Error loading configuration from YAML {yaml_path}: {e}")
print(f"Error loading configuration from YAML {yaml_path}: {e}")
raise
def load_and_preprocess(self, config: InstructConfig) -> Tuple[Dict[str, List[Dict]], Dict[str, Any]]:
"""Load and preprocess conversation data"""
logger.info(f"Starting conversation data loading and preprocessing...")
logger.info(f"Data source: {config.data_source}")
print(f"Starting conversation data loading and preprocessing...")
print(f"Data source: {config.data_source}")
try:
# Load data
if config.data_source == "huggingface":
logger.info("Loading HuggingFace conversation dataset...")
print("Loading HuggingFace conversation dataset...")
raw_splits = self.hf_loader.load(config)
logger.info("Preprocessing HuggingFace conversation dataset...")
print("Preprocessing HuggingFace conversation dataset...")
processed_splits = self.hf_loader.preprocess(raw_splits, config)
elif config.data_source == "custom":
logger.info("Loading custom conversation dataset...")
print("Loading custom conversation dataset...")
raw_splits = self.custom_loader.load(config)
logger.info("Preprocessing custom conversation dataset...")
print("Preprocessing custom conversation dataset...")
processed_splits = self.custom_loader.preprocess(raw_splits, config)
else:
raise ValueError(f"Unsupported data source: {config.data_source}")
logger.info(f"Conversation data loading and preprocessing completed successfully")
print(f"Conversation data loading and preprocessing completed successfully")
# Validate processed data
logger.info("Validating processed conversation data...")
print("Validating processed conversation data...")
is_valid, errors = self.validator.validate_conversation_data(processed_splits, config, is_processed=True)
if not is_valid:
logger.error("Conversation data validation failed:")
print("Conversation data validation failed:")
for error in errors:
logger.error(f" - {error}")
print(f" - {error}")
raise ValueError("Conversation data validation failed")
logger.info("Conversation data validation passed")
print("Conversation data validation passed")
# Analyze dataset
logger.info("Analyzing conversation dataset...")
print("Analyzing conversation dataset...")
analysis = self.validator.analyze_conversation_dataset(processed_splits, config, is_processed=True)
logger.info("Conversation dataset analysis completed")
print("Conversation dataset analysis completed")
return processed_splits, analysis
except Exception as e:
logger.error(f"Error in load_and_preprocess: {e}")
print(f"Error in load_and_preprocess: {e}")
raise
def save_data(self, data: Dict[str, List[Dict]], output_dir: str, format: str = "jsonl"):
@@ -692,7 +688,7 @@ class InstructDataPipeline:
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(split_data, f, ensure_ascii=False, indent=2)
logger.info(f"Saved {len(split_data)} conversation samples to {output_file}")
print(f"Saved {len(split_data)} conversation samples to {output_file}")
def run_pipeline(
self,
@@ -701,7 +697,7 @@ class InstructDataPipeline:
) -> Dict[str, Any]:
"""Run complete instruction data pipeline"""
logger.info("Starting instruction data pipeline...")
print("Starting instruction data pipeline...")
# Load and preprocess data
processed_splits, analysis = self.load_and_preprocess(config)
@@ -723,7 +719,7 @@ class InstructDataPipeline:
"data": processed_splits, # Include the actual processed data
}
logger.info("Instruction data pipeline completed successfully!")
print("Instruction data pipeline completed successfully!")
return result
def load_yaml_config(config_path: str) -> Dict[str, Any]:
@@ -764,13 +760,13 @@ def load_yaml_config(config_path: str) -> Dict[str, Any]:
'encoding': data_config.get('encoding')
})
logger.info(f"Successfully parsed YAML configuration from: {config_path}")
logger.info(f"Extracted {len(config_dict)} configuration parameters")
print(f"Successfully parsed YAML configuration from: {config_path}")
print(f"Extracted {len(config_dict)} configuration parameters")
return config_dict
except Exception as e:
logger.error(f"Error loading YAML config from {config_path}: {e}")
print(f"Error loading YAML config from {config_path}: {e}")
raise
def main():
@@ -805,10 +801,10 @@ def main():
args = parser.parse_args()
# Set up logging
logging.basicConfig(
level=getattr(logging, args.log_level),
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# logging.basicConfig(
# level=getattr(logging, args.log_level),
# format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
# )
# Load configuration
config_dict = {}
@@ -818,7 +814,7 @@ def main():
try:
config_dict = load_yaml_config(args.config)
except Exception as e:
logger.error(f"Error loading YAML config: {e}")
print(f"Error loading YAML config: {e}")
sys.exit(1)
# Override YAML config with CLI arguments (similar to styling pipeline)
@@ -847,7 +843,7 @@ def main():
# Merge configurations
for key, value in cli_overrides.items():
if key in config_dict:
logger.info(f"Overriding YAML config '{key}' with CLI value: {value}")
print(f"Overriding YAML config '{key}' with CLI value: {value}")
config_dict[key] = value
# Validate required arguments