#!/usr/bin/env python3
# qwen_notebook_clone.py
# Direct clone of the notebook implementation with minimal changes.
# Imports follow the same order as the notebook.

import re
import torch
import os

from unsloth import FastLanguageModel, PatchFastRL, is_bfloat16_supported

# Enable Unsloth's CLI training metrics visualization
os.environ["UNSLOTH_DISPLAY_METRICS"] = "true"

# Apply patch exactly like the notebook
PatchFastRL("GRPO", FastLanguageModel)

# Load the model just like the notebook
model_name = "../models/pretrained/DeepSeek-R1-Distill-Qwen-1.5B"
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    max_seq_length=2048,
    load_in_4bit=True,            # 4-bit quantized base weights
    fast_inference=True,          # vLLM-backed generation for GRPO rollouts
    max_lora_rank=128,
    gpu_memory_utilization=0.80,
)

model = FastLanguageModel.get_peft_model(
    model,
    r=128,                        # LoRA rank, kept <= max_lora_rank above
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=128,
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)

# Constants
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

# Helper functions
def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()
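
# Illustration (hypothetical strings, not taken from the dataset):
#   extract_xml_answer("<answer>\n42\n</answer>")                 -> "42"
#   extract_hash_answer("... so the total is 6*7 = 42\n#### 42")  -> "42"
# GSM8K stores the gold answer after a "####" marker, which is what
# extract_hash_answer strips out.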

# Dataset preparation
from datasets import load_dataset, Dataset

def get_gsm8k_questions(split="train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split]
    data = data.map(lambda x: {
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    })
    return data

# Get dataset
dataset = get_gsm8k_questions()
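
# Optional sanity check (commented out): each record should now carry a chat-style
# prompt plus the extracted gold answer, roughly:
#   dataset[0]['prompt'] -> [{'role': 'system', ...}, {'role': 'user', ...}]
#   dataset[0]['answer'] -> e.g. "72" (a plain numeric string)
# print(dataset[0]['prompt'], dataset[0]['answer'])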

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    # 2.0 if the extracted <answer> exactly matches the gold answer, else 0.0
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-' * 20,
          f"Question:\n{q}",
          f"\nAnswer:\n{answer[0]}",
          f"\nResponse:\n{responses[0]}",
          f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    # 0.5 if the extracted answer is a bare integer string
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    # 0.5 only if the completion matches the XML template exactly, newlines included
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    # 0.5 if the tags appear in order, with looser whitespace requirements
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    # Partial credit (0.125 per tag) for each well-placed XML tag, with a small
    # penalty for any trailing text after </answer>
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]
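
# Worked example (hypothetical completion, not real model output), showing how the
# reward pieces stack for a reply that follows the template:
#   _c = "<reasoning>\n2+2=4\n</reasoning>\n<answer>\n4\n</answer>\n"
#   strict_format_reward_func([[{"content": _c}]])  -> [0.5]
#   int_reward_func([[{"content": _c}]])            -> [0.5]
#   count_xml(_c)                                   -> 0.5
#   soft_format_reward_func([[{"content": _c}]])    -> [0.0]  (its regex is compiled
#       without re.DOTALL, so "." cannot cross the newlines inside <reasoning>)
# The completion shape assumed here (a list of single-message dicts with a "content"
# key) mirrors how the functions above index completions.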

# Set up training args
from trl import GRPOConfig, GRPOTrainer
from vllm import SamplingParams

# IMPORTANT: Extended training configuration for better results
training_args = GRPOConfig(
    use_vllm = True,
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",
    logging_steps = 5,            # More frequent logs for better CLI visualization
    bf16 = is_bfloat16_supported(),
    fp16 = not is_bfloat16_supported(),
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 2,   # Increased for better stability
    num_generations = 8,
    max_prompt_length = 256,
    max_completion_length = 200,
    max_steps = 2000,             # Increased 8x for longer training
    save_steps = 500,             # Save checkpoints more frequently
    max_grad_norm = 0.1,
    report_to = "tensorboard",    # Enable TensorBoard reporting for metrics display
    output_dir = "outputs",
    save_total_limit = 3,         # Keep only the last 3 checkpoints to save disk space
    # Enable detailed metrics logging
    log_level = "info",
    disable_tqdm = False,         # Ensure progress bars are displayed
    evaluation_strategy = "no",   # Disable evaluation since we don't have an eval dataset
)
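
# Note on the batch math (hedged; based on the TRL behaviour this script was written
# against, verify for your installed version): GRPO samples num_generations completions
# per prompt, and GRPOTrainer requires the train batch size to be evenly divisible by
# num_generations. With per_device_train_batch_size = 1 on a single GPU, that check
# fails for num_generations = 8 and raises the ValueError that the monkey patch below
# catches and bypasses.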

# Train the model with extended training
print("Starting GRPO training with EXTENDED training settings...")
print(f"per_device_train_batch_size = {training_args.per_device_train_batch_size}")
print(f"gradient_accumulation_steps = {training_args.gradient_accumulation_steps}")
print(f"num_generations = {training_args.num_generations}")
print(f"max_steps = {training_args.max_steps} (increased for better results)")

# Monkey patch the validation in GRPOTrainer to bypass the divisibility check.
# This is a workaround for TRL's batch-size check described in the note above.
import trl.trainer.grpo_trainer

original_init = trl.trainer.grpo_trainer.GRPOTrainer.__init__

def patched_init(self, *args, **kwargs):
    try:
        original_init(self, *args, **kwargs)
    except ValueError as e:
        if "evenly divisible by the number of generations per prompt" in str(e):
            print("Bypassing TRL's batch divisibility check...")
            # Continue with initialization despite the error
            self.args = kwargs.get("args")
            self.model = kwargs.get("model")
            self.processing_class = kwargs.get("processing_class")
            self.reward_funcs = kwargs.get("reward_funcs")
            self.train_dataset = kwargs.get("train_dataset")
            # Set up necessary trainer components without the check
            self._setup_trainer()
        else:
            raise

trl.trainer.grpo_trainer.GRPOTrainer.__init__ = patched_init

# Initialize trainer and train
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args = training_args,
    train_dataset = dataset,
)

# Train the model
trainer.train()

# Save the trained LoRA adapter
print("Saving LoRA weights to ../models/trained/grpoModel...")
model.save_lora("../models/trained/grpoModel")
print("Training complete!")