inference.py

import torch
from unsloth import FastLanguageModel
from transformers import TextStreamer
from conf_train import load_config


class ModelInference:
    def __init__(self, model_path, max_seq_length, dtype, load_in_4bit):
        self.model_path = model_path
        self.max_seq_length = max_seq_length
        self.dtype = dtype
        self.load_in_4bit = load_in_4bit
        self.model = None
        self.tokenizer = None
        self.lora_rank = 64
    def load_model(self):
        # Load the fine-tuned model and tokenizer.
        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=self.model_path,
            max_seq_length=self.max_seq_length,
            # True: load with 4-bit quantization; False: LoRA in 16-bit.
            # 4-bit quantization maps the weights onto a limited set of values,
            # cutting memory use roughly 4x so a free 16 GB GPU is enough for
            # fine-tuning, at the cost of about 1-2% accuracy. Set it to False
            # on a larger GPU (e.g. an H100) if you want that small extra accuracy.
            load_in_4bit=self.load_in_4bit,
            dtype=self.dtype,
            fast_inference=False,  # Enable vLLM fast inference
            max_lora_rank=self.lora_rank,
            gpu_memory_utilization=0.6,  # Reduce if out of memory
        )
        # Switch the model to inference mode.
        self.model = FastLanguageModel.for_inference(self.model)
        print("Model and tokenizer loaded successfully.")
    def chat(self):
        # Interactive loop with the model.
        print("Start chatting with the model (type 'exit' to stop)!")
        while True:
            user_input = input("You: ")
            if user_input.lower() == "exit":
                print("Exiting chat.")
                break
            # Encode the user input as model input.
            inputs = self.tokenizer(
                user_input,
                return_tensors="pt",
                max_length=self.max_seq_length,
                truncation=True,
            )
            inputs = inputs.to("cuda")  # Move the input tensors to the GPU.
            # Generate the model's reply, streaming tokens as they are produced.
            with torch.no_grad():
                text_streamer = TextStreamer(self.tokenizer, skip_prompt=True)
                outputs = self.model.generate(
                    **inputs,
                    streamer=text_streamer,
                    max_length=self.max_seq_length,
                    pad_token_id=self.tokenizer.eos_token_id,
                )
            # Decode the model's output.
            model_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            print(f"AI: {model_response}")
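
    # Hedged sketch (added, not in the original file): a single chat turn that wraps
    # the user message with the tokenizer's chat template via the standard transformers
    # `apply_chat_template` API. Instruction-tuned checkpoints usually expect their
    # training-time template rather than raw text; whether that applies to this model
    # depends on how it was fine-tuned, so treat this as an optional variant of chat().
    def chat_once_with_template(self, user_input):
        messages = [{"role": "user", "content": user_input}]
        # Tokenize the templated conversation and append the assistant generation prompt.
        input_ids = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to("cuda")
        text_streamer = TextStreamer(self.tokenizer, skip_prompt=True)
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                streamer=text_streamer,
                max_new_tokens=512,  # Cap only the newly generated tokens.
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # Return only the generated continuation, without the prompt tokens.
        return self.tokenizer.decode(
            outputs[0][input_ids.shape[-1]:], skip_special_tokens=True
        )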
if __name__ == "__main__":
    # Load configuration.
    config = load_config()
    # Inference parameters.
    model_path = config.save_path
    max_seq_length = 2048
    dtype = torch.float16
    load_in_4bit = True
    # Initialize ModelInference.
    inference = ModelInference(model_path, max_seq_length, dtype, load_in_4bit)
    # Load the model and tokenizer.
    inference.load_model()
    # Start chatting with the model.
    inference.chat()