fine_tune_model.py

import os

import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import DataCollatorForLanguageModeling
from unsloth import FastLanguageModel


class ModelFineTuner:
    def __init__(self, model_path, max_seq_length):
        self.model_path = model_path
        self.max_seq_length = max_seq_length

    def load_model(self):
        # Unsloth returns the model together with its tokenizer.
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=self.model_path, max_seq_length=self.max_seq_length
        )
        return model, tokenizer

    def load_data(self, data_path):
        # NOTE: called from __main__ but missing from the original script; this is a
        # minimal JSON loader that yields a DatasetDict with a "train" split.
        return load_dataset("json", data_files=data_path)

    def fine_tune(self, model, tokenizer, dataset):
        # Wrap the base model with LoRA adapters; only the adapter weights are trained.
        model = FastLanguageModel.get_peft_model(
            model,
            r=16,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                            "gate_proj", "up_proj", "down_proj"],
            lora_alpha=16,
            lora_dropout=0,
            bias="none",
            use_gradient_checkpointing=True,
            random_state=3407,
            max_seq_length=self.max_seq_length,
        )
        # Padding to max_length requires a pad token; fall back to EOS if unset.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        # Tokenize the text column and drop the raw columns so only tensor fields remain.
        tokenized = dataset.map(
            lambda x: tokenizer(x["text"], truncation=True, padding="max_length",
                                max_length=self.max_seq_length),
            batched=True,
            remove_columns=dataset["train"].column_names,
        )
        # The collator stacks each batch into tensors and adds causal-LM labels
        # (input_ids with padding masked to -100), so outputs.loss is populated.
        collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
        train_loader = DataLoader(tokenized["train"], batch_size=2,
                                  shuffle=True, collate_fn=collator)
        # Optimize only the trainable (LoRA) parameters.
        optimizer = torch.optim.AdamW(
            (p for p in model.parameters() if p.requires_grad), lr=2e-5
        )
        model.train()
        for epoch in range(3):
            for batch in train_loader:
                batch = {k: v.to(model.device) for k, v in batch.items()}
                outputs = model(**batch)
                loss = outputs.loss
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
        return model

    def save_fine_tuned_model(self, model, save_path):
        # Saves the LoRA adapter weights and config to save_path.
        model.save_pretrained(save_path)


if __name__ == "__main__":
    model_path = os.path.join('..', 'models', 'deepseek-r1-distill-1.5B')
    max_seq_length = 2048
    fine_tuner = ModelFineTuner(model_path, max_seq_length)
    model, tokenizer = fine_tuner.load_model()
    dataset = fine_tuner.load_data(os.path.join('..', 'data', 'processed', 'train.json'))
    model = fine_tuner.fine_tune(model, tokenizer, dataset)
    fine_tuner.save_fine_tuned_model(model, os.path.join('..', 'models', 'deepseek-r1-distill-1.5B-fine-tuned'))