#!/usr/bin/env python3
import re
import torch
import os
import json
from unsloth import FastLanguageModel, PatchFastRL, is_bfloat16_supported
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
from vllm import SamplingParams
import trl.trainer.grpo_trainer
from conf_train import load_config
class GRPOTrainerWrapper:
"""
Wrapper class for GRPO training with object-oriented design.
"""
# Constants
    SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
    XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
def __init__(self, config):
"""
Initialize the trainer with configuration.
"""
self.config = config
self.model = None
self.tokenizer = None
self.trainer = None
# Enable Unsloth's CLI training metrics visualization
os.environ["UNSLOTH_DISPLAY_METRICS"] = "true"
# Apply patch
PatchFastRL("GRPO", FastLanguageModel)
# Monkey patch the validation in the GRPOTrainer
self._patch_grpo_trainer()
def _patch_grpo_trainer(self):
"""
Monkey patch the GRPOTrainer to bypass the divisibility check.
"""
original_init = trl.trainer.grpo_trainer.GRPOTrainer.__init__
def patched_init(self, *args, **kwargs):
try:
original_init(self, *args, **kwargs)
except ValueError as e:
if "evenly divisible by the number of generations per prompt" in str(e):
print("Bypassing TRL's batch divisibility check...")
# Continue with initialization despite the error
self.args = kwargs.get("args")
self.model = kwargs.get("model")
self.processing_class = kwargs.get("processing_class")
self.reward_funcs = kwargs.get("reward_funcs")
self.train_dataset = kwargs.get("train_dataset")
                    # Set up the remaining trainer components without the check.
                    # NOTE: this assumes a _setup_trainer() helper exists on the trainer
                    # instance; stock TRL's GRPOTrainer does not appear to provide one.
                    self._setup_trainer()
else:
                    raise
trl.trainer.grpo_trainer.GRPOTrainer.__init__ = patched_init
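        # NOTE: the bypass above leans on TRL internals; the error-message check and the
        # attributes copied from kwargs may differ across TRL versions, so treat the
        # patched path as best-effort rather than a stable API.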
def load_model(self):
"""
Load the model and tokenizer.
"""
self.model, self.tokenizer = FastLanguageModel.from_pretrained(
model_name=self.config.model_name,
max_seq_length=self.config.max_seq_length,
load_in_4bit=self.config.load_in_4bit,
fast_inference=self.config.fast_inference,
max_lora_rank=self.config.lora_rank,
gpu_memory_utilization=self.config.gpu_memory_utilization,
)
self.model = FastLanguageModel.get_peft_model(
self.model,
r=self.config.lora_rank,
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
],
lora_alpha=self.config.lora_rank,
use_gradient_checkpointing="unsloth",
random_state=3407,
)
@staticmethod
def extract_xml_answer(text: str) -> str:
"""
Extract answer from XML formatted text.
"""
        answer = text.split("<answer>")[-1]
        answer = answer.split("</answer>")[0]
return answer.strip()
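    # Example (illustrative):
    #   extract_xml_answer("<reasoning>\n...\n</reasoning>\n<answer>\n42\n</answer>")
    # returns "42"; anything outside the <answer> tags is discarded.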
@staticmethod
def extract_hash_answer(text: str) -> str | None:
"""
Extract answer from hash formatted text.
"""
if "####" not in text:
return None
return text.split("####")[1].strip()
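    # Example (illustrative): extract_hash_answer("He sells 6 eggs... #### 18") returns
    # "18", following the GSM8K-style "#### <final answer>" convention.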
def load_dataset(self) -> Dataset:
"""
Load and prepare the training dataset.
"""
# Read JSONL file
data = []
with open(self.config.train_data_path, 'r') as f:
for line in f:
data.append(json.loads(line))
# Convert list to HuggingFace Dataset object
return Dataset.from_list(data)
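    # The training JSONL is assumed (not verified here) to hold one record per line with
    # at least a chat-style "prompt" list and an "answer" string, e.g.:
    #   {"prompt": [{"role": "system", "content": "..."},
    #               {"role": "user", "content": "What is 7 * 6?"}],
    #    "answer": "42"}
    # GRPOTrainer reads the "prompt" column, and the reward functions below receive the
    # "answer" column as a keyword argument.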
def correctness_reward_func(self, prompts, completions, answer, **kwargs) -> list[float]:
"""
Reward function for answer correctness.
"""
responses = [completion[0]['content'] for completion in completions]
q = prompts[0][-1]['content']
extracted_responses = [self.extract_xml_answer(r) for r in responses]
print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}",
f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
def int_reward_func(self, completions, **kwargs) -> list[float]:
"""
Reward function for integer answers.
"""
responses = [completion[0]['content'] for completion in completions]
extracted_responses = [self.extract_xml_answer(r) for r in responses]
return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
def strict_format_reward_func(self, completions, **kwargs) -> list[float]:
"""
Strict format reward function.
"""
pattern = r"^\n.*?\n\n\n.*?\n\n$"
responses = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, r) for r in responses]
return [0.5 if match else 0.0 for match in matches]
def soft_format_reward_func(self, completions, **kwargs) -> list[float]:
"""
Soft format reward function.
"""
pattern = r".*?\s*.*?"
responses = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, r) for r in responses]
return [0.5 if match else 0.0 for match in matches]
def count_xml(self, text) -> float:
"""
Count XML tags in text.
"""
count = 0.0
if text.count("\n") == 1:
count += 0.125
if text.count("\n\n") == 1:
count += 0.125
if text.count("\n\n") == 1:
count += 0.125
count -= len(text.split("\n\n")[-1])*0.001
if text.count("\n") == 1:
count += 0.125
count -= (len(text.split("\n")[-1]) - 1)*0.001
return count
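    # Worked example (illustrative): a completion that is exactly
    # "<reasoning>\n...\n</reasoning>\n<answer>\n42\n</answer>\n" satisfies all four
    # checks for 4 * 0.125 = 0.5, while any text trailing the closing </answer> tag
    # draws a small length-proportional penalty.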
def xmlcount_reward_func(self, completions, **kwargs) -> list[float]:
"""
Reward function based on XML tag count.
"""
contents = [completion[0]["content"] for completion in completions]
return [self.count_xml(c) for c in contents]
def prepare_training_args(self) -> GRPOConfig:
"""
Prepare training arguments from config.
"""
return GRPOConfig(
use_vllm=self.config.use_vllm,
            learning_rate=float(self.config.learning_rate),
adam_beta1=self.config.adam_beta1,
adam_beta2=self.config.adam_beta2,
weight_decay=self.config.weight_decay,
warmup_ratio=self.config.warmup_ratio,
lr_scheduler_type=self.config.lr_scheduler_type,
optim=self.config.optim,
logging_steps=self.config.logging_steps,
bf16=is_bfloat16_supported(),
fp16=not is_bfloat16_supported(),
per_device_train_batch_size=self.config.per_device_train_batch_size,
gradient_accumulation_steps=self.config.gradient_accumulation_steps,
num_generations=self.config.num_generations,
max_prompt_length=self.config.max_prompt_length,
max_completion_length=self.config.max_completion_length,
max_steps=self.config.max_steps,
save_steps=self.config.save_steps,
max_grad_norm=self.config.max_grad_norm,
report_to=self.config.report_to,
output_dir=self.config.output_dir,
save_total_limit=3,
log_level="info",
disable_tqdm=False,
evaluation_strategy="no",
)
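    # Depending on the TRL version, GRPO expects the effective train batch size to be
    # evenly divisible by num_generations; the values above come straight from conf_train,
    # so if they violate that constraint the ValueError is swallowed by _patch_grpo_trainer.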
def initialize_trainer(self, dataset):
"""
Initialize the GRPO trainer.
"""
training_args = self.prepare_training_args()
self.trainer = GRPOTrainer(
model=self.model,
processing_class=self.tokenizer,
reward_funcs=[
self.xmlcount_reward_func,
self.soft_format_reward_func,
self.strict_format_reward_func,
self.int_reward_func,
self.correctness_reward_func,
],
args=training_args,
train_dataset=dataset,
)
def train(self):
"""
Execute the training process.
"""
print("Starting GRPO training...")
print(f"per_device_train_batch_size = {self.config.per_device_train_batch_size}")
print(f"gradient_accumulation_steps = {self.config.gradient_accumulation_steps}")
print(f"num_generations = {self.config.num_generations}")
print(f"max_steps = {self.config.max_steps}")
self.trainer.train()
def save_model(self):
"""
Save the trained model.
"""
print(f"Saving model to {self.config.save_path}...")
os.makedirs(os.path.dirname(self.config.save_path), exist_ok=True)
self.model.save_pretrained(self.config.save_path)
self.tokenizer.save_pretrained(self.config.save_path)
print("Training complete!")
def main():
# Load configuration
config = load_config()
# Initialize and run trainer
trainer = GRPOTrainerWrapper(config)
trainer.load_model()
dataset = trainer.load_dataset()
trainer.initialize_trainer(dataset)
trainer.train()
trainer.save_model()
if __name__ == "__main__":
main()