#!/usr/bin/env python3
import re
import torch
import os
import json
from unsloth import FastLanguageModel, PatchFastRL, is_bfloat16_supported
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
from vllm import SamplingParams
import trl.trainer.grpo_trainer

from conf_train import load_config


class GRPOTrainerWrapper:
    """
    Wrapper class for GRPO training with object-oriented design.
    """

    # Constants
    SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

    XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

    def __init__(self, config):
        """
        Initialize the trainer with configuration.
        """
        self.config = config
        self.model = None
        self.tokenizer = None
        self.trainer = None

        # Enable Unsloth's CLI training metrics visualization
        os.environ["UNSLOTH_DISPLAY_METRICS"] = "true"

        # Apply Unsloth's GRPO patch to FastLanguageModel
        PatchFastRL("GRPO", FastLanguageModel)

        # Monkey patch the validation in the GRPOTrainer
        self._patch_grpo_trainer()

    def _patch_grpo_trainer(self):
        """
        Monkey patch the GRPOTrainer to bypass the divisibility check.
        """
        original_init = trl.trainer.grpo_trainer.GRPOTrainer.__init__

        def patched_init(self, *args, **kwargs):
            try:
                original_init(self, *args, **kwargs)
            except ValueError as e:
                if "evenly divisible by the number of generations per prompt" in str(e):
                    print("Bypassing TRL's batch divisibility check...")
                    # Continue with initialization despite the error
                    self.args = kwargs.get("args")
                    self.model = kwargs.get("model")
                    self.processing_class = kwargs.get("processing_class")
                    self.reward_funcs = kwargs.get("reward_funcs")
                    self.train_dataset = kwargs.get("train_dataset")
                    # Set up necessary trainer components without the check
                    self._setup_trainer()
                else:
                    raise e

        trl.trainer.grpo_trainer.GRPOTrainer.__init__ = patched_init

    def load_model(self):
        """
        Load the model and tokenizer.
        """
        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=self.config.model_name,
            max_seq_length=self.config.max_seq_length,
            load_in_4bit=self.config.load_in_4bit,
            fast_inference=self.config.fast_inference,
            max_lora_rank=self.config.lora_rank,
            gpu_memory_utilization=self.config.gpu_memory_utilization,
        )

        self.model = FastLanguageModel.get_peft_model(
            self.model,
            r=self.config.lora_rank,
            target_modules=[
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ],
            lora_alpha=self.config.lora_rank,
            use_gradient_checkpointing="unsloth",
            random_state=3407,
        )

    @staticmethod
    def extract_xml_answer(text: str) -> str:
        """
        Extract the answer from XML-formatted text.
        """
        answer = text.split("<answer>")[-1]
        answer = answer.split("</answer>")[0]
        return answer.strip()

    @staticmethod
    def extract_hash_answer(text: str) -> str | None:
        """
        Extract the answer from hash-formatted (####) text.
        """
        if "####" not in text:
            return None
        return text.split("####")[1].strip()

    def load_dataset(self) -> Dataset:
        """
        Load and prepare the training dataset.
        """
        # Read JSONL file
        data = []
        with open(self.config.train_data_path, 'r') as f:
            for line in f:
                data.append(json.loads(line))

        # Convert list to HuggingFace Dataset object
        return Dataset.from_list(data)
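    # Illustrative sketch of one training row as load_dataset() expects it.
    # The field names are assumptions inferred from the reward functions below
    # (which read a chat-style "prompt" list and an "answer" string); they are
    # not taken from conf_train, so adjust to the actual JSONL schema.
    #
    # {"prompt": [{"role": "system", "content": "<SYSTEM_PROMPT>"},
    #             {"role": "user", "content": "What is 7 * 6?"}],
    #  "answer": "42"}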
""" responses = [completion[0]['content'] for completion in completions] q = prompts[0][-1]['content'] extracted_responses = [self.extract_xml_answer(r) for r in responses] print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}") return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)] def int_reward_func(self, completions, **kwargs) -> list[float]: """ Reward function for integer answers. """ responses = [completion[0]['content'] for completion in completions] extracted_responses = [self.extract_xml_answer(r) for r in responses] return [0.5 if r.isdigit() else 0.0 for r in extracted_responses] def strict_format_reward_func(self, completions, **kwargs) -> list[float]: """ Strict format reward function. """ pattern = r"^\n.*?\n\n\n.*?\n\n$" responses = [completion[0]["content"] for completion in completions] matches = [re.match(pattern, r) for r in responses] return [0.5 if match else 0.0 for match in matches] def soft_format_reward_func(self, completions, **kwargs) -> list[float]: """ Soft format reward function. """ pattern = r".*?\s*.*?" responses = [completion[0]["content"] for completion in completions] matches = [re.match(pattern, r) for r in responses] return [0.5 if match else 0.0 for match in matches] def count_xml(self, text) -> float: """ Count XML tags in text. """ count = 0.0 if text.count("\n") == 1: count += 0.125 if text.count("\n\n") == 1: count += 0.125 if text.count("\n\n") == 1: count += 0.125 count -= len(text.split("\n\n")[-1])*0.001 if text.count("\n") == 1: count += 0.125 count -= (len(text.split("\n")[-1]) - 1)*0.001 return count def xmlcount_reward_func(self, completions, **kwargs) -> list[float]: """ Reward function based on XML tag count. """ contents = [completion[0]["content"] for completion in completions] return [self.count_xml(c) for c in contents] def prepare_training_args(self) -> GRPOConfig: """ Prepare training arguments from config. """ return GRPOConfig( use_vllm=self.config.use_vllm, learning_rate=float(self.config.learning_rate) , adam_beta1=self.config.adam_beta1, adam_beta2=self.config.adam_beta2, weight_decay=self.config.weight_decay, warmup_ratio=self.config.warmup_ratio, lr_scheduler_type=self.config.lr_scheduler_type, optim=self.config.optim, logging_steps=self.config.logging_steps, bf16=is_bfloat16_supported(), fp16=not is_bfloat16_supported(), per_device_train_batch_size=self.config.per_device_train_batch_size, gradient_accumulation_steps=self.config.gradient_accumulation_steps, num_generations=self.config.num_generations, max_prompt_length=self.config.max_prompt_length, max_completion_length=self.config.max_completion_length, max_steps=self.config.max_steps, save_steps=self.config.save_steps, max_grad_norm=self.config.max_grad_norm, report_to=self.config.report_to, output_dir=self.config.output_dir, save_total_limit=3, log_level="info", disable_tqdm=False, evaluation_strategy="no", ) def initialize_trainer(self, dataset): """ Initialize the GRPO trainer. """ training_args = self.prepare_training_args() self.trainer = GRPOTrainer( model=self.model, processing_class=self.tokenizer, reward_funcs=[ self.xmlcount_reward_func, self.soft_format_reward_func, self.strict_format_reward_func, self.int_reward_func, self.correctness_reward_func, ], args=training_args, train_dataset=dataset, ) def train(self): """ Execute the training process. 
""" print("Starting GRPO training...") print(f"per_device_train_batch_size = {self.config.per_device_train_batch_size}") print(f"gradient_accumulation_steps = {self.config.gradient_accumulation_steps}") print(f"num_generations = {self.config.num_generations}") print(f"max_steps = {self.config.max_steps}") self.trainer.train() def save_model(self): """ Save the trained model. """ print(f"Saving model to {self.config.save_path}...") os.makedirs(os.path.dirname(self.config.save_path), exist_ok=True) self.model.save_pretrained(self.config.save_path) self.tokenizer.save_pretrained(self.config.save_path) print("Training complete!") def main(): # Load configuration config = load_config() # Initialize and run trainer trainer = GRPOTrainerWrapper(config) trainer.load_model() dataset = trainer.load_dataset() trainer.initialize_trainer(dataset) trainer.train() trainer.save_model() if __name__ == "__main__": main()