initial setup
This commit is contained in:
@@ -0,0 +1,130 @@
|
||||
import yaml
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
import logging
|
||||
|
||||
# Module-level logger, named after this module's ``__name__``; used by the
# configuration helpers below to report load/merge progress.
logger = logging.getLogger(__name__)
|
||||
|
||||
class ConfigManager:
    """Load configuration from YAML files and merge it with CLI arguments.

    Nested settings are addressed with dot notation, for example
    ``training.batch_size`` refers to ``config["training"]["batch_size"]``.
    """

    def __init__(self):
        # Not read by any method of this class; kept so existing callers
        # that access ``manager.config`` keep working.
        self.config = {}

    def load_yaml_config(self, config_path: str) -> Dict[str, Any]:
        """Load configuration from a YAML file.

        Args:
            config_path: Path to the YAML file.

        Returns:
            The parsed YAML document. An empty file (or one containing only
            comments) yields an empty dict instead of ``None``.

        Raises:
            FileNotFoundError: If ``config_path`` does not exist.
        """
        config_file = Path(config_path)

        if not config_file.exists():
            raise FileNotFoundError(f"Configuration file not found: {config_path}")

        with open(config_file, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)

        # safe_load returns None for an empty document; normalise it so the
        # result can always be merged as a mapping.
        if config is None:
            config = {}

        logger.info("Loaded configuration from: %s", config_path)
        return config

    def merge_configs(self, yaml_config: Dict[str, Any], cli_args: Dict[str, Any]) -> Dict[str, Any]:
        """Merge YAML configuration with command-line arguments.

        CLI values take precedence over YAML values. CLI entries whose value
        is ``None`` are treated as "not provided" and skipped.

        Args:
            yaml_config: Base configuration (may be empty or ``None``).
            cli_args: Overrides keyed by plain or dot-notation names
                (may be empty or ``None``).

        Returns:
            A new dictionary; ``yaml_config`` and its nested dictionaries are
            left unmodified.
        """
        # Function-scope import keeps this block self-contained.
        import copy

        # Deep copy: a shallow copy would share nested dicts with the input,
        # and the nested assignments below would then mutate the caller's data.
        merged_config = copy.deepcopy(yaml_config) if yaml_config else {}

        # Override YAML config with CLI args (CLI takes precedence)
        for key, value in (cli_args or {}).items():
            if value is not None:  # Only override if CLI arg is provided
                self._set_nested_value(merged_config, key, value)

        return merged_config

    def _set_nested_value(self, config: Dict[str, Any], key: str, value: Any):
        """Set a value in ``config`` in place using dot notation.

        ``'training.batch_size'`` sets ``config['training']['batch_size']``.
        Missing intermediate levels are created. An intermediate entry that
        exists but is not a dict is replaced by a new dict, so the override
        wins instead of failing with ``TypeError``.

        Args:
            config: Dictionary to modify.
            key: Plain or dot-separated key.
            value: Value to store at the final key.
        """
        keys = key.split('.')
        current = config

        for k in keys[:-1]:
            child = current.get(k)
            if not isinstance(child, dict):
                child = {}
                current[k] = child
            current = child

        current[keys[-1]] = value

    def get_config_value(self, config: Dict[str, Any], key: str, default: Any = None) -> Any:
        """Get a nested value from ``config`` using dot notation.

        Args:
            config: Dictionary to read from.
            key: Plain or dot-separated key.
            default: Value returned when any segment of the path is missing
                or when a non-dict value is reached before the path ends.

        Returns:
            The value stored at ``key``, or ``default``.
        """
        current = config

        for k in key.split('.'):
            if not (isinstance(current, dict) and k in current):
                return default
            current = current[k]

        return current

    def create_argparser_with_yaml(
        self, description: str = "Pipeline with YAML configuration"
    ) -> argparse.ArgumentParser:
        """Create an argument parser that supports a YAML config file option.

        No option declares a default, so every option that is not given on
        the command line parses to ``None``.

        NOTE(review): argparse stores these options under underscore names
        (``--batch-size`` -> ``batch_size``). Passed to ``merge_configs``
        they become top-level keys, not nested ones; confirm this matches
        the YAML layout used by callers.

        Args:
            description: Text shown in the parser's help output.

        Returns:
            The configured ``argparse.ArgumentParser``.
        """
        parser = argparse.ArgumentParser(description=description)

        # YAML configuration
        parser.add_argument("--config", type=str, help="Path to YAML configuration file")

        # Data source arguments
        parser.add_argument("--data-source", choices=["huggingface", "custom"], help="Data source")
        parser.add_argument("--dataset-name", type=str, help="HuggingFace dataset name")
        parser.add_argument("--data-path", type=str, help="Path to custom data file")
        parser.add_argument("--data-format", choices=["jsonl", "csv", "json"], help="Data format")

        # Field mapping
        parser.add_argument("--input-field", type=str, help="Input field name")
        parser.add_argument("--label-field", type=str, help="Label field name")

        # Processing options
        parser.add_argument("--max-samples", type=int, help="Maximum samples to process")
        parser.add_argument("--output-dir", type=str, help="Output directory")

        # Model options
        parser.add_argument("--model-name", type=str, help="Model name")
        parser.add_argument("--max-length", type=int, help="Maximum sequence length")

        # Training options
        parser.add_argument("--num-epochs", type=int, help="Number of epochs")
        parser.add_argument("--batch-size", type=int, help="Batch size")
        parser.add_argument("--learning-rate", type=float, help="Learning rate")

        # Inference options
        parser.add_argument("--model-path", type=str, help="Path to saved model")
        parser.add_argument("--device", choices=["auto", "cuda", "cpu"], help="Device")

        # Logging
        parser.add_argument("--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], help="Log level")

        return parser

    def load_and_merge_config(
        self,
        config_path: Optional[str] = None,
        cli_args: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Load a YAML config (if a path is given) and merge CLI arguments.

        Args:
            config_path: Optional path to a YAML file. When falsy, merging
                starts from an empty configuration.
            cli_args: Optional overrides; entries whose value is ``None``
                are ignored.

        Returns:
            The merged configuration dictionary.

        Raises:
            FileNotFoundError: If ``config_path`` is given but does not exist.
        """
        yaml_config = {}

        if config_path:
            yaml_config = self.load_yaml_config(config_path)

        # Filter out None values from CLI args
        provided_args = {k: v for k, v in (cli_args or {}).items() if v is not None}

        merged_config = self.merge_configs(yaml_config, provided_args)

        logger.info("Configuration loaded and merged successfully")
        return merged_config
|
||||
|
||||
|
||||
def load_config_from_yaml(config_path: str) -> Dict[str, Any]:
    """Convenience wrapper that loads a YAML configuration file.

    Creates a temporary ``ConfigManager`` and returns the result of its
    ``load_yaml_config`` method for ``config_path``.
    """
    return ConfigManager().load_yaml_config(config_path)
|
||||
|
||||
|
||||
def create_config_from_yaml_and_cli(config_path: Optional[str] = None, **cli_args) -> Dict[str, Any]:
    """Build a configuration from an optional YAML file plus keyword overrides.

    The keyword arguments are collected into a dict and handed, together
    with ``config_path``, to ``ConfigManager.load_and_merge_config``.
    """
    manager = ConfigManager()
    merged = manager.load_and_merge_config(config_path=config_path, cli_args=cli_args)
    return merged
|
||||
Reference in New Issue
Block a user