@@ -17,11 +17,11 @@ PatchFastRL("GRPO", FastLanguageModel)
 model_name = f"../models/pretrained/DeepSeek-R1-Distill-Qwen-1.5B"
 model, tokenizer = FastLanguageModel.from_pretrained(
     model_name=model_name,
-    max_seq_length=2048,
+    max_seq_length=1024,  # was 2048
     load_in_4bit=True,
     fast_inference=True,
     max_lora_rank=128,
-    gpu_memory_utilization=0.80,
+    gpu_memory_utilization=0.60,
 )

 model = FastLanguageModel.get_peft_model(
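Both edits in this hunk trade capacity for GPU headroom: halving `max_seq_length` lowers the per-sequence memory ceiling, and the smaller `gpu_memory_utilization` caps how much of the card the vLLM backend claims when `fast_inference=True`. If the right fraction for a given GPU is unclear, a probe loop like the sketch below is one way to find it empirically. This is an illustration reusing the `from_pretrained` call above, not part of the change; the candidate fractions are assumptions, and since a failed vLLM initialization does not always surface as a cleanly catchable exception, running each attempt in a fresh process is the safer variant.

```python
import torch
from unsloth import FastLanguageModel

model_name = "../models/pretrained/DeepSeek-R1-Distill-Qwen-1.5B"

model, tokenizer = None, None
for util in (0.80, 0.60, 0.45):  # candidate fractions, highest first (assumed values)
    try:
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_name,
            max_seq_length=1024,
            load_in_4bit=True,
            fast_inference=True,
            max_lora_rank=128,
            gpu_memory_utilization=util,
        )
        break  # loaded successfully at this fraction
    except RuntimeError:  # torch.cuda.OutOfMemoryError subclasses RuntimeError
        torch.cuda.empty_cache()  # drop cached allocations before the next, lower try
```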
@@ -145,7 +145,7 @@ training_args = GRPOConfig(
     bf16 = is_bfloat16_supported(),
     fp16 = not is_bfloat16_supported(),
     per_device_train_batch_size = 1,
-    gradient_accumulation_steps = 2,  # Increased for better stability
+    gradient_accumulation_steps = 1,  # was 2 ("Increased for better stability")
     num_generations = 8,
     max_prompt_length = 256,
     max_completion_length = 200,
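As a quick sanity check on the new numbers (an illustration, not part of the diff): with `per_device_train_batch_size = 1`, dropping `gradient_accumulation_steps` from 2 to 1 halves what is accumulated per optimizer step, and the generation window of 256 + 200 = 456 tokens still fits well inside the reduced `max_seq_length` of 1024.

```python
# Back-of-envelope token budget under the new settings (plain arithmetic only;
# actual memory also depends on vLLM paging, optimizer state, and TRL internals).
num_generations = 8
max_prompt_length = 256
max_completion_length = 200
max_seq_length = 1024  # reduced from 2048 in this change

per_sample = max_prompt_length + max_completion_length   # 456 tokens
assert per_sample <= max_seq_length                      # the shorter window still fits
tokens_per_group = num_generations * per_sample          # 3648 tokens per prompt group
print(per_sample, tokens_per_group)
```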