#!/usr/bin/env python3 # Direct clone of notebook implementation with minimal changes # Import in the same order as the notebook import re import torch import os import json from unsloth import FastLanguageModel, PatchFastRL, is_bfloat16_supported # Enable Unsloth's CLI training metrics visualization os.environ["UNSLOTH_DISPLAY_METRICS"] = "true" # Apply patch exactly like notebook PatchFastRL("GRPO", FastLanguageModel) # Load the model just like the notebook model_name = f"../models/pretrained/DeepSeek-R1-Distill-Qwen-1.5B" model, tokenizer = FastLanguageModel.from_pretrained( model_name=model_name, max_seq_length=2048, # 2048 load_in_4bit=True, fast_inference=False, max_lora_rank=128, gpu_memory_utilization=0.5, ) model = FastLanguageModel.get_peft_model( model, r=128, target_modules=[ "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", ], lora_alpha=128, use_gradient_checkpointing="unsloth", random_state=3407, ) # Constants SYSTEM_PROMPT = """ Respond in the following format: ... ... """ XML_COT_FORMAT = """\ {reasoning} {answer} """ # Helper functions def extract_xml_answer(text: str) -> str: answer = text.split("")[-1] answer = answer.split("")[0] return answer.strip() def extract_hash_answer(text: str) -> str | None: if "####" not in text: return None return text.split("####")[1].strip() # Dataset preparation from datasets import load_dataset, Dataset from modelscope.msdatasets import MsDataset def get_gsm8k_questions(split="train") -> Dataset: # data = load_dataset('openai/gsm8k', 'main')[split] data = MsDataset.load('openai-mirror/gsm8k', subset_name='main', split=split) os.makedirs(f'../data/temp/',exist_ok=False) # Save original datasets to JSONL with open(f'../data/temp/gsm8k_original_{split}.jsonl', 'w') as f: for item in data: f.write(json.dumps(item) + '\n') data = data.map(lambda x: { 'prompt': [ {'role': 'system', 'content': SYSTEM_PROMPT}, {'role': 'user', 'content': x['question']} ], 'answer': extract_hash_answer(x['answer']) }) # Save formatted datasets to JSONL with open(f'../data/temp/gsm8k_formatted_{split}.jsonl', 'w') as f: for item in data: f.write(json.dumps(item) + '\n') return data # Get dataset dataset = get_gsm8k_questions() # Reward functions def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]: responses = [completion[0]['content'] for completion in completions] q = prompts[0][-1]['content'] extracted_responses = [extract_xml_answer(r) for r in responses] print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}") return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)] def int_reward_func(completions, **kwargs) -> list[float]: responses = [completion[0]['content'] for completion in completions] extracted_responses = [extract_xml_answer(r) for r in responses] return [0.5 if r.isdigit() else 0.0 for r in extracted_responses] def strict_format_reward_func(completions, **kwargs) -> list[float]: pattern = r"^\n.*?\n\n\n.*?\n\n$" responses = [completion[0]["content"] for completion in completions] matches = [re.match(pattern, r) for r in responses] return [0.5 if match else 0.0 for match in matches] def soft_format_reward_func(completions, **kwargs) -> list[float]: pattern = r".*?\s*.*?" responses = [completion[0]["content"] for completion in completions] matches = [re.match(pattern, r) for r in responses] return [0.5 if match else 0.0 for match in matches] def count_xml(text) -> float: count = 0.0 if text.count("\n") == 1: count += 0.125 if text.count("\n\n") == 1: count += 0.125 if text.count("\n\n") == 1: count += 0.125 count -= len(text.split("\n\n")[-1])*0.001 if text.count("\n") == 1: count += 0.125 count -= (len(text.split("\n")[-1]) - 1)*0.001 return count def xmlcount_reward_func(completions, **kwargs) -> list[float]: contents = [completion[0]["content"] for completion in completions] return [count_xml(c) for c in contents] # Set up training args from trl import GRPOConfig, GRPOTrainer from vllm import SamplingParams # IMPORTANT: Extended training configuration for better results training_args = GRPOConfig( use_vllm = False, learning_rate = 5e-6, adam_beta1 = 0.9, adam_beta2 = 0.99, weight_decay = 0.1, warmup_ratio = 0.1, lr_scheduler_type = "cosine", optim = "adamw_8bit", logging_steps = 5, # More frequent logs for better CLI visualization bf16 = is_bfloat16_supported(), fp16 = not is_bfloat16_supported(), per_device_train_batch_size = 1, gradient_accumulation_steps = 1, # 2 Increased for better stability num_generations = 8, max_prompt_length = 256, max_completion_length = 200, max_steps = 10, # 2000 Increased 8x for longer training save_steps = 10, # Save checkpoints more frequently max_grad_norm = 0.1, report_to = "tensorboard", # Enable tensorboard reporting for metrics display output_dir = "outputs", save_total_limit = 3, # Keep only the last 3 checkpoints to save disk space # Enable detailed metrics logging log_level = "info", disable_tqdm = False, # Ensure progress bars are displayed # logging_steps = 5, # Log metrics frequently evaluation_strategy = "no", # Disable evaluation since we don't have an eval dataset ) # Train the model with extended training print("Starting GRPO training with EXTENDED training settings...") print(f"per_device_train_batch_size = {training_args.per_device_train_batch_size}") print(f"gradient_accumulation_steps = {training_args.gradient_accumulation_steps}") print(f"num_generations = {training_args.num_generations}") print(f"max_steps = {training_args.max_steps} (increased for better results)") # Monkey patch the validation in the GRPOTrainer to bypass the divisibility check # This is a workaround for the mysterious bug in TRL's implementation import trl.trainer.grpo_trainer original_init = trl.trainer.grpo_trainer.GRPOTrainer.__init__ def patched_init(self, *args, **kwargs): try: original_init(self, *args, **kwargs) except ValueError as e: if "evenly divisible by the number of generations per prompt" in str(e): print("Bypassing TRL's batch divisibility check...") # Continue with initialization despite the error self.args = kwargs.get("args") self.model = kwargs.get("model") self.processing_class = kwargs.get("processing_class") self.reward_funcs = kwargs.get("reward_funcs") self.train_dataset = kwargs.get("train_dataset") # Set up necessary trainer components without the check self._setup_trainer() else: raise e trl.trainer.grpo_trainer.GRPOTrainer.__init__ = patched_init # Initialize trainer and train trainer = GRPOTrainer( model = model, processing_class = tokenizer, reward_funcs = [ xmlcount_reward_func, soft_format_reward_func, strict_format_reward_func, int_reward_func, correctness_reward_func, ], args = training_args, train_dataset = dataset, ) # Train the model trainer.train() # Save the trained model print("Saving LoRA weights to grpo_saved_lora...") model.save_lora(f"../models/trained/grpoModel") print("Training complete!")