# Source: DS-LLM-TEMPLATE-FINETUNING/utils/config/config_manager.py
# Commit: fef3f5ae35 by OwusuBlessing ("initial setupt"), 2025-08-06 22:45:37 +01:00
# Original listing: 130 lines, 5.2 KiB, Python

import argparse
import copy
import logging
from pathlib import Path
from typing import Dict, Any, Optional

import yaml
# Module-wide logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
class ConfigManager:
    """Manages configuration loading from YAML files and command-line arguments.

    Nested settings are addressed with dot-separated keys (for example
    ``"training.batch_size"``). Values supplied on the command line take
    precedence over values read from YAML.
    """

    def __init__(self):
        # NOTE(review): not read or written by any method in this class;
        # presumably kept for callers that store a config on the instance.
        self.config = {}

    def load_yaml_config(self, config_path: str) -> Dict[str, Any]:
        """Load and return the top-level mapping stored in a YAML file.

        Args:
            config_path: Path to the YAML file.

        Returns:
            The parsed mapping. An empty (or comment-only) file yields ``{}``
            instead of ``None`` so the result can always be merged.

        Raises:
            FileNotFoundError: If ``config_path`` does not exist.
            ValueError: If the YAML document's top level is not a mapping.
        """
        config_file = Path(config_path)
        if not config_file.exists():
            raise FileNotFoundError(f"Configuration file not found: {config_path}")

        # safe_load builds only plain Python objects (dict/list/str/...),
        # so a config file cannot instantiate arbitrary classes.
        with open(config_file, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)

        if config is None:
            # yaml.safe_load returns None for an empty document.
            config = {}
        elif not isinstance(config, dict):
            raise ValueError(
                "Configuration file must contain a mapping at the top level, "
                f"got {type(config).__name__}: {config_path}"
            )

        logger.info("Loaded configuration from: %s", config_path)
        return config

    def merge_configs(self, yaml_config: Dict[str, Any], cli_args: Dict[str, Any]) -> Dict[str, Any]:
        """Return a new config with CLI values overlaid on the YAML values.

        Args:
            yaml_config: Base configuration. It is not modified; a falsy value
                (e.g. ``None`` or ``{}``) is treated as an empty mapping.
            cli_args: Overrides keyed by plain or dot-separated names. Entries
                whose value is ``None`` are skipped (argument not supplied).

        Returns:
            A deep copy of ``yaml_config`` with the overrides applied.

        Raises:
            TypeError: If a dotted key passes through a non-mapping value.
        """
        # Deep copy: _set_nested_value writes into nested dicts, and a shallow
        # copy would share those dicts with, and silently mutate, the input.
        merged_config = copy.deepcopy(yaml_config) if yaml_config else {}

        # Override YAML config with CLI args (CLI takes precedence)
        for key, value in cli_args.items():
            if value is not None:  # Only override if CLI arg is provided
                self._set_nested_value(merged_config, key, value)
        return merged_config

    def _set_nested_value(self, config: Dict[str, Any], key: str, value: Any) -> None:
        """Set ``value`` at dot-separated ``key`` inside ``config``, in place.

        Example: key ``'training.batch_size'`` writes
        ``config['training']['batch_size']``. Missing intermediate sections are
        created as empty dicts, as are sections whose value is ``None`` (what
        YAML produces for an empty ``section:`` entry).

        Raises:
            TypeError: If an intermediate section exists but is not a dict.
        """
        keys = key.split('.')
        current = config
        for depth, k in enumerate(keys[:-1]):
            child = current.get(k)
            if child is None:
                child = current[k] = {}
            elif not isinstance(child, dict):
                section = '.'.join(keys[:depth + 1])
                raise TypeError(
                    f"Cannot set '{key}': '{section}' is a "
                    f"{type(child).__name__}, not a mapping"
                )
            current = child
        current[keys[-1]] = value

    def get_config_value(self, config: Dict[str, Any], key: str, default: Any = None) -> Any:
        """Return the value at dot-separated ``key`` in ``config``, or ``default``.

        ``default`` is returned when any path segment is missing or when a
        segment is reached through a non-dict value.
        """
        keys = key.split('.')
        current = config
        for k in keys:
            if isinstance(current, dict) and k in current:
                current = current[k]
            else:
                return default
        return current

    def create_argparser_with_yaml(self, description: str = "Pipeline with YAML configuration") -> argparse.ArgumentParser:
        """Build an ``ArgumentParser`` with ``--config`` plus pipeline overrides.

        No option has a default, so an option the user omits parses to
        ``None``. ``merge_configs`` treats ``None`` as "not supplied", so the
        YAML value wins for omitted options. argparse stores dashed flags under
        underscored names (``--batch-size`` -> ``batch_size``).
        """
        parser = argparse.ArgumentParser(description=description)

        # YAML configuration
        parser.add_argument("--config", type=str, help="Path to YAML configuration file")

        # Data source arguments
        parser.add_argument("--data-source", choices=["huggingface", "custom"], help="Data source")
        parser.add_argument("--dataset-name", type=str, help="HuggingFace dataset name")
        parser.add_argument("--data-path", type=str, help="Path to custom data file")
        parser.add_argument("--data-format", choices=["jsonl", "csv", "json"], help="Data format")

        # Field mapping
        parser.add_argument("--input-field", type=str, help="Input field name")
        parser.add_argument("--label-field", type=str, help="Label field name")

        # Processing options
        parser.add_argument("--max-samples", type=int, help="Maximum samples to process")
        parser.add_argument("--output-dir", type=str, help="Output directory")

        # Model options
        parser.add_argument("--model-name", type=str, help="Model name")
        parser.add_argument("--max-length", type=int, help="Maximum sequence length")

        # Training options
        parser.add_argument("--num-epochs", type=int, help="Number of epochs")
        parser.add_argument("--batch-size", type=int, help="Batch size")
        parser.add_argument("--learning-rate", type=float, help="Learning rate")

        # Inference options
        parser.add_argument("--model-path", type=str, help="Path to saved model")
        parser.add_argument("--device", choices=["auto", "cuda", "cpu"], help="Device")

        # Logging
        parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], help="Log level")

        return parser

    def load_and_merge_config(self, config_path: Optional[str] = None, cli_args: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Load the YAML config (if a path is given) and overlay CLI arguments.

        Args:
            config_path: Optional YAML file path; when falsy, start from ``{}``.
            cli_args: Optional overrides (plain or dot-separated keys);
                ``None`` values are dropped.

        Returns:
            The merged configuration dict.
        """
        yaml_config = {}
        if config_path:
            yaml_config = self.load_yaml_config(config_path)

        cli_args = cli_args or {}
        # Filter out None values from CLI args
        cli_args = {k: v for k, v in cli_args.items() if v is not None}

        merged_config = self.merge_configs(yaml_config, cli_args)
        logger.info("Configuration loaded and merged successfully")
        return merged_config
def load_config_from_yaml(config_path: str) -> Dict[str, Any]:
    """Read a YAML config file using a throwaway ``ConfigManager``.

    Args:
        config_path: Path to the YAML file.

    Returns:
        Whatever ``ConfigManager.load_yaml_config`` returns for that path.
    """
    return ConfigManager().load_yaml_config(config_path)
def create_config_from_yaml_and_cli(config_path: Optional[str] = None, **cli_args) -> Dict[str, Any]:
    """Build a configuration from an optional YAML file plus keyword overrides.

    Every keyword argument is forwarded as a CLI-style override to
    ``ConfigManager.load_and_merge_config``.

    Args:
        config_path: Optional path to a YAML configuration file.
        **cli_args: Override values keyed by setting name.

    Returns:
        The merged configuration dict.
    """
    manager = ConfigManager()
    return manager.load_and_merge_config(config_path=config_path, cli_args=cli_args)