instruct model setup
This commit is contained in:
@@ -20,7 +20,7 @@ from utils.config.config_manager import ConfigManager
|
||||
# Training imports
|
||||
import torch
|
||||
from datasets import load_from_disk, Dataset
|
||||
from unsloth import FastLanguageModel #is_bfloat16_supported
|
||||
from unsloth import FastLanguageModel, is_bfloat16_supported
|
||||
from unsloth.chat_templates import get_chat_template, standardize_sharegpt, train_on_responses_only
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
@@ -132,56 +132,23 @@ class InstructTrainer:
|
||||
raise
|
||||
|
||||
def load_dataset(self, dataset_path: str) -> Dataset:
|
||||
"""Load the conversation training dataset"""
|
||||
"""Load the conversation training dataset directly from JSONL file"""
|
||||
print(f"Loading conversation dataset from: {dataset_path}")
|
||||
|
||||
try:
|
||||
if Path(dataset_path).exists():
|
||||
# Check if it's a HuggingFace dataset directory
|
||||
if (Path(dataset_path) / "dataset_info.json").exists():
|
||||
# Load from HuggingFace dataset directory
|
||||
dataset = load_from_disk(dataset_path)
|
||||
print(f"Loaded HuggingFace dataset from disk: {len(dataset)} samples")
|
||||
else:
|
||||
# Load from processed conversation data files (JSONL format)
|
||||
print("Loading from processed conversation data files...")
|
||||
from datasets import Dataset
|
||||
import json
|
||||
|
||||
all_data = []
|
||||
data_dir = Path(dataset_path)
|
||||
|
||||
# Look for train.jsonl, validation.jsonl, test.jsonl
|
||||
for split_file in ["train.jsonl", "validation.jsonl", "test.jsonl"]:
|
||||
file_path = data_dir / split_file
|
||||
if file_path.exists():
|
||||
print(f"Loading {split_file}...")
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
all_data.append(data)
|
||||
|
||||
if not all_data:
|
||||
raise ValueError(f"No conversation data found in {dataset_path}")
|
||||
|
||||
# Create HuggingFace dataset
|
||||
dataset = Dataset.from_list(all_data)
|
||||
print(f"Created HuggingFace dataset from {len(all_data)} conversation samples")
|
||||
else:
|
||||
# Try loading from HuggingFace Hub
|
||||
print(f"Attempting to load from HuggingFace Hub: {dataset_path}")
|
||||
dataset = Dataset.load_dataset(dataset_path, split="train")
|
||||
print(f"Loaded from HuggingFace Hub: {len(dataset)} samples")
|
||||
|
||||
print(f"Dataset loaded: {len(dataset)} samples")
|
||||
print(f"Dataset features: {dataset.features}")
|
||||
|
||||
# Verify required fields exist for conversation data
|
||||
required_fields = ["conversation"]
|
||||
missing_fields = [field for field in required_fields if field not in dataset.features]
|
||||
if missing_fields:
|
||||
raise ValueError(f"Missing required fields in conversation dataset: {missing_fields}")
|
||||
# Load JSONL data exactly as provided
|
||||
data = []
|
||||
with open(dataset_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
data.append(json.loads(line))
|
||||
|
||||
print(f"Loaded {len(data)} examples")
|
||||
|
||||
# Convert to HuggingFace Dataset
|
||||
dataset = Dataset.from_list(data)
|
||||
|
||||
print(dataset)
|
||||
print(dataset[0]) # Show first example
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -194,22 +161,16 @@ class InstructTrainer:
|
||||
print("Formatting conversation dataset for training...")
|
||||
|
||||
try:
|
||||
# Define the formatting function exactly as provided
|
||||
def formatting_prompts_func(examples):
|
||||
convos = examples["conversation"]
|
||||
texts = [self.tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
|
||||
return {"text": texts}
|
||||
|
||||
# Standardize the ShareGPT format
|
||||
print("Standardizing ShareGPT format...")
|
||||
dataset = standardize_sharegpt(dataset)
|
||||
|
||||
# Define the formatting function for chat templates
|
||||
def formatting_prompts_func(examples):
|
||||
convos = examples["conversation"]
|
||||
texts = [
|
||||
self.tokenizer.apply_chat_template(
|
||||
convo,
|
||||
tokenize=False,
|
||||
add_generation_prompt=False
|
||||
) for convo in convos
|
||||
]
|
||||
return {"text": texts}
|
||||
|
||||
# Apply the formatting function
|
||||
print("Applying chat template formatting...")
|
||||
dataset = dataset.map(formatting_prompts_func, batched=True)
|
||||
@@ -277,18 +238,29 @@ class InstructTrainer:
|
||||
print("Setting up response-only training...")
|
||||
|
||||
try:
|
||||
# For Qwen models, we need to use the correct chat template tokens
|
||||
# Qwen uses different tokens than Llama
|
||||
if "qwen" in self.model_name.lower():
|
||||
instruction_part = "<|im_start|>user\n"
|
||||
response_part = "<|im_start|>assistant\n"
|
||||
else:
|
||||
# Default for other models
|
||||
instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n"
|
||||
response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
|
||||
# Configure trainer to only train on responses
|
||||
self.trainer = train_on_responses_only(
|
||||
self.trainer,
|
||||
instruction_part="<|im_start|>user\n",
|
||||
response_part="<|im_start|>assistant\n",
|
||||
instruction_part=instruction_part,
|
||||
response_part=response_part,
|
||||
)
|
||||
|
||||
print("✅ Response-only training configured")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error setting up response-only training: {e}")
|
||||
raise
|
||||
print("Skipping response-only training and proceeding with full training...")
|
||||
# Don't raise the exception, just continue with regular training
|
||||
|
||||
def train(self, dataset_path: str):
|
||||
"""Run the instruction fine-tuning process"""
|
||||
@@ -321,7 +293,11 @@ class InstructTrainer:
|
||||
|
||||
# Setup response-only training (optional but recommended for chat models)
|
||||
print("Step 7: Setting up response-only training...")
|
||||
self.setup_response_only_training()
|
||||
try:
|
||||
self.setup_response_only_training()
|
||||
except Exception as e:
|
||||
print(f"⚠️ Response-only training failed: {e}")
|
||||
print("Continuing with full training (will train on all tokens)...")
|
||||
|
||||
# Start training
|
||||
print("Step 8: Starting training...")
|
||||
@@ -432,13 +408,12 @@ def load_training_config(yaml_path: str) -> Dict[str, Any]:
|
||||
])
|
||||
})
|
||||
|
||||
# Data configuration - use output_dir from data section
|
||||
# Data configuration - use data_path from data section
|
||||
if 'data' in config:
|
||||
data_config = config['data']
|
||||
output_dir = data_config.get('output_dir', './data/processed/instruct')
|
||||
data_path = data_config.get('data_path', './data/raw/instruct/code_reasoning.jsonl')
|
||||
training_config.update({
|
||||
'data_output_dir': output_dir,
|
||||
'dataset_path': output_dir, # Default dataset path is the output_dir
|
||||
'dataset_path': data_path, # Use data_path directly for JSONL file
|
||||
})
|
||||
|
||||
# Output configuration
|
||||
|
||||
+43
-68
@@ -20,7 +20,7 @@ from utils.config.config_manager import ConfigManager
|
||||
# Training imports
|
||||
import torch
|
||||
from datasets import load_from_disk, Dataset
|
||||
from unsloth import FastLanguageModel #is_bfloat16_supported
|
||||
from unsloth import FastLanguageModel, is_bfloat16_supported
|
||||
from unsloth.chat_templates import get_chat_template, standardize_sharegpt, train_on_responses_only
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
@@ -132,56 +132,23 @@ class InstructTrainer:
|
||||
raise
|
||||
|
||||
def load_dataset(self, dataset_path: str) -> Dataset:
|
||||
"""Load the conversation training dataset"""
|
||||
"""Load the conversation training dataset directly from JSONL file"""
|
||||
print(f"Loading conversation dataset from: {dataset_path}")
|
||||
|
||||
try:
|
||||
if Path(dataset_path).exists():
|
||||
# Check if it's a HuggingFace dataset directory
|
||||
if (Path(dataset_path) / "dataset_info.json").exists():
|
||||
# Load from HuggingFace dataset directory
|
||||
dataset = load_from_disk(dataset_path)
|
||||
print(f"Loaded HuggingFace dataset from disk: {len(dataset)} samples")
|
||||
else:
|
||||
# Load from processed conversation data files (JSONL format)
|
||||
print("Loading from processed conversation data files...")
|
||||
from datasets import Dataset
|
||||
import json
|
||||
|
||||
all_data = []
|
||||
data_dir = Path(dataset_path)
|
||||
|
||||
# Look for train.jsonl, validation.jsonl, test.jsonl
|
||||
for split_file in ["train.jsonl", "validation.jsonl", "test.jsonl"]:
|
||||
file_path = data_dir / split_file
|
||||
if file_path.exists():
|
||||
print(f"Loading {split_file}...")
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
all_data.append(data)
|
||||
|
||||
if not all_data:
|
||||
raise ValueError(f"No conversation data found in {dataset_path}")
|
||||
|
||||
# Create HuggingFace dataset
|
||||
dataset = Dataset.from_list(all_data)
|
||||
print(f"Created HuggingFace dataset from {len(all_data)} conversation samples")
|
||||
else:
|
||||
# Try loading from HuggingFace Hub
|
||||
print(f"Attempting to load from HuggingFace Hub: {dataset_path}")
|
||||
dataset = Dataset.load_dataset(dataset_path, split="train")
|
||||
print(f"Loaded from HuggingFace Hub: {len(dataset)} samples")
|
||||
|
||||
print(f"Dataset loaded: {len(dataset)} samples")
|
||||
print(f"Dataset features: {dataset.features}")
|
||||
|
||||
# Verify required fields exist for conversation data
|
||||
required_fields = ["conversation"]
|
||||
missing_fields = [field for field in required_fields if field not in dataset.features]
|
||||
if missing_fields:
|
||||
raise ValueError(f"Missing required fields in conversation dataset: {missing_fields}")
|
||||
# Load JSONL data exactly as provided
|
||||
data = []
|
||||
with open(dataset_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
data.append(json.loads(line))
|
||||
|
||||
print(f"Loaded {len(data)} examples")
|
||||
|
||||
# Convert to HuggingFace Dataset
|
||||
dataset = Dataset.from_list(data)
|
||||
|
||||
print(dataset)
|
||||
print(dataset[0]) # Show first example
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -194,22 +161,16 @@ class InstructTrainer:
|
||||
print("Formatting conversation dataset for training...")
|
||||
|
||||
try:
|
||||
# Define the formatting function exactly as provided
|
||||
def formatting_prompts_func(examples):
|
||||
convos = examples["conversation"]
|
||||
texts = [self.tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
|
||||
return {"text": texts}
|
||||
|
||||
# Standardize the ShareGPT format
|
||||
print("Standardizing ShareGPT format...")
|
||||
dataset = standardize_sharegpt(dataset)
|
||||
|
||||
# Define the formatting function for chat templates
|
||||
def formatting_prompts_func(examples):
|
||||
convos = examples["conversation"]
|
||||
texts = [
|
||||
self.tokenizer.apply_chat_template(
|
||||
convo,
|
||||
tokenize=False,
|
||||
add_generation_prompt=False
|
||||
) for convo in convos
|
||||
]
|
||||
return {"text": texts}
|
||||
|
||||
# Apply the formatting function
|
||||
print("Applying chat template formatting...")
|
||||
dataset = dataset.map(formatting_prompts_func, batched=True)
|
||||
@@ -277,18 +238,29 @@ class InstructTrainer:
|
||||
print("Setting up response-only training...")
|
||||
|
||||
try:
|
||||
# For Qwen models, we need to use the correct chat template tokens
|
||||
# Qwen uses different tokens than Llama
|
||||
if "qwen" in self.model_name.lower():
|
||||
instruction_part = "<|im_start|>user\n"
|
||||
response_part = "<|im_start|>assistant\n"
|
||||
else:
|
||||
# Default for other models
|
||||
instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n"
|
||||
response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
|
||||
# Configure trainer to only train on responses
|
||||
self.trainer = train_on_responses_only(
|
||||
self.trainer,
|
||||
instruction_part="<|im_start|>user\n",
|
||||
response_part="<|im_start|>assistant\n",
|
||||
instruction_part=instruction_part,
|
||||
response_part=response_part,
|
||||
)
|
||||
|
||||
print("✅ Response-only training configured")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error setting up response-only training: {e}")
|
||||
raise
|
||||
print("Skipping response-only training and proceeding with full training...")
|
||||
# Don't raise the exception, just continue with regular training
|
||||
|
||||
def train(self, dataset_path: str):
|
||||
"""Run the instruction fine-tuning process"""
|
||||
@@ -321,7 +293,11 @@ class InstructTrainer:
|
||||
|
||||
# Setup response-only training (optional but recommended for chat models)
|
||||
print("Step 7: Setting up response-only training...")
|
||||
self.setup_response_only_training()
|
||||
try:
|
||||
self.setup_response_only_training()
|
||||
except Exception as e:
|
||||
print(f"⚠️ Response-only training failed: {e}")
|
||||
print("Continuing with full training (will train on all tokens)...")
|
||||
|
||||
# Start training
|
||||
print("Step 8: Starting training...")
|
||||
@@ -432,13 +408,12 @@ def load_training_config(yaml_path: str) -> Dict[str, Any]:
|
||||
])
|
||||
})
|
||||
|
||||
# Data configuration - use output_dir from data section
|
||||
# Data configuration - use data_path from data section
|
||||
if 'data' in config:
|
||||
data_config = config['data']
|
||||
output_dir = data_config.get('output_dir', './data/processed/instruct')
|
||||
data_path = data_config.get('data_path', './data/raw/instruct/code_reasoning.jsonl')
|
||||
training_config.update({
|
||||
'data_output_dir': output_dir,
|
||||
'dataset_path': output_dir, # Default dataset path is the output_dir
|
||||
'dataset_path': data_path, # Use data_path directly for JSONL file
|
||||
})
|
||||
|
||||
# Output configuration
|
||||
|
||||
Reference in New Issue
Block a user