# quantize_model.py
# Dynamic int8 quantization of a fine-tuned DeepSeek-R1 distill model.
  1. import os
  2. import torch
  3. from unsloth import FastLanguageModel
  4. class ModelQuantizer:
  5. def __init__(self, model_path):
  6. self.model_path = model_path
  7. def quantize(self):
  8. model, tokenizer = FastLanguageModel.from_pretrained(self.model_path)
  9. # Apply dynamic quantization
  10. quantized_model = torch.quantization.quantize_dynamic(
  11. model, {torch.nn.Linear}, dtype=torch.qint8
  12. )
  13. return quantized_model, tokenizer
  14. def save_quantized_model(self, model, save_path):
  15. model.save_pretrained(save_path)
  16. if __name__ == "__main__":
  17. model_path = os.path.join('..', 'models', 'deepseek-r1-distill-1.5B-finetuned')
  18. quantizer = ModelQuantizer(model_path)
  19. quantized_model, tokenizer = quantizer.quantize()
  20. quantizer.save_quantized_model(quantized_model, os.path.join('..', 'models', 'deepseek-r1-distill-1.5B-quantized'))