Kaynağa Gözat

遵循面向对象思想重构train_model_grpo.py

zhouyang.xie 2 ay önce
ebeveyn
işleme
7196266feb
3 değiştirilmiş dosya ile 306 ekleme ve 1 silme
  1. 48 0
      src/conf_train.py
  2. 1 1
      src/train_model_grpo.py
  3. 257 0
      src/train_model_grpo_v1.py

+ 48 - 0
src/conf_train.py

@@ -0,0 +1,48 @@
+import yaml
+from dataclasses import dataclass
+
+@dataclass
+class Config:
+    """
+    配置类,用于加载和管理训练配置。
+    """
+    model_name: str
+    max_seq_length: int
+    dtype: str
+    load_in_4bit: bool
+    lora_rank: int
+    gpu_memory_utilization: float
+    learning_rate: float
+    adam_beta1: float
+    adam_beta2: float
+    weight_decay: float
+    warmup_ratio: float
+    lr_scheduler_type: str
+    optim: str
+    logging_steps: int
+    per_device_train_batch_size: int
+    gradient_accumulation_steps: int
+    num_generations: int
+    max_prompt_length: int
+    max_completion_length: int
+    num_train_epochs: int
+    max_steps: int
+    save_steps: int
+    max_grad_norm: float
+    report_to: str
+    output_dir: str
+    train_data_path: str
+    save_path: str
+
+def load_config(config_path: str) -> Config:
+    """
+    加载配置文件。
+    :param config_path: 配置文件路径
+    :return: 返回配置对象
+    """
+    with open(config_path, 'r') as f:
+        config_dict = yaml.safe_load(f)
+    return Config(**config_dict)
+
+# 加载配置文件
+config = load_config("./conf/conf_train.yaml")

+ 1 - 1
src/train_model_grpo.py

@@ -253,7 +253,7 @@ if __name__ == "__main__":
         # export MASTER_PORT=12345
 
         # 初始化进程组
-        # dist.init_process_group(backend='nccl', init_method='env://')
+        dist.init_process_group(backend='nccl', init_method='env://')
         # 初始化 ModelTrainer
         trainer = ModelTrainer(model_name, max_seq_length, dtype, load_in_4bit,lora_rank)
         

+ 257 - 0
src/train_model_grpo_v1.py

@@ -0,0 +1,257 @@
+import os
+import torch
+import torch.distributed as dist
+from unsloth import FastLanguageModel
+from unsloth import is_bfloat16_supported
+from trl import GRPOConfig, GRPOTrainer
+from datasets import load_dataset
+from conf_train import Config  # 导入配置文件
+import re
+
+class ModelTrainer:
+    def __init__(self, config:Config):
+        """
+        初始化 ModelTrainer 类,加载配置参数。
+        :param config: 配置对象,包含模型训练所需的参数
+        """
+        self.config = config
+        self.model_name = config.model_name
+        self.max_seq_length = config.max_seq_length
+        self.dtype = torch.float16 if config.dtype == "float16" else torch.bfloat16
+        self.load_in_4bit = config.load_in_4bit
+        self.lora_rank = config.lora_rank
+        self.gpu_memory_utilization=config.gpu_memory_utilization
+
+    def load_model(self):
+        """
+        加载预训练模型和分词器。
+        :return: 返回加载的模型和分词器
+        """
+        model, tokenizer = FastLanguageModel.from_pretrained(
+            model_name=self.model_name,
+            max_seq_length=self.max_seq_length,
+            load_in_4bit=self.load_in_4bit,
+            dtype=self.dtype,
+            fast_inference=False,
+            max_lora_rank=self.lora_rank,
+            gpu_memory_utilization=0.6,
+        )
+
+        model = model.to_empty(device='cuda')
+
+        # 初始化模型的权重
+        for param in model.parameters():
+            if param.is_meta:
+                param.data = torch.randn_like(param)
+
+        # 添加 LoRA 适配器
+        model = FastLanguageModel.get_peft_model(
+            model,
+            max_seq_length=self.max_seq_length,
+            r=self.lora_rank,
+            target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
+                          "gate_proj", "up_proj", "down_proj"],
+            lora_alpha=16,
+            lora_dropout=0,
+            bias="none",
+            use_gradient_checkpointing="unsloth",
+            random_state=3407,
+            use_rslora=False,
+            loftq_config=None,
+        )
+
+        return model, tokenizer
+
+    def load_data(self, train_data_path):
+        """
+        加载训练数据集。
+        :param train_data_path: 训练数据路径
+        :return: 返回加载的训练数据集
+        """
+        with open(train_data_path, 'r') as f:
+            train_dataset = load_dataset("json", data_files={"train": train_data_path}, split="train")
+        return train_dataset
+
+    def train(self, model, tokenizer, train_dataset):
+        """
+        训练模型。
+        :param model: 预训练模型
+        :param tokenizer: 分词器
+        :param train_dataset: 训练数据集
+        :return: 返回训练后的模型
+        """
+        print("is_bfloat16_supported()=", is_bfloat16_supported())
+        print(f"Reserved memory: {torch.cuda.memory_reserved()}")
+        print(f"Allocated memory: {torch.cuda.memory_allocated()}")
+
+        train_loader = torch.utils.data.DataLoader(
+            train_dataset, batch_size=1, shuffle=True, pin_memory=True
+        )
+
+        torch.cuda.empty_cache()
+
+        training_args = GRPOConfig(
+            use_vllm=False,
+            learning_rate=self.config.learning_rate,
+            adam_beta1=self.config.adam_beta1,
+            adam_beta2=self.config.adam_beta2,
+            weight_decay=self.config.weight_decay,
+            warmup_ratio=self.config.warmup_ratio,
+            lr_scheduler_type=self.config.lr_scheduler_type,
+            optim=self.config.optim,
+            logging_steps=self.config.logging_steps,
+            bf16=is_bfloat16_supported(),
+            fp16=not is_bfloat16_supported(),
+            per_device_train_batch_size=self.config.per_device_train_batch_size,
+            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
+            num_generations=self.config.num_generations,
+            max_prompt_length=self.config.max_prompt_length,
+            max_completion_length=self.config.max_completion_length,
+            num_train_epochs=self.config.num_train_epochs,
+            max_steps=self.config.max_steps,
+            save_steps=self.config.save_steps,
+            max_grad_norm=self.config.max_grad_norm,
+            report_to=self.config.report_to,
+            output_dir=self.config.output_dir,
+        )
+
+        trainer = GRPOTrainer(
+            model=model,
+            processing_class=tokenizer,
+            reward_funcs=[
+                self.xmlcount_reward_func,
+                self.soft_format_reward_func,
+                self.strict_format_reward_func,
+                self.int_reward_func,
+                self.correctness_reward_func,
+            ],
+            args=training_args,
+            train_dataset=train_dataset,
+        )
+
+        trainer.train()
+        return model
+
+    def save_model(self, model, tokenizer, save_path):
+        """
+        保存训练后的模型和分词器。
+        :param model: 训练后的模型
+        :param tokenizer: 分词器
+        :param save_path: 保存路径
+        """
+        model.save_pretrained(save_path)
+        tokenizer.save_pretrained(save_path)
+        print(f"Model saved to {save_path}")
+
+    @staticmethod
+    def extract_xml_answer(text: str) -> str:
+        answer = text.split("<answer>")[-1]
+        answer = answer.split("</answer>")[0]
+        return answer.strip()
+    
+    @staticmethod
+    def count_xml(text) -> float:
+        count = 0.0
+        if text.count("<reasoning>\n") == 1:
+            count += 0.125
+        if text.count("\n</reasoning>\n") == 1:
+            count += 0.125
+        if text.count("\n<answer>\n") == 1:
+            count += 0.125
+            count -= len(text.split("\n</answer>\n")[-1])*0.001
+        if text.count("\n</answer>") == 1:
+            count += 0.125
+            count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
+        return count
+
+
+    @staticmethod
+    def xmlcount_reward_func(completions, **kwargs):
+        """
+        Reward function that counts XML tags in the completion.
+        """
+        contents = [completion[0]["content"] for completion in completions]
+        return [ModelTrainer.count_xml(c) for c in contents]
+
+    @staticmethod
+    def soft_format_reward_func(completions, **kwargs):
+        """
+        Reward function that checks if the completion has a specific format.
+        """
+        pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
+        responses = [completion[0]["content"] for completion in completions]
+        matches = [re.match(pattern, r) for r in responses]
+        return [0.5 if match else 0.0 for match in matches]
+
+    @staticmethod
+    def strict_format_reward_func(completions, **kwargs):
+        """
+        Reward function that checks if the completion has a specific format.
+        """
+        pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
+        responses = [completion[0]["content"] for completion in completions]
+        matches = [re.match(pattern, r) for r in responses]
+        return [0.5 if match else 0.0 for match in matches]
+
+    @staticmethod
+    def int_reward_func(completions, **kwargs):
+        """
+        Reward function that checks if the completion contains an integer.
+        """
+        responses = [completion[0]['content'] for completion in completions]
+        extracted_responses = [ModelTrainer.extract_xml_answer(r) for r in responses]
+        return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
+
+    @staticmethod
+    def correctness_reward_func(prompts, completions, answer, **kwargs):
+        """
+        Reward function that checks if the completion matches the correct answer.
+        """
+        responses = [completion[0]['content'] for completion in completions]
+        q = prompts[0][-1]['content']
+        extracted_responses = [ModelTrainer.extract_xml_answer(r) for r in responses]
+        print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
+        return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
+
+if __name__ == "__main__":
+    # 加载配置文件
+    config = Config()
+
+    # 设置环境变量
+    """
+        # 多机多卡
+        # export RANK=0  # 第一台机器的 rank
+        # export WORLD_SIZE=4  # 总共有 4 台机器
+        # export MASTER_ADDR=<主节点 IP>
+        # export MASTER_PORT=12345
+
+    """
+    # 单机多卡
+    os.environ['RANK'] = '0'
+    os.environ['WORLD_SIZE'] = '1'
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '12345'
+    
+
+    # 初始化进程组
+    dist.init_process_group(backend='nccl', init_method='env://')
+
+    # 初始化 ModelTrainer
+    trainer = ModelTrainer(config)
+
+    # 加载模型和分词器
+    model, tokenizer = trainer.load_model()
+
+    # 加载数据集
+    train_dataset = trainer.load_data(config.train_data_path)
+
+    # 训练模型
+    model = trainer.train(model, tokenizer, train_dataset)
+
+    # 保存模型
+    trainer.save_model(model, tokenizer, config.save_path)
+
+    # 确保进程组被销毁
+    if dist.is_initialized():
+        dist.destroy_process_group()
+    print("Training completed.")