# -*- coding: utf-8 -*- # @Time : 2024/6/7 # @Author : 魏志亮 import os import yaml from typing import Any, Optional, Dict def load_yaml_config(file_path: str, encoding: str = 'utf-8') -> Dict[str, Any]: """ 加载YAML配置文件 Args: file_path: YAML文件路径 encoding: 文件编码,默认为utf-8 Returns: 解析后的配置字典 Raises: FileNotFoundError: 文件不存在时抛出 yaml.YAMLError: YAML解析错误时抛出 """ try: with open(file_path, 'r', encoding=encoding) as f: data = yaml.safe_load(f) # 确保返回字典类型,防止YAML文件为空时返回None return data if isinstance(data, dict) else {} except FileNotFoundError: raise FileNotFoundError(f"配置文件不存在: {file_path}") except yaml.YAMLError as e: raise yaml.YAMLError(f"YAML解析错误: {e}") def get_config_value(config: Dict[str, Any], key: str, default: Optional[Any] = None) -> Any: """ 从配置字典中安全地获取值 Args: config: 配置字典 key: 配置键名 default: 默认值,当键不存在或值为None时返回 Returns: 配置值或默认值 """ # 处理config为None的情况 if config is None: return default # 支持嵌套键,如 "database.host" keys = key.split('.') value = config for k in keys: if isinstance(value, dict) and k in value: value = value[k] else: value = None break # 如果值为None且提供了默认值,返回默认值 if value is None and default is not None: return default return value def merge_configs(base_config: Dict[str, Any], override_config: Dict[str, Any]) -> Dict[str, Any]: """ 合并配置字典 Args: base_config: 基础配置 override_config: 覆盖配置 Returns: 合并后的配置 """ result = base_config.copy() for key, value in override_config.items(): if key in result and isinstance(result[key], dict) and isinstance(value, dict): # 递归合并嵌套字典 result[key] = merge_configs(result[key], value) else: # 直接覆盖 result[key] = value return result def load_config_with_env(file_path: str, encoding: str = 'utf-8') -> Dict[str, Any]: """ 加载配置文件并支持环境变量覆盖 Args: file_path: YAML文件路径 encoding: 文件编码,默认为utf-8 Returns: 解析后的配置字典 """ # 加载基础配置 base_config = load_yaml_config(file_path, encoding) # 检查是否有环境变量覆盖 env_prefix = "ETL_" override_config = {} for key, value in os.environ.items(): if key.startswith(env_prefix): # 转换环境变量名到配置键名 config_key = key[len(env_prefix):].lower().replace('_', '.') # 解析值 if value.lower() == 'true': parsed_value = True elif value.lower() == 'false': parsed_value = False elif value.isdigit(): parsed_value = int(value) elif '.' in value and all(part.isdigit() for part in value.split('.')): parsed_value = float(value) else: parsed_value = value # 构建嵌套配置 keys = config_key.split('.') current = override_config for k in keys[:-1]: if k not in current: current[k] = {} current = current[k] current[keys[-1]] = parsed_value # 合并配置 if override_config: base_config = merge_configs(base_config, override_config) return base_config # 为了保持向后兼容,保留原函数名(可选) yaml_conf = load_yaml_config read_conf = get_config_value