added style mimicking pipelines
This commit is contained in:
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
Styling Scripts Package
|
||||
Provides command-line interfaces for styling data processing, training, and inference
|
||||
"""
|
||||
|
||||
from .data_processor import (
|
||||
run_with_yaml_config,
|
||||
run_styling_examples,
|
||||
create_sample_styling_data,
|
||||
create_custom_styling_config,
|
||||
show_styling_features
|
||||
)
|
||||
|
||||
from .train import (
|
||||
run_training_with_config,
|
||||
create_training_example,
|
||||
show_training_features
|
||||
)
|
||||
|
||||
from .inference import (
|
||||
run_inference_with_config,
|
||||
create_inference_example,
|
||||
run_batch_inference_example,
|
||||
show_inference_features
|
||||
)
|
||||
|
||||
# Names re-exported by the package, grouped by the submodule they come from.
__all__ = [
    # .data_processor
    "run_with_yaml_config",
    "run_styling_examples",
    "create_sample_styling_data",
    "create_custom_styling_config",
    "show_styling_features",
    # .train
    "run_training_with_config",
    "create_training_example",
    "show_training_features",
    # .inference
    "run_inference_with_config",
    "create_inference_example",
    "run_batch_inference_example",
    "show_inference_features",
]
|
||||
@@ -0,0 +1,302 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Styling data processor script that uses YAML configurations.
|
||||
This provides a flexible and maintainable approach for style transfer tasks.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import subprocess
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
def run_with_yaml_config(config_path: str, **cli_overrides):
    """Run the styling data processor pipeline with a YAML configuration.

    Builds a command line for ``pipelines/styling/data_processor.py`` (a
    relative path, so it is resolved against the current working directory)
    and runs it in a subprocess. The child's output is captured and printed
    once the process has finished.

    Args:
        config_path: Path to the YAML configuration file, passed as ``--config``.
        **cli_overrides: Extra options forwarded to the pipeline. Each key is
            turned into a flag by replacing underscores with dashes
            (``max_samples`` -> ``--max-samples``). ``None`` values are
            skipped. Booleans are treated as switches: ``True`` emits the bare
            flag and ``False`` emits nothing, matching ``handle_direct_args``.

    Returns:
        True if the subprocess exited with status 0, False otherwise.
    """
    print(f"=== Running Styling Data Processor with YAML config: {config_path} ===")

    # Launch the pipeline with the interpreter running this script so that the
    # same environment is used even when no "python" executable is on PATH.
    cmd = [
        sys.executable or "python", "pipelines/styling/data_processor.py",
        "--config", config_path,
    ]

    # Add CLI overrides
    for key, value in cli_overrides.items():
        if value is None:
            continue
        flag = f"--{key.replace('_', '-')}"
        if isinstance(value, bool):
            # Switch-style option: never pass "True"/"False" as a value.
            if value:
                cmd.append(flag)
        else:
            cmd.extend([flag, str(value)])

    print(f"Running command: {' '.join(cmd)}")
    print()

    try:
        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print(f"❌ Error running styling data processor: {e}")
        print(f"Error output: {e.stderr}")
        return False

    print("✅ Styling data processing completed successfully!")
    print(result.stdout)
    return True
|
||||
|
||||
def run_styling_examples():
    """Run the two bundled styling examples through ``run_with_yaml_config``.

    Example 1 always runs with the formal config. Example 2 runs only when
    ``data/raw/styling/custom_dataset.jsonl`` exists; otherwise a hint is
    printed and the example is skipped.
    """
    formal_config = "configs/styling/formal.yaml"
    custom_dataset = "data/raw/styling/custom_dataset.jsonl"

    # Example 1: formal style transfer, overriding two values from the YAML.
    print("=== Example 1: Formal Style Transfer ===")
    formal_overrides = {
        "max_samples": 1000,
        "output_format": "alpaca",
    }
    if run_with_yaml_config(formal_config, **formal_overrides):
        print("✅ Formal style transfer completed!")

    # Example 2: custom dataset, reusing the formal config as a base.
    print("\n=== Example 2: Custom Styling Dataset ===")
    if not os.path.exists(custom_dataset):
        print("⚠️ Custom styling dataset not found, skipping...")
        print(" You can create one with the 'create-sample-data' option")
        return

    custom_overrides = {
        "data_source": "custom",
        "data_path": custom_dataset,
        "instruction": "Rewrite the following text in a casual, friendly style",
        "output_dir": "./data/processed/styling/casual",
    }
    if run_with_yaml_config(formal_config, **custom_overrides):
        print("✅ Custom styling dataset processing completed!")
|
||||
|
||||
def create_sample_styling_data():
    """Write a five-example informal-to-formal dataset for testing.

    The records are saved as JSON Lines (one ``{"text", "styled_text"}``
    object per line) to ``data/raw/styling/sample_formal.jsonl``; the
    directory is created if it does not exist.
    """
    import json

    # (informal source, formal rewrite) pairs.
    pairs = (
        ("Hey, what's up? How are you doing today?",
         "Hello, how are you doing today?"),
        ("This is really cool stuff!",
         "This is quite impressive material."),
        ("I'm gonna go to the store later.",
         "I will go to the store later."),
        ("What's the deal with this?",
         "What is the situation regarding this matter?"),
        ("That's totally awesome!",
         "That is quite remarkable!"),
    )
    records = [{"text": informal, "styled_text": formal} for informal, formal in pairs]

    target_dir = Path("data/raw/styling")
    target_dir.mkdir(parents=True, exist_ok=True)
    sample_file = target_dir / "sample_formal.jsonl"

    with sample_file.open("w", encoding="utf-8") as handle:
        handle.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in records
        )

    print(f"✅ Created sample styling dataset: {sample_file}")
    print(f" Contains {len(records)} examples")
    print(" Format: text → styled_text")
    print(" Ready to use with configs/styling/formal.yaml")
|
||||
|
||||
def create_custom_styling_config():
    """Write a sample "professional business style" YAML configuration.

    The file is written to ``configs/styling/professional.yaml`` (relative to
    the current working directory), creating the directory if needed and
    overwriting any existing file.

    The settings are nested under their section (``task``, ``data``, ``model``,
    ``training``, ``inference``). Without that nesting, keys that appear in
    more than one section (``max_length``, ``batch_size``) would collide at the
    top level and the section headers would be empty.
    """
    custom_config = """task:
  name: "styling"
  type: "style_transfer"

data:
  source: "custom"
  input_field: "text"
  output_field: "styled_text"
  instruction: "Rewrite the following text in a professional business style"
  data_format: "jsonl"
  max_length: 512
  min_length: 10
  clean_text: true
  lowercase: false
  train_split: 0.8
  validation_split: 0.1
  test_split: 0.1
  output_format: "alpaca"
  output_dir: "./data/processed/styling/professional"

model:
  name: "t5-base"
  max_length: 512

training:
  num_epochs: 3
  batch_size: 16
  learning_rate: 3e-5
  weight_decay: 0.01
  warmup_ratio: 0.1
  lr_scheduler_type: "linear"

inference:
  batch_size: 32
  max_new_tokens: 128
  temperature: 0.8
"""

    config_path = "configs/styling/professional.yaml"
    os.makedirs(os.path.dirname(config_path), exist_ok=True)

    # Explicit encoding keeps the output independent of the platform default,
    # consistent with how the sample dataset is written.
    with open(config_path, 'w', encoding='utf-8') as f:
        f.write(custom_config)

    print(f"✅ Created custom styling config: {config_path}")
    print(" This config is set up for professional business style transfer")
|
||||
|
||||
def handle_direct_args():
    """Parse ``sys.argv`` and forward the options to the styling pipeline.

    Every option that was supplied (or has a non-``None`` default, such as
    ``--log-level``) is passed through to
    ``pipelines/styling/data_processor.py``. ``store_true`` switches are
    forwarded as bare flags and only when set. Options are forwarded in the
    order they are declared below.

    Returns:
        True if the pipeline subprocess exited with status 0, False otherwise.
    """
    parser = argparse.ArgumentParser(description="Styling Data Processor")

    # Add all the same arguments as the styling pipeline
    parser.add_argument("--config", type=str, help="Path to YAML configuration file")
    parser.add_argument("--data-source", choices=["huggingface", "custom"], help="Data source")
    parser.add_argument("--dataset-name", type=str, help="HuggingFace dataset name")
    parser.add_argument("--data-path", type=str, help="Path to custom data file")
    parser.add_argument("--data-format", choices=["jsonl", "csv", "json"], help="Data format")
    parser.add_argument("--input-field", type=str, help="Input field name")
    parser.add_argument("--output-field", type=str, help="Output field name")
    parser.add_argument("--instruction", type=str, help="Style instruction")
    parser.add_argument("--max-samples", type=int, help="Maximum samples to process")
    parser.add_argument("--train-split", type=float, help="Training split ratio")
    parser.add_argument("--validation-split", type=float, help="Validation split ratio")
    parser.add_argument("--test-split", type=float, help="Test split ratio")
    parser.add_argument("--clean-text", action="store_true", help="Clean and normalize text")
    parser.add_argument("--remove-special-chars", action="store_true", help="Remove special characters")
    parser.add_argument("--lowercase", action="store_true", help="Convert text to lowercase")
    parser.add_argument("--min-length", type=int, help="Minimum text length")
    parser.add_argument("--max-length", type=int, help="Maximum text length")
    parser.add_argument("--output-format", choices=["styling", "alpaca"], help="Output format")
    parser.add_argument("--output-dir", type=str, help="Output directory")

    # HuggingFace dataset options
    parser.add_argument("--create-hf-dataset", action="store_true", help="Create HuggingFace dataset")
    parser.add_argument("--hf-dataset-path", type=str, help="Path to save HuggingFace dataset")

    # Logging
    parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO", help="Logging level")

    args = parser.parse_args()

    # Launch the pipeline with the interpreter running this script so that the
    # same environment is used even when no "python" executable is on PATH.
    cmd = [sys.executable or "python", "pipelines/styling/data_processor.py"]

    # Add all arguments that were provided
    for arg_name, arg_value in vars(args).items():
        if arg_value is None:
            continue
        flag = f"--{arg_name.replace('_', '-')}"
        if isinstance(arg_value, bool):
            if arg_value:  # Only add flag if True
                cmd.append(flag)
        else:
            cmd.extend([flag, str(arg_value)])

    print(f"Running: {' '.join(cmd)}")
    print()

    try:
        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print(f"❌ Error running styling data processor: {e}")
        print(f"Error output: {e.stderr}")
        return False

    print("✅ Styling data processing completed successfully!")
    print(result.stdout)
    return True
|
||||
|
||||
def show_styling_features():
    """Print the feature overview and usage examples for the data processor."""
    script = "python scripts/styling/data_processor.py"
    overview = (
        "=== Styling Data Processor Features ===",
        "",
        "1. **Style Transfer Tasks**:",
        " - Formal vs. Informal style",
        " - Professional vs. Casual tone",
        " - Academic vs. Conversational",
        " - Any custom style instruction",
        "",
        "2. **Data Formats Supported**:",
        " - HuggingFace datasets",
        " - Custom JSONL/CSV/JSON files",
        " - Automatic train/validation/test splits",
        "",
        "3. **Output Formats**:",
        " - Raw styling format (input/output)",
        " - Alpaca format (instruction/input/output)",
        " - HuggingFace dataset format",
        "",
        "4. **Advanced Features**:",
        " - Configurable field mapping",
        " - Text preprocessing options",
        " - Automatic dataset saving/loading",
        " - YAML configuration support",
        "",
        "=== Usage Examples ===",
        "",
        "1. Use YAML config only:",
        f" {script} --config configs/styling/formal.yaml",
        "",
        "2. Override YAML values:",
        f" {script} --config configs/styling/formal.yaml --max-samples 500",
        "",
        "3. Create sample data:",
        f" {script} create-sample-data",
        "",
        "4. Create custom config:",
        f" {script} create-config",
    )
    for text in overview:
        print(text)
|
||||
|
||||
def main():
    """Command-line entry point for the styling data processor script.

    Behavior depends on ``sys.argv``:

    * no arguments: print the usage summary;
    * first argument is ``examples``, ``create-sample-data``,
      ``create-config`` or ``features``: run that subcommand;
    * anything else: treat the arguments as pipeline options and forward them
      via ``handle_direct_args``.

    Exits with status 1 when the forwarded pipeline run fails, so shell
    callers can detect the failure (the inference and training scripts
    already behave this way).
    """
    if len(sys.argv) <= 1:
        usage = (
            "Styling Data Processor",
            "=====================",
            "",
            "This script runs the styling data processor for style transfer tasks.",
            "It supports both YAML configurations and command-line overrides.",
            "",
            "Usage:",
            " python scripts/styling/data_processor.py examples # Run examples",
            " python scripts/styling/data_processor.py create-sample-data # Create sample dataset",
            " python scripts/styling/data_processor.py create-config # Create custom config",
            " python scripts/styling/data_processor.py features # Show features",
            "",
            "Direct pipeline usage:",
            " python scripts/styling/data_processor.py --config configs/styling/formal.yaml",
            " python scripts/styling/data_processor.py --data-source custom --data-path ./data.jsonl",
            "",
            "Key Features:",
            " ✅ Style transfer with custom instructions",
            " ✅ Multiple data source support",
            " ✅ YAML configuration files",
            " ✅ CLI argument overrides",
            " ✅ Automatic data splitting",
            " ✅ HuggingFace dataset export",
        )
        for line in usage:
            print(line)
        return

    subcommand = sys.argv[1]
    if subcommand == "examples":
        run_styling_examples()
    elif subcommand == "create-sample-data":
        create_sample_styling_data()
    elif subcommand == "create-config":
        create_custom_styling_config()
    elif subcommand == "features":
        show_styling_features()
    else:
        # Handle direct arguments (pass through to pipeline) and surface a
        # failed run through the process exit status.
        if not handle_direct_args():
            sys.exit(1)
||||
|
||||
# Run the command-line interface only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,223 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Styling Inference Script
|
||||
Provides a command-line interface to run the styling inference pipeline
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import subprocess
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
def run_inference_with_config(config_path: str, **cli_overrides):
    """Run the styling inference pipeline with a YAML configuration.

    Builds a command line for ``pipelines/styling/inference.py`` (a relative
    path, so it is resolved against the current working directory) and runs
    it in a subprocess. The child's output is captured and printed once the
    process has finished.

    Args:
        config_path: Path to the YAML configuration file, passed as ``--config``.
        **cli_overrides: Optional overrides. Recognised keys are
            ``model_path``, ``text``, ``input_file``, ``max_tokens``,
            ``temperature``, ``instruction``, ``output_file`` (each forwarded
            as ``--flag value``) and ``streaming`` (a switch: ``--streaming``
            is added only when the value is truthy). ``None`` values and
            unrecognised keys are ignored.

    Returns:
        True if the subprocess exited with status 0, False otherwise.
    """
    print(f"🚀 Starting styling inference with config: {config_path}")
    print()

    # Launch the pipeline with the interpreter running this script so that the
    # same environment is used even when no "python" executable is on PATH.
    cmd = [sys.executable or "python", "pipelines/styling/inference.py", "--config", config_path]

    value_flags = {
        "model_path": "--model-path",
        "text": "--text",
        "input_file": "--input-file",
        "max_tokens": "--max-tokens",
        "temperature": "--temperature",
        "instruction": "--instruction",
        "output_file": "--output-file",
    }

    # Add CLI overrides
    for key, value in cli_overrides.items():
        if value is None:
            continue
        if key == "streaming":
            # argparse's store_true yields False (not None) when the flag is
            # absent, so the value itself must decide whether to stream.
            if value:
                cmd.append("--streaming")
        elif key in value_flags:
            cmd.extend([value_flags[key], str(value)])

    print(f"Running: {' '.join(cmd)}")
    print()

    try:
        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print(f"❌ Inference failed: {e}")
        print(f"Error output: {e.stderr}")
        return False

    print("✅ Inference completed successfully!")
    print(result.stdout)
    return True
|
||||
|
||||
def show_inference_features():
    """Print the feature overview of the styling inference pipeline."""
    sections = (
        ("1. **Model Support**:", (
            "Trained LoRA models",
            "Base models from HuggingFace Hub",
            "Automatic model loading and preparation",
        )),
        ("2. **Inference Modes**:", (
            "Single text inference",
            "Batch file processing",
            "Interactive mode",
            "Streaming generation",
        )),
        ("3. **Generation Control**:", (
            "Configurable temperature and top-p",
            "Adjustable max tokens",
            "Custom style instructions",
        )),
        ("4. **Output Options**:", (
            "Console output",
            "File output",
            "Streaming real-time generation",
        )),
    )

    print("=== Styling Inference Pipeline Features ===")
    for heading, bullets in sections:
        # A blank line precedes every section heading.
        print()
        print(heading)
        for bullet in bullets:
            print(f" - {bullet}")
|
||||
|
||||
def create_inference_example():
    """Run a single-text inference demo with the formal style configuration.

    Returns:
        False when ``configs/styling/formal.yaml`` does not exist; otherwise
        the result of ``run_inference_with_config`` for the example text.
    """
    formal_config = "configs/styling/formal.yaml"
    sample = "Hey, what's up? I'm gonna go grab some food later."

    print("=== Inference Example: Formal Style Transfer ===")
    print()

    # Nothing can run without the configuration file.
    if not Path(formal_config).exists():
        print(f"❌ Configuration file not found: {formal_config}")
        print(" Please run the data processor first to create the configuration")
        return False

    print("✅ Found configuration file!")
    print(f" Config: {formal_config}")
    print()

    print(f"📝 Example text: {sample}")
    print()

    succeeded = run_inference_with_config(
        config_path=formal_config,
        text=sample,
        instruction="Rewrite the following text in a formal style",
    )
    if succeeded:
        print("🎉 Inference example completed!")
    return succeeded
|
||||
|
||||
def create_test_file():
    """Write five informal sample sentences to ``test_texts.txt``.

    The file is created in the current working directory with one sentence
    per line, for use as batch-inference input.

    Returns:
        The file name that was written (``"test_texts.txt"``).
    """
    sentences = (
        "Hey, what's up? How are you doing today?",
        "I'm gonna go to the store later to get some stuff.",
        "This is pretty cool, right?",
        "Can you help me out with this?",
        "Thanks a lot for your help!",
    )
    test_file = "test_texts.txt"

    with open(test_file, "w", encoding="utf-8") as handle:
        handle.writelines(f"{sentence}\n" for sentence in sentences)

    print(f"✅ Created test file: {test_file}")
    print(f" Contains {len(sentences)} sample texts")
    return test_file
|
||||
|
||||
def run_batch_inference_example():
    """Run a batch inference demo over a freshly created sample text file.

    The sample input file is created first (via ``create_test_file``), even
    if the configuration later turns out to be missing.

    Returns:
        False when ``configs/styling/formal.yaml`` does not exist; otherwise
        the result of ``run_inference_with_config`` for the batch run.
    """
    formal_config = "configs/styling/formal.yaml"
    results_file = "styled_results.txt"

    print("=== Batch Inference Example ===")
    print()

    batch_input = create_test_file()

    if not Path(formal_config).exists():
        print(f"❌ Configuration file not found: {formal_config}")
        return False

    print("✅ Running batch inference...")
    print()

    succeeded = run_inference_with_config(
        config_path=formal_config,
        input_file=batch_input,
        output_file=results_file,
        instruction="Rewrite the following text in a formal style",
    )
    if succeeded:
        print("🎉 Batch inference completed!")
        print(" Results saved to: styled_results.txt")
    return succeeded
|
||||
|
||||
def main():
    """Command-line entry point for the styling inference script.

    The positional ``command`` selects what to do: ``features`` prints the
    feature list, ``example`` and ``batch`` run the bundled demos, and
    ``infer`` forwards the given options to ``run_inference_with_config``.

    ``infer`` requires ``--config`` plus either ``--text`` or ``--input-file``;
    the process exits with status 1 if these are missing or if inference fails.
    """
    cli = argparse.ArgumentParser(description="Styling Inference Script")

    # Subcommand selector
    cli.add_argument("command", choices=["infer", "example", "batch", "features"],
                     help="Command to run")

    # Options that take a value, declared in display order.
    valued_options = (
        ("--config", str, "Path to YAML configuration file"),
        ("--model-path", str, "Path to trained model"),
        ("--text", str, "Single text to style transfer"),
        ("--input-file", str, "File containing texts to process"),
        ("--max-tokens", int, "Maximum new tokens to generate"),
        ("--temperature", float, "Sampling temperature"),
        ("--instruction", str, "Custom style instruction"),
        ("--output-file", str, "Output file for results"),
    )
    for flag, value_type, description in valued_options:
        cli.add_argument(flag, type=value_type, help=description)
    cli.add_argument("--streaming", action="store_true", help="Enable streaming generation")

    opts = cli.parse_args()

    if opts.command == "features":
        show_inference_features()
        return
    if opts.command == "example":
        create_inference_example()
        return
    if opts.command == "batch":
        run_batch_inference_example()
        return
    if opts.command != "infer":
        return

    if not opts.config:
        print("❌ --config is required for inference")
        print("Usage: python scripts/styling/inference.py infer --config config.yaml [options]")
        sys.exit(1)

    # At least one input source must be given.
    if not (opts.text or opts.input_file):
        print("❌ Either --text or --input-file is required")
        print("Usage: python scripts/styling/inference.py infer --config config.yaml --text 'your text'")
        sys.exit(1)

    forwarded = {
        "model_path": opts.model_path,
        "text": opts.text,
        "input_file": opts.input_file,
        "max_tokens": opts.max_tokens,
        "temperature": opts.temperature,
        "instruction": opts.instruction,
        "output_file": opts.output_file,
        "streaming": opts.streaming,
    }
    if not run_inference_with_config(config_path=opts.config, **forwarded):
        sys.exit(1)
|
||||
|
||||
# Run the command-line interface only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,168 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Styling Training Script
|
||||
Provides a command-line interface to run the styling training pipeline
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import subprocess
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
def run_training_with_config(config_path: str, dataset_path: str = None, **cli_overrides):
    """Run the styling training pipeline with a YAML configuration.

    Builds a command line for ``pipelines/styling/train.py`` (a relative path,
    so it is resolved against the current working directory) and runs it in a
    subprocess. The child's output is captured and printed only after the
    process has finished.

    Args:
        config_path: Path to the YAML configuration file, passed as ``--config``.
        dataset_path: Optional training dataset path, passed as ``--dataset``
            when non-empty. When omitted, no ``--dataset`` option is sent.
        **cli_overrides: Optional overrides. Recognised keys are
            ``output_dir``, ``epochs``, ``batch_size``, ``learning_rate`` and
            ``max_steps``; each is forwarded as ``--flag value``. ``None``
            values and unrecognised keys are ignored.

    Returns:
        True if the subprocess exited with status 0, False otherwise.
    """
    print(f"Starting styling training with config: {config_path}")
    if dataset_path:
        print(f"Training dataset: {dataset_path}")
    else:
        print("Training dataset: Will use output_dir from YAML config")
    print()

    # Launch the pipeline with the interpreter running this script so that the
    # same environment is used even when no "python" executable is on PATH.
    cmd = [sys.executable or "python", "pipelines/styling/train.py", "--config", config_path]

    # Add dataset path if provided
    if dataset_path:
        cmd.extend(["--dataset", dataset_path])

    option_flags = {
        "output_dir": "--output-dir",
        "epochs": "--epochs",
        "batch_size": "--batch-size",
        "learning_rate": "--learning-rate",
        "max_steps": "--max-steps",
    }

    # Add CLI overrides
    for key, value in cli_overrides.items():
        if value is not None and key in option_flags:
            cmd.extend([option_flags[key], str(value)])

    print(f"Running: {' '.join(cmd)}")
    print()

    try:
        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print(f"Training failed: {e}")
        print(f"Error output: {e.stderr}")
        return False

    print("Training completed successfully!")
    print(result.stdout)
    return True
|
||||
|
||||
def show_training_features():
    """Print the feature overview of the styling training pipeline."""
    overview = "\n".join((
        "=== Styling Training Pipeline Features ===",
        "",
        "1. **Model Support**:",
        " - Unsloth optimized models (4x faster)",
        " - LoRA fine-tuning for efficiency",
        " - Support for Llama-3.1, Mistral, Phi-3, Gemma",
        "",
        "2. **Training Features**:",
        " - SFTTrainer with instruction tuning",
        " - Automatic mixed precision (FP16/BF16)",
        " - Gradient checkpointing for memory efficiency",
        " - Configurable LoRA parameters",
        "",
        "3. **Configuration**:",
        " - YAML configuration files",
        " - CLI argument overrides",
        " - Automatic device detection",
        "",
        "4. **Output**:",
        " - Saved LoRA models",
        " - Training logs and checkpoints",
        " - Ready for inference",
    ))
    # print() supplies the final newline after the last line.
    print(overview)
|
||||
|
||||
def create_training_example():
    """Run a short training demo with the formal style configuration.

    No dataset path is passed, so ``run_training_with_config`` is called with
    ``dataset_path=None`` plus small overrides (1 epoch, batch size 2,
    learning rate 2e-4).

    Returns:
        False when ``configs/styling/formal.yaml`` does not exist; otherwise
        the result of ``run_training_with_config``.
    """
    formal_config = "configs/styling/formal.yaml"

    print("=== Training Example: Formal Style Transfer ===")
    print()

    # Nothing can run without the configuration file.
    if not Path(formal_config).exists():
        print(f"Configuration file not found: {formal_config}")
        print(" Please run the data processor first to create the configuration")
        return False

    summary = (
        "Found required files!",
        f" Config: {formal_config}",
        " Dataset: Will use output_dir from YAML config",
        " The training pipeline will automatically:",
        " - Load data from the output_dir specified in YAML",
        " - Convert JSONL files to HuggingFace dataset format",
        " - Apply formatting with EOS tokens",
        " - Train the model using SFTTrainer",
        "",
    )
    for line in summary:
        print(line)

    demo_overrides = {"epochs": 1, "batch_size": 2, "learning_rate": 2e-4}
    succeeded = run_training_with_config(
        config_path=formal_config,
        dataset_path=None,
        **demo_overrides,
    )
    if succeeded:
        print("Training example completed!")
        print(" Model saved to: ./models/styling")
        print(" Ready for inference!")
    return succeeded
|
||||
|
||||
def main():
    """Command-line entry point for the styling training script.

    The positional ``command`` selects what to do: ``features`` prints the
    feature list, ``example`` runs the bundled demo, and ``train`` forwards
    the given options to ``run_training_with_config``.

    ``train`` requires ``--config``; the process exits with status 1 if it is
    missing or if training fails.
    """
    cli = argparse.ArgumentParser(description="Styling Training Script")

    # Subcommand selector
    cli.add_argument("command", choices=["train", "example", "features"],
                     help="Command to run")

    # Training options, declared in display order.
    valued_options = (
        ("--config", str, "Path to YAML configuration file"),
        ("--dataset", str, "Path to training dataset"),
        ("--output-dir", str, "Output directory for model"),
        ("--epochs", int, "Number of training epochs"),
        ("--batch-size", int, "Training batch size"),
        ("--learning-rate", float, "Learning rate"),
        ("--max-steps", int, "Maximum training steps"),
    )
    for flag, value_type, description in valued_options:
        cli.add_argument(flag, type=value_type, help=description)

    opts = cli.parse_args()

    if opts.command == "features":
        show_training_features()
        return
    if opts.command == "example":
        create_training_example()
        return
    if opts.command != "train":
        return

    if not opts.config:
        print("❌ --config is required for training")
        print("Usage: python scripts/styling/train.py train --config config.yaml")
        sys.exit(1)

    overrides = {
        "output_dir": opts.output_dir,
        "epochs": opts.epochs,
        "batch_size": opts.batch_size,
        "learning_rate": opts.learning_rate,
        "max_steps": opts.max_steps,
    }
    # An absent or empty --dataset is passed on as None.
    succeeded = run_training_with_config(
        config_path=opts.config,
        dataset_path=opts.dataset or None,
        **overrides,
    )
    if not succeeded:
        sys.exit(1)
|
||||
|
||||
# Run the command-line interface only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user