Ver Fonte

遵循面向对象思想重构train_model_grpo.py

zhouyang.xie há 2 meses atrás
pai
commit
ad874737a9
2 ficheiros alterados com 5 adições e 3 exclusões
  1. 3 1
      src/conf_train.py
  2. 2 2
      src/train_model_grpo_v1.py

+ 3 - 1
src/conf_train.py

@@ -12,6 +12,7 @@ class Config:
     load_in_4bit: bool
     lora_rank: int
     gpu_memory_utilization: float
+    
     learning_rate: float
     adam_beta1: float
     adam_beta2: float
@@ -31,6 +32,7 @@ class Config:
     max_grad_norm: float
     report_to: str
     output_dir: str
+
     train_data_path: str
     save_path: str
 
@@ -45,4 +47,4 @@ def load_config(config_path: str) -> Config:
     return Config(**config_dict)
 
 # 加载配置文件
-config = load_config(f"../conf/conf_train.yaml")
+config = load_config(config_path=f"../conf/conf_train.yaml")

+ 2 - 2
src/train_model_grpo_v1.py

@@ -5,7 +5,7 @@ from unsloth import FastLanguageModel
 from unsloth import is_bfloat16_supported
 from trl import GRPOConfig, GRPOTrainer
 from datasets import load_dataset
-from conf_train import Config  # 导入配置文件
+from conf_train import Config ,load_config # 导入配置文件
 import re
 
 class ModelTrainer:
@@ -215,7 +215,7 @@ class ModelTrainer:
 
 if __name__ == "__main__":
     # 加载配置文件
-    config = Config()
+    config = load_config()
 
     # 设置环境变量
     """