"""Configuration helpers: load YAML files and overlay command-line arguments."""

import argparse
import copy
import logging
from pathlib import Path
from typing import Any, Dict, Optional

import yaml

logger = logging.getLogger(__name__)


class ConfigManager:
    """Manages configuration loading from YAML files and command-line arguments.

    All methods operate on the dictionaries passed to them and return new
    results; ``self.config`` is initialised to an empty dict and is not
    modified by any method of this class.
    """

    def __init__(self):
        self.config: Dict[str, Any] = {}

    def load_yaml_config(self, config_path: str) -> Dict[str, Any]:
        """Load configuration from a YAML file.

        Args:
            config_path: Path to the YAML file.

        Returns:
            The parsed document. An empty (or comment-only) file yields an
            empty dict rather than ``None``.

        Raises:
            FileNotFoundError: If ``config_path`` does not exist.
        """
        config_file = Path(config_path)
        if not config_file.exists():
            raise FileNotFoundError(f"Configuration file not found: {config_path}")

        with open(config_file, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)

        # safe_load returns None for an empty document; normalise so callers
        # (e.g. merge_configs) always receive a mapping.
        if config is None:
            config = {}

        logger.info("Loaded configuration from: %s", config_path)
        return config

    def merge_configs(self, yaml_config: Dict[str, Any], cli_args: Dict[str, Any]) -> Dict[str, Any]:
        """Merge YAML configuration with command-line arguments.

        CLI values take precedence over YAML values. CLI entries whose value
        is ``None`` are treated as "not provided" and skipped. Keys containing
        dots are interpreted as nested paths (see ``_set_nested_value``).

        Args:
            yaml_config: Base configuration; it is not modified.
            cli_args: Overrides keyed by (optionally dotted) config key.

        Returns:
            A new dictionary containing the merged configuration.
        """
        # Deep copy: a dotted override writes into nested dicts, which a
        # shallow copy would still share with the caller's yaml_config.
        merged_config = copy.deepcopy(yaml_config)

        # Override YAML config with CLI args (CLI takes precedence)
        for key, value in cli_args.items():
            if value is not None:  # Only override if CLI arg is provided
                self._set_nested_value(merged_config, key, value)

        return merged_config

    def _set_nested_value(self, config: Dict[str, Any], key: str, value: Any):
        """Set a nested value in config using dot notation (e.g., 'training.batch_size').

        Intermediate dictionaries are created as needed. An intermediate key
        that exists but does not hold a dict (for example an empty YAML
        section parsed as ``None``) is replaced by a new dict so the
        assignment can proceed.

        Args:
            config: Dictionary to modify in place.
            key: Dot-separated path to the target entry.
            value: Value to store at that path.
        """
        *parents, leaf = key.split('.')
        current = config
        for k in parents:
            child = current.get(k)
            if not isinstance(child, dict):
                child = {}
                current[k] = child
            current = child
        current[leaf] = value

    def get_config_value(self, config: Dict[str, Any], key: str, default: Any = None) -> Any:
        """Get a nested value from config using dot notation.

        Args:
            config: Dictionary to read from.
            key: Dot-separated path to the wanted entry.
            default: Value returned when any path component is missing or a
                non-dict value is met before the path is exhausted.

        Returns:
            The value stored at ``key``, or ``default``.
        """
        current = config
        for k in key.split('.'):
            if isinstance(current, dict) and k in current:
                current = current[k]
            else:
                return default
        return current

    def create_argparser_with_yaml(self, description: str = "Pipeline with YAML configuration") -> argparse.ArgumentParser:
        """Create argument parser that supports YAML config files.

        No option declares a default, so options that are not passed on the
        command line parse as ``None`` and are ignored by ``merge_configs``.

        Args:
            description: Text used as the parser description.

        Returns:
            The configured ``argparse.ArgumentParser``.
        """
        parser = argparse.ArgumentParser(description=description)

        # YAML configuration
        parser.add_argument("--config", type=str,
                            help="Path to YAML configuration file")

        # Data source arguments
        parser.add_argument("--data-source", choices=["huggingface", "custom"],
                            help="Data source")
        parser.add_argument("--dataset-name", type=str,
                            help="HuggingFace dataset name")
        parser.add_argument("--data-path", type=str,
                            help="Path to custom data file")
        parser.add_argument("--data-format", choices=["jsonl", "csv", "json"],
                            help="Data format")

        # Field mapping
        parser.add_argument("--input-field", type=str,
                            help="Input field name")
        parser.add_argument("--label-field", type=str,
                            help="Label field name")

        # Processing options
        parser.add_argument("--max-samples", type=int,
                            help="Maximum samples to process")
        parser.add_argument("--output-dir", type=str,
                            help="Output directory")

        # Model options
        parser.add_argument("--model-name", type=str,
                            help="Model name")
        parser.add_argument("--max-length", type=int,
                            help="Maximum sequence length")

        # Training options
        parser.add_argument("--num-epochs", type=int,
                            help="Number of epochs")
        parser.add_argument("--batch-size", type=int,
                            help="Batch size")
        parser.add_argument("--learning-rate", type=float,
                            help="Learning rate")

        # Inference options
        parser.add_argument("--model-path", type=str,
                            help="Path to saved model")
        parser.add_argument("--device", choices=["auto", "cuda", "cpu"],
                            help="Device")

        # Logging
        parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"],
                            help="Log level")

        return parser

    def load_and_merge_config(self, config_path: Optional[str] = None,
                              cli_args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Load YAML config and merge with CLI arguments.

        Args:
            config_path: Optional path to a YAML file. When falsy, the base
                configuration is an empty dict.
            cli_args: Optional overrides; entries whose value is ``None`` are
                dropped before merging.

        Returns:
            The merged configuration dictionary.

        Raises:
            FileNotFoundError: If ``config_path`` is given but does not exist.
        """
        yaml_config = {}
        if config_path:
            yaml_config = self.load_yaml_config(config_path)

        cli_args = cli_args or {}

        # Filter out None values from CLI args
        cli_args = {k: v for k, v in cli_args.items() if v is not None}

        merged_config = self.merge_configs(yaml_config, cli_args)

        logger.info("Configuration loaded and merged successfully")
        return merged_config


def load_config_from_yaml(config_path: str) -> Dict[str, Any]:
    """Simple function to load config from YAML file.

    Args:
        config_path: Path to the YAML file.

    Returns:
        The parsed configuration (an empty dict for an empty file).

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
    """
    config_manager = ConfigManager()
    return config_manager.load_yaml_config(config_path)


def create_config_from_yaml_and_cli(config_path: Optional[str] = None, **cli_args) -> Dict[str, Any]:
    """Create configuration from YAML file and CLI arguments.

    Args:
        config_path: Optional path to a YAML file.
        **cli_args: Overrides passed as keyword arguments; ``None`` values
            are ignored.

    Returns:
        The merged configuration dictionary.
    """
    config_manager = ConfigManager()
    return config_manager.load_and_merge_config(config_path, cli_args)