@@ -9,7 +9,6 @@ from trl import GRPOConfig, GRPOTrainer
 from datasets import load_dataset
 from conf_train import Config, load_config  # import the training configuration
 import re
-from transformers import LongformerTokenizer, LongformerModel  # the tokenizer/model supports at most 4096 tokens
 import numpy as np


@@ -27,9 +26,6 @@ class ModelTrainer:
         self.fast_inference = config.fast_inference
         self.lora_rank = config.lora_rank
         self.gpu_memory_utilization = config.gpu_memory_utilization
-        # Initialize the Longformer model and tokenizer
-        self.tokenizer = LongformerTokenizer.from_pretrained(f'../models/allenai/longformer-base-4096')
-        self.longformer_model = LongformerModel.from_pretrained(f'../models/allenai/longformer-base-4096')

     def load_model(self):
         """
@@ -131,14 +127,11 @@ class ModelTrainer:
             model=model,
             processing_class=tokenizer,  # tokenizer used to process the input text
             reward_funcs=[
-                # self.xmlcount_reward_func,  # XML tag completeness reward
-                # self.soft_format_reward_func,  # soft format reward
-                # self.strict_format_reward_func,  # strict format reward
-                # self.int_reward_func,  # integer reward
-                # self.correctness_reward_func,  # correctness reward
-                # self.semantic_correctness_reward_func,  # semantic correctness reward
-                # self.reasoning_quality_reward_func,  # reasoning quality reward
-                self.combined_reward_func,  # combined reward
+                self.xmlcount_reward_func,
+                self.soft_format_reward_func,
+                self.strict_format_reward_func,
+                self.int_reward_func,
+                self.correctness_reward_func,
             ],
             args=training_args,  # the training hyperparameters defined above
             train_dataset=train_dataset,  # training dataset
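
Every callable kept in reward_funcs above follows the contract this diff relies on: it receives the batch of completions (plus any extra dataset columns such as answer, via **kwargs) and returns one float per completion. A minimal sketch of an additional reward in that style; the name and scoring rule are illustrative and not part of this patch:

def length_penalty_reward_func(completions, **kwargs):
    # Illustrative only: mildly penalize very long completions.
    responses = [completion[0]["content"] for completion in completions]
    return [max(0.0, 1.0 - len(r) / 4096.0) for r in responses]

Such a function would be registered by appending it to the reward_funcs list, exactly like the five functions kept above.
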
@@ -159,116 +152,13 @@ class ModelTrainer:
         print(f"Model saved to {save_path}")

     @staticmethod
-    def cosine_similarity(vec1, vec2):
-        """
-        Compute the cosine similarity between two vectors.
-        :param vec1: first vector, shape (1, 768)
-        :param vec2: second vector, shape (1, 768)
-        :return: cosine similarity
-        """
-        vec1 = vec1.squeeze()  # shape goes from (1, 768) to (768,)
-        vec2 = vec2.squeeze()  # shape goes from (1, 768) to (768,)
-        return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
-
-    def semantic_correctness_reward_func(self, prompts, completions, answer, **kwargs):
-        """
-        Use Longformer to score the semantic similarity between the generated answer and the reference answer.
-        """
-        responses = [completion[0]['content'] for completion in completions]
-        extracted_responses = [self.extract_xml_answer(r) for r in responses]
-        scores = []
-        for resp, ans in zip(extracted_responses, answer):
-            # Truncate the text so it does not exceed 4096 tokens
-            resp = self.tokenizer.decode(self.tokenizer.encode(resp, truncation=True, max_length=4096))
-            ans = self.tokenizer.decode(self.tokenizer.encode(ans, truncation=True, max_length=4096))
-            # Encode the generated answer and the reference answer
-            inputs_resp = self.tokenizer(resp, return_tensors='pt', padding=True, truncation=True, max_length=4096)
-            inputs_ans = self.tokenizer(ans, return_tensors='pt', padding=True, truncation=True, max_length=4096)
-            with torch.no_grad():
-                outputs_resp = self.longformer_model(**inputs_resp).last_hidden_state.mean(dim=1)  # shape (1, 768)
-                outputs_ans = self.longformer_model(**inputs_ans).last_hidden_state.mean(dim=1)  # shape (1, 768)
-            # Compute the cosine similarity
-            similarity = self.cosine_similarity(outputs_resp.numpy(), outputs_ans.numpy())
-            scores.append(similarity)
-        return scores
-
-    def combined_reward_func(self, prompts, completions, answer, **kwargs):
-        """
-        Combine several reward functions with dynamically adjusted weights.
-        :param prompts: input prompts
-        :param completions: model-generated completions
-        :param answer: reference answers
-        :return: list of combined scores
-        """
-        # Compute each reward function's score
-        format_score = self.strict_format_reward_func(completions)
-        semantic_score = self.semantic_correctness_reward_func(prompts, completions, answer)
-        correctness_score = self.correctness_reward_func(prompts, completions, answer)
-
-        # Dynamically adjust the weights
-        combined_scores = []
-        for fs, ss, cs in zip(format_score, semantic_score, correctness_score):
-            if cs == 2.0:  # the answer is fully correct
-                combined_scores.append(fs * 0.2 + ss * 0.3 + cs * 0.5)
-            else:  # the answer is not fully correct
-                combined_scores.append(fs * 0.4 + ss * 0.4 + cs * 0.2)
-        return combined_scores
-
-    @staticmethod
-    def reasoning_quality_reward_func(completions, **kwargs):
-        """
-        Check the quality of the reasoning section.
-        :param completions: model-generated completions
-        :return: list of reasoning-quality scores
-        """
-        responses = [completion[0]["content"] for completion in completions]
-        scores = []
-        for response in responses:
-            reasoning_match = re.search(r"<reasoning>\n(.+?)\n</reasoning>", response, re.DOTALL)
-            if reasoning_match:
-                reasoning_content = reasoning_match.group(1).strip()
-                # Simple keyword check on the reasoning content
-                if "诊断" in reasoning_content and "原因" in reasoning_content:
-                    scores.append(1.0)
-                else:
-                    scores.append(0.5)
-            else:
-                scores.append(0.0)
-        return scores
-
-    @staticmethod
     def extract_xml_answer(text: str) -> str:
-        """
-        Extract the XML-formatted answer from the text.
-        :param text: text containing the XML tags
-        :return: the extracted answer
-        """
-        try:
-            print("text -> \n", text)
-            if "<answer>" in text and "</answer>" in text:
-                answer = text.split("<answer>")[-1]
-                answer = answer.split("</answer>")[0]
-                return answer.strip()
-            else:
-                print("Warning: <answer> tag not found in response.")
-                # Try to extract other meaningful parts
-                if "诊断" in text:
-                    return text.split("诊断")[-1].strip()
-                elif "排查建议" in text:
-                    return text.split("排查建议")[-1].strip()
-                else:
-                    return text.strip()  # fall back to the raw text
-        except Exception as e:
-            print(f"Error extracting XML answer: {e}")
-            return ""  # return an empty string as the default
-
+        answer = text.split("<answer>")[-1]
+        answer = answer.split("</answer>")[0]
+        return answer.strip()
+
     @staticmethod
     def count_xml(text) -> float:
-        """
-        Count the XML tags and score their completeness.
-        :param text: text containing the XML tags
-        :return: XML tag completeness score
-        """
         count = 0.0
         if text.count("<reasoning>\n") == 1:
             count += 0.125
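
After the hunk above, extract_xml_answer is reduced to a plain tag split, with no try/except and no keyword-based fallback; input without an <answer> block now comes back stripped but otherwise unchanged. A quick standalone check of the new behaviour, a sketch with illustrative sample strings:

def extract_xml_answer(text: str) -> str:
    # Mirrors the simplified static method introduced above.
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

tagged = "<reasoning>\nchecked the logs\n</reasoning>\n<answer>\nrestart the gateway\n</answer>\n"
print(extract_xml_answer(tagged))     # -> restart the gateway
print(extract_xml_answer("no tags"))  # -> no tags (the whole input, stripped)
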
@@ -276,18 +166,17 @@ class ModelTrainer:
             count += 0.125
         if text.count("\n<answer>\n") == 1:
             count += 0.125
-            count -= len(text.split("\n</answer>\n")[-1]) * 0.001
+            count -= len(text.split("\n</answer>\n")[-1])*0.001
         if text.count("\n</answer>") == 1:
             count += 0.125
-            count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
+            count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
         return count

+
     @staticmethod
     def xmlcount_reward_func(completions, **kwargs):
         """
-        Compute the XML tag completeness score.
-        :param completions: model-generated completions
-        :return: list of XML tag completeness scores
+        Reward function that counts XML tags in the completion.
         """
         contents = [completion[0]["content"] for completion in completions]
         return [ModelTrainer.count_xml(c) for c in contents]
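
To make the count_xml scoring concrete: each well-placed tag is worth 0.125 (the unchanged lines between these hunks, not shown, award the remaining 0.125 for the closing </reasoning> tag), so a cleanly tagged completion with nothing after </answer> earns 0.5, while text trailing the closing tag is penalized by both visible 0.001-per-character terms. A rough check, with illustrative sample strings:

good  = "<reasoning>\nsteps here\n</reasoning>\n<answer>\n42\n</answer>\n"
noisy = good + "trailing chatter"

# The two penalty terms visible in the hunk above:
print(len(noisy.split("\n</answer>\n")[-1]) * 0.001)      # ~0.016
print((len(noisy.split("\n</answer>")[-1]) - 1) * 0.001)  # ~0.016
# With all four 0.125 tag bonuses, count_xml comes to 0.5 for `good`
# and roughly 0.468 for `noisy`.
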
@@ -295,9 +184,7 @@ class ModelTrainer:
     @staticmethod
     def soft_format_reward_func(completions, **kwargs):
         """
-        Check whether the completion meets the soft format requirement.
-        :param completions: model-generated completions
-        :return: list of soft-format scores
+        Reward function that checks if the completion has a specific format.
         """
         pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
         responses = [completion[0]["content"] for completion in completions]
@@ -307,33 +194,17 @@ class ModelTrainer:
     @staticmethod
     def strict_format_reward_func(completions, **kwargs):
         """
-        Check whether the response matches the strict XML format and that the tag contents are non-empty.
-        :param completions: model-generated completions
-        :return: list of strict-format scores
+        Reward function that checks if the completion has a specific format.
         """
-        pattern = r"^<reasoning>\n(.+?)\n</reasoning>\n<answer>\n(.+?)\n</answer>\n$"
+        pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
         responses = [completion[0]["content"] for completion in completions]
-        scores = []
-        for response in responses:
-            match = re.match(pattern, response, re.DOTALL)
-            if match:
-                reasoning_content = match.group(1).strip()
-                answer_content = match.group(2).strip()
-                # Check that the contents are non-empty
-                if reasoning_content and answer_content:
-                    scores.append(1.0)  # format and content both valid
-                else:
-                    scores.append(0.5)  # format valid but content empty
-            else:
-                scores.append(0.0)  # format does not match
-        return scores
+        matches = [re.match(pattern, r) for r in responses]
+        return [0.5 if match else 0.0 for match in matches]

     @staticmethod
     def int_reward_func(completions, **kwargs):
         """
-        Check whether the completion contains an integer.
-        :param completions: model-generated completions
-        :return: list of integer scores
+        Reward function that checks if the completion contains an integer.
         """
         responses = [completion[0]['content'] for completion in completions]
         extracted_responses = [ModelTrainer.extract_xml_answer(r) for r in responses]
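
The simplified strict_format_reward_func above drops re.DOTALL along with the capturing groups, so '.' no longer crosses newlines and only single-line reasoning and answer bodies can earn the 0.5. A small standalone check of the new pattern, with illustrative sample strings:

import re

pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"

one_line   = "<reasoning>\nlogs show a timeout\n</reasoning>\n<answer>\n42\n</answer>\n"
multi_line = "<reasoning>\nstep one\nstep two\n</reasoning>\n<answer>\n42\n</answer>\n"

print(bool(re.match(pattern, one_line)))    # True  -> reward 0.5
print(bool(re.match(pattern, multi_line)))  # False -> reward 0.0 (no re.DOTALL, '.' stops at newlines)
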
@@ -342,20 +213,14 @@ class ModelTrainer:
     @staticmethod
     def correctness_reward_func(prompts, completions, answer, **kwargs):
         """
-        Check whether the completion is correct.
-        :param prompts: input prompts
-        :param completions: model-generated completions
-        :param answer: the correct answers
-        :return: list of correctness scores
+        Reward function that checks if the completion matches the correct answer.
         """
-        print("completions : \n ", completions)
         responses = [completion[0]['content'] for completion in completions]
         q = prompts[0][-1]['content']
         extracted_responses = [ModelTrainer.extract_xml_answer(r) for r in responses]
-        print('-' * 20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
+        print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
         return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

-
 if __name__ == "__main__":
     try:
         # Load the configuration file