train_model_github_jwjohns.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. #!/usr/bin/env python3
  2. # Direct clone of notebook implementation with minimal changes
  3. # Import in the same order as the notebook
  4. import re
  5. import torch
  6. import os
  7. import json
  8. from unsloth import FastLanguageModel, PatchFastRL, is_bfloat16_supported
  9. # Enable Unsloth's CLI training metrics visualization
  10. os.environ["UNSLOTH_DISPLAY_METRICS"] = "true"
  11. # Apply patch exactly like notebook
  12. PatchFastRL("GRPO", FastLanguageModel)
  13. # Load the model just like the notebook
  14. model_name = f"../models/pretrained/DeepSeek-R1-Distill-Qwen-1.5B"
  15. model, tokenizer = FastLanguageModel.from_pretrained(
  16. model_name=model_name,
  17. max_seq_length=2048, # 2048
  18. load_in_4bit=True,
  19. fast_inference=False,
  20. max_lora_rank=128,
  21. gpu_memory_utilization=0.5,
  22. )
  23. model = FastLanguageModel.get_peft_model(
  24. model,
  25. r=128,
  26. target_modules=[
  27. "q_proj", "k_proj", "v_proj", "o_proj",
  28. "gate_proj", "up_proj", "down_proj",
  29. ],
  30. lora_alpha=128,
  31. use_gradient_checkpointing="unsloth",
  32. random_state=3407,
  33. )
  34. # Constants
  35. SYSTEM_PROMPT = """
  36. Respond in the following format:
  37. <reasoning>
  38. ...
  39. </reasoning>
  40. <answer>
  41. ...
  42. </answer>
  43. """
  44. XML_COT_FORMAT = """\
  45. <reasoning>
  46. {reasoning}
  47. </reasoning>
  48. <answer>
  49. {answer}
  50. </answer>
  51. """
  52. # Helper functions
  53. def extract_xml_answer(text: str) -> str:
  54. answer = text.split("<answer>")[-1]
  55. answer = answer.split("</answer>")[0]
  56. return answer.strip()
  57. def extract_hash_answer(text: str) -> str | None:
  58. if "####" not in text:
  59. return None
  60. return text.split("####")[1].strip()
  61. # Dataset preparation
  62. from datasets import load_dataset, Dataset
  63. from modelscope.msdatasets import MsDataset
  64. def get_gsm8k_questions(split="train") -> Dataset:
  65. # data = load_dataset('openai/gsm8k', 'main')[split]
  66. data = MsDataset.load('openai-mirror/gsm8k', subset_name='main', split=split)
  67. os.makedirs(f'../data/temp/',exist_ok=False)
  68. # Save original datasets to JSONL
  69. with open(f'../data/temp/gsm8k_original_{split}.jsonl', 'w') as f:
  70. for item in data:
  71. f.write(json.dumps(item) + '\n')
  72. data = data.map(lambda x: {
  73. 'prompt': [
  74. {'role': 'system', 'content': SYSTEM_PROMPT},
  75. {'role': 'user', 'content': x['question']}
  76. ],
  77. 'answer': extract_hash_answer(x['answer'])
  78. })
  79. # Save formatted datasets to JSONL
  80. with open(f'../data/temp/gsm8k_formatted_{split}.jsonl', 'w') as f:
  81. for item in data:
  82. f.write(json.dumps(item) + '\n')
  83. return data
  84. # Get dataset
  85. dataset = get_gsm8k_questions()
  86. # Reward functions
  87. def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
  88. responses = [completion[0]['content'] for completion in completions]
  89. q = prompts[0][-1]['content']
  90. extracted_responses = [extract_xml_answer(r) for r in responses]
  91. print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
  92. return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
  93. def int_reward_func(completions, **kwargs) -> list[float]:
  94. responses = [completion[0]['content'] for completion in completions]
  95. extracted_responses = [extract_xml_answer(r) for r in responses]
  96. return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
  97. def strict_format_reward_func(completions, **kwargs) -> list[float]:
  98. pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
  99. responses = [completion[0]["content"] for completion in completions]
  100. matches = [re.match(pattern, r) for r in responses]
  101. return [0.5 if match else 0.0 for match in matches]
  102. def soft_format_reward_func(completions, **kwargs) -> list[float]:
  103. pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
  104. responses = [completion[0]["content"] for completion in completions]
  105. matches = [re.match(pattern, r) for r in responses]
  106. return [0.5 if match else 0.0 for match in matches]
  107. def count_xml(text) -> float:
  108. count = 0.0
  109. if text.count("<reasoning>\n") == 1:
  110. count += 0.125
  111. if text.count("\n</reasoning>\n") == 1:
  112. count += 0.125
  113. if text.count("\n<answer>\n") == 1:
  114. count += 0.125
  115. count -= len(text.split("\n</answer>\n")[-1])*0.001
  116. if text.count("\n</answer>") == 1:
  117. count += 0.125
  118. count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
  119. return count
  120. def xmlcount_reward_func(completions, **kwargs) -> list[float]:
  121. contents = [completion[0]["content"] for completion in completions]
  122. return [count_xml(c) for c in contents]
  123. # Set up training args
  124. from trl import GRPOConfig, GRPOTrainer
  125. from vllm import SamplingParams
  126. # IMPORTANT: Extended training configuration for better results
  127. training_args = GRPOConfig(
  128. use_vllm = False,
  129. learning_rate = 5e-6,
  130. adam_beta1 = 0.9,
  131. adam_beta2 = 0.99,
  132. weight_decay = 0.1,
  133. warmup_ratio = 0.1,
  134. lr_scheduler_type = "cosine",
  135. optim = "adamw_8bit",
  136. logging_steps = 5, # More frequent logs for better CLI visualization
  137. bf16 = is_bfloat16_supported(),
  138. fp16 = not is_bfloat16_supported(),
  139. per_device_train_batch_size = 1,
  140. gradient_accumulation_steps = 1, # 2 Increased for better stability
  141. num_generations = 8,
  142. max_prompt_length = 256,
  143. max_completion_length = 200,
  144. max_steps = 10, # 2000 Increased 8x for longer training
  145. save_steps = 10, # Save checkpoints more frequently
  146. max_grad_norm = 0.1,
  147. report_to = "tensorboard", # Enable tensorboard reporting for metrics display
  148. output_dir = "outputs",
  149. save_total_limit = 3, # Keep only the last 3 checkpoints to save disk space
  150. # Enable detailed metrics logging
  151. log_level = "info",
  152. disable_tqdm = False, # Ensure progress bars are displayed
  153. # logging_steps = 5, # Log metrics frequently
  154. evaluation_strategy = "no", # Disable evaluation since we don't have an eval dataset
  155. )
  156. # Train the model with extended training
  157. print("Starting GRPO training with EXTENDED training settings...")
  158. print(f"per_device_train_batch_size = {training_args.per_device_train_batch_size}")
  159. print(f"gradient_accumulation_steps = {training_args.gradient_accumulation_steps}")
  160. print(f"num_generations = {training_args.num_generations}")
  161. print(f"max_steps = {training_args.max_steps} (increased for better results)")
  162. # Monkey patch the validation in the GRPOTrainer to bypass the divisibility check
  163. # This is a workaround for the mysterious bug in TRL's implementation
  164. import trl.trainer.grpo_trainer
  165. original_init = trl.trainer.grpo_trainer.GRPOTrainer.__init__
  166. def patched_init(self, *args, **kwargs):
  167. try:
  168. original_init(self, *args, **kwargs)
  169. except ValueError as e:
  170. if "evenly divisible by the number of generations per prompt" in str(e):
  171. print("Bypassing TRL's batch divisibility check...")
  172. # Continue with initialization despite the error
  173. self.args = kwargs.get("args")
  174. self.model = kwargs.get("model")
  175. self.processing_class = kwargs.get("processing_class")
  176. self.reward_funcs = kwargs.get("reward_funcs")
  177. self.train_dataset = kwargs.get("train_dataset")
  178. # Set up necessary trainer components without the check
  179. self._setup_trainer()
  180. else:
  181. raise e
  182. trl.trainer.grpo_trainer.GRPOTrainer.__init__ = patched_init
  183. # Initialize trainer and train
  184. trainer = GRPOTrainer(
  185. model = model,
  186. processing_class = tokenizer,
  187. reward_funcs = [
  188. xmlcount_reward_func,
  189. soft_format_reward_func,
  190. strict_format_reward_func,
  191. int_reward_func,
  192. correctness_reward_func,
  193. ],
  194. args = training_args,
  195. train_dataset = dataset,
  196. )
  197. # Train the model
  198. trainer.train()
  199. # Save the trained model
  200. print("Saving LoRA weights to grpo_saved_lora...")
  201. model.save_lora(f"../models/trained/grpoModel")
  202. print("Training complete!")