|
@@ -2,6 +2,7 @@ import os
|
|
|
import torch
|
|
|
from unsloth import FastLanguageModel
|
|
|
from transformers import TextStreamer
|
|
|
+from conf_train import load_config
|
|
|
|
|
|
class ModelInference:
|
|
|
def __init__(self, model_path, max_seq_length, dtype, load_in_4bit):
|
|
@@ -53,8 +54,11 @@ class ModelInference:
|
|
|
print(f"人工智能: {model_response}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
+ # Load configuration
|
|
|
+ config = load_config()
|
|
|
+
|
|
|
# Configuration parameters
|
|
|
- model_path = os.path.join('..', 'models', 'trained', 'DeepSeek-R1-Distill-Qwen-1.5B-GRPO')
|
|
|
+ model_path = config.save_path
|
|
|
max_seq_length = 2048
|
|
|
dtype = torch.float16
|
|
|
load_in_4bit = True
|