From 78d519efbf6ad7ebf60549ca7512f2e37e59fc74 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 28 Aug 2025 14:12:30 +0000 Subject: [PATCH] updated instruct --- .../.ipynb_checkpoints/sample-checkpoint.yaml | 78 ++ configs/instruct/sample.yaml | 78 ++ configs/untitled.txt | 0 .../data_processor-checkpoint.py | 0 .../inference-checkpoint.py | 0 .../.ipynb_checkpoints/train-checkpoint.py | 0 pipelines/instruct/data_processor.py | 917 ++++++++++++++++++ pipelines/instruct/inference.py | 393 ++++++++ pipelines/instruct/train.py | 525 ++++++++++ .../data_processor-checkpoint.py | 320 ++++++ .../inference-checkpoint.py | 291 ++++++ .../.ipynb_checkpoints/train-checkpoint.py | 214 ++++ scripts/instruct/data_processor.py | 320 ++++++ scripts/instruct/inference.py | 291 ++++++ scripts/instruct/train.py | 214 ++++ 15 files changed, 3641 insertions(+) create mode 100644 configs/instruct/.ipynb_checkpoints/sample-checkpoint.yaml create mode 100644 configs/instruct/sample.yaml create mode 100644 configs/untitled.txt create mode 100644 pipelines/instruct/.ipynb_checkpoints/data_processor-checkpoint.py create mode 100644 pipelines/instruct/.ipynb_checkpoints/inference-checkpoint.py create mode 100644 pipelines/instruct/.ipynb_checkpoints/train-checkpoint.py create mode 100644 pipelines/instruct/data_processor.py create mode 100644 pipelines/instruct/inference.py create mode 100644 pipelines/instruct/train.py create mode 100644 scripts/instruct/.ipynb_checkpoints/data_processor-checkpoint.py create mode 100644 scripts/instruct/.ipynb_checkpoints/inference-checkpoint.py create mode 100644 scripts/instruct/.ipynb_checkpoints/train-checkpoint.py create mode 100644 scripts/instruct/data_processor.py create mode 100644 scripts/instruct/inference.py create mode 100644 scripts/instruct/train.py diff --git a/configs/instruct/.ipynb_checkpoints/sample-checkpoint.yaml b/configs/instruct/.ipynb_checkpoints/sample-checkpoint.yaml new file mode 100644 index 0000000..17fc405 --- /dev/null +++ 
b/configs/instruct/.ipynb_checkpoints/sample-checkpoint.yaml @@ -0,0 +1,78 @@ +# Comprehensive Instruct Configuration +# This file defines all parameters for instruction fine-tuning using conversational data +# Organized by level: task, data processing, model, training, and inference + +# Task Configuration +task: + name: "code_reasoning" # Task name: instruct, code_reasoning, general_chat + type: "instruction_following" # Model type: instruction_following, conversational + +# Data Processing Configuration +data: + source: "custom" # Data source: "huggingface" or "custom" + data_path: "./data/raw/instruct/code_reasoning.jsonl" # Path to conversation data file + data_format: "jsonl" # Data format: "jsonl", "json" + + # Field Mapping for Conversation Data + conversation_field: "conversation" # Field name containing conversation array + + # Data Format & Processing + max_length: 2048 # Maximum text length (truncate longer texts) + min_length: 10 # Minimum text length (filter out shorter texts) + + # Text Preprocessing + clean_text: true # Clean and normalize text + + # Data Splitting + train_split: 0.8 # Training split ratio (0.0 to 1.0) + validation_split: 0.1 # Validation split ratio (0.0 to 1.0) + test_split: 0.1 # Test split ratio (0.0 to 1.0) + + # Output Configuration + output_format: "conversation" # Output format: "conversation" (chat format) + output_dir: "./data/processed/instruct/code_reasoning" # Output directory for processed data + +# Model Configuration +model: + name: "unsloth/Qwen2.5-72B-Instruct" # Model name from HuggingFace Hub (optimized for instruction following) + max_length: 2048 # Maximum sequence length for tokenization + max_seq_length: 2048 # Maximum sequence length for training (RoPE scaling supported) + dtype: null # Data type: null for auto detection, float16 for Tesla T4/V100, bfloat16 for Ampere+ + load_in_4bit: true # Use 4bit quantization to reduce memory usage + token: null # HuggingFace token for gated models (e.g., "hf_...") + + # 
Training Model Parameters + training_model: "unsloth/Qwen2.5-72B-Instruct" # Model to use for training + training_max_seq_length: 2048 # Max sequence length for training + training_dtype: null # Data type for training + training_load_in_4bit: true # 4bit quantization for training + +# Training Configuration +training: + num_epochs: 1 # Number of training epochs (1 epoch is often sufficient for instruction tuning) + batch_size: 1 # Training batch size (small for large models) + learning_rate: 2e-4 # Learning rate (typical for instruction tuning) + weight_decay: 0.01 # Weight decay for optimizer (prevents overfitting) + warmup_steps: 5 # Warmup steps (fixed value) + max_steps: 30 # Maximum training steps (adjust based on dataset size) + gradient_accumulation_steps: 4 # Gradient accumulation steps + lr_scheduler_type: "linear" # Scheduler type: "linear", "cosine", "polynomial" + seed: 3407 # Random seed for reproducibility + + # LoRA Configuration + lora_r: 32 # LoRA rank (higher = more parameters) + lora_alpha: 16 # LoRA alpha (scaling factor) + lora_dropout: 0 # LoRA dropout (0 is optimized) + target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + + # Output Configuration + output_dir: "./outputs" # Directory for training checkpoints + model_output_dir: "./models/instruct" # Directory to save the trained model + +# Inference Configuration +inference: + batch_size: 1 # Batch size for inference + max_new_tokens: 128 # Maximum new tokens to generate during inference + temperature: 1.5 # Sampling temperature (higher = more creative) + min_p: 0.1 # Min-p sampling parameter + use_cache: true # Use key-value cache for faster generation diff --git a/configs/instruct/sample.yaml b/configs/instruct/sample.yaml new file mode 100644 index 0000000..17fc405 --- /dev/null +++ b/configs/instruct/sample.yaml @@ -0,0 +1,78 @@ +# Comprehensive Instruct Configuration +# This file defines all parameters for instruction fine-tuning using 
conversational data +# Organized by level: task, data processing, model, training, and inference + +# Task Configuration +task: + name: "code_reasoning" # Task name: instruct, code_reasoning, general_chat + type: "instruction_following" # Model type: instruction_following, conversational + +# Data Processing Configuration +data: + source: "custom" # Data source: "huggingface" or "custom" + data_path: "./data/raw/instruct/code_reasoning.jsonl" # Path to conversation data file + data_format: "jsonl" # Data format: "jsonl", "json" + + # Field Mapping for Conversation Data + conversation_field: "conversation" # Field name containing conversation array + + # Data Format & Processing + max_length: 2048 # Maximum text length (truncate longer texts) + min_length: 10 # Minimum text length (filter out shorter texts) + + # Text Preprocessing + clean_text: true # Clean and normalize text + + # Data Splitting + train_split: 0.8 # Training split ratio (0.0 to 1.0) + validation_split: 0.1 # Validation split ratio (0.0 to 1.0) + test_split: 0.1 # Test split ratio (0.0 to 1.0) + + # Output Configuration + output_format: "conversation" # Output format: "conversation" (chat format) + output_dir: "./data/processed/instruct/code_reasoning" # Output directory for processed data + +# Model Configuration +model: + name: "unsloth/Qwen2.5-72B-Instruct" # Model name from HuggingFace Hub (optimized for instruction following) + max_length: 2048 # Maximum sequence length for tokenization + max_seq_length: 2048 # Maximum sequence length for training (RoPE scaling supported) + dtype: null # Data type: null for auto detection, float16 for Tesla T4/V100, bfloat16 for Ampere+ + load_in_4bit: true # Use 4bit quantization to reduce memory usage + token: null # HuggingFace token for gated models (e.g., "hf_...") + + # Training Model Parameters + training_model: "unsloth/Qwen2.5-72B-Instruct" # Model to use for training + training_max_seq_length: 2048 # Max sequence length for training + training_dtype: 
null # Data type for training + training_load_in_4bit: true # 4bit quantization for training + +# Training Configuration +training: + num_epochs: 1 # Number of training epochs (1 epoch is often sufficient for instruction tuning) + batch_size: 1 # Training batch size (small for large models) + learning_rate: 2e-4 # Learning rate (typical for instruction tuning) + weight_decay: 0.01 # Weight decay for optimizer (prevents overfitting) + warmup_steps: 5 # Warmup steps (fixed value) + max_steps: 30 # Maximum training steps (adjust based on dataset size) + gradient_accumulation_steps: 4 # Gradient accumulation steps + lr_scheduler_type: "linear" # Scheduler type: "linear", "cosine", "polynomial" + seed: 3407 # Random seed for reproducibility + + # LoRA Configuration + lora_r: 32 # LoRA rank (higher = more parameters) + lora_alpha: 16 # LoRA alpha (scaling factor) + lora_dropout: 0 # LoRA dropout (0 is optimized) + target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + + # Output Configuration + output_dir: "./outputs" # Directory for training checkpoints + model_output_dir: "./models/instruct" # Directory to save the trained model + +# Inference Configuration +inference: + batch_size: 1 # Batch size for inference + max_new_tokens: 128 # Maximum new tokens to generate during inference + temperature: 1.5 # Sampling temperature (higher = more creative) + min_p: 0.1 # Min-p sampling parameter + use_cache: true # Use key-value cache for faster generation diff --git a/configs/untitled.txt b/configs/untitled.txt new file mode 100644 index 0000000..e69de29 diff --git a/pipelines/instruct/.ipynb_checkpoints/data_processor-checkpoint.py b/pipelines/instruct/.ipynb_checkpoints/data_processor-checkpoint.py new file mode 100644 index 0000000..e69de29 diff --git a/pipelines/instruct/.ipynb_checkpoints/inference-checkpoint.py b/pipelines/instruct/.ipynb_checkpoints/inference-checkpoint.py new file mode 100644 index 0000000..e69de29 diff --git 
@dataclass
class InstructConfig:
    """Settings that drive the instruction fine-tuning data pipeline.

    Every field has a default, so callers override only what they need.
    Passing ``None`` for a field whose declared default is not ``None``
    (for example a key that is absent from a YAML file) restores the default
    instead of leaking ``None`` into later length comparisons or split
    arithmetic.
    """

    # Data source configuration
    data_source: str = "custom"  # "huggingface" or "custom"
    dataset_name: Optional[str] = None  # For Hugging Face datasets
    data_path: Optional[str] = None  # For custom datasets
    data_format: str = "jsonl"  # jsonl, json

    # Field mapping - conversation data specific
    conversation_field: str = "conversation"  # Field containing conversation array

    # Data processing
    max_samples: Optional[int] = None
    train_split: float = 0.8
    validation_split: float = 0.1
    test_split: float = 0.1

    # Text preprocessing
    clean_text: bool = True
    min_length: int = 10
    max_length: int = 2048

    # Output configuration
    output_format: str = "conversation"  # conversation, alpaca
    output_dir: str = "./data/processed/instruct"

    # Hugging Face specific
    hf_split: str = "train"
    hf_cache_dir: Optional[str] = None

    # Split configuration
    test_split_from: str = "train"
    val_split_from: str = "train"

    # Custom data specific
    encoding: str = "utf-8"

    def __post_init__(self) -> None:
        """Replace ``None`` with the declared default on non-optional fields."""
        # Local import: the module only imports ``dataclass`` at file level.
        from dataclasses import MISSING, fields

        for spec in fields(self):
            if spec.default is MISSING or spec.default is None:
                # Either no simple default exists or None is a legal value.
                continue
            if getattr(self, spec.name) is None:
                setattr(self, spec.name, spec.default)
class ConversationValidator:
    """Validate and summarise conversation datasets split into train/validation/test.

    Both methods are static and never modify the data they are given.
    """

    @staticmethod
    def validate_conversation_data(data: Dict[str, List[Dict]], config: InstructConfig, is_processed: bool = False) -> Tuple[bool, List[str]]:
        """Check that every split holds well-formed conversations.

        Args:
            data: Mapping of split name to a list of items. The keys
                ``train``, ``validation`` and ``test`` must all be present and
                ``train`` must be non-empty.
            config: Pipeline configuration; ``conversation_field`` names the
                key holding the conversation in raw (unprocessed) items.
            is_processed: When true the items are expected to use the fixed
                ``"conversation"`` key instead of ``config.conversation_field``.

        Returns:
            ``(is_valid, errors)`` where ``errors`` lists every problem found.
            Malformed items are reported in ``errors``; they never raise.
        """
        errors: List[str] = []

        # All three splits must exist before item-level checks make sense.
        for split in ("train", "validation", "test"):
            if split not in data:
                errors.append(f"Missing '{split}' split")
            elif split == "train" and not data[split]:
                errors.append("Train split cannot be empty")

        if errors:
            return False, errors

        total_samples = sum(len(split_data) for split_data in data.values())
        logger.info(f"Validating {total_samples} total samples across all splits...")

        # Raw items use the configured key; processed items always use "conversation".
        if is_processed:
            conversation_field = "conversation"
        else:
            conversation_field = config.conversation_field or "conversation"
        valid_roles = ("user", "assistant", "system")

        for split_name, split_data in data.items():
            if not split_data:
                logger.info(f"Skipping validation for empty {split_name} split")
                continue

            logger.info(f"Validating {split_name} split with {len(split_data)} samples...")
            missing_conversation_count = 0

            for i, item in enumerate(split_data):
                if not isinstance(item, dict):
                    errors.append(f"Item {i} in {split_name} split must be a dict")
                    continue
                if conversation_field not in item:
                    errors.append(f"Missing conversation field '{conversation_field}' in {split_name} split, item {i}")
                    missing_conversation_count += 1
                    continue

                conversation = item[conversation_field]
                if not isinstance(conversation, list):
                    errors.append(f"Conversation field must be a list in {split_name} split, item {i}")
                    continue

                for j, turn in enumerate(conversation):
                    if not isinstance(turn, dict):
                        errors.append(f"Each conversation turn must be a dict in {split_name} split, item {i}, turn {j}")
                        continue
                    if "role" not in turn:
                        errors.append(f"Missing 'role' field in conversation turn {j}, {split_name} split, item {i}")
                    elif turn["role"] not in valid_roles:
                        errors.append(f"Invalid role '{turn['role']}' in conversation turn {j}, {split_name} split, item {i}. Must be 'user', 'assistant', or 'system'")
                    if "content" not in turn:
                        errors.append(f"Missing 'content' field in conversation turn {j}, {split_name} split, item {i}")

            logger.info(f"{split_name} - Items missing conversation field: {missing_conversation_count}")

            # Preview a couple of items for debugging. The data may have just
            # failed validation, so every access is guarded.
            logger.info(f"Sample conversation from {split_name}:")
            for i, item in enumerate(split_data[:2]):
                conversation = item.get(conversation_field, []) if isinstance(item, dict) else []
                if not isinstance(conversation, list):
                    continue
                logger.info(f"  Item {i} conversation length: {len(conversation)} turns")
                for j, turn in enumerate(conversation[:3]):  # Show first 3 turns
                    if not isinstance(turn, dict):
                        continue
                    role = turn.get("role", "unknown")
                    text = str(turn.get("content", ""))
                    preview = text[:100] + "..." if len(text) > 100 else text
                    logger.info(f"    Turn {j}: {role} -> '{preview}'")

        return not errors, errors

    @staticmethod
    def analyze_conversation_dataset(data: Dict[str, List[Dict]], config: InstructConfig, is_processed: bool = False) -> Dict[str, Any]:
        """Compute per-split and overall conversation statistics.

        Args:
            data: Mapping of split name to a list of items.
            config: Pipeline configuration; ``conversation_field`` names the
                conversation key of raw (unprocessed) items.
            is_processed: When true the fixed ``"conversation"`` key is used.

        Returns:
            A dict with a ``splits`` entry (per-split sample count, turn
            statistics and missing-value count) and an ``overall`` entry
            (total samples, split sizes, aggregated turn statistics). All
            numbers are plain Python ``int``/``float`` values.
        """
        analysis: Dict[str, Any] = {
            "splits": {},
            "overall": {
                "total_samples": 0,
                "split_sizes": {},
                "conversation_stats": {
                    "total_turns": 0,
                    "avg_turns_per_conversation": 0,
                    "role_distribution": {"user": 0, "assistant": 0, "system": 0}
                }
            }
        }

        if is_processed:
            conversation_field = "conversation"
        else:
            conversation_field = config.conversation_field or "conversation"

        total_turns = 0
        total_conversations = 0
        role_counts = {"user": 0, "assistant": 0, "system": 0}

        for split_name, split_data in data.items():
            split_analysis: Dict[str, Any] = {
                "total_samples": len(split_data) if split_data else 0,
                "conversation_stats": {},
                "missing_values": {}
            }
            analysis["splits"][split_name] = split_analysis

            if not split_data:
                analysis["overall"]["split_sizes"][split_name] = 0
                continue

            split_turns = 0
            split_role_counts = {"user": 0, "assistant": 0, "system": 0}
            conversation_lengths = []
            missing_count = 0

            for item in split_data:
                conversation = item.get(conversation_field, []) if isinstance(item, dict) else []
                if not conversation:
                    missing_count += 1
                if not isinstance(conversation, list):
                    continue
                conversation_lengths.append(len(conversation))
                split_turns += len(conversation)
                for turn in conversation:
                    if isinstance(turn, dict) and turn.get("role") in split_role_counts:
                        split_role_counts[turn["role"]] += 1

            if conversation_lengths:
                split_analysis["conversation_stats"] = {
                    "total_turns": split_turns,
                    # float() keeps numpy scalars out of the result so it can
                    # be serialised with the standard json module.
                    "avg_turns_per_conversation": float(np.mean(conversation_lengths)),
                    "min_turns": min(conversation_lengths),
                    "max_turns": max(conversation_lengths),
                    "median_turns": float(np.median(conversation_lengths)),
                    "role_distribution": split_role_counts
                }

            split_analysis["missing_values"][conversation_field] = missing_count

            analysis["overall"]["total_samples"] += len(split_data)
            analysis["overall"]["split_sizes"][split_name] = len(split_data)

            # Accumulate overall stats
            total_turns += split_turns
            total_conversations += len(split_data)
            for role, count in split_role_counts.items():
                role_counts[role] += count

        if total_conversations > 0:
            overall_stats = analysis["overall"]["conversation_stats"]
            overall_stats["total_turns"] = total_turns
            overall_stats["avg_turns_per_conversation"] = total_turns / total_conversations
            overall_stats["role_distribution"] = role_counts

        return analysis
class BaseInstructDataLoader(ABC):
    """Interface shared by every instruction data loader.

    Concrete loaders must provide both ``load`` and ``preprocess``; the class
    cannot be instantiated until they do.
    """

    @abstractmethod
    def load(self, config: InstructConfig) -> Dict[str, List[Dict]]:
        """Read the dataset described by ``config``.

        Implementations return a dictionary holding the train, validation
        and test splits.
        """

    @abstractmethod
    def preprocess(self, data: Dict[str, List[Dict]], config: InstructConfig) -> Dict[str, List[Dict]]:
        """Run the preprocessing steps over every split in ``data``."""
class HuggingFaceInstructDataLoader(BaseInstructDataLoader):
    """Load conversation datasets from Hugging Face Hub"""

    # Roles accepted in a conversation turn.
    _VALID_ROLES = ("user", "assistant", "system")
    # Split names treated as the validation split, in order of preference.
    _VALIDATION_SPLIT_NAMES = ("validation", "valid", "val", "dev")

    def load(self, config: InstructConfig) -> Dict[str, List[Dict]]:
        """Download the dataset named in ``config`` and return its splits.

        Args:
            config: Must carry ``dataset_name``; ``hf_cache_dir`` and
                ``max_samples`` are honoured when set.

        Returns:
            A dict with ``train``, ``validation`` and ``test`` lists. Splits
            the dataset does not provide are returned empty.

        Raises:
            ValueError: If ``dataset_name`` is missing or the dataset has no
                ``train`` split.
        """
        if not config.dataset_name:
            raise ValueError("Dataset name is required for Hugging Face datasets")

        logger.info(f"Loading Hugging Face conversation dataset: {config.dataset_name}")

        try:
            dataset = load_dataset(
                config.dataset_name,
                cache_dir=config.hf_cache_dir
            )
            splits_data = self._collect_splits(dataset, config)
            logger.info(f"Successfully loaded dataset {config.dataset_name}")
            return splits_data
        except Exception as e:
            # Log for context, then let the caller decide how to react.
            logger.error(f"Error loading dataset {config.dataset_name}: {e}")
            raise

    def _collect_splits(self, dataset: Any, config: InstructConfig) -> Dict[str, List[Dict]]:
        """Copy the train/validation/test splits out of a loaded dataset mapping.

        ``dataset`` only needs ``keys()`` and item access, so any mapping of
        split name to an iterable of records works.
        """
        available_splits = list(dataset.keys())
        logger.info(f"Available splits in dataset: {available_splits}")

        if "train" not in available_splits:
            logger.error("No 'train' split found in dataset!")
            raise ValueError(f"Dataset {config.dataset_name} does not have a 'train' split")

        splits_data: Dict[str, List[Dict]] = {
            "train": list(dataset["train"]),
            "validation": [],
            "test": []
        }
        logger.info(f"Using 'train' split with {len(splits_data['train'])} samples")

        # Keep held-out splits the dataset already ships rather than
        # silently discarding them.
        validation_name = next(
            (name for name in self._VALIDATION_SPLIT_NAMES if name in available_splits),
            None,
        )
        if validation_name is not None:
            splits_data["validation"] = list(dataset[validation_name])
            logger.info(f"Using '{validation_name}' split with {len(splits_data['validation'])} samples for validation")
        if "test" in available_splits:
            splits_data["test"] = list(dataset["test"])
            logger.info(f"Using 'test' split with {len(splits_data['test'])} samples")

        # Apply max_samples limit if specified (per split)
        if config.max_samples:
            for split_name, rows in splits_data.items():
                if rows:
                    original_size = len(rows)
                    splits_data[split_name] = rows[:config.max_samples]
                    logger.info(f"Limited {split_name} split from {original_size} to {len(splits_data[split_name])} samples")

        return splits_data

    def preprocess(self, data: Dict[str, List[Dict]], config: InstructConfig) -> Dict[str, List[Dict]]:
        """Preprocess each split independently, dropping items that do not survive.

        Returns a new dict with the same split names; items for which
        ``_preprocess_item`` returns ``None`` are omitted.
        """
        processed_splits = {}

        logger.info("=== PREPROCESSING CONVERSATION DATA ===")

        for split_name, split_data in data.items():
            logger.info(f"Processing {split_name} split with {len(split_data)} items...")

            candidates = (self._preprocess_item(item, config) for item in split_data)
            processed_data = [item for item in candidates if item is not None]
            skipped_count = len(split_data) - len(processed_data)

            processed_splits[split_name] = processed_data
            logger.info(f"{split_name} - Preprocessed {len(processed_data)} samples, skipped {skipped_count} samples")

        return processed_splits

    def _preprocess_item(self, item: Dict, config: InstructConfig) -> Optional[Dict]:
        """Normalise one conversation item.

        Turns that are not dicts, lack ``role``/``content``, use an unknown
        role, or whose stripped content length falls outside
        ``[config.min_length, config.max_length]`` are dropped. Returns
        ``None`` when fewer than two turns remain, otherwise
        ``{"conversation": [...]}``.
        """
        conversation = item.get(config.conversation_field, [])

        if not isinstance(conversation, list) or not conversation:
            return None

        valid_conversation = []
        for turn in conversation:
            if not isinstance(turn, dict):
                continue
            if "role" not in turn or "content" not in turn:
                continue
            if turn["role"] not in self._VALID_ROLES:
                continue

            content = str(turn["content"]).strip()
            # Length limits apply to the stripped text, before whitespace cleanup.
            if not config.min_length <= len(content) <= config.max_length:
                continue

            if config.clean_text:
                content = self._clean_text(content)

            valid_conversation.append({
                "role": turn["role"],
                "content": content
            })

        if len(valid_conversation) < 2:  # Need at least 2 turns for a conversation
            return None

        return {"conversation": valid_conversation}

    def _clean_text(self, text: str) -> str:
        """Collapse every whitespace run to a single space and trim the ends.

        Non-string input yields an empty string.
        """
        if not isinstance(text, str):
            return ""

        return re.sub(r'\s+', ' ', text).strip()
class CustomInstructDataLoader(BaseInstructDataLoader):
    """Load custom conversation datasets from local files"""

    # Roles accepted in a conversation turn.
    _VALID_ROLES = ("user", "assistant", "system")

    def load(self, config: InstructConfig) -> Dict[str, List[Dict]]:
        """Read the file at ``config.data_path`` and split it into train/validation/test.

        Raises:
            ValueError: If ``data_path`` is missing or ``data_format`` is not
                ``"jsonl"`` or ``"json"``.
            FileNotFoundError: If the data file does not exist.
        """
        if not config.data_path:
            raise ValueError("Data path is required for custom datasets")

        file_path = Path(config.data_path)

        if not file_path.exists():
            raise FileNotFoundError(f"Data file not found: {file_path}")

        logger.info(f"Loading custom conversation dataset: {file_path}")

        if config.data_format == "jsonl":
            raw_data = self._load_jsonl(file_path, config)
        elif config.data_format == "json":
            raw_data = self._load_json(file_path, config)
        else:
            raise ValueError(f"Unsupported format: {config.data_format}")

        if config.max_samples:
            raw_data = raw_data[:config.max_samples]

        logger.info(f"Loaded {len(raw_data)} conversation samples from {file_path}")

        return self._create_splits(raw_data, config)

    def _create_splits(self, data: List[Dict], config: InstructConfig) -> Dict[str, List[Dict]]:
        """Create train/validation/test splits from raw data.

        Datasets with fewer than 3 samples go entirely to ``train``. Datasets
        with fewer than 10 samples use fixed 0.6/0.2/0.2 ratios; ``config``
        itself is never modified. A split whose ratio is 0 stays empty, any
        other held-out split gets at least one sample, and ``train`` always
        keeps at least one sample.
        """
        logger.info(f"Creating splits from {len(data)} conversation samples...")

        total_samples = len(data)

        # Handle very small datasets
        if total_samples < 3:
            logger.warning(f"Dataset has only {total_samples} samples. Using all data for training.")
            return {
                "train": data,
                "validation": [],
                "test": []
            }

        train_ratio = config.train_split
        val_ratio = config.validation_split
        test_ratio = config.test_split

        # Use local overrides so the caller's config is left untouched.
        if total_samples < 10:
            train_ratio, val_ratio, test_ratio = 0.6, 0.2, 0.2
            logger.info(f"Small dataset detected. Adjusted split ratios to: train={train_ratio}, val={val_ratio}, test={test_ratio}")

        # A ratio of 0 means "no such split"; otherwise keep at least one sample.
        val_size = max(1, int(total_samples * val_ratio)) if val_ratio > 0 else 0
        test_size = max(1, int(total_samples * test_ratio)) if test_ratio > 0 else 0

        # Ensure train split has at least 1 sample
        while val_size + test_size >= total_samples:
            if val_size >= test_size:
                val_size -= 1
            else:
                test_size -= 1
        train_size = total_samples - val_size - test_size

        logger.info(f"Split sizes: train={train_size}, validation={val_size}, test={test_size}")

        if val_size == 0 and test_size == 0:
            splits_data = {
                "train": data,
                "validation": [],
                "test": []
            }
        elif val_size == 0:
            train_data, test_data = train_test_split(data, test_size=test_size, random_state=42)
            splits_data = {
                "train": train_data,
                "validation": [],
                "test": test_data
            }
        elif test_size == 0:
            train_data, val_data = train_test_split(data, test_size=val_size, random_state=42)
            splits_data = {
                "train": train_data,
                "validation": val_data,
                "test": []
            }
        else:
            # Full three-way split: carve off the held-out pool, then divide it.
            train_data, temp_data = train_test_split(
                data,
                test_size=val_size + test_size,
                random_state=42
            )
            val_data, test_data = train_test_split(
                temp_data,
                test_size=test_size,
                random_state=42
            )
            splits_data = {
                "train": train_data,
                "validation": val_data,
                "test": test_data
            }

        logger.info("Created conversation splits:")
        logger.info(f"  Train: {len(splits_data['train'])} samples")
        logger.info(f"  Validation: {len(splits_data['validation'])} samples")
        logger.info(f"  Test: {len(splits_data['test'])} samples")

        return splits_data

    def _load_jsonl(self, file_path: Path, config: InstructConfig) -> List[Dict]:
        """Parse a JSON Lines file, skipping blank lines and warning on invalid ones."""
        data = []
        with open(file_path, 'r', encoding=config.encoding) as f:
            for line_num, line in enumerate(f, 1):
                if not line.strip():
                    continue
                try:
                    data.append(json.loads(line))
                except json.JSONDecodeError as e:
                    logger.warning(f"Invalid JSON at line {line_num}: {e}")
        return data

    def _load_json(self, file_path: Path, config: InstructConfig) -> List[Dict]:
        """Parse a JSON file.

        A top-level list is returned as-is, a dict with a ``"data"`` key
        yields that value, and anything else is wrapped in a one-element list.
        """
        with open(file_path, 'r', encoding=config.encoding) as f:
            data = json.load(f)

        if isinstance(data, list):
            return data
        if isinstance(data, dict) and "data" in data:
            return data["data"]
        return [data]

    def preprocess(self, data: Dict[str, List[Dict]], config: InstructConfig) -> Dict[str, List[Dict]]:
        """Preprocess each split independently, dropping items that do not survive.

        The first three skipped items of every split are logged to help
        diagnose field-mapping mistakes.
        """
        processed_splits = {}

        logger.info("=== PREPROCESSING CUSTOM CONVERSATION DATA ===")

        for split_name, split_data in data.items():
            logger.info(f"Processing {split_name} split with {len(split_data)} items...")

            processed_data = []
            skipped_count = 0

            for i, item in enumerate(split_data):
                processed_item = self._preprocess_item(item, config)
                if processed_item is not None:
                    processed_data.append(processed_item)
                    continue
                skipped_count += 1
                if skipped_count <= 3:  # Log first few skipped items
                    logger.info(f"Skipped item {i} from {split_name}: {item}")

            processed_splits[split_name] = processed_data
            logger.info(f"{split_name} - Preprocessed {len(processed_data)} samples, skipped {skipped_count} samples")

        return processed_splits

    def _preprocess_item(self, item: Dict, config: InstructConfig) -> Optional[Dict]:
        """Normalise one conversation item.

        Turns that are not dicts, lack ``role``/``content``, use an unknown
        role, or whose stripped content length falls outside
        ``[config.min_length, config.max_length]`` are dropped. Returns
        ``None`` when fewer than two turns remain, otherwise
        ``{"conversation": [...]}``.
        """
        conversation = item.get(config.conversation_field, [])

        if not isinstance(conversation, list) or not conversation:
            return None

        valid_conversation = []
        for turn in conversation:
            if not isinstance(turn, dict):
                continue
            if "role" not in turn or "content" not in turn:
                continue
            if turn["role"] not in self._VALID_ROLES:
                continue

            content = str(turn["content"]).strip()
            # Length limits apply to the stripped text, before whitespace cleanup.
            if not config.min_length <= len(content) <= config.max_length:
                continue

            if config.clean_text:
                content = self._clean_text(content)

            valid_conversation.append({
                "role": turn["role"],
                "content": content
            })

        if len(valid_conversation) < 2:  # Need at least 2 turns for a conversation
            return None

        return {"conversation": valid_conversation}

    def _clean_text(self, text: str) -> str:
        """Collapse every whitespace run to a single space and trim the ends.

        Non-string input yields an empty string.
        """
        if not isinstance(text, str):
            return ""

        return re.sub(r'\s+', ' ', text).strip()
class InstructDataPipeline:
    """Main instruction fine-tuning data pipeline.

    Ties together loading, preprocessing, validation, analysis and saving of
    conversation data.
    """

    def __init__(self):
        self.validator = ConversationValidator()
        self.hf_loader = HuggingFaceInstructDataLoader()
        self.custom_loader = CustomInstructDataLoader()

    def create_config(
        self,
        data_source: str,
        dataset_name: Optional[str] = None,
        data_path: Optional[str] = None,
        conversation_field: str = "conversation",
        **kwargs
    ) -> InstructConfig:
        """Build an ``InstructConfig``; extra keyword arguments are forwarded to it."""
        return InstructConfig(
            data_source=data_source,
            dataset_name=dataset_name,
            data_path=data_path,
            conversation_field=conversation_field,
            **kwargs
        )

    def load_config_from_yaml(self, yaml_path: str) -> InstructConfig:
        """Build an ``InstructConfig`` from a YAML file.

        Keys that are absent from the YAML file (reported as ``None`` by
        ``load_yaml_config``) and keys that are not ``InstructConfig`` fields
        are ignored, so the dataclass defaults apply to them.

        Raises:
            Exception: Any error raised while reading or parsing the file is
                logged and re-raised unchanged.
        """
        # Local import: the module only imports ``dataclass`` at file level.
        from dataclasses import fields

        try:
            config_dict = load_yaml_config(yaml_path)

            known_fields = {spec.name for spec in fields(InstructConfig)}
            overrides = {
                key: value
                for key, value in config_dict.items()
                if key in known_fields and value is not None
            }
            config = InstructConfig(**overrides)

            logger.info(f"Configuration loaded from YAML: {yaml_path}")
            logger.info(f"Output directory: {config.output_dir}")
            logger.info(f"Conversation field: {config.conversation_field}")

            return config

        except Exception as e:
            logger.error(f"Error loading configuration from YAML {yaml_path}: {e}")
            raise

    def load_and_preprocess(self, config: InstructConfig) -> Tuple[Dict[str, List[Dict]], Dict[str, Any]]:
        """Load, preprocess, validate and analyse conversation data.

        Returns:
            ``(processed_splits, analysis)``.

        Raises:
            ValueError: If ``config.data_source`` is unsupported or the
                processed data fails validation.
        """
        logger.info("Starting conversation data loading and preprocessing...")
        logger.info(f"Data source: {config.data_source}")

        try:
            if config.data_source == "huggingface":
                logger.info("Loading HuggingFace conversation dataset...")
                raw_splits = self.hf_loader.load(config)
                logger.info("Preprocessing HuggingFace conversation dataset...")
                processed_splits = self.hf_loader.preprocess(raw_splits, config)
            elif config.data_source == "custom":
                logger.info("Loading custom conversation dataset...")
                raw_splits = self.custom_loader.load(config)
                logger.info("Preprocessing custom conversation dataset...")
                processed_splits = self.custom_loader.preprocess(raw_splits, config)
            else:
                raise ValueError(f"Unsupported data source: {config.data_source}")

            logger.info("Conversation data loading and preprocessing completed successfully")

            logger.info("Validating processed conversation data...")
            is_valid, errors = self.validator.validate_conversation_data(processed_splits, config, is_processed=True)
            if not is_valid:
                logger.error("Conversation data validation failed:")
                for error in errors:
                    logger.error(f"  - {error}")
                raise ValueError("Conversation data validation failed")

            logger.info("Conversation data validation passed")

            logger.info("Analyzing conversation dataset...")
            analysis = self.validator.analyze_conversation_dataset(processed_splits, config, is_processed=True)
            logger.info("Conversation dataset analysis completed")

            return processed_splits, analysis

        except Exception as e:
            logger.error(f"Error in load_and_preprocess: {e}")
            raise

    def save_data(self, data: Dict[str, List[Dict]], output_dir: str, format: str = "jsonl"):
        """Write every split to ``<output_dir>/<split>.<format>``.

        Args:
            data: Mapping of split name to the items to write.
            output_dir: Target directory; created if it does not exist.
            format: ``"jsonl"`` (one JSON object per line) or ``"json"``
                (one indented JSON array per file).

        Raises:
            ValueError: If ``format`` is neither ``"jsonl"`` nor ``"json"``.
        """
        # Reject unknown formats up front instead of failing later on an
        # unbound ``output_file`` in the log call.
        if format not in ("jsonl", "json"):
            raise ValueError(f"Unsupported output format: {format}")

        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        for split_name, split_data in data.items():
            output_file = output_path / f"{split_name}.{format}"
            with open(output_file, 'w', encoding='utf-8') as f:
                if format == "jsonl":
                    for item in split_data:
                        f.write(json.dumps(item, ensure_ascii=False) + '\n')
                else:
                    json.dump(split_data, f, ensure_ascii=False, indent=2)

            logger.info(f"Saved {len(split_data)} conversation samples to {output_file}")

    def run_pipeline(
        self,
        config: InstructConfig,
        save_splits: bool = True
    ) -> Dict[str, Any]:
        """Run the complete pipeline and return a summary.

        Returns:
            A dict with the config, the analysis, per-split sample counts,
            the output format and directory, and the processed data itself
            under ``"data"``.
        """
        logger.info("Starting instruction data pipeline...")

        processed_splits, analysis = self.load_and_preprocess(config)

        if save_splits:
            self.save_data(processed_splits, str(Path(config.output_dir)))

        split_counts = {
            split_name: len(split_data)
            for split_name, split_data in processed_splits.items()
        }
        result = {
            "config": config,
            "analysis": analysis,
            "splits": split_counts,
            "output_format": config.output_format,
            "output_dir": config.output_dir,
            "data": processed_splits,  # Include the actual processed data
        }

        logger.info("Instruction data pipeline completed successfully!")
        return result
+ + # Load and preprocess data + processed_splits, analysis = self.load_and_preprocess(config) + + # Save data if requested + if save_splits: + output_dir = Path(config.output_dir) + self.save_data(processed_splits, str(output_dir)) + + # Create result summary + result = { + "config": config, + "analysis": analysis, + "splits": { + split_name: len(split_data) for split_name, split_data in processed_splits.items() + }, + "output_format": config.output_format, + "output_dir": config.output_dir, + "data": processed_splits, # Include the actual processed data + } + + logger.info("Instruction data pipeline completed successfully!") + return result + +def load_yaml_config(config_path: str) -> Dict[str, Any]: + """Load and parse YAML configuration file with proper structure handling""" + try: + with open(config_path, 'r', encoding='utf-8') as f: + yaml_data = yaml.safe_load(f) + + # Extract configuration from YAML structure + config_dict = {} + + # Handle task section + if 'task' in yaml_data: + task_data = yaml_data['task'] + config_dict.update({ + 'task_name': task_data.get('name'), + 'task_type': task_data.get('type') + }) + + # Handle data section + if 'data' in yaml_data: + data_config = yaml_data['data'] + config_dict.update({ + 'data_source': data_config.get('source'), + 'dataset_name': data_config.get('dataset_name'), + 'data_path': data_config.get('data_path'), + 'data_format': data_config.get('data_format'), + 'conversation_field': data_config.get('conversation_field'), + 'max_samples': data_config.get('max_samples'), + 'train_split': data_config.get('train_split'), + 'validation_split': data_config.get('validation_split'), + 'test_split': data_config.get('test_split'), + 'clean_text': data_config.get('clean_text'), + 'min_length': data_config.get('min_length'), + 'max_length': data_config.get('max_length'), + 'output_format': data_config.get('output_format'), + 'output_dir': data_config.get('output_dir'), + 'encoding': data_config.get('encoding') + }) + + 
def load_yaml_config(config_path: str) -> Dict[str, Any]:
    """Load a YAML configuration file and flatten it into a config dict.

    Reads the ``task`` and ``data`` sections and maps their entries onto the
    flat key names used by the rest of this module. Keys that are absent from
    the YAML (or explicitly ``null``) are omitted from the result, so a later
    ``config_dict.get(key, default)`` really falls back to ``default``.

    Args:
        config_path: Path to the YAML file.

    Returns:
        Flat dict of the configuration values that were actually set.

    Raises:
        ValueError: If the top-level YAML value is not a mapping.
        Exception: Any I/O or YAML parsing error is logged and re-raised.
    """
    # YAML key -> flat config key, per section.
    task_fields = {'name': 'task_name', 'type': 'task_type'}
    data_fields = {
        'source': 'data_source',
        'dataset_name': 'dataset_name',
        'data_path': 'data_path',
        'data_format': 'data_format',
        'conversation_field': 'conversation_field',
        'max_samples': 'max_samples',
        'train_split': 'train_split',
        'validation_split': 'validation_split',
        'test_split': 'test_split',
        'clean_text': 'clean_text',
        'min_length': 'min_length',
        'max_length': 'max_length',
        'output_format': 'output_format',
        'output_dir': 'output_dir',
        'encoding': 'encoding',
        # main() looks these up in the config dict, so they are extracted too.
        'hf_split': 'hf_split',
        'hf_cache_dir': 'hf_cache_dir',
        'test_split_from': 'test_split_from',
        'val_split_from': 'val_split_from',
    }

    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            yaml_data = yaml.safe_load(f)

        # safe_load returns None for an empty document.
        if yaml_data is None:
            yaml_data = {}
        elif not isinstance(yaml_data, dict):
            raise ValueError("Top-level YAML structure must be a mapping")

        # Extract configuration from YAML structure
        config_dict = {}
        for section, fields in (('task', task_fields), ('data', data_fields)):
            section_data = yaml_data.get(section) or {}
            for yaml_key, config_key in fields.items():
                value = section_data.get(yaml_key)
                if value is not None:
                    config_dict[config_key] = value

        logger.info(f"Successfully parsed YAML configuration from: {config_path}")
        logger.info(f"Extracted {len(config_dict)} configuration parameters")

        return config_dict

    except Exception as e:
        logger.error(f"Error loading YAML config from {config_path}: {e}")
        raise


def main():
    """Command-line entry point with YAML configuration support.

    Configuration is resolved in this order: built-in defaults, then the YAML
    file given by ``--config``, then any CLI option that was explicitly
    passed. Exits with status 1 if the YAML cannot be loaded or the pipeline
    fails.
    """
    parser = argparse.ArgumentParser(description="Instruction Data Processing Pipeline")

    # YAML configuration
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")

    # Data source arguments
    parser.add_argument("--data-source", choices=["huggingface", "custom"], help="Data source")
    parser.add_argument("--dataset-name", type=str, help="HuggingFace dataset name")
    parser.add_argument("--data-path", type=str, help="Path to custom data file")
    parser.add_argument("--data-format", choices=["jsonl", "json"], help="Data format")

    # Field mapping
    parser.add_argument("--conversation-field", type=str, help="Conversation field name")

    # Data processing
    parser.add_argument("--max-samples", type=int, help="Maximum samples to process")
    parser.add_argument("--train-split", type=float, help="Training split ratio")
    parser.add_argument("--validation-split", type=float, help="Validation split ratio")
    parser.add_argument("--test-split", type=float, help="Test split ratio")

    # Output configuration
    parser.add_argument("--output-dir", type=str, help="Output directory")

    # Logging
    parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO", help="Logging level")

    args = parser.parse_args()

    # Set up logging
    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # Load configuration
    config_dict = {}

    # Load YAML config if provided
    if args.config:
        try:
            config_dict = load_yaml_config(args.config)
        except Exception as e:
            logger.error(f"Error loading YAML config: {e}")
            sys.exit(1)

    # Override YAML config with CLI arguments. "is not None" rather than
    # truthiness, so that explicit zero values (e.g. --test-split 0) are kept.
    cli_fields = (
        'data_source', 'dataset_name', 'data_path', 'data_format',
        'conversation_field', 'max_samples', 'train_split',
        'validation_split', 'test_split', 'output_dir',
    )
    cli_overrides = {
        name: getattr(args, name)
        for name in cli_fields
        if getattr(args, name) is not None
    }

    # Merge configurations
    for key, value in cli_overrides.items():
        if config_dict.get(key) is not None:
            logger.info(f"Overriding YAML config '{key}' with CLI value: {value}")
        config_dict[key] = value

    def setting(key, default=None):
        """Return the configured value, treating a missing or None entry as unset."""
        value = config_dict.get(key)
        return default if value is None else value

    # Validate required arguments
    if not config_dict.get('data_source'):
        parser.error("--data-source is required (either in YAML config or CLI)")

    if config_dict.get('data_source') == "huggingface" and not config_dict.get('dataset_name'):
        parser.error("--dataset-name is required for HuggingFace datasets")

    if config_dict.get('data_source') == "custom" and not config_dict.get('data_path'):
        parser.error("--data-path is required for custom datasets")

    # Create configuration object
    config = InstructConfig(
        data_source=setting('data_source', 'custom'),
        dataset_name=setting('dataset_name'),
        data_path=setting('data_path'),
        data_format=setting('data_format', 'jsonl'),
        conversation_field=setting('conversation_field', 'conversation'),
        max_samples=setting('max_samples'),
        train_split=setting('train_split', 0.8),
        validation_split=setting('validation_split', 0.1),
        test_split=setting('test_split', 0.1),
        clean_text=setting('clean_text', True),
        min_length=setting('min_length', 10),
        max_length=setting('max_length', 2048),
        output_format=setting('output_format', 'conversation'),
        output_dir=setting('output_dir', './data/processed/instruct'),
        hf_split=setting('hf_split', 'train'),
        hf_cache_dir=setting('hf_cache_dir'),
        test_split_from=setting('test_split_from', 'train'),
        val_split_from=setting('val_split_from', 'train'),
        encoding=setting('encoding', 'utf-8')
    )

    # Initialize pipeline
    pipeline = InstructDataPipeline()

    try:
        print(f"Starting instruction data pipeline with {config.data_source} data source...")
        if args.config:
            print(f"Using YAML configuration: {args.config}")
        print(f"Conversation field: {config.conversation_field}")
        print()

        result = pipeline.run_pipeline(config, save_splits=True)

        overall = result['analysis']['overall']
        print(f"✅ Pipeline completed successfully!")
        print(f" Data source: {config.data_source}")
        if config.data_source == "huggingface":
            print(f" Dataset: {config.dataset_name}")
        else:
            print(f" Data file: {config.data_path}")
        print(f" Total samples: {overall['total_samples']}")
        print(f" Split sizes: {overall['split_sizes']}")
        print(f" Output directory: {config.output_dir}")
        print(f" Conversation stats: {overall['conversation_stats']}")

    except Exception as e:
        print(f"❌ Error running pipeline: {e}")
        import traceback
        print("Full error traceback:")
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
class InstructInference:
    """Conversational inference on top of a fine-tuned instruction model.

    Loads the model saved under ``model_output_dir`` through Unsloth's
    ``FastLanguageModel``, applies a chat template to the tokenizer, and
    offers single-turn, batch and interactive generation.
    """

    def __init__(self, config: Dict[str, Any]):
        """Read model and generation settings from ``config``.

        Every key is optional; the defaults are the literals given below.
        """
        self.config = config
        self.model = None
        self.tokenizer = None

        # Set device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Model parameters
        self.model_output_dir = config.get('model_output_dir', './models/instruct')
        self.base_model_name = config.get('base_model_name', 'unsloth/Qwen2.5-72B-Instruct')
        self.max_seq_length = config.get('max_seq_length', 2048)
        self.dtype = config.get('dtype', None)
        self.load_in_4bit = config.get('load_in_4bit', True)
        self.hf_token = config.get('hf_token', None)

        # Inference parameters
        self.batch_size = config.get('batch_size', 1)
        self.max_new_tokens = config.get('max_new_tokens', 128)
        self.temperature = config.get('temperature', 1.5)
        self.min_p = config.get('min_p', 0.1)
        self.use_cache = config.get('use_cache', True)

        # Chat template
        self.chat_template = config.get('chat_template', 'qwen-2.5')

    def load_model_and_tokenizer(self):
        """Load the trained model and tokenizer from ``model_output_dir``.

        Raises:
            FileNotFoundError: If ``model_output_dir`` does not exist.
        """
        print("Loading trained instruction model and tokenizer...")

        try:
            # Load the saved LoRA model
            model_path = self.model_output_dir
            print(f"Loading model from: {model_path}")

            # Check if the model directory exists
            if not Path(model_path).exists():
                raise FileNotFoundError(f"Model directory not found: {model_path}")

            # Load the model directly from the saved path. The token is passed
            # along, as the training pipeline does, in case the base weights
            # have to be fetched from a gated repository.
            self.model, self.tokenizer = FastLanguageModel.from_pretrained(
                model_name=model_path,
                max_seq_length=self.max_seq_length,
                dtype=self.dtype,
                load_in_4bit=self.load_in_4bit,
                token=self.hf_token,
            )

            # Enable native 2x faster inference
            FastLanguageModel.for_inference(self.model)

            print(f"✅ Model loaded from: {model_path}")
            print(f"✅ Tokenizer loaded with vocab size: {self.tokenizer.vocab_size}")

        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise

    def setup_chat_template(self):
        """Attach the configured chat template to the tokenizer."""
        print("Setting up chat template...")

        try:
            self.tokenizer = get_chat_template(
                self.tokenizer,
                chat_template=self.chat_template,
            )

            print(f"✅ Chat template configured: {self.chat_template}")

        except Exception as e:
            print(f"❌ Error setting up chat template: {e}")
            raise

    def format_messages(self, messages: List[Dict[str, str]]) -> str:
        """Render ``messages`` into a prompt string with the chat template.

        The generation prompt is appended so the model continues with an
        assistant turn.
        """
        try:
            return self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,  # Add generation prompt for inference
            )

        except Exception as e:
            print(f"❌ Error formatting messages: {e}")
            raise

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        max_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        stream: bool = False
    ) -> str:
        """Generate the assistant reply for ``messages``.

        Args:
            messages: Conversation so far, as role/content dicts.
            max_new_tokens: Generation cap; ``None`` uses the configured value.
            temperature: Sampling temperature; ``None`` uses the configured
                value. An explicit ``0`` is passed through, not replaced.
            stream: When true, tokens are also printed as they are generated
                via ``TextStreamer``.

        Returns:
            The decoded reply, stripped of surrounding whitespace, in both
            streaming and non-streaming mode.
        """
        try:
            # Only fall back to the defaults when the caller passed nothing.
            max_tokens = self.max_new_tokens if max_new_tokens is None else max_new_tokens
            temp = self.temperature if temperature is None else temperature

            # Format the messages
            formatted_prompt = self.format_messages(messages)
            print(f"Formatted prompt: {formatted_prompt[:200]}...")

            # Tokenize the input
            inputs = self.tokenizer(
                [formatted_prompt],
                return_tensors="pt"
            ).to(self.device)
            prompt_token_count = inputs.input_ids.shape[1]

            generate_kwargs = {
                "input_ids": inputs.input_ids,
                "max_new_tokens": max_tokens,
                "use_cache": self.use_cache,
                "temperature": temp,
                "min_p": self.min_p,
            }
            if stream:
                print("Generating with streaming...")
                generate_kwargs["streamer"] = TextStreamer(self.tokenizer, skip_prompt=True)
            else:
                print("Generating response...")
                generate_kwargs["pad_token_id"] = self.tokenizer.eos_token_id

            outputs = self.model.generate(**generate_kwargs)

            # Drop the prompt at token level. Slicing the decoded string by
            # len(formatted_prompt) is wrong: decoding with
            # skip_special_tokens=True removes the template's special tokens,
            # so character offsets no longer match.
            new_tokens = outputs[0][prompt_token_count:]
            return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        except Exception as e:
            print(f"❌ Error generating response: {e}")
            raise

    def chat(self, user_input: str, conversation_history: Optional[List[Dict[str, str]]] = None, stream: bool = False) -> str:
        """Send one user turn and return the assistant reply.

        ``conversation_history`` is not modified; the caller is responsible
        for appending the new turns.
        """
        try:
            # Initialize conversation history if not provided
            if conversation_history is None:
                conversation_history = []

            # Add user input to conversation
            messages = conversation_history + [{"role": "user", "content": user_input}]

            print(f"User: {user_input}")

            if stream:
                print("Assistant: ", end="", flush=True)
                # The streamer prints the text; the return value lets callers
                # record the real reply.
                return self.generate_response(messages, stream=True)

            response = self.generate_response(messages, stream=False)
            print(f"Assistant: {response}")
            return response

        except Exception as e:
            print(f"❌ Error in chat: {e}")
            raise

    def batch_inference(
        self,
        conversations: List[List[Dict[str, str]]],
        max_new_tokens: Optional[int] = None
    ) -> List[str]:
        """Generate one reply per conversation, sequentially.

        Returns the replies in the same order as ``conversations``.
        """
        responses = []
        total = len(conversations)

        for index, messages in enumerate(conversations, start=1):
            print(f"Processing conversation {index}/{total}")
            responses.append(self.generate_response(messages, max_new_tokens))

        return responses

    def interactive_chat(self):
        """Run a REPL-style chat session on stdin/stdout.

        Commands: ``quit``/``exit``/``bye`` end the session, ``clear`` resets
        the history, ``stream on``/``stream off`` toggle streaming output.
        """
        print("🤖 Starting interactive chat session...")
        print("Type 'quit', 'exit', or 'bye' to end the conversation.")
        print("Type 'clear' to clear conversation history.")
        print("Type 'stream on' or 'stream off' to toggle streaming.")
        print("-" * 50)

        conversation_history = []
        streaming = False

        while True:
            try:
                user_input = input("\n👤 You: ").strip()
                command = user_input.lower()

                if command in ['quit', 'exit', 'bye']:
                    print("👋 Goodbye!")
                    break
                if command == 'clear':
                    conversation_history = []
                    print("🗑️ Conversation history cleared.")
                    continue
                if command == 'stream on':
                    streaming = True
                    print("🔄 Streaming enabled.")
                    continue
                if command == 'stream off':
                    streaming = False
                    print("⏸️ Streaming disabled.")
                    continue
                if not user_input:
                    continue

                # Generate response
                if streaming:
                    print("🤖 Assistant: ", end="", flush=True)
                response = self.chat(user_input, conversation_history, stream=streaming)

                # Record the actual reply. Storing a placeholder for streamed
                # turns would feed it back to the model as context.
                conversation_history.extend([
                    {"role": "user", "content": user_input},
                    {"role": "assistant", "content": response}
                ])

            except KeyboardInterrupt:
                print("\n👋 Goodbye!")
                break
            except Exception as e:
                print(f"❌ Error: {e}")
                continue
def load_inference_config(config_path: str) -> Dict[str, Any]:
    """Load inference configuration from YAML file.

    Reads the ``model``, ``training`` and ``inference`` sections. Keys are
    only added for sections that are present; an empty file or an empty
    section is treated as "nothing configured" rather than an error.

    Args:
        config_path: Path to the YAML file.

    Returns:
        Flat dict for ``InstructInference``. ``chat_template`` is always set;
        it defaults to ``'qwen-2.5'`` and can be overridden with
        ``inference.chat_template``.

    Raises:
        Exception: Any I/O or YAML parsing error is printed and re-raised.
    """
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document.
            config = yaml.safe_load(f) or {}

        # Extract inference configuration
        inference_config = {}
        chat_template = 'qwen-2.5'  # Use Qwen chat template by default

        # Model configuration
        if 'model' in config:
            model_data = config['model'] or {}
            inference_config.update({
                'base_model_name': model_data.get('training_model', 'unsloth/Qwen2.5-72B-Instruct'),
                'max_seq_length': model_data.get('training_max_seq_length', 2048),
                'dtype': model_data.get('training_dtype'),
                'load_in_4bit': model_data.get('training_load_in_4bit', True),
                'hf_token': model_data.get('training_token')
            })

        # Training configuration - to get model_output_dir
        if 'training' in config:
            training_data = config['training'] or {}
            inference_config.update({
                'model_output_dir': training_data.get('model_output_dir', './models/instruct')
            })

        # Inference configuration
        if 'inference' in config:
            inference_data = config['inference'] or {}
            inference_config.update({
                'batch_size': inference_data.get('batch_size', 1),
                'max_new_tokens': inference_data.get('max_new_tokens', 128),
                'temperature': inference_data.get('temperature', 1.5),
                'min_p': inference_data.get('min_p', 0.1),
                'use_cache': inference_data.get('use_cache', True)
            })
            chat_template = inference_data.get('chat_template') or chat_template

        # Chat template
        inference_config['chat_template'] = chat_template

        return inference_config

    except Exception as e:
        print(f"Error loading inference config: {e}")
        raise


def main():
    """Main inference function.

    Loads the YAML configuration, applies CLI overrides, loads the model and
    then runs interactive chat (``--interactive`` or no mode given) or a
    single-message request (``--message``). Exits with status 1 on failure.
    """
    parser = argparse.ArgumentParser(description="Instruction Inference Pipeline")

    # Configuration
    parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
    parser.add_argument("--interactive", action="store_true", help="Start interactive chat session")
    parser.add_argument("--message", type=str, help="Single message to send to the model")
    parser.add_argument("--max-tokens", type=int, help="Maximum new tokens to generate")
    parser.add_argument("--stream", action="store_true", help="Enable streaming generation")
    parser.add_argument("--temperature", type=float, help="Sampling temperature")

    args = parser.parse_args()

    try:
        # Load configuration
        print(f"Loading configuration from: {args.config}")
        inference_config = load_inference_config(args.config)

        # Override with CLI arguments. Compare against None so that an
        # explicit zero is not mistaken for "not given".
        if args.max_tokens is not None:
            inference_config['max_new_tokens'] = args.max_tokens
        if args.temperature is not None:
            inference_config['temperature'] = args.temperature

        print("Inference configuration:")
        for key, value in inference_config.items():
            # Never echo the access token to the console or captured logs.
            shown = "***" if key == 'hf_token' and value else value
            print(f" {key}: {shown}")

        # Initialize inference
        inference = InstructInference(inference_config)

        # Load model and tokenizer
        inference.load_model_and_tokenizer()

        # Setup chat template
        inference.setup_chat_template()

        # Run inference based on mode
        if args.interactive:
            # Interactive chat mode
            inference.interactive_chat()
        elif args.message:
            # Single message mode
            print("Running single message inference...")
            messages = [{"role": "user", "content": args.message}]

            if args.stream:
                print("User:", args.message)
                print("Assistant: ", end="", flush=True)
                inference.generate_response(messages, stream=True)
            else:
                response = inference.generate_response(messages, stream=False)
                print(f"User: {args.message}")
                print(f"Assistant: {response}")
        else:
            # Default to interactive mode if no specific mode is chosen
            print("No specific mode chosen. Starting interactive chat...")
            inference.interactive_chat()

    except Exception as e:
        print(f"Inference failed: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
class InstructTrainer:
    """Instruction fine-tuning trainer using Unsloth and SFTTrainer.

    Typical use is ``InstructTrainer(config).train(dataset_path)``, which
    loads the base model, attaches LoRA adapters, formats the conversation
    data with the chat template, trains, and saves the result to
    ``model_output_dir``.
    """

    # Config keys whose values must never be persisted with the model.
    _SECRET_CONFIG_KEYS = ("hf_token",)

    def __init__(self, config: Dict[str, Any]):
        """Read model, LoRA and training settings from ``config``.

        Every key is optional; the defaults are the literals given below.
        """
        self.config = config
        self.model = None
        self.tokenizer = None
        self.trainer = None

        # Set device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        # Model parameters
        self.model_name = config.get('model_name', 'unsloth/Qwen2.5-72B-Instruct')
        self.max_seq_length = config.get('max_seq_length', 2048)
        self.dtype = config.get('dtype', None)
        self.load_in_4bit = config.get('load_in_4bit', True)
        self.hf_token = config.get('hf_token', None)

        # LoRA parameters
        self.lora_r = config.get('lora_r', 32)
        self.lora_alpha = config.get('lora_alpha', 16)
        self.lora_dropout = config.get('lora_dropout', 0)
        self.target_modules = config.get('target_modules', [
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj"
        ])

        # Training arguments
        self.batch_size = config.get('batch_size', 1)
        self.gradient_accumulation_steps = config.get('gradient_accumulation_steps', 4)
        self.learning_rate = config.get('learning_rate', 2e-4)
        self.num_epochs = config.get('num_epochs', 1)
        self.max_steps = config.get('max_steps', 30)
        self.warmup_steps = config.get('warmup_steps', 5)
        self.weight_decay = config.get('weight_decay', 0.01)
        self.lr_scheduler_type = config.get('lr_scheduler_type', 'linear')
        self.seed = config.get('seed', 3407)

        # Output paths
        self.output_dir = config.get('output_dir', './outputs')
        self.model_output_dir = config.get('model_output_dir', './models/instruct')

        # Chat template
        self.chat_template = config.get('chat_template', 'qwen-2.5')

    def _serializable_config(self) -> Dict[str, Any]:
        """Return a copy of the config with secret values masked.

        Used when persisting the training configuration so that access
        tokens do not end up next to the saved model.
        """
        safe_config = dict(self.config)
        for key in self._SECRET_CONFIG_KEYS:
            if safe_config.get(key):
                safe_config[key] = "<redacted>"
        return safe_config

    def load_model_and_tokenizer(self):
        """Load the pre-trained model and tokenizer."""
        logger.info("Loading model and tokenizer...")

        try:
            self.model, self.tokenizer = FastLanguageModel.from_pretrained(
                model_name=self.model_name,
                max_seq_length=self.max_seq_length,
                dtype=self.dtype,
                load_in_4bit=self.load_in_4bit,
                token=self.hf_token
            )

            logger.info(f"✅ Model loaded: {self.model_name}")
            logger.info(f"✅ Tokenizer loaded with vocab size: {self.tokenizer.vocab_size}")

        except Exception as e:
            logger.error(f"❌ Error loading model: {e}")
            raise

    def setup_lora(self):
        """Wrap the loaded model with LoRA adapters for efficient fine-tuning."""
        logger.info("Setting up LoRA configuration...")

        try:
            self.model = FastLanguageModel.get_peft_model(
                self.model,
                r=self.lora_r,
                target_modules=self.target_modules,
                lora_alpha=self.lora_alpha,
                lora_dropout=self.lora_dropout,
                bias="none",
                use_gradient_checkpointing="unsloth",
                random_state=self.seed,
                use_rslora=False,
                loftq_config=None
            )

            logger.info(f"✅ LoRA configured with r={self.lora_r}, alpha={self.lora_alpha}")

        except Exception as e:
            logger.error(f"❌ Error setting up LoRA: {e}")
            raise

    def setup_chat_template(self):
        """Attach the configured chat template to the tokenizer."""
        logger.info("Setting up chat template...")

        try:
            self.tokenizer = get_chat_template(
                self.tokenizer,
                chat_template=self.chat_template,
            )

            logger.info(f"✅ Chat template configured: {self.chat_template}")

        except Exception as e:
            logger.error(f"❌ Error setting up chat template: {e}")
            raise

    def load_dataset(self, dataset_path: str, splits: Optional[List[str]] = None) -> "Dataset":
        """Load the conversation training dataset.

        Resolution order for ``dataset_path``:

        1. an existing directory containing ``dataset_info.json`` is loaded
           with ``load_from_disk``;
        2. any other existing directory is read as ``<split>.jsonl`` files;
        3. a non-existent path is treated as a HuggingFace Hub dataset name
           and its ``train`` split is downloaded.

        Args:
            dataset_path: Local directory or Hub dataset name.
            splits: JSONL split names to read in case 2. Defaults to
                ``["train"]`` so that validation and test data stay held
                out; pass more names explicitly to train on them as well.

        Raises:
            ValueError: If no samples are found, or the ``conversation``
                field is missing from the dataset.
        """
        logger.info(f"Loading conversation dataset from: {dataset_path}")
        split_names = ["train"] if splits is None else list(splits)

        try:
            data_dir = Path(dataset_path)
            if data_dir.exists():
                # Check if it's a HuggingFace dataset directory
                if (data_dir / "dataset_info.json").exists():
                    # Load from HuggingFace dataset directory
                    dataset = load_from_disk(dataset_path)
                    logger.info(f"Loaded HuggingFace dataset from disk: {len(dataset)} samples")
                else:
                    # Load from processed conversation data files (JSONL format)
                    logger.info("Loading from processed conversation data files...")

                    all_data = []
                    for split_name in split_names:
                        file_path = data_dir / f"{split_name}.jsonl"
                        if not file_path.exists():
                            continue
                        logger.info(f"Loading {file_path.name}...")
                        with open(file_path, 'r', encoding='utf-8') as f:
                            all_data.extend(
                                json.loads(line) for line in f if line.strip()
                            )

                    if not all_data:
                        raise ValueError(f"No conversation data found in {dataset_path}")

                    # Create HuggingFace dataset
                    dataset = Dataset.from_list(all_data)
                    logger.info(f"Created HuggingFace dataset from {len(all_data)} conversation samples")
            else:
                # Try loading from HuggingFace Hub. load_dataset is a
                # module-level function of `datasets`, not a Dataset method.
                from datasets import load_dataset as hf_load_dataset

                logger.info(f"Attempting to load from HuggingFace Hub: {dataset_path}")
                dataset = hf_load_dataset(dataset_path, split="train")
                logger.info(f"Loaded from HuggingFace Hub: {len(dataset)} samples")

            logger.info(f"Dataset loaded: {len(dataset)} samples")
            logger.info(f"Dataset features: {dataset.features}")

            # Verify required fields exist for conversation data
            required_fields = ["conversation"]
            missing_fields = [field for field in required_fields if field not in dataset.features]
            if missing_fields:
                raise ValueError(f"Missing required fields in conversation dataset: {missing_fields}")

            return dataset

        except Exception as e:
            logger.error(f"Error loading conversation dataset: {e}")
            raise

    def format_dataset_for_training(self, dataset: "Dataset") -> "Dataset":
        """Render every conversation into a ``text`` column with the chat template.

        The dataset is first passed through ``standardize_sharegpt`` and then
        each entry of the ``conversation`` column is formatted without a
        trailing generation prompt.
        """
        logger.info("Formatting conversation dataset for training...")

        try:
            # Standardize the ShareGPT format
            # NOTE(review): confirm standardize_sharegpt acts on a column named
            # "conversation"; it may only look at "conversations".
            logger.info("Standardizing ShareGPT format...")
            dataset = standardize_sharegpt(dataset)

            # Define the formatting function for chat templates
            def formatting_prompts_func(examples):
                texts = [
                    self.tokenizer.apply_chat_template(
                        convo,
                        tokenize=False,
                        add_generation_prompt=False
                    )
                    for convo in examples["conversation"]
                ]
                return {"text": texts}

            # Apply the formatting function
            logger.info("Applying chat template formatting...")
            dataset = dataset.map(formatting_prompts_func, batched=True)

            logger.info(f"✅ Dataset formatted for training with {len(dataset)} samples")
            logger.info(f"Sample formatted text: {dataset[0]['text'][:200]}...")

            return dataset

        except Exception as e:
            logger.error(f"❌ Error formatting dataset: {e}")
            raise

    def setup_trainer(self, train_dataset: "Dataset"):
        """Create the ``SFTTrainer`` for instruction fine-tuning.

        ``train_dataset`` must already contain the ``text`` column produced by
        ``format_dataset_for_training``.
        """
        logger.info("Setting up SFTTrainer for instruction fine-tuning...")

        try:
            # SFT Configuration. num_train_epochs only takes effect when
            # max_steps is not positive; otherwise max_steps wins.
            sft_config = SFTConfig(
                per_device_train_batch_size=self.batch_size,
                gradient_accumulation_steps=self.gradient_accumulation_steps,
                warmup_steps=self.warmup_steps,
                num_train_epochs=self.num_epochs,
                max_steps=self.max_steps,
                learning_rate=self.learning_rate,
                logging_steps=1,
                optim="paged_adamw_8bit",
                weight_decay=self.weight_decay,
                lr_scheduler_type=self.lr_scheduler_type,
                seed=self.seed,
                output_dir=self.output_dir,
                report_to="none",  # Disable wandb for now
            )

            logger.info("SFT Configuration:")
            logger.info(f" batch_size: {self.batch_size}")
            logger.info(f" gradient_accumulation_steps: {self.gradient_accumulation_steps}")
            logger.info(f" warmup_steps: {self.warmup_steps}")
            logger.info(f" max_steps: {self.max_steps}")
            logger.info(f" learning_rate: {self.learning_rate}")

            # Create SFTTrainer
            self.trainer = SFTTrainer(
                model=self.model,
                tokenizer=self.tokenizer,
                train_dataset=train_dataset,
                dataset_text_field="text",
                max_seq_length=self.max_seq_length,
                data_collator=DataCollatorForSeq2Seq(tokenizer=self.tokenizer),
                packing=False,  # Disable packing for conversation data
                args=sft_config,
            )

            logger.info("✅ SFTTrainer configured successfully")

        except Exception as e:
            logger.error(f"❌ Error setting up trainer: {e}")
            logger.exception("Full error traceback:")
            raise

    def setup_response_only_training(self):
        """Restrict the loss to assistant responses.

        NOTE(review): the instruction/response markers are the ChatML-style
        ``<|im_start|>`` tags; they are not derived from ``self.chat_template``,
        so confirm them if a different template is configured.
        """
        logger.info("Setting up response-only training...")

        try:
            # Configure trainer to only train on responses
            self.trainer = train_on_responses_only(
                self.trainer,
                instruction_part="<|im_start|>user\n",
                response_part="<|im_start|>assistant\n",
            )

            logger.info("✅ Response-only training configured")

        except Exception as e:
            logger.error(f"❌ Error setting up response-only training: {e}")
            raise

    def train(self, dataset_path: str):
        """Run the full instruction fine-tuning process.

        Args:
            dataset_path: Passed to ``load_dataset``.

        Returns:
            The value returned by ``SFTTrainer.train()``.
        """
        logger.info("🚀 Starting instruction fine-tuning process...")

        try:
            # Load model and tokenizer
            logger.info("Step 1: Loading model and tokenizer...")
            self.load_model_and_tokenizer()

            # Setup LoRA
            logger.info("Step 2: Setting up LoRA...")
            self.setup_lora()

            # Setup chat template
            logger.info("Step 3: Setting up chat template...")
            self.setup_chat_template()

            # Load dataset
            logger.info(f"Step 4: Loading conversation dataset from: {dataset_path}")
            train_dataset = self.load_dataset(dataset_path)

            # Format dataset for training
            logger.info("Step 5: Formatting dataset for training...")
            formatted_dataset = self.format_dataset_for_training(train_dataset)

            # Setup trainer
            logger.info("Step 6: Setting up trainer...")
            self.setup_trainer(formatted_dataset)

            # Setup response-only training (optional but recommended for chat models)
            logger.info("Step 7: Setting up response-only training...")
            self.setup_response_only_training()

            # Start training
            logger.info("Step 8: Starting training...")
            trainer_stats = self.trainer.train()

            logger.info("✅ Instruction fine-tuning completed successfully!")
            logger.info(f"Training stats: {trainer_stats}")

            # Save the model
            self.save_model()

            return trainer_stats

        except Exception as e:
            logger.error(f"❌ Instruction fine-tuning failed: {e}")
            logger.exception("Full error traceback:")
            raise

    def save_model(self):
        """Save the trained model, tokenizer and a copy of the training config.

        The config is written to ``training_config.json`` inside
        ``model_output_dir`` with secret values (the HF token) masked.
        """
        logger.info("Saving trained instruction model...")

        try:
            # Create output directory
            model_dir = Path(self.model_output_dir)
            model_dir.mkdir(parents=True, exist_ok=True)

            # Save model and tokenizer
            self.model.save_pretrained(self.model_output_dir)
            self.tokenizer.save_pretrained(self.model_output_dir)

            # Save training config (without credentials)
            config_path = model_dir / "training_config.json"
            with open(config_path, 'w', encoding='utf-8') as f:
                json.dump(self._serializable_config(), f, indent=2)

            logger.info(f"✅ Instruction model saved to: {self.model_output_dir}")
            logger.info(f"✅ You can now use this model for inference")

        except Exception as e:
            logger.error(f"❌ Error saving model: {e}")
            raise

    def prepare_for_inference(self):
        """Switch the model into Unsloth's inference mode."""
        logger.info("Preparing model for inference...")

        try:
            FastLanguageModel.for_inference(self.model)
            logger.info("✅ Model prepared for inference")

        except Exception as e:
            logger.error(f"❌ Error preparing for inference: {e}")
            raise
{training_data.get('gradient_accumulation_steps')} (type: {type(training_data.get('gradient_accumulation_steps'))})") + logger.info(f" seed: {training_data.get('seed')} (type: {type(training_data.get('seed'))})") + logger.info(f" model_output_dir: {training_data.get('model_output_dir')} (type: {type(training_data.get('model_output_dir'))})") + + training_config.update({ + 'num_epochs': int(training_data.get('num_epochs', 1)), + 'batch_size': int(training_data.get('batch_size', 1)), + 'learning_rate': float(training_data.get('learning_rate', 2e-4)), + 'weight_decay': float(training_data.get('weight_decay', 0.01)), + 'warmup_steps': int(training_data.get('warmup_steps', 5)), + 'max_steps': int(training_data.get('max_steps', 30)), + 'gradient_accumulation_steps': int(training_data.get('gradient_accumulation_steps', 4)), + 'lr_scheduler_type': training_data.get('lr_scheduler_type', 'linear'), + 'seed': int(training_data.get('seed', 3407)), + 'model_output_dir': training_data.get('model_output_dir', './models/instruct'), + # LoRA configuration + 'lora_r': int(training_data.get('lora_r', 32)), + 'lora_alpha': int(training_data.get('lora_alpha', 16)), + 'lora_dropout': float(training_data.get('lora_dropout', 0)), + 'target_modules': training_data.get('target_modules', [ + "q_proj", "k_proj", "v_proj", "o_proj", + "gate_proj", "up_proj", "down_proj" + ]) + }) + + # Data configuration - use output_dir from data section + if 'data' in config: + data_config = config['data'] + output_dir = data_config.get('output_dir', './data/processed/instruct') + training_config.update({ + 'data_output_dir': output_dir, + 'dataset_path': output_dir, # Default dataset path is the output_dir + }) + + # Output configuration + training_config.update({ + 'output_dir': './outputs', + 'chat_template': 'qwen-2.5' # Use Qwen chat template by default + }) + + logger.info("Final training_config:") + for key, value in training_config.items(): + logger.info(f" {key}: {value} (type: {type(value)})") + + 
return training_config + + except Exception as e: + logger.error(f"Error loading training config: {e}") + raise + +def main(): + """Main training function""" + parser = argparse.ArgumentParser(description="Instruction Fine-tuning Training Pipeline") + + # Configuration + parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file") + parser.add_argument("--dataset", type=str, help="Path to training dataset (conversation data path)") + parser.add_argument("--output-dir", type=str, help="Output directory for model") + parser.add_argument("--epochs", type=int, help="Number of training epochs") + parser.add_argument("--batch-size", type=int, help="Training batch size") + parser.add_argument("--learning-rate", type=float, help="Learning rate") + parser.add_argument("--max-steps", type=int, help="Maximum training steps") + + args = parser.parse_args() + + # Setup logging + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + ) + + try: + # Load configuration + logger.info(f"Loading configuration from: {args.config}") + training_config = load_training_config(args.config) + + # Override with CLI arguments + if args.output_dir: + training_config['model_output_dir'] = args.output_dir + if args.epochs: + training_config['num_epochs'] = int(args.epochs) + if args.batch_size: + training_config['batch_size'] = int(args.batch_size) + if args.learning_rate: + training_config['learning_rate'] = float(args.learning_rate) + if args.max_steps: + training_config['max_steps'] = int(args.max_steps) + + # Determine dataset path: CLI argument takes precedence, then YAML config + dataset_path = args.dataset or training_config.get('dataset_path') + if not dataset_path: + logger.error("No dataset path provided. 
Use --dataset or ensure output_dir is set in YAML config.") + sys.exit(1) + + logger.info("Training configuration:") + for key, value in training_config.items(): + logger.info(f" {key}: {value}") + logger.info(f" Dataset path: {dataset_path}") + + # Initialize trainer + trainer = InstructTrainer(training_config) + + # Start training + trainer.train(dataset_path) + + logger.info("Instruction fine-tuning completed successfully!") + + except Exception as e: + logger.error(f"Instruction fine-tuning failed: {e}") + sys.exit(1) + +if __name__ == "__main__": + main() diff --git a/scripts/instruct/.ipynb_checkpoints/data_processor-checkpoint.py b/scripts/instruct/.ipynb_checkpoints/data_processor-checkpoint.py new file mode 100644 index 0000000..4716208 --- /dev/null +++ b/scripts/instruct/.ipynb_checkpoints/data_processor-checkpoint.py @@ -0,0 +1,320 @@ +#!/usr/bin/env python3 +""" +Instruct data processor script that uses YAML configurations. +This provides a flexible and maintainable approach for instruction fine-tuning tasks. 
def run_with_yaml_config(config_path: str, **cli_overrides):
    """Run the instruct data-processor pipeline with a YAML configuration.

    Args:
        config_path: Path to the YAML configuration file.
        **cli_overrides: Extra options forwarded to the pipeline. Each key is
            turned into a ``--dashed-flag``; entries whose value is ``None``
            are skipped.

    Returns:
        ``True`` when the pipeline exits with status 0, ``False`` when it
        exits non-zero or cannot be started at all.
    """
    print(f"=== Running Instruct Data Processor with YAML config: {config_path} ===")

    # Use the interpreter that is running this script so the child sees the
    # same environment; a bare "python" may be absent or a different install.
    cmd = [
        sys.executable or "python", "pipelines/instruct/data_processor.py",
        "--config", config_path
    ]

    # Add CLI overrides
    for key, value in cli_overrides.items():
        if value is not None:
            cmd.extend([f"--{key.replace('_', '-')}", str(value)])

    print(f"Running command: {' '.join(cmd)}")
    print()

    try:
        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print(f"❌ Error running instruct data processor: {e}")
        print(f"Error output: {e.stderr}")
        return False
    except OSError as e:
        # The interpreter itself could not be launched.
        print(f"❌ Error running instruct data processor: {e}")
        return False

    print("✅ Instruct data processing completed successfully!")
    print(result.stdout)
    return True


def run_instruct_examples():
    """Run the two bundled data-processing examples.

    Example 1 always runs against ``configs/instruct/code_reasoning.yaml``.
    Example 2 runs only when ``data/raw/instruct/custom_conversations.jsonl``
    exists; otherwise a hint is printed and the example is skipped.
    """
    # Example 1: Code reasoning instruction tuning
    print("=== Example 1: Code Reasoning Instruction Tuning ===")
    success = run_with_yaml_config(
        "configs/instruct/code_reasoning.yaml",
        max_samples=1000,  # Override YAML value
        output_format="conversation"
    )

    if success:
        print("✅ Code reasoning instruction tuning data processing completed!")

    # Example 2: Custom conversation dataset (if available)
    print("\n=== Example 2: Custom Conversation Dataset ===")
    custom_dataset = "data/raw/instruct/custom_conversations.jsonl"
    if not os.path.exists(custom_dataset):
        print("⚠️ Custom conversation dataset not found, skipping...")
        print(" You can create one with the 'create-sample-data' option")
        return

    success = run_with_yaml_config(
        "configs/instruct/code_reasoning.yaml",  # Use code_reasoning config as base
        data_source="custom",
        data_path=custom_dataset,
        output_dir="./data/processed/instruct/custom"
    )
    if success:
        print("✅ Custom conversation dataset processing completed!")
**First parameter**: Methods automatically receive 'self' as the first parameter (the instance they're called on), functions don't have this automatic parameter.\n\nExample:\n```python\n# Function\ndef my_function(x):\n return x * 2\n\n# Method (inside a class)\nclass MyClass:\n def my_method(self, x):\n return x * 2\n\n# Usage\nresult1 = my_function(5) # Function call\nobj = MyClass()\nresult2 = obj.my_method(5) # Method call\n```"} + ] + }, + { + "conversation": [ + {"role": "user", "content": "How can I optimize a slow Python loop?"}, + {"role": "assistant", "content": "Here are several strategies to optimize slow Python loops:\n\n1. **Use list comprehensions** instead of explicit loops when possible\n2. **Vectorize with NumPy** for numerical operations\n3. **Use built-in functions** like map(), filter(), sum()\n4. **Avoid repeated lookups** - store frequently accessed values in variables\n5. **Use sets for membership testing** instead of lists\n6. **Consider using itertools** for efficient iteration patterns\n7. **Profile your code** with cProfile to identify bottlenecks\n8. **Use numba or Cython** for compute-intensive loops\n\nExample optimization:\n```python\n# Slow\nresult = []\nfor i in range(1000000):\n if i % 2 == 0:\n result.append(i ** 2)\n\n# Faster\nresult = [i ** 2 for i in range(1000000) if i % 2 == 0]\n\n# Even faster with NumPy\nimport numpy as np\narr = np.arange(1000000)\nresult = (arr[arr % 2 == 0] ** 2).tolist()\n```"} + ] + }, + { + "conversation": [ + {"role": "user", "content": "What are Python decorators and how do I use them?"}, + {"role": "assistant", "content": "Python decorators are a powerful feature that allows you to modify or extend the behavior of functions or classes without permanently modifying their code. 
def create_custom_instruct_config(config_path="configs/instruct/general_chat.yaml"):
    """Write a template YAML configuration for general chat instruction tuning.

    Args:
        config_path: Destination file. Parent directories are created when
            missing. An existing file at this path is overwritten.
    """
    # NOTE(review): PyYAML reads "2e-4" as a string; the training pipeline
    # converts it with float(), so the value is left in this form.
    custom_config = """# Custom Instruct Configuration
task:
  name: "general_chat"
  type: "instruction_following"

data:
  source: "custom"
  data_path: "./data/raw/instruct/general_chat.jsonl"
  data_format: "jsonl"
  conversation_field: "conversation"
  max_length: 2048
  min_length: 10
  clean_text: true
  train_split: 0.8
  validation_split: 0.1
  test_split: 0.1
  output_format: "conversation"
  output_dir: "./data/processed/instruct/general_chat"

model:
  name: "unsloth/Qwen2.5-7B-Instruct"
  max_length: 2048
  max_seq_length: 2048
  dtype: null
  load_in_4bit: true
  token: null
  training_model: "unsloth/Qwen2.5-7B-Instruct"
  training_max_seq_length: 2048
  training_dtype: null
  training_load_in_4bit: true

training:
  num_epochs: 1
  batch_size: 1
  learning_rate: 2e-4
  weight_decay: 0.01
  warmup_steps: 5
  max_steps: 50
  gradient_accumulation_steps: 4
  lr_scheduler_type: "linear"
  seed: 3407
  lora_r: 16
  lora_alpha: 16
  lora_dropout: 0
  target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
  output_dir: "./outputs"
  model_output_dir: "./models/instruct/general_chat"

inference:
  batch_size: 1
  max_new_tokens: 256
  temperature: 0.8
  min_p: 0.1
  use_cache: true
"""

    # Path.parent is "." for a bare file name, so this also works when the
    # target has no directory component.
    target = Path(config_path)
    target.parent.mkdir(parents=True, exist_ok=True)

    with open(target, 'w', encoding='utf-8') as f:
        f.write(custom_config)

    print(f"✅ Created custom instruct config: {config_path}")
    print(" This config is set up for general chat instruction tuning")


def handle_direct_args(argv=None):
    """Forward command-line arguments to the instruct data-processor pipeline.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``.

    Returns:
        ``True`` when the pipeline exits with status 0, ``False`` when it
        exits non-zero or cannot be started.
    """
    parser = argparse.ArgumentParser(description="Instruct Data Processor")

    # Add all the same arguments as the instruct pipeline
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")
    parser.add_argument("--data-source", choices=["huggingface", "custom"], help="Data source")
    parser.add_argument("--dataset-name", type=str, help="HuggingFace dataset name")
    parser.add_argument("--data-path", type=str, help="Path to custom data file")
    parser.add_argument("--data-format", choices=["jsonl", "json"], help="Data format")
    parser.add_argument("--conversation-field", type=str, help="Conversation field name")
    parser.add_argument("--max-samples", type=int, help="Maximum samples to process")
    parser.add_argument("--train-split", type=float, help="Training split ratio")
    parser.add_argument("--validation-split", type=float, help="Validation split ratio")
    parser.add_argument("--test-split", type=float, help="Test split ratio")
    parser.add_argument("--output-dir", type=str, help="Output directory")

    # Logging
    parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO", help="Logging level")

    args = parser.parse_args(argv)

    # Build command to call the instruct pipeline with the interpreter that
    # runs this script (a bare "python" may be absent or a different install).
    cmd = [sys.executable or "python", "pipelines/instruct/data_processor.py"]

    # Add all arguments that were provided
    for arg_name, arg_value in vars(args).items():
        if arg_value is None:
            continue
        flag = f"--{arg_name.replace('_', '-')}"
        if isinstance(arg_value, bool):
            if arg_value:  # Only add flag if True
                cmd.append(flag)
        else:
            cmd.extend([flag, str(arg_value)])

    print(f"Running: {' '.join(cmd)}")
    print()

    try:
        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print(f"❌ Error running instruct data processor: {e}")
        print(f"Error output: {e.stderr}")
        return False
    except OSError as e:
        # The interpreter itself could not be launched.
        print(f"❌ Error running instruct data processor: {e}")
        return False

    print("✅ Instruct data processing completed successfully!")
    print(result.stdout)
    return True
**Conversation Validation**:") + print(" - Role validation (user/assistant/system)") + print(" - Content length and quality checks") + print(" - Conversation structure validation") + print(" - Turn-level statistics and analysis") + print() + print("4. **Advanced Features**:") + print(" - Configurable conversation field mapping") + print(" - Text preprocessing options") + print(" - Automatic dataset saving/loading") + print(" - YAML configuration support") + print(" - Compatible with Unsloth chat templates") + print() + print("=== Usage Examples ===") + print() + print("1. Use YAML config only:") + print(" python scripts/instruct/data_processor.py --config configs/instruct/code_reasoning.yaml") + print() + print("2. Override YAML values:") + print(" python scripts/instruct/data_processor.py --config configs/instruct/code_reasoning.yaml --max-samples 500") + print() + print("3. Create sample data:") + print(" python scripts/instruct/data_processor.py create-sample-data") + print() + print("4. 
def main():
    """Dispatch a subcommand, forward arguments to the pipeline, or show usage.

    With no arguments the usage text is printed. A first argument naming a
    known subcommand runs it. Anything else is passed through to the pipeline
    via ``handle_direct_args``; when that reports failure the process exits
    with status 1 so calling scripts can detect it.
    """
    if len(sys.argv) > 1:
        # Subcommand name -> handler. Built only on this path because the
        # usage path below needs none of them.
        subcommands = {
            "examples": run_instruct_examples,
            "create-sample-data": create_sample_instruct_data,
            "create-config": create_custom_instruct_config,
            "features": show_instruct_features,
        }
        handler = subcommands.get(sys.argv[1])
        if handler is not None:
            handler()
        elif not handle_direct_args():
            # Handle direct arguments (pass through to pipeline)
            sys.exit(1)
        return

    print("Instruct Data Processor")
    print("======================")
    print()
    print("This script runs the instruct data processor for instruction fine-tuning tasks.")
    print("It supports both YAML configurations and command-line overrides.")
    print()
    print("Usage:")
    print(" python scripts/instruct/data_processor.py examples # Run examples")
    print(" python scripts/instruct/data_processor.py create-sample-data # Create sample dataset")
    print(" python scripts/instruct/data_processor.py create-config # Create custom config")
    print(" python scripts/instruct/data_processor.py features # Show features")
    print()
    print("Direct pipeline usage:")
    print(" python scripts/instruct/data_processor.py --config configs/instruct/code_reasoning.yaml")
    print(" python scripts/instruct/data_processor.py --data-source custom --data-path ./conversations.jsonl")
    print()
    print("Key Features:")
    print(" ✅ Instruction fine-tuning with conversation data")
    print(" ✅ Multiple data source support")
    print(" ✅ YAML configuration files")
    print(" ✅ CLI argument overrides")
    print(" ✅ Conversation validation and analysis")
    print(" ✅ Compatible with Unsloth chat templates")


if __name__ == "__main__":
    main()
def run_inference_with_config(config_path: str, message: str = "", max_tokens: int = 128, stream: bool = False, interactive: bool = False):
    """Run the instruct inference pipeline with a YAML configuration.

    Args:
        config_path: Path to the YAML configuration file.
        message: Single message to send; ignored when ``interactive`` is set.
        max_tokens: Value forwarded as ``--max-tokens``.
        stream: Forward ``--stream`` when true.
        interactive: Forward ``--interactive`` and leave the child attached to
            the terminal instead of capturing its output.

    Returns:
        ``True`` after a successful interactive session, the captured stdout
        (a string, possibly empty) after a successful single run, or ``None``
        when the pipeline fails or cannot be started. Callers should test
        ``is None`` for failure because an empty string is a valid success.
    """
    print(f"Running instruct inference with config: {config_path}")
    if interactive:
        print("Mode: Interactive chat")
    elif message:
        print(f"Message: {message}")
    print(f"Max tokens: {max_tokens}")
    print(f"Streaming: {stream}")

    # Use the interpreter that is running this script; a bare "python" may be
    # absent or a different install.
    cmd = [
        sys.executable or "python", "pipelines/instruct/inference.py",
        "--config", config_path,
        "--max-tokens", str(max_tokens)
    ]

    if interactive:
        cmd.append("--interactive")
    elif message:
        cmd.extend(["--message", message])

    if stream:
        cmd.append("--stream")

    print(f"Running: {' '.join(cmd)}")

    try:
        if interactive:
            # For interactive mode, don't capture output
            subprocess.run(cmd, check=True)
            return True

        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        print(f"❌ Inference failed: {e}")
        print("Error output:")
        print(e.stderr)
        return None
    except OSError as e:
        # The interpreter itself could not be launched.
        print(f"❌ Inference failed: {e}")
        return None

    print("✅ Inference completed successfully!")
    print("Output:")
    print(result.stdout)
    return result.stdout


def run_conversation_example(config_path: str):
    """Send a fixed list of example questions through the inference pipeline.

    Args:
        config_path: Path to the YAML configuration file.
    """
    print("=== Conversation Example ===")
    print(f"Config: {config_path}")

    example_messages = [
        "Can you explain what recursion is in programming?",
        "How do I debug a Python program?",
        "What's the difference between a list and a tuple in Python?",
        "Can you show me how to use a for loop?",
        "What are the benefits of using functions in programming?"
    ]

    print("Running example conversations...")
    print()

    for number, message in enumerate(example_messages, start=1):
        print(f"--- Example {number} ---")
        result = run_inference_with_config(config_path, message, max_tokens=256)
        # Only None signals failure; an empty reply is still a completed run.
        if result is None:
            print(f"❌ Failed to process message {number}")
        print()

    print("✅ Conversation examples completed!")
def create_inference_example():
    """Run one example question against the code-reasoning configuration.

    Returns:
        ``True`` when inference completes, ``False`` when the configuration
        file is missing or inference fails.
    """
    print("=== Inference Example: Code Reasoning Chat ===")
    print()

    # Check if we have the required files
    config_path = "configs/instruct/code_reasoning.yaml"

    if not Path(config_path).exists():
        print(f"❌ Configuration file not found: {config_path}")
        print(" Please run the data processor and training first")
        return False

    print("✅ Found configuration file!")
    print(f" Config: {config_path}")
    print()

    # Example conversation
    example_message = "Can you explain what a Python decorator is and show me a simple example?"

    print(f"Example message: {example_message}")
    print()

    # Run inference
    result = run_inference_with_config(
        config_path=config_path,
        message=example_message,
        max_tokens=256
    )

    # The runner returns None on failure; an empty reply still counts as a
    # completed run.
    if result is None:
        print("❌ Example inference failed!")
        return False

    print("✅ Example inference completed successfully!")
    return True


def start_interactive_chat(config_path: str, stream: bool = False):
    """Start an interactive chat session.

    Args:
        config_path: Path to the YAML configuration file.
        stream: Enable streaming generation.

    Returns:
        ``True`` when the session ends normally, ``False`` when the
        configuration file is missing or the pipeline fails.
    """
    print("=== Interactive Chat Session ===")
    print()

    if not Path(config_path).exists():
        print(f"❌ Configuration file not found: {config_path}")
        print(" Please run the data processor and training first")
        return False

    print(f"Starting interactive chat with config: {config_path}")
    print("Streaming:", "enabled" if stream else "disabled")
    print()

    # Run interactive inference
    result = run_inference_with_config(
        config_path=config_path,
        interactive=True,
        stream=stream
    )

    # Normalise the runner's True/None result to a plain bool.
    return bool(result)


def create_batch_test(config_path: str = "configs/instruct/code_reasoning.yaml"):
    """Run a fixed batch of test questions through the inference pipeline.

    Args:
        config_path: Path to the YAML configuration file. Defaults to the
            code-reasoning configuration.

    Returns:
        ``True`` only when every question is processed successfully;
        ``False`` when the configuration file is missing or any run fails.
    """
    print("=== Batch Test: Multiple Questions ===")
    print()

    if not Path(config_path).exists():
        print(f"❌ Configuration file not found: {config_path}")
        print(" Please run the data processor and training first")
        return False

    # Create a batch of test questions
    test_questions = [
        "What is object-oriented programming?",
        "How do you handle errors in Python?",
        "Explain the concept of variables in programming.",
        "What's the difference between a compiler and an interpreter?"
    ]

    print("Running batch test with multiple questions...")
    print()

    success_count = 0
    for number, question in enumerate(test_questions, start=1):
        print(f"Question {number}: {question}")
        result = run_inference_with_config(config_path, question, max_tokens=200)
        # None is the failure signal; an empty reply is still a success.
        if result is not None:
            success_count += 1
        print("-" * 50)

    print(f"✅ Batch test completed: {success_count}/{len(test_questions)} questions processed successfully")
    return success_count == len(test_questions)


def main():
    """Parse the command line and dispatch to the selected inference command.

    Without a subcommand the help overview is printed.
    """
    parser = argparse.ArgumentParser(description="Instruct Inference Pipeline")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # Inference command
    infer_parser = subparsers.add_parser("infer", help="Run single inference")
    infer_parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
    infer_parser.add_argument("--message", type=str, required=True, help="Message to send to the model")
    infer_parser.add_argument("--max-tokens", type=int, default=128, help="Maximum new tokens to generate")
    infer_parser.add_argument("--stream", action="store_true", help="Enable streaming generation")

    # Interactive command
    interactive_parser = subparsers.add_parser("chat", help="Start interactive chat")
    interactive_parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
    interactive_parser.add_argument("--stream", action="store_true", help="Enable streaming generation")

    # Batch test command
    batch_parser = subparsers.add_parser("batch", help="Run batch test")
    batch_parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")

    # Conversation example command
    conv_parser = subparsers.add_parser("conversation", help="Run conversation examples")
    conv_parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")

    # Features command
    subparsers.add_parser("features", help="Show available features")

    # Example command
    subparsers.add_parser("example", help="Run example inference")

    args = parser.parse_args()

    if args.command == "infer":
        run_inference_with_config(
            args.config,
            args.message,
            args.max_tokens,
            args.stream
        )
    elif args.command == "chat":
        start_interactive_chat(args.config, args.stream)
    elif args.command == "batch":
        # --config is mandatory for this subcommand, so it must be used.
        create_batch_test(args.config)
    elif args.command == "conversation":
        run_conversation_example(args.config)
    elif args.command == "features":
        show_inference_features()
    elif args.command == "example":
        create_inference_example()
    else:
        print("Instruct Inference Pipeline")
        print("==========================")
        print()
        print("Available commands:")
        print(" infer - Run single message inference")
        print(" chat - Start interactive chat session")
        print(" batch - Run batch test with multiple questions")
        print(" conversation - Run conversation examples")
        print(" features - Show available features")
        print(" example - Run example inference")
        print()
        print("Examples:")
        print(" python scripts/instruct/inference.py infer --config configs/instruct/code_reasoning.yaml --message 'Explain Python loops'")
        print(" python scripts/instruct/inference.py chat --config configs/instruct/code_reasoning.yaml")
        print(" python scripts/instruct/inference.py chat --config configs/instruct/code_reasoning.yaml --stream")
        print()
        print("Key Features:")
        print(" ✅ Interactive chat with conversation history")
        print(" ✅ Streaming generation support")
        print(" ✅ Multi-turn conversation handling")
        print(" ✅ Chat template integration")
        print(" ✅ Configurable generation parameters")


if __name__ == "__main__":
    main()
+ print(f"Error output: {e.stderr}") + return False + +def show_training_features(): + """Show the features of the instruct training pipeline""" + print("=== Instruct Training Pipeline Features ===") + print() + print("1. **Model Support**:") + print(" - Unsloth optimized models (4x faster)") + print(" - LoRA fine-tuning for efficiency") + print(" - Support for Qwen2.5, Llama-3.1, Mistral, Phi-3") + print(" - Chat template integration") + print() + print("2. **Training Features**:") + print(" - SFTTrainer with conversation data") + print(" - Response-only training (train only on assistant responses)") + print(" - ShareGPT format standardization") + print(" - Automatic mixed precision (FP16/BF16)") + print(" - Gradient checkpointing for memory efficiency") + print(" - Configurable LoRA parameters") + print() + print("3. **Conversation Handling**:") + print(" - Multi-turn conversation support") + print(" - Proper chat template formatting") + print(" - Role-based training (user/assistant/system)") + print(" - Context preservation across turns") + print() + print("4. **Configuration**:") + print(" - YAML configuration files") + print(" - CLI argument overrides") + print(" - Automatic device detection") + print(" - Flexible LoRA configuration") + print() + print("5. 
def create_training_example():
    """Run a short demo training (30 steps) with the code reasoning config.

    Returns False without starting anything when the config file is missing;
    otherwise returns whatever run_training_with_config reports.
    """
    print("=== Training Example: Code Reasoning Instruction Tuning ===")
    print()

    config_path = "configs/instruct/code_reasoning.yaml"

    # Guard: nothing can run without the YAML configuration.
    if not os.path.exists(config_path):
        print(f"Configuration file not found: {config_path}")
        print(" Please run the data processor first to create the configuration")
        return False

    banner = (
        "Found required files!",
        f" Config: {config_path}",
        " Dataset: Will use output_dir from YAML config",
        " The training pipeline will automatically:",
        " - Load conversation data from the output_dir specified in YAML",
        " - Convert JSONL files to HuggingFace dataset format",
        " - Apply ShareGPT standardization",
        " - Format conversations with chat templates",
        " - Train the model using SFTTrainer with response-only training",
    )
    for line in banner:
        print(line)
    print()

    # dataset_path stays None so the pipeline falls back to the YAML output_dir.
    run_kwargs = {
        "config_path": config_path,
        "dataset_path": None,
        "epochs": 1,
        "batch_size": 1,
        "learning_rate": 2e-4,
        "max_steps": 30,
    }
    success = run_training_with_config(**run_kwargs)

    if success:
        for line in (
            "Training example completed!",
            " Model saved to: ./models/instruct",
            " Ready for conversational inference!",
        ):
            print(line)

    return success


def create_quick_test():
    """Run a 5-step smoke-test training with the code reasoning config.

    Returns False without starting anything when the config file is missing;
    otherwise returns whatever run_training_with_config reports.
    """
    print("=== Quick Test: Minimal Training Steps ===")
    print()

    config_path = "configs/instruct/code_reasoning.yaml"

    if not os.path.exists(config_path):
        print(f"Configuration file not found: {config_path}")
        print(" Please run the data processor first to create the configuration")
        return False

    print("Running quick test with minimal training steps...")

    # Only a handful of steps: this exercises the pipeline, not the model.
    run_kwargs = {
        "config_path": config_path,
        "dataset_path": None,
        "epochs": 1,
        "batch_size": 1,
        "learning_rate": 2e-4,
        "max_steps": 5,
    }
    success = run_training_with_config(**run_kwargs)

    if success:
        for line in (
            "Quick test completed!",
            " Model saved with minimal training",
            " This is just for testing the pipeline",
        ):
            print(line)

    return success
def main(argv=None):
    """Command-line entry point for the instruct training wrapper.

    Args:
        argv: Optional argument list (without the program name). Defaults to
            ``sys.argv[1:]`` when None, as argparse does.

    Exits with status 1 when ``train`` is requested without ``--config`` or
    when the selected training command reports failure.
    """
    parser = argparse.ArgumentParser(description="Instruct Training Script")

    # Subcommands
    parser.add_argument("command", choices=["train", "example", "features", "quick-test"],
                        help="Command to run")

    # Training arguments
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")
    parser.add_argument("--dataset", type=str, help="Path to training dataset")
    parser.add_argument("--output-dir", type=str, help="Output directory for model")
    parser.add_argument("--epochs", type=int, help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, help="Training batch size")
    parser.add_argument("--learning-rate", type=float, help="Learning rate")
    parser.add_argument("--max-steps", type=int, help="Maximum training steps")

    args = parser.parse_args(argv)

    if args.command == "features":
        show_training_features()
        return

    if args.command == "example":
        success = create_training_example()
    elif args.command == "quick-test":
        success = create_quick_test()
    else:  # "train": argparse `choices` rules out anything else
        if not args.config:
            print("❌ --config is required for training")
            print("Usage: python scripts/instruct/train.py train --config config.yaml")
            sys.exit(1)

        success = run_training_with_config(
            config_path=args.config,
            # An empty --dataset means "use output_dir from the YAML config".
            dataset_path=args.dataset or None,
            output_dir=args.output_dir,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.learning_rate,
            max_steps=args.max_steps,
        )

    # Every command that runs training must surface failure in the exit code.
    if not success:
        sys.exit(1)
def run_with_yaml_config(config_path: str, **cli_overrides):
    """Run the instruct data processor pipeline with a YAML configuration.

    Args:
        config_path: YAML configuration file, forwarded as ``--config``.
        **cli_overrides: Extra options. Each non-None ``some_key=value`` pair
            is forwarded as ``--some-key value``.

    Returns:
        True when the child exits with status 0, False otherwise (including
        when the child process cannot be started).
    """
    print(f"=== Running Instruct Data Processor with YAML config: {config_path} ===")

    # Use the interpreter running this script instead of whatever "python"
    # happens to resolve to on PATH.
    cmd = [
        sys.executable, "pipelines/instruct/data_processor.py",
        "--config", config_path,
    ]

    for key, value in cli_overrides.items():
        if value is not None:
            cmd.extend([f"--{key.replace('_', '-')}", str(value)])

    print(f"Running command: {' '.join(cmd)}")
    print()

    try:
        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print(f"❌ Error running instruct data processor: {e}")
        print(f"Error output: {e.stderr}")
        if e.stdout:
            # Output produced before the failure is often the only clue.
            print(f"Standard output: {e.stdout}")
        return False
    except OSError as e:
        # The child could not even be started.
        print(f"❌ Error running instruct data processor: {e}")
        return False

    print("✅ Instruct data processing completed successfully!")
    print(result.stdout)
    return True


def run_instruct_examples():
    """Run the bundled data-processing examples.

    Example 1 always runs against configs/instruct/code_reasoning.yaml.
    Example 2 runs only when data/raw/instruct/custom_conversations.jsonl
    exists. Failures are reported by run_with_yaml_config; nothing is returned.
    """
    # Example 1: Code reasoning instruction tuning
    print("=== Example 1: Code Reasoning Instruction Tuning ===")
    success = run_with_yaml_config(
        "configs/instruct/code_reasoning.yaml",
        max_samples=1000,  # Override YAML value
        output_format="conversation",
    )

    if success:
        print("✅ Code reasoning instruction tuning data processing completed!")

    # Example 2: Custom conversation dataset (if available)
    print("\n=== Example 2: Custom Conversation Dataset ===")
    custom_data = "data/raw/instruct/custom_conversations.jsonl"
    if os.path.exists(custom_data):
        success = run_with_yaml_config(
            "configs/instruct/code_reasoning.yaml",  # Use code_reasoning config as base
            data_source="custom",
            data_path=custom_data,
            output_dir="./data/processed/instruct/custom",
        )
        if success:
            print("✅ Custom conversation dataset processing completed!")
    else:
        print("⚠️ Custom conversation dataset not found, skipping...")
        print(" You can create one with the 'create-sample-data' option")
def create_sample_instruct_data(output_dir="data/raw/instruct"):
    """Write a five-conversation sample dataset as JSONL for testing.

    Each line of the output file is one JSON object with a ``conversation``
    list of ``{"role", "content"}`` turns (one user turn, one assistant turn).

    Args:
        output_dir: Directory that receives ``code_reasoning.jsonl``; created
            if missing. Defaults to the location the script always used.

    Returns:
        pathlib.Path of the file that was written.
    """
    import json

    # NOTE(review): the code samples embedded in the assistant answers use
    # single-space indentation as found in the source; confirm against the
    # intended text before relying on them as training data.
    sample_data = [
        {
            "conversation": [
                {"role": "user", "content": "Can you explain what a Python list comprehension is?"},
                {"role": "assistant", "content": "A Python list comprehension is a concise way to create lists. It allows you to generate a new list by applying an expression to each item in an existing iterable, optionally filtering items with a condition. The basic syntax is: [expression for item in iterable if condition]. For example: [x**2 for x in range(10) if x%2==0] creates a list of squares of even numbers from 0 to 8."}
            ]
        },
        {
            "conversation": [
                {"role": "user", "content": "How do I handle exceptions in Python?"},
                {"role": "assistant", "content": "In Python, you handle exceptions using try-except blocks. Here's the basic structure:\n\ntry:\n # Code that might raise an exception\n risky_code()\nexcept SpecificException as e:\n # Handle specific exception\n print(f'Error: {e}')\nexcept Exception as e:\n # Handle any other exception\n print(f'Unexpected error: {e}')\nelse:\n # Runs if no exception occurred\n print('Success!')\nfinally:\n # Always runs, regardless of exceptions\n cleanup_code()\n\nYou can catch multiple specific exceptions and handle them differently."}
            ]
        },
        {
            "conversation": [
                {"role": "user", "content": "What's the difference between a function and a method in Python?"},
                {"role": "assistant", "content": "In Python, the main differences between functions and methods are:\n\n1. **Definition location**: Functions are defined independently, while methods are defined inside classes.\n\n2. **Calling**: Functions are called directly by name, methods are called on objects using dot notation.\n\n3. **First parameter**: Methods automatically receive 'self' as the first parameter (the instance they're called on), functions don't have this automatic parameter.\n\nExample:\n```python\n# Function\ndef my_function(x):\n return x * 2\n\n# Method (inside a class)\nclass MyClass:\n def my_method(self, x):\n return x * 2\n\n# Usage\nresult1 = my_function(5) # Function call\nobj = MyClass()\nresult2 = obj.my_method(5) # Method call\n```"}
            ]
        },
        {
            "conversation": [
                {"role": "user", "content": "How can I optimize a slow Python loop?"},
                {"role": "assistant", "content": "Here are several strategies to optimize slow Python loops:\n\n1. **Use list comprehensions** instead of explicit loops when possible\n2. **Vectorize with NumPy** for numerical operations\n3. **Use built-in functions** like map(), filter(), sum()\n4. **Avoid repeated lookups** - store frequently accessed values in variables\n5. **Use sets for membership testing** instead of lists\n6. **Consider using itertools** for efficient iteration patterns\n7. **Profile your code** with cProfile to identify bottlenecks\n8. **Use numba or Cython** for compute-intensive loops\n\nExample optimization:\n```python\n# Slow\nresult = []\nfor i in range(1000000):\n if i % 2 == 0:\n result.append(i ** 2)\n\n# Faster\nresult = [i ** 2 for i in range(1000000) if i % 2 == 0]\n\n# Even faster with NumPy\nimport numpy as np\narr = np.arange(1000000)\nresult = (arr[arr % 2 == 0] ** 2).tolist()\n```"}
            ]
        },
        {
            "conversation": [
                {"role": "user", "content": "What are Python decorators and how do I use them?"},
                {"role": "assistant", "content": "Python decorators are a powerful feature that allows you to modify or extend the behavior of functions or classes without permanently modifying their code. They use the @ symbol and are placed above the function definition.\n\n**Basic decorator example:**\n```python\ndef my_decorator(func):\n def wrapper(*args, **kwargs):\n print('Before function call')\n result = func(*args, **kwargs)\n print('After function call')\n return result\n return wrapper\n\n@my_decorator\ndef greet(name):\n print(f'Hello, {name}!')\n\ngreet('Alice') # Prints before, greeting, and after messages\n```\n\n**Common use cases:**\n- Timing function execution\n- Logging\n- Authentication/authorization\n- Caching results\n- Input validation\n\n**Built-in decorators:**\n- `@property` - creates getter/setter methods\n- `@staticmethod` - methods that don't need self or cls\n- `@classmethod` - methods that receive the class as first argument\n\nDecorators make code more modular and reusable by separating concerns."}
            ]
        }
    ]

    # Create directory structure
    data_dir = Path(output_dir)
    data_dir.mkdir(parents=True, exist_ok=True)

    # One JSON document per line; ensure_ascii=False keeps the text readable.
    sample_file = data_dir / "code_reasoning.jsonl"
    with open(sample_file, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(item, ensure_ascii=False) + '\n' for item in sample_data)

    print(f"✅ Created sample conversation dataset: {sample_file}")
    print(f" Contains {len(sample_data)} conversation examples")
    print(" Format: conversation array with role/content pairs")
    print(" Ready to use with configs/instruct/code_reasoning.yaml")
    return sample_file
def create_custom_instruct_config(config_path="configs/instruct/general_chat.yaml"):
    """Write a template YAML configuration for general chat instruction tuning.

    Args:
        config_path: Destination file; parent directories are created when
            missing. Defaults to the location the script always used.

    Returns:
        The path that was written, as given.
    """
    # learning_rate is spelled "2.0e-4": YAML 1.1 loaders such as PyYAML only
    # resolve exponent notation as a float when the mantissa contains a dot,
    # so a bare "2e-4" would be loaded as a string.
    custom_config = """# Custom Instruct Configuration
task:
  name: "general_chat"
  type: "instruction_following"

data:
  source: "custom"
  data_path: "./data/raw/instruct/general_chat.jsonl"
  data_format: "jsonl"
  conversation_field: "conversation"
  max_length: 2048
  min_length: 10
  clean_text: true
  train_split: 0.8
  validation_split: 0.1
  test_split: 0.1
  output_format: "conversation"
  output_dir: "./data/processed/instruct/general_chat"

model:
  name: "unsloth/Qwen2.5-7B-Instruct"
  max_length: 2048
  max_seq_length: 2048
  dtype: null
  load_in_4bit: true
  token: null
  training_model: "unsloth/Qwen2.5-7B-Instruct"
  training_max_seq_length: 2048
  training_dtype: null
  training_load_in_4bit: true

training:
  num_epochs: 1
  batch_size: 1
  learning_rate: 2.0e-4
  weight_decay: 0.01
  warmup_steps: 5
  max_steps: 50
  gradient_accumulation_steps: 4
  lr_scheduler_type: "linear"
  seed: 3407
  lora_r: 16
  lora_alpha: 16
  lora_dropout: 0
  target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
  output_dir: "./outputs"
  model_output_dir: "./models/instruct/general_chat"

inference:
  batch_size: 1
  max_new_tokens: 256
  temperature: 0.8
  min_p: 0.1
  use_cache: true
"""

    # Path.parent is "." for a bare filename, so this also works when
    # config_path has no directory component (os.makedirs("") would raise).
    Path(config_path).parent.mkdir(parents=True, exist_ok=True)

    with open(config_path, 'w', encoding='utf-8') as f:
        f.write(custom_config)

    print(f"✅ Created custom instruct config: {config_path}")
    print(" This config is set up for general chat instruction tuning")
    return config_path


def handle_direct_args(argv=None):
    """Parse pass-through options and forward them to the instruct pipeline.

    Args:
        argv: Optional argument list (without the program name). Defaults to
            ``sys.argv[1:]`` when None, as argparse does.

    Returns:
        True when the pipeline exits with status 0, False otherwise.
    """
    parser = argparse.ArgumentParser(description="Instruct Data Processor")

    # Add all the same arguments as the instruct pipeline
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")
    parser.add_argument("--data-source", choices=["huggingface", "custom"], help="Data source")
    parser.add_argument("--dataset-name", type=str, help="HuggingFace dataset name")
    parser.add_argument("--data-path", type=str, help="Path to custom data file")
    parser.add_argument("--data-format", choices=["jsonl", "json"], help="Data format")
    parser.add_argument("--conversation-field", type=str, help="Conversation field name")
    parser.add_argument("--max-samples", type=int, help="Maximum samples to process")
    parser.add_argument("--train-split", type=float, help="Training split ratio")
    parser.add_argument("--validation-split", type=float, help="Validation split ratio")
    parser.add_argument("--test-split", type=float, help="Test split ratio")
    parser.add_argument("--output-dir", type=str, help="Output directory")

    # Logging (has a default, so --log-level is always forwarded)
    parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO", help="Logging level")

    args = parser.parse_args(argv)

    # Use the interpreter running this script instead of whatever "python"
    # happens to resolve to on PATH.
    cmd = [sys.executable, "pipelines/instruct/data_processor.py"]

    # Forward every option that carries a value.
    for arg_name, arg_value in vars(args).items():
        if arg_value is None:
            continue
        flag = f"--{arg_name.replace('_', '-')}"
        if isinstance(arg_value, bool):
            if arg_value:  # Boolean switches are passed as bare flags
                cmd.append(flag)
        else:
            cmd.extend([flag, str(arg_value)])

    print(f"Running: {' '.join(cmd)}")
    print()

    try:
        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print(f"❌ Error running instruct data processor: {e}")
        print(f"Error output: {e.stderr}")
        return False
    except OSError as e:
        # The child could not even be started.
        print(f"❌ Error running instruct data processor: {e}")
        return False

    print("✅ Instruct data processing completed successfully!")
    print(result.stdout)
    return True
def show_instruct_features():
    """Print the feature overview and usage examples of the data processor."""
    # An empty entry stands for a blank output line.
    lines = [
        "=== Instruct Data Processor Features ===",
        "",
        "1. **Instruction Fine-tuning Tasks**:",
        " - Code reasoning and explanation",
        " - General conversation and chat",
        " - Question answering",
        " - Task-specific instruction following",
        "",
        "2. **Conversation Data Formats Supported**:",
        " - HuggingFace conversation datasets",
        " - Custom JSONL/JSON files with conversation arrays",
        " - ShareGPT format with role/content structure",
        " - Automatic train/validation/test splits",
        "",
        "3. **Conversation Validation**:",
        " - Role validation (user/assistant/system)",
        " - Content length and quality checks",
        " - Conversation structure validation",
        " - Turn-level statistics and analysis",
        "",
        "4. **Advanced Features**:",
        " - Configurable conversation field mapping",
        " - Text preprocessing options",
        " - Automatic dataset saving/loading",
        " - YAML configuration support",
        " - Compatible with Unsloth chat templates",
        "",
        "=== Usage Examples ===",
        "",
        "1. Use YAML config only:",
        " python scripts/instruct/data_processor.py --config configs/instruct/code_reasoning.yaml",
        "",
        "2. Override YAML values:",
        " python scripts/instruct/data_processor.py --config configs/instruct/code_reasoning.yaml --max-samples 500",
        "",
        "3. Create sample data:",
        " python scripts/instruct/data_processor.py create-sample-data",
        "",
        "4. Create custom config:",
        " python scripts/instruct/data_processor.py create-config",
    ]
    print("\n".join(lines))
def _print_usage():
    """Print the help text shown when the script is run without arguments."""
    print("Instruct Data Processor")
    print("======================")
    print()
    print("This script runs the instruct data processor for instruction fine-tuning tasks.")
    print("It supports both YAML configurations and command-line overrides.")
    print()
    print("Usage:")
    print(" python scripts/instruct/data_processor.py examples # Run examples")
    print(" python scripts/instruct/data_processor.py create-sample-data # Create sample dataset")
    print(" python scripts/instruct/data_processor.py create-config # Create custom config")
    print(" python scripts/instruct/data_processor.py features # Show features")
    print()
    print("Direct pipeline usage:")
    print(" python scripts/instruct/data_processor.py --config configs/instruct/code_reasoning.yaml")
    print(" python scripts/instruct/data_processor.py --data-source custom --data-path ./conversations.jsonl")
    print()
    print("Key Features:")
    print(" ✅ Instruction fine-tuning with conversation data")
    print(" ✅ Multiple data source support")
    print(" ✅ YAML configuration files")
    print(" ✅ CLI argument overrides")
    print(" ✅ Conversation validation and analysis")
    print(" ✅ Compatible with Unsloth chat templates")


def main():
    """Dispatch on ``sys.argv``: subcommand, pass-through options, or usage.

    With no arguments the usage text is printed. A first argument of
    ``examples``, ``create-sample-data``, ``create-config`` or ``features``
    runs that subcommand. Anything else is handed to handle_direct_args, and
    the process exits with status 1 when that reports failure.
    """
    if len(sys.argv) <= 1:
        _print_usage()
        return

    # Built only after the no-argument case so the usage path stays trivial.
    subcommands = {
        "examples": run_instruct_examples,
        "create-sample-data": create_sample_instruct_data,
        "create-config": create_custom_instruct_config,
        "features": show_instruct_features,
    }

    handler = subcommands.get(sys.argv[1])
    if handler is not None:
        handler()
        return

    # Handle direct arguments (pass through to pipeline); a failed pipeline
    # run must be visible in the exit status.
    if handle_direct_args() is False:
        sys.exit(1)
def run_inference_with_config(config_path: str, message: str = "", max_tokens: int = 128, stream: bool = False, interactive: bool = False):
    """Run the instruct inference pipeline as a child process.

    Args:
        config_path: YAML configuration file, forwarded as ``--config``.
        message: Prompt forwarded as ``--message``; ignored when
            ``interactive`` is set or when empty.
        max_tokens: Forwarded as ``--max-tokens``.
        stream: When True, ``--stream`` is added.
        interactive: When True, ``--interactive`` is added and the child
            inherits the terminal instead of having its output captured.

    Returns:
        True after a successful interactive session, the captured stdout
        (a str, possibly empty) after a successful single run, or None on
        failure. Callers should test ``is None`` rather than truthiness.
    """
    print(f"Running instruct inference with config: {config_path}")
    if interactive:
        print("Mode: Interactive chat")
    elif message:
        print(f"Message: {message}")
    print(f"Max tokens: {max_tokens}")
    print(f"Streaming: {stream}")

    # Use the interpreter running this script instead of whatever "python"
    # happens to resolve to on PATH.
    cmd = [
        sys.executable, "pipelines/instruct/inference.py",
        "--config", config_path,
        "--max-tokens", str(max_tokens),
    ]

    if interactive:
        cmd.append("--interactive")
    elif message:
        cmd.extend(["--message", message])

    if stream:
        cmd.append("--stream")

    print(f"Running: {' '.join(cmd)}")

    try:
        if interactive:
            # For interactive mode, don't capture output
            subprocess.run(cmd, check=True)
            return True

        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        print(f"❌ Inference failed: {e}")
        # stderr is only available when output was captured.
        if e.stderr:
            print("Error output:")
            print(e.stderr)
        return None
    except OSError as e:
        # The child could not even be started.
        print(f"❌ Inference failed: {e}")
        return None

    print("✅ Inference completed successfully!")
    print("Output:")
    print(result.stdout)
    return result.stdout


def run_conversation_example(config_path: str):
    """Send five canned programming questions through the inference pipeline.

    Args:
        config_path: YAML configuration file passed to every inference run.
    """
    print("=== Conversation Example ===")
    print(f"Config: {config_path}")

    example_messages = [
        "Can you explain what recursion is in programming?",
        "How do I debug a Python program?",
        "What's the difference between a list and a tuple in Python?",
        "Can you show me how to use a for loop?",
        "What are the benefits of using functions in programming?",
    ]

    print("Running example conversations...")
    print()

    for number, message in enumerate(example_messages, start=1):
        print(f"--- Example {number} ---")
        result = run_inference_with_config(config_path, message, max_tokens=256)
        # None is the failure signal; an empty reply is still a success.
        if result is None:
            print(f"❌ Failed to process message {number}")
        print()

    print("✅ Conversation examples completed!")
def show_inference_features():
    """Print a static overview of the instruct inference pipeline features."""
    # (section heading, bullet lines); each section is preceded by a blank line.
    sections = (
        ("1. **Model Support**:", (
            " - Trained LoRA models from instruct training pipeline",
            " - Automatic model loading from config",
            " - Native Unsloth inference optimization",
            " - Chat template integration",
        )),
        ("2. **Inference Modes**:", (
            " - Single message inference",
            " - Interactive chat session",
            " - Streaming generation",
            " - Batch conversation processing",
        )),
        ("3. **Conversation Features**:", (
            " - Multi-turn conversation support",
            " - Context preservation across turns",
            " - Proper role handling (user/assistant/system)",
            " - Chat history management",
        )),
        ("4. **Generation Control**:", (
            " - Configurable max tokens",
            " - Temperature and sampling parameters",
            " - Streaming output support",
            " - Chat template formatting",
        )),
        ("5. **Interactive Features**:", (
            " - Real-time chat interface",
            " - Command support (clear, stream toggle)",
            " - Conversation history tracking",
            " - Graceful exit handling",
        )),
        ("6. **Usage Examples**:", (
            " - Single message: --message 'your question here'",
            " - Interactive chat: --interactive",
            " - Streaming: add --stream flag",
            " - Custom tokens: --max-tokens 256",
        )),
    )

    print("=== Instruct Inference Pipeline Features ===")
    for heading, bullets in sections:
        print()
        print(heading)
        for bullet in bullets:
            print(bullet)
def create_inference_example():
    """Run one sample question through inference with the code reasoning config.

    Returns:
        True when inference succeeded, False when the config file is missing
        or the inference run failed.
    """
    print("=== Inference Example: Code Reasoning Chat ===")
    print()

    # Check if we have the required files
    config_path = "configs/instruct/code_reasoning.yaml"

    if not Path(config_path).exists():
        print(f"❌ Configuration file not found: {config_path}")
        print(" Please run the data processor and training first")
        return False

    print("✅ Found configuration file!")
    print(f" Config: {config_path}")
    print()

    example_message = "Can you explain what a Python decorator is and show me a simple example?"

    print(f"Example message: {example_message}")
    print()

    result = run_inference_with_config(
        config_path=config_path,
        message=example_message,
        max_tokens=256,
    )

    # run_inference_with_config returns None on failure; a successful run may
    # legitimately return an empty string, so truthiness is the wrong test.
    if result is None:
        print("❌ Example inference failed!")
        return False

    print("✅ Example inference completed successfully!")
    return True


def start_interactive_chat(config_path: str, stream: bool = False):
    """Start an interactive chat session through the inference pipeline.

    Args:
        config_path: YAML configuration file; must exist.
        stream: Whether to request streaming generation.

    Returns:
        False when the config file is missing; otherwise the result of
        run_inference_with_config (True on success, None on failure).
    """
    print("=== Interactive Chat Session ===")
    print()

    if not Path(config_path).exists():
        print(f"❌ Configuration file not found: {config_path}")
        print(" Please run the data processor and training first")
        return False

    print(f"Starting interactive chat with config: {config_path}")
    print("Streaming:", "enabled" if stream else "disabled")
    print()

    return run_inference_with_config(
        config_path=config_path,
        interactive=True,
        stream=stream,
    )


def create_batch_test(config_path: str = "configs/instruct/code_reasoning.yaml"):
    """Run four canned questions through inference and report the tally.

    Args:
        config_path: YAML configuration file; must exist. Defaults to the
            code reasoning config the script always used.

    Returns:
        True only when every question was processed successfully; False when
        the config file is missing or any run failed.
    """
    print("=== Batch Test: Multiple Questions ===")
    print()

    if not Path(config_path).exists():
        print(f"❌ Configuration file not found: {config_path}")
        print(" Please run the data processor and training first")
        return False

    test_questions = [
        "What is object-oriented programming?",
        "How do you handle errors in Python?",
        "Explain the concept of variables in programming.",
        "What's the difference between a compiler and an interpreter?",
    ]

    print("Running batch test with multiple questions...")
    print()

    success_count = 0
    for number, question in enumerate(test_questions, start=1):
        print(f"Question {number}: {question}")
        result = run_inference_with_config(config_path, question, max_tokens=200)
        # None is the failure signal; an empty reply still counts as success.
        if result is not None:
            success_count += 1
        print("-" * 50)

    print(f"✅ Batch test completed: {success_count}/{len(test_questions)} questions processed successfully")
    return success_count == len(test_questions)


def main(argv=None):
    """Command-line entry point for the instruct inference wrapper.

    Args:
        argv: Optional argument list (without the program name). Defaults to
            ``sys.argv[1:]`` when None, as argparse does.

    Prints usage when no subcommand is given. Exits with status 1 when the
    ``infer``, ``chat``, ``batch`` or ``example`` command reports failure.
    """
    parser = argparse.ArgumentParser(description="Instruct Inference Pipeline")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # Inference command
    infer_parser = subparsers.add_parser("infer", help="Run single inference")
    infer_parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
    infer_parser.add_argument("--message", type=str, required=True, help="Message to send to the model")
    infer_parser.add_argument("--max-tokens", type=int, default=128, help="Maximum new tokens to generate")
    infer_parser.add_argument("--stream", action="store_true", help="Enable streaming generation")

    # Interactive command
    interactive_parser = subparsers.add_parser("chat", help="Start interactive chat")
    interactive_parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
    interactive_parser.add_argument("--stream", action="store_true", help="Enable streaming generation")

    # Batch test command
    batch_parser = subparsers.add_parser("batch", help="Run batch test")
    batch_parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")

    # Conversation example command
    conv_parser = subparsers.add_parser("conversation", help="Run conversation examples")
    conv_parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")

    # Features command
    subparsers.add_parser("features", help="Show available features")

    # Example command
    subparsers.add_parser("example", help="Run example inference")

    args = parser.parse_args(argv)

    succeeded = True
    if args.command == "infer":
        succeeded = run_inference_with_config(
            args.config,
            args.message,
            args.max_tokens,
            args.stream,
        ) is not None
    elif args.command == "chat":
        # False means missing config, None means the session failed.
        succeeded = start_interactive_chat(args.config, args.stream) not in (False, None)
    elif args.command == "batch":
        # Honour the required --config instead of the built-in default path.
        succeeded = create_batch_test(args.config)
    elif args.command == "conversation":
        run_conversation_example(args.config)
    elif args.command == "features":
        show_inference_features()
    elif args.command == "example":
        succeeded = create_inference_example()
    else:
        print("Instruct Inference Pipeline")
        print("==========================")
        print()
        print("Available commands:")
        print(" infer - Run single message inference")
        print(" chat - Start interactive chat session")
        print(" batch - Run batch test with multiple questions")
        print(" conversation - Run conversation examples")
        print(" features - Show available features")
        print(" example - Run example inference")
        print()
        print("Examples:")
        print(" python scripts/instruct/inference.py infer --config configs/instruct/code_reasoning.yaml --message 'Explain Python loops'")
        print(" python scripts/instruct/inference.py chat --config configs/instruct/code_reasoning.yaml")
        print(" python scripts/instruct/inference.py chat --config configs/instruct/code_reasoning.yaml --stream")
        print()
        print("Key Features:")
        print(" ✅ Interactive chat with conversation history")
        print(" ✅ Streaming generation support")
        print(" ✅ Multi-turn conversation handling")
        print(" ✅ Chat template integration")
        print(" ✅ Configurable generation parameters")

    if not succeeded:
        sys.exit(1)
print(" ✅ Streaming generation support") + print(" ✅ Multi-turn conversation handling") + print(" ✅ Chat template integration") + print(" ✅ Configurable generation parameters") + +if __name__ == "__main__": + main() diff --git a/scripts/instruct/train.py b/scripts/instruct/train.py new file mode 100644 index 0000000..e3c43ea --- /dev/null +++ b/scripts/instruct/train.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 +""" +Instruct Training Script +Provides a command-line interface to run the instruct training pipeline +""" + +import sys +import os +import subprocess +import argparse +from pathlib import Path + +def run_training_with_config(config_path: str, dataset_path: str = None, **cli_overrides): + """Run the instruct training pipeline with YAML configuration""" + print(f"Starting instruct training with config: {config_path}") + if dataset_path: + print(f"Training dataset: {dataset_path}") + else: + print("Training dataset: Will use output_dir from YAML config") + print() + + # Build command + cmd = ["python", "pipelines/instruct/train.py", "--config", config_path] + + # Add dataset path if provided + if dataset_path: + cmd.extend(["--dataset", dataset_path]) + + # Add CLI overrides + for key, value in cli_overrides.items(): + if value is not None: + if key == "output_dir": + cmd.extend(["--output-dir", str(value)]) + elif key == "epochs": + cmd.extend(["--epochs", str(value)]) + elif key == "batch_size": + cmd.extend(["--batch-size", str(value)]) + elif key == "learning_rate": + cmd.extend(["--learning-rate", str(value)]) + elif key == "max_steps": + cmd.extend(["--max-steps", str(value)]) + + print(f"Running: {' '.join(cmd)}") + print() + + try: + result = subprocess.run(cmd, check=True, capture_output=True, text=True) + print("Training completed successfully!") + print(result.stdout) + return True + except subprocess.CalledProcessError as e: + print(f"Training failed: {e}") + print(f"Error output: {e.stderr}") + return False + +def show_training_features(): + 
"""Show the features of the instruct training pipeline""" + print("=== Instruct Training Pipeline Features ===") + print() + print("1. **Model Support**:") + print(" - Unsloth optimized models (4x faster)") + print(" - LoRA fine-tuning for efficiency") + print(" - Support for Qwen2.5, Llama-3.1, Mistral, Phi-3") + print(" - Chat template integration") + print() + print("2. **Training Features**:") + print(" - SFTTrainer with conversation data") + print(" - Response-only training (train only on assistant responses)") + print(" - ShareGPT format standardization") + print(" - Automatic mixed precision (FP16/BF16)") + print(" - Gradient checkpointing for memory efficiency") + print(" - Configurable LoRA parameters") + print() + print("3. **Conversation Handling**:") + print(" - Multi-turn conversation support") + print(" - Proper chat template formatting") + print(" - Role-based training (user/assistant/system)") + print(" - Context preservation across turns") + print() + print("4. **Configuration**:") + print(" - YAML configuration files") + print(" - CLI argument overrides") + print(" - Automatic device detection") + print(" - Flexible LoRA configuration") + print() + print("5. 
def create_training_example():
    """Run a short demo training (30 steps) with the code reasoning config.

    Returns False without starting anything when the config file is missing;
    otherwise returns whatever run_training_with_config reports.
    """
    print("=== Training Example: Code Reasoning Instruction Tuning ===")
    print()

    config_path = "configs/instruct/code_reasoning.yaml"

    # Guard: nothing can run without the YAML configuration.
    if not os.path.exists(config_path):
        print(f"Configuration file not found: {config_path}")
        print(" Please run the data processor first to create the configuration")
        return False

    banner = (
        "Found required files!",
        f" Config: {config_path}",
        " Dataset: Will use output_dir from YAML config",
        " The training pipeline will automatically:",
        " - Load conversation data from the output_dir specified in YAML",
        " - Convert JSONL files to HuggingFace dataset format",
        " - Apply ShareGPT standardization",
        " - Format conversations with chat templates",
        " - Train the model using SFTTrainer with response-only training",
    )
    for line in banner:
        print(line)
    print()

    # dataset_path stays None so the pipeline falls back to the YAML output_dir.
    run_kwargs = {
        "config_path": config_path,
        "dataset_path": None,
        "epochs": 1,
        "batch_size": 1,
        "learning_rate": 2e-4,
        "max_steps": 30,
    }
    success = run_training_with_config(**run_kwargs)

    if success:
        for line in (
            "Training example completed!",
            " Model saved to: ./models/instruct",
            " Ready for conversational inference!",
        ):
            print(line)

    return success


def create_quick_test():
    """Run a 5-step smoke-test training with the code reasoning config.

    Returns False without starting anything when the config file is missing;
    otherwise returns whatever run_training_with_config reports.
    """
    print("=== Quick Test: Minimal Training Steps ===")
    print()

    config_path = "configs/instruct/code_reasoning.yaml"

    if not os.path.exists(config_path):
        print(f"Configuration file not found: {config_path}")
        print(" Please run the data processor first to create the configuration")
        return False

    print("Running quick test with minimal training steps...")

    # Only a handful of steps: this exercises the pipeline, not the model.
    run_kwargs = {
        "config_path": config_path,
        "dataset_path": None,
        "epochs": 1,
        "batch_size": 1,
        "learning_rate": 2e-4,
        "max_steps": 5,
    }
    success = run_training_with_config(**run_kwargs)

    if success:
        for line in (
            "Quick test completed!",
            " Model saved with minimal training",
            " This is just for testing the pipeline",
        ):
            print(line)

    return success
steps...") + + # Run training with very few steps for quick testing + success = run_training_with_config( + config_path=config_path, + dataset_path=None, + epochs=1, + batch_size=1, + learning_rate=2e-4, + max_steps=5 # Very few steps for quick test + ) + + if success: + print("Quick test completed!") + print(" Model saved with minimal training") + print(" This is just for testing the pipeline") + + return success + +def main(): + """Main function""" + parser = argparse.ArgumentParser(description="Instruct Training Script") + + # Subcommands + parser.add_argument("command", choices=["train", "example", "features", "quick-test"], + help="Command to run") + + # Training arguments + parser.add_argument("--config", type=str, help="Path to YAML configuration file") + parser.add_argument("--dataset", type=str, help="Path to training dataset") + parser.add_argument("--output-dir", type=str, help="Output directory for model") + parser.add_argument("--epochs", type=int, help="Number of training epochs") + parser.add_argument("--batch-size", type=int, help="Training batch size") + parser.add_argument("--learning-rate", type=float, help="Learning rate") + parser.add_argument("--max-steps", type=int, help="Maximum training steps") + + args = parser.parse_args() + + if args.command == "features": + show_training_features() + + elif args.command == "example": + create_training_example() + + elif args.command == "quick-test": + create_quick_test() + + elif args.command == "train": + if not args.config: + print("❌ --config is required for training") + print("Usage: python scripts/instruct/train.py train --config config.yaml") + sys.exit(1) + + # If dataset is not provided, try to use output_dir from config + dataset_path = args.dataset if args.dataset else None + + success = run_training_with_config( + config_path=args.config, + dataset_path=dataset_path, + output_dir=args.output_dir, + epochs=args.epochs, + batch_size=args.batch_size, + learning_rate=args.learning_rate, + 
max_steps=args.max_steps + ) + + if not success: + sys.exit(1) + +if __name__ == "__main__": + main()