@@ -9,6 +9,9 @@ from trl import GRPOConfig, GRPOTrainer
 from datasets import load_dataset
 from conf_train import Config, load_config  # import the training configuration
 import re
+import torch  # needed for torch.no_grad() below; drop if torch is already imported at the top of the file
+from transformers import BertTokenizer, BertModel
+import numpy as np
 
 class ModelTrainer:
     def __init__(self, config: Config):
@@ -24,6 +27,9 @@ class ModelTrainer:
         self.fast_inference = config.fast_inference
         self.lora_rank = config.lora_rank
         self.gpu_memory_utilization = config.gpu_memory_utilization
+        # Initialize the BERT tokenizer and model used for semantic similarity scoring
+        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+        self.bert_model = BertModel.from_pretrained('bert-base-uncased')
 
     def load_model(self):
         """
@@ -122,11 +128,15 @@ class ModelTrainer:
             model=model,
             processing_class=tokenizer,
             reward_funcs=[
-                self.xmlcount_reward_func,
-                self.soft_format_reward_func,
+                # self.xmlcount_reward_func,
+                # self.soft_format_reward_func,
+                # self.strict_format_reward_func,
+                # self.int_reward_func,
+                # self.correctness_reward_func,
                 self.strict_format_reward_func,
-                self.int_reward_func,
-                self.correctness_reward_func,
+                self.semantic_correctness_reward_func,
+                self.reasoning_quality_reward_func,
+                self.combined_reward_func,
             ],
             args=training_args,
             train_dataset=train_dataset,
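
Each entry in reward_funcs is invoked once per generation step and must return one float per completion; GRPOTrainer combines the per-function scores into a single reward (summed, and in recent TRL releases optionally weighted via reward_weights in GRPOConfig). Extra dataset columns such as answer are forwarded to the functions as keyword arguments, which is why the signatures below accept **kwargs. A minimal sketch of that contract, with a hypothetical reward function for illustration:

# Hypothetical reward function, shown only to illustrate the contract that the
# reward functions in this diff follow: take a batch, return one float per completion.
def length_penalty_reward_func(completions, **kwargs):
    responses = [completion[0]["content"] for completion in completions]
    # Arbitrary illustrative threshold; not part of the actual change set.
    return [0.0 if len(r) <= 1024 else -0.5 for r in responses]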
@@ -145,6 +155,80 @@ class ModelTrainer:
         model.save_pretrained(save_path)
         tokenizer.save_pretrained(save_path)
         print(f"Model saved to {save_path}")
+
+    @staticmethod
+    def cosine_similarity(vec1, vec2):
+        """
+        Compute the cosine similarity between two 1-D vectors.
+        """
+        return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
+
+    def semantic_correctness_reward_func(self, prompts, completions, answer, **kwargs):
+        """
+        Use BERT to score the semantic similarity between the generated answer and the reference answer.
+        :param prompts: input prompts
+        :param completions: model-generated completions
+        :param answer: reference answers
+        :return: list of semantic similarity scores
+        """
+        responses = [completion[0]['content'] for completion in completions]
+        extracted_responses = [self.extract_xml_answer(r) for r in responses]
+        scores = []
+        for resp, ans in zip(extracted_responses, answer):
+            # Encode the generated answer and the reference answer
+            inputs_resp = self.tokenizer(resp, return_tensors='pt', padding=True, truncation=True)
+            inputs_ans = self.tokenizer(ans, return_tensors='pt', padding=True, truncation=True)
+            with torch.no_grad():
+                outputs_resp = self.bert_model(**inputs_resp).last_hidden_state.mean(dim=1)
+                outputs_ans = self.bert_model(**inputs_ans).last_hidden_state.mean(dim=1)
+            # Squeeze the (1, hidden) mean-pooled outputs to 1-D before the dot product
+            similarity = self.cosine_similarity(outputs_resp.squeeze(0).numpy(), outputs_ans.squeeze(0).numpy())
+            scores.append(float(similarity))
+        return scores
+
+    def combined_reward_func(self, prompts, completions, answer, **kwargs):
+        """
+        Combine several reward functions with dynamically adjusted weights.
+        :param prompts: input prompts
+        :param completions: model-generated completions
+        :param answer: reference answers
+        :return: list of combined scores
+        """
+        # Compute each component reward
+        format_score = self.strict_format_reward_func(completions)
+        semantic_score = self.semantic_correctness_reward_func(prompts, completions, answer)
+        correctness_score = self.correctness_reward_func(prompts, completions, answer)
+
+        # Shift weight toward correctness once the answer is exactly right
+        combined_scores = []
+        for fs, ss, cs in zip(format_score, semantic_score, correctness_score):
+            if cs == 2.0:  # answer fully correct
+                combined_scores.append(fs * 0.2 + ss * 0.3 + cs * 0.5)
+            else:  # answer not fully correct
+                combined_scores.append(fs * 0.4 + ss * 0.4 + cs * 0.2)
+        return combined_scores
+
+    @staticmethod
+    def reasoning_quality_reward_func(completions, **kwargs):
+        """
+        Check the quality of the reasoning section.
+        :param completions: model-generated completions
+        :return: list of reasoning-quality scores
+        """
+        responses = [completion[0]["content"] for completion in completions]
+        scores = []
+        for response in responses:
+            reasoning_match = re.search(r"<reasoning>\n(.+?)\n</reasoning>", response, re.DOTALL)
+            if reasoning_match:
+                reasoning_content = reasoning_match.group(1).strip()
+                # Simple keyword check: "诊断" (diagnosis) and "原因" (cause) must both appear
+                if "诊断" in reasoning_content and "原因" in reasoning_content:
+                    scores.append(1.0)
+                else:
+                    scores.append(0.5)
+            else:
+                scores.append(0.0)
+        return scores
 
     @staticmethod
     def extract_xml_answer(text: str) -> str:
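
A note on the squeeze(0) calls above: last_hidden_state.mean(dim=1) produces a (1, hidden) matrix, and np.dot applied to two such 2-D arrays raises a shape error instead of returning a scalar, which is why the pooled outputs are flattened to 1-D vectors before cosine_similarity is applied. A standalone check of that behavior (768 is the hidden size of bert-base-uncased):

import numpy as np

a = np.random.rand(1, 768)  # shape of last_hidden_state.mean(dim=1).numpy()
b = np.random.rand(1, 768)
# np.dot(a, b) would raise ValueError: shapes (1,768) and (1,768) not aligned
v1, v2 = a.squeeze(0), b.squeeze(0)
print(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)))  # a single scalar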
@@ -202,14 +286,26 @@ class ModelTrainer:
     @staticmethod
     def strict_format_reward_func(completions, **kwargs):
         """
-        Check whether the completion matches the strict format requirements.
+        Check that the response matches the strict XML format and that both tag bodies are non-empty.
         :param completions: model-generated completions
         :return: list of strict-format scores
         """
-        pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
+        pattern = r"^<reasoning>\n(.+?)\n</reasoning>\n<answer>\n(.+?)\n</answer>\n$"
         responses = [completion[0]["content"] for completion in completions]
-        matches = [re.match(pattern, r) for r in responses]
-        return [0.5 if match else 0.0 for match in matches]
+        scores = []
+        for response in responses:
+            match = re.match(pattern, response, re.DOTALL)
+            if match:
+                reasoning_content = match.group(1).strip()
+                answer_content = match.group(2).strip()
+                # Check that both captured bodies are non-empty
+                if reasoning_content and answer_content:
+                    scores.append(1.0)  # format and content both satisfied
+                else:
+                    scores.append(0.5)  # format matches but a body is empty
+            else:
+                scores.append(0.0)  # format mismatch
+        return scores
 
     @staticmethod
     def int_reward_func(completions, **kwargs):
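
Because strict_format_reward_func is a @staticmethod, the tightened pattern can be sanity-checked without constructing a trainer or loading any model. A quick check, assuming the patched ModelTrainer class has been imported (the content strings are made up):

# Assumes ModelTrainer is imported from the patched training script.
good = [[{"role": "assistant",
          "content": "<reasoning>\nsome reasoning\n</reasoning>\n<answer>\n42\n</answer>\n"}]]
blank = [[{"role": "assistant",
           "content": "<reasoning>\n \n</reasoning>\n<answer>\n \n</answer>\n"}]]
bad = [[{"role": "assistant", "content": "42"}]]

print(ModelTrainer.strict_format_reward_func(good))   # [1.0] format and content OK
print(ModelTrainer.strict_format_reward_func(blank))  # [0.5] format OK, bodies empty
print(ModelTrainer.strict_format_reward_func(bad))    # [0.0] format mismatch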