conf_train.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
# conf_train.py
from dataclasses import MISSING, dataclass, fields

import yaml
  4. @dataclass
  5. class Config:
  6. """
  7. 配置类,用于加载和管理训练配置。
  8. """
  9. model_name: str
  10. max_seq_length: int
  11. dtype: str
  12. load_in_4bit: bool
  13. fast_inference: bool # Enable vLLM fast inference
  14. lora_rank: int
  15. gpu_memory_utilization: float
  16. use_vllm:bool
  17. learning_rate: float
  18. adam_beta1: float
  19. adam_beta2: float
  20. weight_decay: float
  21. warmup_ratio: float
  22. lr_scheduler_type: str
  23. optim: str
  24. logging_steps: int
  25. per_device_train_batch_size: int
  26. gradient_accumulation_steps: int
  27. num_generations: int
  28. max_prompt_length: int
  29. max_completion_length: int
  30. num_train_epochs: int
  31. max_steps: int
  32. save_steps: int
  33. max_grad_norm: float
  34. report_to: str
  35. output_dir: str
  36. train_data_path: str
  37. save_path: str
  38. def load_config(config_path: str=f"../conf/conf_train.yaml") -> Config:
  39. """
  40. 加载配置文件。
  41. :param config_path: 配置文件路径
  42. :return: 返回配置对象
  43. """
  44. with open(config_path, 'r', encoding='utf-8') as f:
  45. config_dict = yaml.safe_load(f)
  46. return Config(**config_dict)
# Example usage (load the training configuration file):
# config = load_config(config_path="../conf/conf_train.yaml")