import os
import re

import torch
import torch.distributed as dist
from unsloth import FastLanguageModel
from unsloth import is_bfloat16_supported
from trl import GRPOConfig, GRPOTrainer
from datasets import load_dataset

from conf_train import Config, load_config  # project training configuration


class ModelTrainer:
    """GRPO fine-tuning pipeline: load a model, attach LoRA, train, and save.

    The reward functions below score completions against an XML-like layout::

        <reasoning>
        ...
        </reasoning>
        <answer>
        ...
        </answer>

    NOTE(review): the tag literals were missing from the reward functions
    (``text.split("")`` raised ``ValueError`` and the format regexes matched
    any text). ``<answer>`` follows from ``extract_xml_answer``; ``<reasoning>``
    is assumed -- confirm both against the system prompt used to build the
    training data.
    """

    def __init__(self, config: Config):
        """Store the training configuration and cache frequently used fields.

        :param config: configuration object holding the model/training
            parameters read by this class
        """
        self.config = config
        self.model_name = config.model_name
        self.max_seq_length = config.max_seq_length
        # Anything other than the string "float16" selects bfloat16.
        self.dtype = torch.float16 if config.dtype == "float16" else torch.bfloat16
        self.load_in_4bit = config.load_in_4bit
        self.lora_rank = config.lora_rank
        self.gpu_memory_utilization = config.gpu_memory_utilization

    def load_model(self):
        """Load the pretrained model and tokenizer and attach LoRA adapters.

        :return: ``(model, tokenizer)`` where ``model`` is the PEFT-wrapped
            model returned by ``FastLanguageModel.get_peft_model``
        """
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=self.model_name,
            max_seq_length=self.max_seq_length,
            load_in_4bit=self.load_in_4bit,
            dtype=self.dtype,
            fast_inference=False,
            max_lora_rank=self.lora_rank,
            # Use the configured value instead of a hard-coded 0.6.
            gpu_memory_utilization=self.gpu_memory_utilization,
        )

        # Module.to_empty() re-allocates every parameter and buffer WITHOUT
        # copying its contents, so calling it unconditionally throws away the
        # pretrained weights that were just loaded. Only materialize storage
        # when the loader actually left tensors on the meta device.
        if any(param.is_meta for param in model.parameters()):
            print(
                "Warning: model has meta-device parameters; materializing "
                "on CUDA with uninitialized weights."
            )
            model = model.to_empty(device="cuda")

        # Attach LoRA adapters to the attention and MLP projections.
        model = FastLanguageModel.get_peft_model(
            model,
            max_seq_length=self.max_seq_length,
            r=self.lora_rank,
            target_modules=[
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "gate_proj",
                "up_proj",
                "down_proj",
            ],
            lora_alpha=16,
            lora_dropout=0,
            bias="none",
            use_gradient_checkpointing="unsloth",
            random_state=3407,
            use_rslora=False,
            loftq_config=None,
        )
        return model, tokenizer

    def load_data(self, train_data_path):
        """Load the training split from a JSON/JSONL file.

        :param train_data_path: path to the training data file
        :return: the dataset returned by ``load_dataset(..., split="train")``
        :raises FileNotFoundError: if ``train_data_path`` is not an existing file
        """
        # Fail early with a clear message; no file handle needs to be opened
        # here because load_dataset reads the file itself.
        if not os.path.isfile(train_data_path):
            raise FileNotFoundError(
                f"Training data file not found: {train_data_path}"
            )
        return load_dataset(
            "json", data_files={"train": train_data_path}, split="train"
        )

    def train(self, model, tokenizer, train_dataset):
        """Run GRPO training with the reward functions defined on this class.

        :param model: model to train (as returned by ``load_model``)
        :param tokenizer: tokenizer passed to the trainer as ``processing_class``
        :param train_dataset: dataset passed to the trainer as ``train_dataset``
        :return: the same ``model`` object after ``trainer.train()`` has run
        """
        bf16_supported = is_bfloat16_supported()
        print("is_bfloat16_supported()=", bf16_supported)
        print(f"Reserved memory: {torch.cuda.memory_reserved()}")
        print(f"Allocated memory: {torch.cuda.memory_allocated()}")

        # Release cached allocator blocks before the trainer allocates.
        torch.cuda.empty_cache()

        cfg = self.config
        training_args = GRPOConfig(
            use_vllm=False,
            learning_rate=cfg.learning_rate,
            adam_beta1=cfg.adam_beta1,
            adam_beta2=cfg.adam_beta2,
            weight_decay=cfg.weight_decay,
            warmup_ratio=cfg.warmup_ratio,
            lr_scheduler_type=cfg.lr_scheduler_type,
            optim=cfg.optim,
            logging_steps=cfg.logging_steps,
            # Exactly one of bf16/fp16 is enabled, chosen by hardware support.
            bf16=bf16_supported,
            fp16=not bf16_supported,
            per_device_train_batch_size=cfg.per_device_train_batch_size,
            gradient_accumulation_steps=cfg.gradient_accumulation_steps,
            num_generations=cfg.num_generations,
            max_prompt_length=cfg.max_prompt_length,
            max_completion_length=cfg.max_completion_length,
            num_train_epochs=cfg.num_train_epochs,
            max_steps=cfg.max_steps,
            save_steps=cfg.save_steps,
            max_grad_norm=cfg.max_grad_norm,
            report_to=cfg.report_to,
            output_dir=cfg.output_dir,
        )

        trainer = GRPOTrainer(
            model=model,
            processing_class=tokenizer,
            reward_funcs=[
                self.xmlcount_reward_func,
                self.soft_format_reward_func,
                self.strict_format_reward_func,
                self.int_reward_func,
                self.correctness_reward_func,
            ],
            args=training_args,
            train_dataset=train_dataset,
        )
        trainer.train()
        return model

    def save_model(self, model, tokenizer, save_path):
        """Save the trained model and tokenizer to ``save_path``.

        :param model: trained model
        :param tokenizer: tokenizer to save next to the model
        :param save_path: destination directory
        """
        model.save_pretrained(save_path)
        tokenizer.save_pretrained(save_path)
        print(f"Model saved to {save_path}")

    @staticmethod
    def extract_xml_answer(text: str) -> str:
        """Return the stripped text between the last ``<answer>`` and the next ``</answer>``.

        If a tag is absent, the corresponding split leaves the text unchanged,
        so the result degrades to the stripped remainder of ``text``.
        """
        answer = text.split("<answer>")[-1]
        answer = answer.split("</answer>")[0]
        return answer.strip()

    @staticmethod
    def count_xml(text) -> float:
        """Score how closely ``text`` follows the expected tag layout.

        Each of the four tag markers that appears exactly once adds 0.125.
        Text trailing the closing answer tag is penalized at 0.001 per
        character.
        """
        count = 0.0
        if text.count("<reasoning>\n") == 1:
            count += 0.125
        if text.count("\n</reasoning>\n") == 1:
            count += 0.125
        if text.count("\n<answer>\n") == 1:
            count += 0.125
            count -= len(text.split("\n</answer>\n")[-1]) * 0.001
        if text.count("\n</answer>") == 1:
            count += 0.125
            # The "- 1" leaves a single trailing character unpenalized.
            count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
        return count

    @staticmethod
    def xmlcount_reward_func(completions, **kwargs):
        """Reward each completion with its ``count_xml`` score.

        :param completions: list of completions; each is indexed as
            ``completion[0]["content"]``
        :return: list of float rewards, one per completion
        """
        contents = [completion[0]["content"] for completion in completions]
        return [ModelTrainer.count_xml(c) for c in contents]

    @staticmethod
    def soft_format_reward_func(completions, **kwargs):
        """Reward 0.5 when a completion starts with a reasoning block followed by an answer block.

        Whitespace between the two blocks is free-form.
        """
        pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
        responses = [completion[0]["content"] for completion in completions]
        # DOTALL lets ".*?" span newlines so multi-line reasoning can match.
        matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
        return [0.5 if match else 0.0 for match in matches]

    @staticmethod
    def strict_format_reward_func(completions, **kwargs):
        """Reward 0.5 when the whole completion matches the exact tag layout.

        Every tag must sit on its own line and the text must end with a
        newline after the closing answer tag.
        """
        pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
        responses = [completion[0]["content"] for completion in completions]
        # DOTALL lets ".*?" span newlines so multi-line reasoning can match.
        matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
        return [0.5 if match else 0.0 for match in matches]

    @staticmethod
    def int_reward_func(completions, **kwargs):
        """Reward 0.5 when the extracted answer consists only of digits."""
        responses = [completion[0]["content"] for completion in completions]
        extracted_responses = [
            ModelTrainer.extract_xml_answer(r) for r in responses
        ]
        return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

    @staticmethod
    def correctness_reward_func(prompts, completions, answer, **kwargs):
        """Reward 2.0 when the extracted answer equals the reference answer.

        Also prints the first question/answer/response of the batch for
        inspection.

        :param prompts: list of prompts; ``prompts[0][-1]["content"]`` is
            printed as the question
        :param completions: list of completions, indexed as
            ``completion[0]["content"]``
        :param answer: reference answers, paired with completions by position
        :return: list of float rewards, one per completion
        """
        responses = [completion[0]["content"] for completion in completions]
        q = prompts[0][-1]["content"]
        extracted_responses = [
            ModelTrainer.extract_xml_answer(r) for r in responses
        ]
        print(
            "-" * 20,
            f"Question:\n{q}",
            f"\nAnswer:\n{answer[0]}",
            f"\nResponse:\n{responses[0]}",
            f"\nExtracted:\n{extracted_responses[0]}",
        )
        return [
            2.0 if r == a else 0.0
            for r, a in zip(extracted_responses, answer)
        ]


def main():
    """Script entry point: load config, train, and save the model."""
    # Load the configuration file.
    config = load_config()

    # Distributed environment.
    # Multi-node example (export before launching):
    #   export RANK=0            # rank of this machine
    #   export WORLD_SIZE=4      # total number of machines
    #   export MASTER_ADDR=<master node IP>
    #   export MASTER_PORT=12345
    #
    # Single-process defaults. setdefault keeps any values already exported by
    # a launcher instead of overwriting them.
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12345")

    # Initialize the process group.
    dist.init_process_group(backend="nccl", init_method="env://")
    try:
        trainer = ModelTrainer(config)

        # Load the model and tokenizer.
        model, tokenizer = trainer.load_model()

        # Load the dataset.
        train_dataset = trainer.load_data(config.train_data_path)

        # Train the model.
        model = trainer.train(model, tokenizer, train_dataset)

        # Save the model.
        trainer.save_model(model, tokenizer, config.save_path)
    finally:
        # Always tear the process group down, even if training fails.
        if dist.is_initialized():
            dist.destroy_process_group()

    print("Training completed.")


if __name__ == "__main__":
    main()