
Refactor train_model_grpo.py following object-oriented design principles

zhouyang.xie 2 months ago
parent
commit
8f30a3f1e1
2 changed files with 288 additions and 0 deletions
  1. src/train_model_grpo_v1.1.py  +286 −0
  2. src/train_model_grpo_v1.py  +2 −0

+ 286 - 0
src/train_model_grpo_v1.1.py

@@ -0,0 +1,286 @@
+# train_model_grpo_v1.1.py
+
+import os
+import torch
+import torch.distributed as dist
+from unsloth import FastLanguageModel
+from unsloth import is_bfloat16_supported
+from trl import GRPOConfig, GRPOTrainer
+from datasets import load_dataset
+from conf_train import Config, load_config  # import the training configuration
+import re
+
+class ModelTrainer:
+    def __init__(self, config: Config):
+        """
+        Initialize the ModelTrainer and store the configuration parameters.
+        :param config: configuration object containing the parameters required for training
+        """
+        self.config: Config = config
+        self.model_name = config.model_name
+        self.max_seq_length = config.max_seq_length
+        self.dtype = torch.float16 if config.dtype == "float16" else torch.bfloat16
+        self.load_in_4bit = config.load_in_4bit
+        self.fast_inference = config.fast_inference
+        self.lora_rank = config.lora_rank
+        self.gpu_memory_utilization = config.gpu_memory_utilization
+
+    def load_model(self):
+        """
+        Load the pretrained model and tokenizer.
+        :return: the loaded model and tokenizer
+        """
+        model, tokenizer = FastLanguageModel.from_pretrained(
+            model_name=self.model_name,
+            max_seq_length=self.max_seq_length,
+            load_in_4bit=self.load_in_4bit,
+            dtype=self.dtype,
+            fast_inference=self.fast_inference,
+            max_lora_rank=self.lora_rank,
+            gpu_memory_utilization=self.gpu_memory_utilization,
+        )
+
+        model = model.to_empty(device='cuda')
+
+        # Initialize the model weights
+        for param in model.parameters():
+            if param.is_meta:
+                param.data = torch.randn_like(param)
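+        # NOTE: any parameters still on the meta device are filled with random values above;
+        # this materializes them but does not restore pretrained weights for those tensors.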
+
+        # Attach the LoRA adapters
+        model = FastLanguageModel.get_peft_model(
+            model,
+            max_seq_length=self.max_seq_length,
+            r=self.lora_rank,
+            target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
+                          "gate_proj", "up_proj", "down_proj"],
+            lora_alpha=16,
+            lora_dropout=0,
+            bias="none",
+            use_gradient_checkpointing="unsloth",
+            random_state=3407,
+            use_rslora=False,
+            loftq_config=None,
+        )
+
+        return model, tokenizer
+
+    def load_data(self, train_data_path):
+        """
+        Load the training dataset.
+        :param train_data_path: path to the training data
+        :return: the loaded training dataset
+        """
+        # load_dataset reads the JSON file itself, so no manual open() is needed
+        train_dataset = load_dataset("json", data_files={"train": train_data_path}, split="train")
+        return train_dataset
+
+    def train(self, model, tokenizer, train_dataset):
+        """
+        Train the model with GRPO.
+        :param model: the pretrained model
+        :param tokenizer: the tokenizer
+        :param train_dataset: the training dataset
+        :return: the trained model
+        """
+        print("is_bfloat16_supported()=", is_bfloat16_supported())
+        print(f"Reserved memory: {torch.cuda.memory_reserved()}")
+        print(f"Allocated memory: {torch.cuda.memory_allocated()}")
+
+        train_loader = torch.utils.data.DataLoader(
+            train_dataset, batch_size=1, shuffle=True, pin_memory=True
+        )
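+        # NOTE: this DataLoader is never referenced again and is not passed to GRPOTrainer below;
+        # the trainer batches train_dataset internally, so the loader above is effectively unused.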
+
+        torch.cuda.empty_cache()
+        print("self.config.learning_rate=", float(self.config.learning_rate))
+        training_args = GRPOConfig(
+            use_vllm=self.config.use_vllm,
+            learning_rate=float(self.config.learning_rate),
+            adam_beta1=self.config.adam_beta1,
+            adam_beta2=self.config.adam_beta2,
+            weight_decay=self.config.weight_decay,
+            warmup_ratio=self.config.warmup_ratio,
+            lr_scheduler_type=self.config.lr_scheduler_type,
+            optim=self.config.optim,
+            logging_steps=self.config.logging_steps,
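+            # automatically use bf16 when the GPU supports it, otherwise fall back to fp16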
+            bf16=is_bfloat16_supported(),
+            fp16=not is_bfloat16_supported(),
+            per_device_train_batch_size=self.config.per_device_train_batch_size,
+            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
+            num_generations=self.config.num_generations,
+            max_prompt_length=self.config.max_prompt_length,
+            max_completion_length=self.config.max_completion_length,
+            num_train_epochs=self.config.num_train_epochs,
+            max_steps=self.config.max_steps,
+            save_steps=self.config.save_steps,
+            max_grad_norm=self.config.max_grad_norm,
+            report_to=self.config.report_to,
+            output_dir=self.config.output_dir,
+        )
+
+        trainer = GRPOTrainer(
+            model=model,
+            processing_class=tokenizer,
+            reward_funcs=[
+                self.xmlcount_reward_func,
+                self.soft_format_reward_func,
+                self.strict_format_reward_func,
+                self.int_reward_func,
+                self.correctness_reward_func,
+            ],
+            args=training_args,
+            train_dataset=train_dataset,
+        )
+
+        trainer.train()
+        return model
+
+    def save_model(self, model, tokenizer, save_path):
+        """
+        Save the trained model and tokenizer.
+        :param model: the trained model
+        :param tokenizer: the tokenizer
+        :param save_path: directory to save to
+        """
+        model.save_pretrained(save_path)
+        tokenizer.save_pretrained(save_path)
+        print(f"Model saved to {save_path}")
+
+    @staticmethod
+    def extract_xml_answer(text: str) -> str:
+        """
+        Extract the answer from XML-formatted text.
+        :param text: text containing the XML-formatted output
+        :return: the extracted answer
+        """
+        answer = text.split("<answer>")[-1]
+        answer = answer.split("</answer>")[0]
+        return answer.strip()
+
+    @staticmethod
+    def count_xml(text) -> float:
+        """
+        Count the XML tags and score their integrity.
+        :param text: text containing the XML-formatted output
+        :return: XML tag integrity score
+        """
+        count = 0.0
+        if text.count("<reasoning>\n") == 1:
+            count += 0.125
+        if text.count("\n</reasoning>\n") == 1:
+            count += 0.125
+        if text.count("\n<answer>\n") == 1:
+            count += 0.125
+            count -= len(text.split("\n</answer>\n")[-1]) * 0.001
+        if text.count("\n</answer>") == 1:
+            count += 0.125
+            count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
+        return count
+
+    @staticmethod
+    def xmlcount_reward_func(completions, **kwargs):
+        """
+        Reward function for XML tag integrity.
+        :param completions: completions generated by the model
+        :return: list of XML tag integrity scores
+        """
+        contents = [completion[0]["content"] for completion in completions]
+        return [ModelTrainer.count_xml(c) for c in contents]
+
+    @staticmethod
+    def soft_format_reward_func(completions, **kwargs):
+        """
+        Check whether the completion meets the soft format requirements.
+        :param completions: completions generated by the model
+        :return: list of soft-format scores
+        """
+        pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
+        responses = [completion[0]["content"] for completion in completions]
+        matches = [re.match(pattern, r) for r in responses]
+        return [0.5 if match else 0.0 for match in matches]
+
+    @staticmethod
+    def strict_format_reward_func(completions, **kwargs):
+        """
+        Check whether the completion meets the strict format requirements.
+        :param completions: completions generated by the model
+        :return: list of strict-format scores
+        """
+        pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
+        responses = [completion[0]["content"] for completion in completions]
+        matches = [re.match(pattern, r) for r in responses]
+        return [0.5 if match else 0.0 for match in matches]
+
+    @staticmethod
+    def int_reward_func(completions, **kwargs):
+        """
+        Check whether the extracted answer is an integer.
+        :param completions: completions generated by the model
+        :return: list of integer-answer scores
+        """
+        responses = [completion[0]['content'] for completion in completions]
+        extracted_responses = [ModelTrainer.extract_xml_answer(r) for r in responses]
+        return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
+
+    @staticmethod
+    def correctness_reward_func(prompts, completions, answer, **kwargs):
+        """
+        Check whether the extracted answer matches the reference answer.
+        :param prompts: input prompts
+        :param completions: completions generated by the model
+        :param answer: reference answers
+        :return: list of correctness scores
+        """
+        responses = [completion[0]['content'] for completion in completions]
+        q = prompts[0][-1]['content']
+        extracted_responses = [ModelTrainer.extract_xml_answer(r) for r in responses]
+        print('-' * 20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
+        return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
+
+if __name__ == "__main__":
+    # Load the configuration file
+    config = load_config("../conf/conf_train.yaml")
+
+    # Set the distributed-training environment variables
+    """
+        # Multi-node, multi-GPU example:
+        # export RANK=0  # rank of the first machine
+        # export WORLD_SIZE=4  # four machines in total
+        # export MASTER_ADDR=<master node IP>
+        # export MASTER_PORT=12345
+
+    """
+    # Single machine, multi-GPU (here configured as a single process: RANK=0, WORLD_SIZE=1)
+    os.environ['RANK'] = '0'
+    os.environ['WORLD_SIZE'] = '1'
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '12345'
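+    # These variables are required because dist.init_process_group below uses the env:// init
+    # method, which reads RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT from the environment.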
+
+    # Choose the backend based on the OS (NCCL is not available on Windows, so fall back to gloo)
+    backend = 'gloo' if os.name == 'nt' else 'nccl'
+
+    # Environment-variable initialization (env://); verified to work on Windows on 2025-3-11
+    init_method = 'env://'
+    dist.init_process_group(backend=backend, init_method=init_method)
+
+    print(f"Initialized distributed training with backend: {backend}")
+
+    # Initialize the ModelTrainer
+    trainer = ModelTrainer(config)
+
+    # Load the model and tokenizer
+    model, tokenizer = trainer.load_model()
+
+    # Load the dataset
+    train_dataset = trainer.load_data(config.train_data_path)
+
+    # Train the model
+    model = trainer.train(model, tokenizer, train_dataset)
+
+    # Save the model
+    trainer.save_model(model, tokenizer, config.save_path)
+
+    # Make sure the process group is destroyed
+    if dist.is_initialized():
+        dist.destroy_process_group()
+    print("Training completed.")

+ 2 - 0
src/train_model_grpo_v1.py

@@ -1,3 +1,5 @@
+# train_model_grpo_v1.py
+
 import os
 import torch
 import torch.distributed as dist