| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354 |
- # conf_train.py
- import yaml
- from dataclasses import dataclass
@dataclass()
class Config:
    """
    Training configuration container.

    Every field is required (none has a default), so whatever builds a
    ``Config`` from a mapping must supply each of these names exactly once.
    The fields are grouped by topic below. Their order is the positional
    constructor order and must not change.
    """

    # --- model loading / LoRA ---
    model_name: str
    max_seq_length: int
    dtype: str
    load_in_4bit: bool
    fast_inference: bool  # Enable vLLM fast inference
    lora_rank: int
    gpu_memory_utilization: float

    # --- trainer / optimisation ---
    use_vllm: bool
    learning_rate: float
    adam_beta1: float
    adam_beta2: float
    weight_decay: float
    warmup_ratio: float
    lr_scheduler_type: str
    optim: str
    logging_steps: int
    per_device_train_batch_size: int
    gradient_accumulation_steps: int
    num_generations: int
    max_prompt_length: int
    max_completion_length: int
    num_train_epochs: int
    max_steps: int
    save_steps: int
    max_grad_norm: float
    report_to: str

    # --- filesystem locations ---
    output_dir: str
    train_data_path: str
    save_path: str
def load_config(config_path: str = "../conf/conf_train.yaml") -> Config:
    """
    Load the training configuration from a YAML file.

    :param config_path: Path to the YAML configuration file. The default is a
        relative path, so it is resolved against the current working
        directory, not against this module's location.
    :return: A ``Config`` built from the file's top-level mapping, whose keys
        must match the ``Config`` field names.
    :raises OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
    :raises yaml.YAMLError: If the file is not valid YAML.
    :raises TypeError: If the document is not a mapping at the top level
        (including an empty file, which ``yaml.safe_load`` returns as
        ``None``), or if its keys do not match the ``Config`` fields
        (missing or unexpected keys).
    """
    with open(config_path, "r", encoding="utf-8") as f:
        # safe_load builds only plain Python objects (dicts, lists, scalars),
        # so a config file cannot instantiate arbitrary Python classes.
        config_dict = yaml.safe_load(f)
    # An empty file yields None, and a top-level list/scalar is not a mapping.
    # Fail here with a readable message instead of the opaque error that
    # ``Config(**None)`` would produce.
    if not isinstance(config_dict, dict):
        raise TypeError(
            f"Config file {config_path!r} must contain a YAML mapping at the "
            f"top level, got {type(config_dict).__name__}"
        )
    return Config(**config_dict)
- # # 加载配置文件
- # config = load_config(config_path=f"../conf/conf_train.yaml")
|