|
@@ -17,7 +17,7 @@ PatchFastRL("GRPO", FastLanguageModel)
|
|
|
model_name = f"../models/pretrained/DeepSeek-R1-Distill-Qwen-1.5B"
|
|
|
model, tokenizer = FastLanguageModel.from_pretrained(
|
|
|
model_name=model_name,
|
|
|
- max_seq_length=512, # 2028
|
|
|
+ max_seq_length=2048, # 2048
|
|
|
load_in_4bit=True,
|
|
|
fast_inference=False,
|
|
|
max_lora_rank=128,
|
|
@@ -151,8 +151,8 @@ training_args = GRPOConfig(
|
|
|
num_generations = 8,
|
|
|
max_prompt_length = 256,
|
|
|
max_completion_length = 200,
|
|
|
- max_steps = 2000, # Increased 8x for longer training
|
|
|
- save_steps = 500, # Save checkpoints more frequently
|
|
|
+ max_steps = 250, # reduced from 2000 for a shorter training run
|
|
|
+ save_steps = 250, # Save checkpoints more frequently
|
|
|
max_grad_norm = 0.1,
|
|
|
report_to = "tensorboard", # Enable tensorboard reporting for metrics display
|
|
|
output_dir = "outputs",
|