# conf_train.py
import yaml
from dataclasses import dataclass, fields


@dataclass
class Config:
    """
    Training configuration loaded from a YAML file.

    Every field is required. The top-level keys of the YAML file must match
    these field names exactly, because ``load_config`` passes the parsed
    mapping straight to this constructor as keyword arguments.
    """

    model_name: str
    max_seq_length: int
    dtype: str
    load_in_4bit: bool
    fast_inference: bool  # Enable vLLM fast inference
    lora_rank: int
    gpu_memory_utilization: float
    use_vllm: bool
    learning_rate: float
    adam_beta1: float
    adam_beta2: float
    weight_decay: float
    warmup_ratio: float
    lr_scheduler_type: str
    optim: str
    logging_steps: int
    per_device_train_batch_size: int
    gradient_accumulation_steps: int
    num_generations: int
    max_prompt_length: int
    max_completion_length: int
    num_train_epochs: int
    max_steps: int
    save_steps: int
    max_grad_norm: float
    report_to: str
    output_dir: str
    train_data_path: str
    save_path: str

    def __post_init__(self) -> None:
        """
        Coerce string values of ``float``-annotated fields to ``float``.

        PyYAML follows YAML 1.1, whose float syntax requires a decimal point,
        so exponent literals such as ``5e-6`` are loaded as strings. Without
        this step such a value would silently stay a ``str`` on the config.

        :raises ValueError: if such a string cannot be parsed as a float.
        """
        for fld in fields(self):
            # Annotations are real types here (no ``from __future__ import
            # annotations``), but accept the string form defensively.
            if fld.type not in (float, "float"):
                continue
            value = getattr(self, fld.name)
            if isinstance(value, str):
                try:
                    setattr(self, fld.name, float(value))
                except ValueError as err:
                    raise ValueError(
                        f"Config field {fld.name!r} expects a float, got {value!r}"
                    ) from err


def load_config(config_path: str = "../conf/conf_train.yaml") -> Config:
    """
    Load the training configuration from a YAML file.

    :param config_path: Path to the YAML config file. A relative path is
        resolved against the current working directory, not this module.
    :return: The populated ``Config`` object.
    :raises OSError: if the file cannot be opened (e.g. it does not exist).
    :raises TypeError: if the file is empty, its top level is not a mapping,
        or its keys do not match the fields of ``Config``.
    :raises ValueError: if a float field holds text that is not a number.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        config_dict = yaml.safe_load(f)
    # safe_load returns None for an empty file; Config(**None) would fail with
    # an opaque message, so report the real problem instead.
    if not isinstance(config_dict, dict):
        raise TypeError(
            f"Config file {config_path!r} must contain a YAML mapping at the "
            f"top level, got {type(config_dict).__name__}"
        )
    return Config(**config_dict)