#!/usr/bin/env python3
import json
import os
import re

from unsloth import FastLanguageModel, PatchFastRL, is_bfloat16_supported
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
import trl.trainer.grpo_trainer

from conf_train import load_config
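# conf_train.load_config() is assumed to return a config object exposing the
# fields referenced throughout this file (model_name, lora_rank, num_generations,
# train_data_path, save_path, and so on); the exact schema lives in conf_train.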


class GRPOTrainerWrapper:
    """
    Object-oriented wrapper around Unsloth and TRL that loads the model,
    defines the reward functions, and runs GRPO training.
    """

    # Constants
    SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

    XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
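    # Example (illustrative): XML_COT_FORMAT.format(reasoning="2 + 2 = 4", answer="4")
    # yields a completion in exactly the layout the reward functions below score.
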
    def __init__(self, config):
        """
        Initialize the trainer wrapper from a loaded configuration object.
        """
        self.config = config
        self.model = None
        self.tokenizer = None
        self.trainer = None

        # Enable Unsloth's CLI training metrics visualization
        os.environ["UNSLOTH_DISPLAY_METRICS"] = "true"

        # Patch FastLanguageModel so Unsloth can accelerate TRL's GRPO path
        PatchFastRL("GRPO", FastLanguageModel)

        # Monkey patch the batch-size validation in TRL's GRPOTrainer
        self._patch_grpo_trainer()

    def _patch_grpo_trainer(self):
        """
        Monkey patch GRPOTrainer.__init__ to bypass TRL's divisibility check.

        TRL requires the global train batch size to be evenly divisible by
        the number of generations per prompt and raises a ValueError when it
        is not; this patch intercepts that specific error and falls back to
        a partial initialization.
        """
        original_init = trl.trainer.grpo_trainer.GRPOTrainer.__init__

        def patched_init(self, *args, **kwargs):
            try:
                original_init(self, *args, **kwargs)
            except ValueError as e:
                if "evenly divisible by the number of generations per prompt" in str(e):
                    print("Bypassing TRL's batch divisibility check...")
                    # Continue with initialization despite the error; only the
                    # attributes below are restored, so any state the original
                    # __init__ would have set after the check must be rebuilt.
                    self.args = kwargs.get("args")
                    self.model = kwargs.get("model")
                    self.processing_class = kwargs.get("processing_class")
                    self.reward_funcs = kwargs.get("reward_funcs")
                    self.train_dataset = kwargs.get("train_dataset")
                    # Set up the remaining trainer components without the check
                    self._setup_trainer()
                else:
                    raise

        trl.trainer.grpo_trainer.GRPOTrainer.__init__ = patched_init
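
    # Worked example of the check being bypassed (hypothetical numbers): with
    # per_device_train_batch_size=4 on 2 GPUs, the global batch is 4 * 2 = 8
    # completions per step; num_generations=8 divides it evenly and passes,
    # while num_generations=6 does not (8 % 6 != 0) and raises the ValueError
    # that patched_init intercepts.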

    def load_model(self):
        """
        Load the base model and tokenizer, then attach LoRA adapters.
        """
        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=self.config.model_name,
            max_seq_length=self.config.max_seq_length,
            load_in_4bit=self.config.load_in_4bit,
            fast_inference=self.config.fast_inference,
            max_lora_rank=self.config.lora_rank,
            gpu_memory_utilization=self.config.gpu_memory_utilization,
        )
        self.model = FastLanguageModel.get_peft_model(
            self.model,
            r=self.config.lora_rank,
            target_modules=[
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ],
            lora_alpha=self.config.lora_rank,  # alpha == r gives a LoRA scaling factor of 1.0
            use_gradient_checkpointing="unsloth",
            random_state=3407,
        )

    @staticmethod
    def extract_xml_answer(text: str) -> str:
        """
        Extract the answer from an XML-formatted completion.
        """
        answer = text.split("<answer>")[-1]
        answer = answer.split("</answer>")[0]
        return answer.strip()

    @staticmethod
    def extract_hash_answer(text: str) -> str | None:
        """
        Extract the answer from "####"-delimited text (the GSM8K convention).
        """
        if "####" not in text:
            return None
        return text.split("####")[1].strip()
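
    # Illustrative inputs and outputs:
    #   extract_xml_answer("<reasoning>\n2+2=4\n</reasoning>\n<answer>\n4\n</answer>") -> "4"
    #   extract_hash_answer("Two plus two is four. #### 4") -> "4"
    #   extract_hash_answer("no delimiter here") -> None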

    def load_dataset(self) -> Dataset:
        """
        Load and prepare the training dataset.
        """
        # Read the JSONL file, one JSON record per line
        data = []
        with open(self.config.train_data_path, "r") as f:
            for line in f:
                data.append(json.loads(line))

        # Convert the list of records to a HuggingFace Dataset object
        return Dataset.from_list(data)
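
    # The JSONL schema is assumed from how the reward functions index the data;
    # a record presumably looks like this hypothetical example:
    #   {"prompt": [{"role": "system", "content": "..."},
    #               {"role": "user", "content": "What is 2 + 2?"}],
    #    "answer": "4"}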

    def correctness_reward_func(self, prompts, completions, answer, **kwargs) -> list[float]:
        """
        Reward answer correctness: 2.0 when the extracted answer matches the reference.
        """
        responses = [completion[0]["content"] for completion in completions]
        q = prompts[0][-1]["content"]
        extracted_responses = [self.extract_xml_answer(r) for r in responses]
        print("-" * 20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}",
              f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
        return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

    def int_reward_func(self, completions, **kwargs) -> list[float]:
        """
        Reward answers that are plain integers.
        """
        responses = [completion[0]["content"] for completion in completions]
        extracted_responses = [self.extract_xml_answer(r) for r in responses]
        return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

    def strict_format_reward_func(self, completions, **kwargs) -> list[float]:
        """
        Reward completions that follow the exact newline-delimited XML layout.
        """
        pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
        responses = [completion[0]["content"] for completion in completions]
        # re.DOTALL lets ".*?" span multi-line reasoning and answers
        matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
        return [0.5 if match else 0.0 for match in matches]

    def soft_format_reward_func(self, completions, **kwargs) -> list[float]:
        """
        Reward completions that contain the tags in order, with looser spacing.
        """
        pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
        responses = [completion[0]["content"] for completion in completions]
        matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
        return [0.5 if match else 0.0 for match in matches]
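
    # Example: "<reasoning>\nstep 1\nstep 2\n</reasoning>\n<answer>\n4\n</answer>\n"
    # earns both the strict and the soft format reward, whereas
    # "<reasoning>steps</reasoning> <answer>4</answer>" earns only the soft one.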

    def count_xml(self, text) -> float:
        """
        Score the presence of each expected XML tag (0.125 apiece) and apply a
        small penalty per character of trailing text after the answer tag.
        """
        count = 0.0
        if text.count("<reasoning>\n") == 1:
            count += 0.125
        if text.count("\n</reasoning>\n") == 1:
            count += 0.125
        if text.count("\n<answer>\n") == 1:
            count += 0.125
            count -= len(text.split("\n</answer>\n")[-1]) * 0.001
        if text.count("\n</answer>") == 1:
            count += 0.125
            count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
        return count

    def xmlcount_reward_func(self, completions, **kwargs) -> list[float]:
        """
        Reward function based on the XML tag-structure score from count_xml.
        """
        contents = [completion[0]["content"] for completion in completions]
        return [self.count_xml(c) for c in contents]
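
    # Worked example: a perfectly formatted completion contains each tag exactly
    # once with nothing after the final "</answer>\n", so count_xml returns
    # 4 * 0.125 = 0.5; each extra trailing character subtracts 0.001.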

    def prepare_training_args(self) -> GRPOConfig:
        """
        Prepare the GRPOConfig training arguments from the loaded config.
        """
        return GRPOConfig(
            use_vllm=self.config.use_vllm,
            learning_rate=float(self.config.learning_rate),
            adam_beta1=self.config.adam_beta1,
            adam_beta2=self.config.adam_beta2,
            weight_decay=self.config.weight_decay,
            warmup_ratio=self.config.warmup_ratio,
            lr_scheduler_type=self.config.lr_scheduler_type,
            optim=self.config.optim,
            logging_steps=self.config.logging_steps,
            bf16=is_bfloat16_supported(),
            fp16=not is_bfloat16_supported(),
            per_device_train_batch_size=self.config.per_device_train_batch_size,
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            num_generations=self.config.num_generations,
            max_prompt_length=self.config.max_prompt_length,
            max_completion_length=self.config.max_completion_length,
            max_steps=self.config.max_steps,
            save_steps=self.config.save_steps,
            max_grad_norm=self.config.max_grad_norm,
            report_to=self.config.report_to,
            output_dir=self.config.output_dir,
            save_total_limit=3,
            log_level="info",
            disable_tqdm=False,
            evaluation_strategy="no",
        )
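
    # Exactly one mixed-precision mode ends up enabled: bf16 on GPUs that
    # support bfloat16 (roughly Ampere and newer), fp16 everywhere else.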

    def initialize_trainer(self, dataset):
        """
        Initialize the GRPO trainer with the model, tokenizer, and reward functions.
        """
        training_args = self.prepare_training_args()

        self.trainer = GRPOTrainer(
            model=self.model,
            processing_class=self.tokenizer,
            reward_funcs=[
                self.xmlcount_reward_func,
                self.soft_format_reward_func,
                self.strict_format_reward_func,
                self.int_reward_func,
                self.correctness_reward_func,
            ],
            args=training_args,
            train_dataset=dataset,
        )
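
    # TRL evaluates every reward function on each completion and sums the
    # results, so a response that is well formed, integer-valued, and correct
    # can earn up to 0.5 + 0.5 + 0.5 + 0.5 + 2.0 = 4.0 before normalization.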

    def train(self):
        """
        Execute the training process.
        """
        print("Starting GRPO training...")
        print(f"per_device_train_batch_size = {self.config.per_device_train_batch_size}")
        print(f"gradient_accumulation_steps = {self.config.gradient_accumulation_steps}")
        print(f"num_generations = {self.config.num_generations}")
        print(f"max_steps = {self.config.max_steps}")

        self.trainer.train()

    def save_model(self):
        """
        Save the trained model and tokenizer.
        """
        print(f"Saving model to {self.config.save_path}...")
        os.makedirs(os.path.dirname(self.config.save_path), exist_ok=True)
        self.model.save_pretrained(self.config.save_path)
        self.tokenizer.save_pretrained(self.config.save_path)
        print("Training complete!")


def main():
    # Load configuration
    config = load_config()

    # Initialize and run the trainer end to end
    trainer = GRPOTrainerWrapper(config)
    trainer.load_model()
    dataset = trainer.load_dataset()
    trainer.initialize_trainer(dataset)
    trainer.train()
    trainer.save_model()


if __name__ == "__main__":
    main()