instruct mode added to pipeline
This commit is contained in:
@@ -7,16 +7,12 @@ from datasets import Dataset, load_dataset
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
import logging
|
||||
from sklearn.model_selection import train_test_split
|
||||
import re
|
||||
import argparse
|
||||
import sys
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
@dataclass
|
||||
class InstructConfig:
|
||||
"""Configuration for instruction fine-tuning tasks"""
|
||||
@@ -75,7 +71,7 @@ class ConversationValidator:
|
||||
return False, errors
|
||||
|
||||
total_samples = sum(len(split_data) for split_data in data.values())
|
||||
logger.info(f"Validating {total_samples} total samples across all splits...")
|
||||
print(f"Validating {total_samples} total samples across all splits...")
|
||||
|
||||
# Determine field names based on whether data is processed or not
|
||||
conversation_field = "conversation" if not is_processed else "conversation"
|
||||
@@ -83,10 +79,10 @@ class ConversationValidator:
|
||||
# Validate each split
|
||||
for split_name, split_data in data.items():
|
||||
if not split_data:
|
||||
logger.info(f"Skipping validation for empty {split_name} split")
|
||||
print(f"Skipping validation for empty {split_name} split")
|
||||
continue
|
||||
|
||||
logger.info(f"Validating {split_name} split with {len(split_data)} samples...")
|
||||
print(f"Validating {split_name} split with {len(split_data)} samples...")
|
||||
|
||||
# Check required fields
|
||||
missing_conversation_count = 0
|
||||
@@ -117,19 +113,19 @@ class ConversationValidator:
|
||||
if "role" in turn and turn["role"] not in ["user", "assistant", "system"]:
|
||||
errors.append(f"Invalid role '{turn['role']}' in conversation turn {j}, {split_name} split, item {i}. Must be 'user', 'assistant', or 'system'")
|
||||
|
||||
logger.info(f"{split_name} - Items missing conversation field: {missing_conversation_count}")
|
||||
print(f"{split_name} - Items missing conversation field: {missing_conversation_count}")
|
||||
|
||||
# Show sample of processed data for debugging
|
||||
if split_data:
|
||||
logger.info(f"Sample conversation from {split_name}:")
|
||||
print(f"Sample conversation from {split_name}:")
|
||||
for i in range(min(2, len(split_data))):
|
||||
item = split_data[i]
|
||||
conversation = item.get(conversation_field, [])
|
||||
logger.info(f" Item {i} conversation length: {len(conversation)} turns")
|
||||
print(f" Item {i} conversation length: {len(conversation)} turns")
|
||||
for j, turn in enumerate(conversation[:3]): # Show first 3 turns
|
||||
role = turn.get("role", "unknown")
|
||||
content = turn.get("content", "")[:100] + "..." if len(turn.get("content", "")) > 100 else turn.get("content", "")
|
||||
logger.info(f" Turn {j}: {role} -> '{content}'")
|
||||
print(f" Turn {j}: {role} -> '{content}'")
|
||||
|
||||
return len(errors) == 0, errors
|
||||
|
||||
@@ -242,7 +238,7 @@ class HuggingFaceInstructDataLoader(BaseInstructDataLoader):
|
||||
if not config.dataset_name:
|
||||
raise ValueError("Dataset name is required for Hugging Face datasets")
|
||||
|
||||
logger.info(f"Loading Hugging Face conversation dataset: {config.dataset_name}")
|
||||
print(f"Loading Hugging Face conversation dataset: {config.dataset_name}")
|
||||
|
||||
try:
|
||||
dataset = load_dataset(
|
||||
@@ -251,7 +247,7 @@ class HuggingFaceInstructDataLoader(BaseInstructDataLoader):
|
||||
)
|
||||
|
||||
available_splits = list(dataset.keys())
|
||||
logger.info(f"Available splits in dataset: {available_splits}")
|
||||
print(f"Available splits in dataset: {available_splits}")
|
||||
|
||||
splits_data = {
|
||||
"train": [],
|
||||
@@ -262,10 +258,10 @@ class HuggingFaceInstructDataLoader(BaseInstructDataLoader):
|
||||
# Handle train split
|
||||
if "train" in available_splits:
|
||||
train_dataset = dataset["train"]
|
||||
logger.info(f"Using 'train' split with {len(train_dataset)} samples")
|
||||
print(f"Using 'train' split with {len(train_dataset)} samples")
|
||||
splits_data["train"] = list(train_dataset)
|
||||
else:
|
||||
logger.error("No 'train' split found in dataset!")
|
||||
print("No 'train' split found in dataset!")
|
||||
raise ValueError(f"Dataset {config.dataset_name} does not have a 'train' split")
|
||||
|
||||
# Handle validation and test splits (similar logic to styling pipeline)
|
||||
@@ -277,23 +273,23 @@ class HuggingFaceInstructDataLoader(BaseInstructDataLoader):
|
||||
if splits_data[split_name]:
|
||||
original_size = len(splits_data[split_name])
|
||||
splits_data[split_name] = splits_data[split_name][:config.max_samples]
|
||||
logger.info(f"Limited {split_name} split from {original_size} to {len(splits_data[split_name])} samples")
|
||||
print(f"Limited {split_name} split from {original_size} to {len(splits_data[split_name])} samples")
|
||||
|
||||
logger.info(f"Successfully loaded dataset {config.dataset_name}")
|
||||
print(f"Successfully loaded dataset {config.dataset_name}")
|
||||
return splits_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading dataset {config.dataset_name}: {e}")
|
||||
print(f"Error loading dataset {config.dataset_name}: {e}")
|
||||
raise
|
||||
|
||||
def preprocess(self, data: Dict[str, List[Dict]], config: InstructConfig) -> Dict[str, List[Dict]]:
|
||||
"""Apply preprocessing steps to all splits separately"""
|
||||
processed_splits = {}
|
||||
|
||||
logger.info(f"=== PREPROCESSING CONVERSATION DATA ===")
|
||||
print(f"=== PREPROCESSING CONVERSATION DATA ===")
|
||||
|
||||
for split_name, split_data in data.items():
|
||||
logger.info(f"Processing {split_name} split with {len(split_data)} items...")
|
||||
print(f"Processing {split_name} split with {len(split_data)} items...")
|
||||
|
||||
processed_data = []
|
||||
processed_count = 0
|
||||
@@ -308,7 +304,7 @@ class HuggingFaceInstructDataLoader(BaseInstructDataLoader):
|
||||
skipped_count += 1
|
||||
|
||||
processed_splits[split_name] = processed_data
|
||||
logger.info(f"{split_name} - Preprocessed {processed_count} samples, skipped {skipped_count} samples")
|
||||
print(f"{split_name} - Preprocessed {processed_count} samples, skipped {skipped_count} samples")
|
||||
|
||||
return processed_splits
|
||||
|
||||
@@ -368,7 +364,7 @@ class CustomInstructDataLoader(BaseInstructDataLoader):
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"Data file not found: {file_path}")
|
||||
|
||||
logger.info(f"Loading custom conversation dataset: {file_path}")
|
||||
print(f"Loading custom conversation dataset: {file_path}")
|
||||
|
||||
if config.data_format == "jsonl":
|
||||
raw_data = self._load_jsonl(file_path, config)
|
||||
@@ -380,7 +376,7 @@ class CustomInstructDataLoader(BaseInstructDataLoader):
|
||||
if config.max_samples:
|
||||
raw_data = raw_data[:config.max_samples]
|
||||
|
||||
logger.info(f"Loaded {len(raw_data)} conversation samples from {file_path}")
|
||||
print(f"Loaded {len(raw_data)} conversation samples from {file_path}")
|
||||
|
||||
# Create splits from the raw data
|
||||
splits_data = self._create_splits(raw_data, config)
|
||||
@@ -389,11 +385,11 @@ class CustomInstructDataLoader(BaseInstructDataLoader):
|
||||
|
||||
def _create_splits(self, data: List[Dict], config: InstructConfig) -> Dict[str, List[Dict]]:
|
||||
"""Create train/validation/test splits from raw data"""
|
||||
logger.info(f"Creating splits from {len(data)} conversation samples...")
|
||||
print(f"Creating splits from {len(data)} conversation samples...")
|
||||
|
||||
# Handle very small datasets
|
||||
if len(data) < 3:
|
||||
logger.warning(f"Dataset has only {len(data)} samples. Using all data for training.")
|
||||
print(f"Dataset has only {len(data)} samples. Using all data for training.")
|
||||
return {
|
||||
"train": data,
|
||||
"validation": [],
|
||||
@@ -408,7 +404,7 @@ class CustomInstructDataLoader(BaseInstructDataLoader):
|
||||
config.train_split = 0.6
|
||||
config.validation_split = 0.2
|
||||
config.test_split = 0.2
|
||||
logger.info(f"Small dataset detected. Adjusted split ratios to: train={config.train_split}, val={config.validation_split}, test={config.test_split}")
|
||||
print(f"Small dataset detected. Adjusted split ratios to: train={config.train_split}, val={config.validation_split}, test={config.test_split}")
|
||||
|
||||
val_size = max(1, int(total_samples * config.validation_split))
|
||||
test_size = max(1, int(total_samples * config.test_split))
|
||||
@@ -423,7 +419,7 @@ class CustomInstructDataLoader(BaseInstructDataLoader):
|
||||
test_size -= 1
|
||||
train_size += 1
|
||||
|
||||
logger.info(f"Split sizes: train={train_size}, validation={val_size}, test={test_size}")
|
||||
print(f"Split sizes: train={train_size}, validation={val_size}, test={test_size}")
|
||||
|
||||
# Create splits
|
||||
if val_size == 0 and test_size == 0:
|
||||
@@ -466,10 +462,10 @@ class CustomInstructDataLoader(BaseInstructDataLoader):
|
||||
"test": test_data
|
||||
}
|
||||
|
||||
logger.info(f"Created conversation splits:")
|
||||
logger.info(f" Train: {len(splits_data['train'])} samples")
|
||||
logger.info(f" Validation: {len(splits_data['validation'])} samples")
|
||||
logger.info(f" Test: {len(splits_data['test'])} samples")
|
||||
print(f"Created conversation splits:")
|
||||
print(f" Train: {len(splits_data['train'])} samples")
|
||||
print(f" Validation: {len(splits_data['validation'])} samples")
|
||||
print(f" Test: {len(splits_data['test'])} samples")
|
||||
|
||||
return splits_data
|
||||
|
||||
@@ -482,7 +478,7 @@ class CustomInstructDataLoader(BaseInstructDataLoader):
|
||||
try:
|
||||
data.append(json.loads(line))
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Invalid JSON at line {line_num}: {e}")
|
||||
print(f"Invalid JSON at line {line_num}: {e}")
|
||||
return data
|
||||
|
||||
def _load_json(self, file_path: Path, config: InstructConfig) -> List[Dict]:
|
||||
@@ -501,10 +497,10 @@ class CustomInstructDataLoader(BaseInstructDataLoader):
|
||||
"""Apply preprocessing steps to all splits separately"""
|
||||
processed_splits = {}
|
||||
|
||||
logger.info(f"=== PREPROCESSING CUSTOM CONVERSATION DATA ===")
|
||||
print(f"=== PREPROCESSING CUSTOM CONVERSATION DATA ===")
|
||||
|
||||
for split_name, split_data in data.items():
|
||||
logger.info(f"Processing {split_name} split with {len(split_data)} items...")
|
||||
print(f"Processing {split_name} split with {len(split_data)} items...")
|
||||
|
||||
processed_data = []
|
||||
processed_count = 0
|
||||
@@ -518,10 +514,10 @@ class CustomInstructDataLoader(BaseInstructDataLoader):
|
||||
else:
|
||||
skipped_count += 1
|
||||
if skipped_count <= 3: # Log first few skipped items
|
||||
logger.info(f"Skipped item {i} from {split_name}: {item}")
|
||||
print(f"Skipped item {i} from {split_name}: {item}")
|
||||
|
||||
processed_splits[split_name] = processed_data
|
||||
logger.info(f"{split_name} - Preprocessed {processed_count} samples, skipped {skipped_count} samples")
|
||||
print(f"{split_name} - Preprocessed {processed_count} samples, skipped {skipped_count} samples")
|
||||
|
||||
return processed_splits
|
||||
|
||||
@@ -621,59 +617,59 @@ class InstructDataPipeline:
|
||||
encoding=config_dict.get('encoding', 'utf-8')
|
||||
)
|
||||
|
||||
logger.info(f"Configuration loaded from YAML: {yaml_path}")
|
||||
logger.info(f"Output directory: {config.output_dir}")
|
||||
logger.info(f"Conversation field: {config.conversation_field}")
|
||||
print(f"Configuration loaded from YAML: {yaml_path}")
|
||||
print(f"Output directory: {config.output_dir}")
|
||||
print(f"Conversation field: {config.conversation_field}")
|
||||
|
||||
return config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading configuration from YAML {yaml_path}: {e}")
|
||||
print(f"Error loading configuration from YAML {yaml_path}: {e}")
|
||||
raise
|
||||
|
||||
def load_and_preprocess(self, config: InstructConfig) -> Tuple[Dict[str, List[Dict]], Dict[str, Any]]:
|
||||
"""Load and preprocess conversation data"""
|
||||
|
||||
logger.info(f"Starting conversation data loading and preprocessing...")
|
||||
logger.info(f"Data source: {config.data_source}")
|
||||
print(f"Starting conversation data loading and preprocessing...")
|
||||
print(f"Data source: {config.data_source}")
|
||||
|
||||
try:
|
||||
# Load data
|
||||
if config.data_source == "huggingface":
|
||||
logger.info("Loading HuggingFace conversation dataset...")
|
||||
print("Loading HuggingFace conversation dataset...")
|
||||
raw_splits = self.hf_loader.load(config)
|
||||
logger.info("Preprocessing HuggingFace conversation dataset...")
|
||||
print("Preprocessing HuggingFace conversation dataset...")
|
||||
processed_splits = self.hf_loader.preprocess(raw_splits, config)
|
||||
elif config.data_source == "custom":
|
||||
logger.info("Loading custom conversation dataset...")
|
||||
print("Loading custom conversation dataset...")
|
||||
raw_splits = self.custom_loader.load(config)
|
||||
logger.info("Preprocessing custom conversation dataset...")
|
||||
print("Preprocessing custom conversation dataset...")
|
||||
processed_splits = self.custom_loader.preprocess(raw_splits, config)
|
||||
else:
|
||||
raise ValueError(f"Unsupported data source: {config.data_source}")
|
||||
|
||||
logger.info(f"Conversation data loading and preprocessing completed successfully")
|
||||
print(f"Conversation data loading and preprocessing completed successfully")
|
||||
|
||||
# Validate processed data
|
||||
logger.info("Validating processed conversation data...")
|
||||
print("Validating processed conversation data...")
|
||||
is_valid, errors = self.validator.validate_conversation_data(processed_splits, config, is_processed=True)
|
||||
if not is_valid:
|
||||
logger.error("Conversation data validation failed:")
|
||||
print("Conversation data validation failed:")
|
||||
for error in errors:
|
||||
logger.error(f" - {error}")
|
||||
print(f" - {error}")
|
||||
raise ValueError("Conversation data validation failed")
|
||||
|
||||
logger.info("Conversation data validation passed")
|
||||
print("Conversation data validation passed")
|
||||
|
||||
# Analyze dataset
|
||||
logger.info("Analyzing conversation dataset...")
|
||||
print("Analyzing conversation dataset...")
|
||||
analysis = self.validator.analyze_conversation_dataset(processed_splits, config, is_processed=True)
|
||||
logger.info("Conversation dataset analysis completed")
|
||||
print("Conversation dataset analysis completed")
|
||||
|
||||
return processed_splits, analysis
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in load_and_preprocess: {e}")
|
||||
print(f"Error in load_and_preprocess: {e}")
|
||||
raise
|
||||
|
||||
def save_data(self, data: Dict[str, List[Dict]], output_dir: str, format: str = "jsonl"):
|
||||
@@ -692,7 +688,7 @@ class InstructDataPipeline:
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(split_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
logger.info(f"Saved {len(split_data)} conversation samples to {output_file}")
|
||||
print(f"Saved {len(split_data)} conversation samples to {output_file}")
|
||||
|
||||
def run_pipeline(
|
||||
self,
|
||||
@@ -701,7 +697,7 @@ class InstructDataPipeline:
|
||||
) -> Dict[str, Any]:
|
||||
"""Run complete instruction data pipeline"""
|
||||
|
||||
logger.info("Starting instruction data pipeline...")
|
||||
print("Starting instruction data pipeline...")
|
||||
|
||||
# Load and preprocess data
|
||||
processed_splits, analysis = self.load_and_preprocess(config)
|
||||
@@ -723,7 +719,7 @@ class InstructDataPipeline:
|
||||
"data": processed_splits, # Include the actual processed data
|
||||
}
|
||||
|
||||
logger.info("Instruction data pipeline completed successfully!")
|
||||
print("Instruction data pipeline completed successfully!")
|
||||
return result
|
||||
|
||||
def load_yaml_config(config_path: str) -> Dict[str, Any]:
|
||||
@@ -764,13 +760,13 @@ def load_yaml_config(config_path: str) -> Dict[str, Any]:
|
||||
'encoding': data_config.get('encoding')
|
||||
})
|
||||
|
||||
logger.info(f"Successfully parsed YAML configuration from: {config_path}")
|
||||
logger.info(f"Extracted {len(config_dict)} configuration parameters")
|
||||
print(f"Successfully parsed YAML configuration from: {config_path}")
|
||||
print(f"Extracted {len(config_dict)} configuration parameters")
|
||||
|
||||
return config_dict
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading YAML config from {config_path}: {e}")
|
||||
print(f"Error loading YAML config from {config_path}: {e}")
|
||||
raise
|
||||
|
||||
def main():
|
||||
@@ -805,10 +801,10 @@ def main():
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, args.log_level),
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
# logging.basicConfig(
|
||||
# level=getattr(logging, args.log_level),
|
||||
# format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
# )
|
||||
|
||||
# Load configuration
|
||||
config_dict = {}
|
||||
@@ -818,7 +814,7 @@ def main():
|
||||
try:
|
||||
config_dict = load_yaml_config(args.config)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading YAML config: {e}")
|
||||
print(f"Error loading YAML config: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# Override YAML config with CLI arguments (similar to styling pipeline)
|
||||
@@ -847,7 +843,7 @@ def main():
|
||||
# Merge configurations
|
||||
for key, value in cli_overrides.items():
|
||||
if key in config_dict:
|
||||
logger.info(f"Overriding YAML config '{key}' with CLI value: {value}")
|
||||
print(f"Overriding YAML config '{key}' with CLI value: {value}")
|
||||
config_dict[key] = value
|
||||
|
||||
# Validate required arguments
|
||||
|
||||
Reference in New Issue
Block a user