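"""Configuration helpers: load YAML config files and merge command-line overrides.

``ConfigManager`` reads a YAML file, builds an ``argparse`` parser with common
pipeline options, and merges non-``None`` CLI values over the YAML values
(dotted keys such as ``training.batch_size`` address nested sections).
"""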
import argparse
import copy
import logging
from pathlib import Path
from typing import Any, Dict, Optional

import yaml
# Module-level logger named after this module; this file configures no
# handlers or levels itself.
logger = logging.getLogger(name=__name__)
class ConfigManager:
    """Manages configuration loading from YAML files and command-line arguments.

    Provides helpers to read a YAML file, build an argparse parser for common
    pipeline options, read nested values with dot notation, and merge CLI
    overrides (which take precedence) over the YAML values.
    """

    def __init__(self):
        # No method of this class reads or writes this attribute; it is kept as
        # a public attribute for callers that may rely on it.
        self.config: Dict[str, Any] = {}

    def load_yaml_config(self, config_path: str) -> Dict[str, Any]:
        """Load configuration from a YAML file.

        Args:
            config_path: Path to the YAML file.

        Returns:
            The parsed top-level mapping. An empty (or comment-only) file
            yields ``{}``.

        Raises:
            FileNotFoundError: If ``config_path`` is not an existing regular file.
            ValueError: If the YAML document's top level is not a mapping.
        """
        config_file = Path(config_path)

        # is_file() (not exists()) so a directory is reported cleanly instead of
        # failing later inside open().
        if not config_file.is_file():
            raise FileNotFoundError(f"Configuration file not found: {config_path}")

        # safe_load only builds plain Python types (no arbitrary object construction).
        with config_file.open('r', encoding='utf-8') as f:
            config = yaml.safe_load(f)

        # An empty YAML document parses to None; treat it as "no settings".
        if config is None:
            config = {}
        elif not isinstance(config, dict):
            raise ValueError(
                f"Configuration file must contain a mapping at the top level, "
                f"got {type(config).__name__}: {config_path}"
            )

        logger.info("Loaded configuration from: %s", config_path)
        return config

    def merge_configs(self, yaml_config: Dict[str, Any], cli_args: Dict[str, Any]) -> Dict[str, Any]:
        """Merge YAML configuration with command-line arguments.

        CLI values take precedence; ``None`` values are ignored. Keys may use dot
        notation (e.g. ``training.batch_size``) to target nested sections.

        Args:
            yaml_config: Base configuration (``None`` or empty is treated as ``{}``).
                It is not mutated.
            cli_args: Overrides to apply on top of ``yaml_config``.

        Returns:
            A new merged configuration dict.
        """
        # Deep copy: _set_nested_value mutates nested dicts in place, so a shallow
        # copy would write CLI overrides back into the caller's yaml_config.
        merged_config = copy.deepcopy(yaml_config) if yaml_config else {}

        # Override YAML config with CLI args (CLI takes precedence)
        for key, value in cli_args.items():
            if value is not None:  # Only override if CLI arg is provided
                self._set_nested_value(merged_config, key, value)

        return merged_config

    def _set_nested_value(self, config: Dict[str, Any], key: str, value: Any) -> None:
        """Set a nested value in ``config`` using dot notation (e.g. ``training.batch_size``).

        Missing intermediate sections are created. An intermediate value that is
        not a dict (e.g. ``None`` from an empty YAML section) is replaced with a
        new section, because the CLI override takes precedence.
        """
        keys = key.split('.')
        current = config

        for k in keys[:-1]:
            child = current.get(k)
            if not isinstance(child, dict):
                if child is not None:
                    # A real scalar/list is being discarded; make that visible.
                    logger.warning(
                        "Replacing non-mapping value %r at %r to apply override %r",
                        child, k, key,
                    )
                child = {}
                current[k] = child
            current = child

        current[keys[-1]] = value

    def get_config_value(self, config: Dict[str, Any], key: str, default: Any = None) -> Any:
        """Get a nested value from ``config`` using dot notation.

        Returns ``default`` as soon as a path segment is missing or an
        intermediate value is not a dict.
        """
        keys = key.split('.')
        current = config

        for k in keys:
            if isinstance(current, dict) and k in current:
                current = current[k]
            else:
                return default

        return current

    def create_argparser_with_yaml(self, description: str = "Pipeline with YAML configuration") -> argparse.ArgumentParser:
        """Create an argument parser that supports a YAML config file plus overrides.

        Every option defaults to ``None`` (no defaults are set here), so unset
        options are skipped when merged. Argparse derives each destination name by
        replacing dashes with underscores (``--batch-size`` -> ``batch_size``).
        When merged, these names are set as top-level keys, not dotted paths.

        Args:
            description: Text shown in ``--help``.

        Returns:
            The configured parser.
        """
        parser = argparse.ArgumentParser(description=description)

        # YAML configuration
        parser.add_argument("--config", type=str, help="Path to YAML configuration file")

        # Data source arguments
        parser.add_argument("--data-source", choices=["huggingface", "custom"], help="Data source")
        parser.add_argument("--dataset-name", type=str, help="HuggingFace dataset name")
        parser.add_argument("--data-path", type=str, help="Path to custom data file")
        parser.add_argument("--data-format", choices=["jsonl", "csv", "json"], help="Data format")

        # Field mapping
        parser.add_argument("--input-field", type=str, help="Input field name")
        parser.add_argument("--label-field", type=str, help="Label field name")

        # Processing options
        parser.add_argument("--max-samples", type=int, help="Maximum samples to process")
        parser.add_argument("--output-dir", type=str, help="Output directory")

        # Model options
        parser.add_argument("--model-name", type=str, help="Model name")
        parser.add_argument("--max-length", type=int, help="Maximum sequence length")

        # Training options
        parser.add_argument("--num-epochs", type=int, help="Number of epochs")
        parser.add_argument("--batch-size", type=int, help="Batch size")
        parser.add_argument("--learning-rate", type=float, help="Learning rate")

        # Inference options
        parser.add_argument("--model-path", type=str, help="Path to saved model")
        parser.add_argument("--device", choices=["auto", "cuda", "cpu"], help="Device")

        # Logging
        parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], help="Log level")

        return parser

    def load_and_merge_config(self, config_path: Optional[str] = None, cli_args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Load YAML config (when a path is given) and merge CLI arguments over it.

        Args:
            config_path: Optional YAML file path; when falsy, start from ``{}``.
            cli_args: Optional overrides; entries whose value is ``None`` are dropped.

        Returns:
            A new merged configuration dict.

        Raises:
            FileNotFoundError, ValueError: Propagated from :meth:`load_yaml_config`.
        """
        yaml_config = self.load_yaml_config(config_path) if config_path else {}

        # Filter out None values from CLI args
        overrides = {k: v for k, v in (cli_args or {}).items() if v is not None}

        merged_config = self.merge_configs(yaml_config, overrides)

        logger.info("Configuration loaded and merged successfully")
        return merged_config
def load_config_from_yaml(config_path: str) -> Dict[str, Any]:
    """Parse a YAML config file using a throwaway ``ConfigManager``.

    Args:
        config_path: Path to the YAML file.

    Returns:
        Whatever ``ConfigManager.load_yaml_config`` returns for that path.
    """
    return ConfigManager().load_yaml_config(config_path)
def create_config_from_yaml_and_cli(config_path: Optional[str] = None, **cli_args) -> Dict[str, Any]:
    """Build a merged configuration from an optional YAML file plus keyword overrides.

    The keyword arguments are treated as CLI overrides and handed to
    ``ConfigManager.load_and_merge_config``, which ignores ``None`` values.
    Dotted keys cannot be written as plain keywords, but they can be passed via
    unpacking, e.g. ``**{"training.batch_size": 32}``.

    Args:
        config_path: Optional path to a YAML configuration file.
        **cli_args: Override values keyed by configuration name.

    Returns:
        The merged configuration dict.
    """
    manager = ConfigManager()
    return manager.load_and_merge_config(config_path=config_path, cli_args=cli_args)