
Modify train_model_grpo_v1.2.py, attempting to restore the model's self-reasoning during training

zhouyang.xie 8 months ago
parent
commit
526921091e
36 changed files with 130 additions and 362 deletions
  1. +3 -3
      README.MD
  2. +21 -156
      src/train_model_grpo_v1.2.py
  3. +0 -7
      src/unsloth_compiled_cache/UnslothAlignPropTrainer.py
  4. +0 -7
      src/unsloth_compiled_cache/UnslothBCOTrainer.py
  5. +0 -7
      src/unsloth_compiled_cache/UnslothCPOTrainer.py
  6. +0 -7
      src/unsloth_compiled_cache/UnslothDDPOTrainer.py
  7. +0 -7
      src/unsloth_compiled_cache/UnslothDPOTrainer.py
  8. +0 -7
      src/unsloth_compiled_cache/UnslothGKDTrainer.py
  9. +4 -17
      src/unsloth_compiled_cache/UnslothGRPOTrainer.py
  10. +0 -7
      src/unsloth_compiled_cache/UnslothKTOTrainer.py
  11. +0 -7
      src/unsloth_compiled_cache/UnslothNashMDTrainer.py
  12. +0 -7
      src/unsloth_compiled_cache/UnslothORPOTrainer.py
  13. +0 -7
      src/unsloth_compiled_cache/UnslothOnlineDPOTrainer.py
  14. +0 -7
      src/unsloth_compiled_cache/UnslothPPOTrainer.py
  15. +0 -7
      src/unsloth_compiled_cache/UnslothPRMTrainer.py
  16. +0 -7
      src/unsloth_compiled_cache/UnslothRLOOTrainer.py
  17. +0 -7
      src/unsloth_compiled_cache/UnslothRewardTrainer.py
  18. +102 -81
      src/unsloth_compiled_cache/UnslothSFTTrainer.py
  19. +0 -7
      src/unsloth_compiled_cache/UnslothXPOTrainer.py
  20. binary
      src/unsloth_compiled_cache/__pycache__/UnslothAlignPropTrainer.cpython-311.pyc
  21. binary
      src/unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-311.pyc
  22. binary
      src/unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-311.pyc
  23. binary
      src/unsloth_compiled_cache/__pycache__/UnslothDDPOTrainer.cpython-311.pyc
  24. binary
      src/unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-311.pyc
  25. binary
      src/unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-311.pyc
  26. binary
      src/unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-311.pyc
  27. binary
      src/unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-311.pyc
  28. binary
      src/unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-311.pyc
  29. binary
      src/unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-311.pyc
  30. binary
      src/unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-311.pyc
  31. binary
      src/unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-311.pyc
  32. binary
      src/unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-311.pyc
  33. binary
      src/unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-311.pyc
  34. binary
      src/unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-311.pyc
  35. binary
      src/unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-311.pyc
  36. binary
      src/unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-311.pyc

+ 3 - 3
README.MD

@@ -744,14 +744,14 @@ GPU: VRAM ≥ 192 GB (164 GB); disk: NVMe SSD, ≥ 8 TB, read speed ≥
 **AI智算云 plan**
 Ascend 910B (compute nodes) ×2 ≈ ¥614,000/year
 Network (100 Mbps): ¥66,000/year
-Total cost: ¥680,000/year
+**Total cost: ¥680,000/year**
 
 The compute resource configuration and quote are shown below:
 <div align=center><img src="./resources/images/计算资源-单节点-源自-AI智算云.png"></div>
 <div align=center><img src="./resources/images/计算资源及报价-源自-AI智算云.png"></div>
 
 **Huawei Cloud plan**
-Ascend cloud services, total cost ≥ ¥1,200,000
+**Ascend cloud services, total cost ≥ ¥1,200,000**
 
 <div align=center><img src="./resources/images/计算资源及报价-源自-华为云.png"></div>
 Notes:
@@ -759,7 +759,7 @@ GPU: VRAM ≥ 192 GB (164 GB); disk: NVMe SSD, ≥ 8 TB, read speed ≥
 2. The "AI professional services" offering includes training and fine-tuning solutions but **does not include compute resources (purchased separately)**.
 
 **李明星**
-Ascend cloud services, total cost: ¥163,000
+**Ascend cloud services, total cost: ¥163,000**
 
 The compute resource configuration and quote are shown below:
 <div align=center><img src="./resources/images/计算资源及报价-源自-李明星.png"></div>

+ 21 - 156
src/train_model_grpo_v1.2.py

@@ -9,7 +9,6 @@ from trl import GRPOConfig, GRPOTrainer
 from datasets import load_dataset
 from conf_train import Config, load_config  # import the training config helpers
 import re
-from transformers import LongformerTokenizer, LongformerModel  # tokenizer model supports at most 4096 tokens
 import numpy as np
 
 
@@ -27,9 +26,6 @@ class ModelTrainer:
         self.fast_inference = config.fast_inference
         self.lora_rank = config.lora_rank
         self.gpu_memory_utilization = config.gpu_memory_utilization
-        # Initialize the Longformer model and tokenizer
-        self.tokenizer = LongformerTokenizer.from_pretrained(f'../models/allenai/longformer-base-4096')
-        self.longformer_model = LongformerModel.from_pretrained(f'../models/allenai/longformer-base-4096')
 
     def load_model(self):
         """
@@ -131,14 +127,11 @@ class ModelTrainer:
             model=model,
             processing_class=tokenizer,  # tokenizer used to process the input text
             reward_funcs=[
-                # self.xmlcount_reward_func,  # XML tag completeness reward
-                # self.soft_format_reward_func,  # soft format reward
-                # self.strict_format_reward_func,  # strict format reward
-                # self.int_reward_func,  # integer reward
-                # self.correctness_reward_func,  # correctness reward
-                # self.semantic_correctness_reward_func,  # semantic correctness reward
-                # self.reasoning_quality_reward_func,  # reasoning quality reward
-                self.combined_reward_func,  # combined reward
+                self.xmlcount_reward_func,
+                self.soft_format_reward_func,
+                self.strict_format_reward_func,
+                self.int_reward_func,
+                self.correctness_reward_func,
             ],
             args=training_args,  # the training hyperparameters defined above
             train_dataset=train_dataset,  # the training dataset
@@ -159,116 +152,13 @@ class ModelTrainer:
         print(f"Model saved to {save_path}")
 
     @staticmethod
-    def cosine_similarity(vec1, vec2):
-        """
-        Compute the cosine similarity of two vectors.
-        :param vec1: first vector, shape (1, 768)
-        :param vec2: second vector, shape (1, 768)
-        :return: cosine similarity
-        """
-        vec1 = vec1.squeeze()  # shape (1, 768) -> (768,)
-        vec2 = vec2.squeeze()  # shape (1, 768) -> (768,)
-        return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
-
-    def semantic_correctness_reward_func(self, prompts, completions, answer, **kwargs):
-        """
-        Use Longformer to compute the semantic similarity between the generated answer and the reference answer.
-        """
-        responses = [completion[0]['content'] for completion in completions]
-        extracted_responses = [self.extract_xml_answer(r) for r in responses]
-        scores = []
-        for resp, ans in zip(extracted_responses, answer):
-            # truncate the text so it stays within 4096 tokens
-            resp = self.tokenizer.decode(self.tokenizer.encode(resp, truncation=True, max_length=4096))
-            ans = self.tokenizer.decode(self.tokenizer.encode(ans, truncation=True, max_length=4096))
-            # encode the generated answer and the reference answer
-            inputs_resp = self.tokenizer(resp, return_tensors='pt', padding=True, truncation=True, max_length=4096)
-            inputs_ans = self.tokenizer(ans, return_tensors='pt', padding=True, truncation=True, max_length=4096)
-            with torch.no_grad():
-                outputs_resp = self.longformer_model(**inputs_resp).last_hidden_state.mean(dim=1)  # shape (1, 768)
-                outputs_ans = self.longformer_model(**inputs_ans).last_hidden_state.mean(dim=1)  # shape (1, 768)
-            # compute the cosine similarity
-            similarity = self.cosine_similarity(outputs_resp.numpy(), outputs_ans.numpy())
-            scores.append(similarity)
-        return scores
-
-    def combined_reward_func(self, prompts, completions, answer, **kwargs):
-        """
-        Combine several reward functions with dynamically adjusted weights.
-        :param prompts: input prompts
-        :param completions: model-generated completions
-        :param answer: reference answers
-        :return: list of combined scores
-        """
-        # score each component reward function
-        format_score = self.strict_format_reward_func(completions)
-        semantic_score = self.semantic_correctness_reward_func(prompts, completions, answer)
-        correctness_score = self.correctness_reward_func(prompts, completions, answer)
-
-        # dynamically adjust the weights
-        combined_scores = []
-        for fs, ss, cs in zip(format_score, semantic_score, correctness_score):
-            if cs == 2.0:  # answer fully correct
-                combined_scores.append(fs * 0.2 + ss * 0.3 + cs * 0.5)
-            else:  # answer not fully correct
-                combined_scores.append(fs * 0.4 + ss * 0.4 + cs * 0.2)
-        return combined_scores
-
-    @staticmethod
-    def reasoning_quality_reward_func(completions, **kwargs):
-        """
-        Check the quality of the reasoning section.
-        :param completions: model-generated completions
-        :return: list of reasoning-quality scores
-        """
-        responses = [completion[0]["content"] for completion in completions]
-        scores = []
-        for response in responses:
-            reasoning_match = re.search(r"<reasoning>\n(.+?)\n</reasoning>", response, re.DOTALL)
-            if reasoning_match:
-                reasoning_content = reasoning_match.group(1).strip()
-                # simple keyword check on the reasoning content ("诊断" = diagnosis, "原因" = cause)
-                if "诊断" in reasoning_content and "原因" in reasoning_content:
-                    scores.append(1.0)
-                else:
-                    scores.append(0.5)
-            else:
-                scores.append(0.0)
-        return scores
-
-    @staticmethod
     def extract_xml_answer(text: str) -> str:
-        """
-        从文本中提取 XML 格式的答案。
-        :param text: 包含 XML 格式的文本
-        :return: 提取的答案
-        """
-        try:
-            print("text -> \n", text)
-            if "<answer>" in text and "</answer>" in text:
-                answer = text.split("<answer>")[-1]
-                answer = answer.split("</answer>")[0]
-                return answer.strip()
-            else:
-                print("Warning: <answer> tag not found in response.")
-                # 尝试提取其他有意义的部分
-                if "诊断" in text:
-                    return text.split("诊断")[-1].strip()
-                elif "排查建议" in text:
-                    return text.split("排查建议")[-1].strip()
-                else:
-                    return text.strip()  # 返回原始文本作为备用
-        except Exception as e:
-            print(f"Error extracting XML answer: {e}")
-            return ""  # 返回空字符串或其他默认值
-
+        answer = text.split("<answer>")[-1]
+        answer = answer.split("</answer>")[0]
+        return answer.strip()
+    
     @staticmethod
     def count_xml(text) -> float:
-        """
-        计算 XML 标签的数量和完整性。
-        :param text: 包含 XML 格式的文本
-        :return: XML 标签的完整性得分
-        """
         count = 0.0
         if text.count("<reasoning>\n") == 1:
             count += 0.125
@@ -276,18 +166,17 @@ class ModelTrainer:
             count += 0.125
         if text.count("\n<answer>\n") == 1:
             count += 0.125
-            count -= len(text.split("\n</answer>\n")[-1]) * 0.001
+            count -= len(text.split("\n</answer>\n")[-1])*0.001
         if text.count("\n</answer>") == 1:
             count += 0.125
-            count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
+            count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
         return count
 
+
     @staticmethod
     def xmlcount_reward_func(completions, **kwargs):
         """
-        Compute the XML tag completeness score.
-        :param completions: model-generated completions
-        :return: list of XML tag completeness scores
+        Reward function that counts XML tags in the completion.
         """
         contents = [completion[0]["content"] for completion in completions]
         return [ModelTrainer.count_xml(c) for c in contents]
@@ -295,9 +184,7 @@ class ModelTrainer:
     @staticmethod
     def soft_format_reward_func(completions, **kwargs):
         """
-        Check whether the completion satisfies the soft format requirement.
-        :param completions: model-generated completions
-        :return: list of soft-format scores
+        Reward function that checks if the completion has a specific format.
         """
         pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
         responses = [completion[0]["content"] for completion in completions]
@@ -307,33 +194,17 @@ class ModelTrainer:
     @staticmethod
     def strict_format_reward_func(completions, **kwargs):
         """
-        Check whether the response satisfies the strict XML format and that the tag contents are non-empty.
-        :param completions: model-generated completions
-        :return: list of strict-format scores
+        Reward function that checks if the completion has a specific format.
         """
-        pattern = r"^<reasoning>\n(.+?)\n</reasoning>\n<answer>\n(.+?)\n</answer>\n$"
+        pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
         responses = [completion[0]["content"] for completion in completions]
-        scores = []
-        for response in responses:
-            match = re.match(pattern, response, re.DOTALL)
-            if match:
-                reasoning_content = match.group(1).strip()
-                answer_content = match.group(2).strip()
-                # check that the contents are non-empty
-                if reasoning_content and answer_content:
-                    scores.append(1.0)  # format and content both satisfied
-                else:
-                    scores.append(0.5)  # format satisfied but content empty
-            else:
-                scores.append(0.0)  # format not satisfied
-        return scores
+        matches = [re.match(pattern, r) for r in responses]
+        return [0.5 if match else 0.0 for match in matches]
 
     @staticmethod
     def int_reward_func(completions, **kwargs):
         """
-        Check whether the completion contains an integer.
-        :param completions: model-generated completions
-        :return: list of integer-presence scores
+        Reward function that checks if the completion contains an integer.
         """
         responses = [completion[0]['content'] for completion in completions]
         extracted_responses = [ModelTrainer.extract_xml_answer(r) for r in responses]
@@ -342,20 +213,14 @@ class ModelTrainer:
     @staticmethod
     def correctness_reward_func(prompts, completions, answer, **kwargs):
         """
-        Check whether the completion is correct.
-        :param prompts: input prompts
-        :param completions: model-generated completions
-        :param answer: correct answers
-        :return: list of correctness scores
+        Reward function that checks if the completion matches the correct answer.
         """
-        print("completions : \n ", completions)
         responses = [completion[0]['content'] for completion in completions]
         q = prompts[0][-1]['content']
         extracted_responses = [ModelTrainer.extract_xml_answer(r) for r in responses]
-        print('-' * 20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
+        print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
         return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
 
-
 if __name__ == "__main__":
     try:
         # load the configuration file

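For context on the change above: TRL's GRPOTrainer takes a list of reward callables, each receiving the batch of completions (plus prompts/answers via **kwargs) and returning one float per completion; the per-function scores are summed per sample (optionally weighted). A minimal, runnable sketch of that contract, using an illustrative soft-format reward and a hand-built batch; this is not code from the commit:

    import re

    def soft_format_reward_func(completions, **kwargs):
        # 0.5 if the completion contains <reasoning>...</reasoning> followed by <answer>...</answer>
        pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
        responses = [completion[0]["content"] for completion in completions]
        return [0.5 if re.search(pattern, r, re.DOTALL) else 0.0 for r in responses]

    # A GRPO-style batch: one chat completion (list of messages) per sample.
    fake_completions = [
        [{"role": "assistant", "content": "<reasoning>\n2 + 2 = 4\n</reasoning>\n<answer>\n4\n</answer>"}],
        [{"role": "assistant", "content": "4"}],
    ]
    print(soft_format_reward_func(fake_completions))  # [0.5, 0.0]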
+ 0 - 7
src/unsloth_compiled_cache/UnslothAlignPropTrainer.py

@@ -1,10 +1,3 @@
-"""
-2025.3.3
-2025.3.5
-4.49.0
-0.15.2
-__UNSLOTH_VERSIONING__
-"""
 from torch import Tensor
 import torch
 import torch.nn as nn

+ 0 - 7
src/unsloth_compiled_cache/UnslothBCOTrainer.py

@@ -1,10 +1,3 @@
-"""
-2025.3.3
-2025.3.5
-4.49.0
-0.15.2
-__UNSLOTH_VERSIONING__
-"""
 from torch import Tensor
 import torch
 import torch.nn as nn

+ 0 - 7
src/unsloth_compiled_cache/UnslothCPOTrainer.py

@@ -1,10 +1,3 @@
-"""
-2025.3.3
-2025.3.5
-4.49.0
-0.15.2
-__UNSLOTH_VERSIONING__
-"""
 from torch import Tensor
 import torch
 import torch.nn as nn

+ 0 - 7
src/unsloth_compiled_cache/UnslothDDPOTrainer.py

@@ -1,10 +1,3 @@
-"""
-2025.3.3
-2025.3.5
-4.49.0
-0.15.2
-__UNSLOTH_VERSIONING__
-"""
 from torch import Tensor
 import torch
 import torch.nn as nn

+ 0 - 7
src/unsloth_compiled_cache/UnslothDPOTrainer.py

@@ -1,10 +1,3 @@
-"""
-2025.3.3
-2025.3.5
-4.49.0
-0.15.2
-__UNSLOTH_VERSIONING__
-"""
 from torch import Tensor
 import torch
 import torch.nn as nn

+ 0 - 7
src/unsloth_compiled_cache/UnslothGKDTrainer.py

@@ -1,10 +1,3 @@
-"""
-2025.3.3
-2025.3.5
-4.49.0
-0.15.2
-__UNSLOTH_VERSIONING__
-"""
 from torch import Tensor
 import torch
 import torch.nn as nn

+ 4 - 17
src/unsloth_compiled_cache/UnslothGRPOTrainer.py

@@ -1,10 +1,3 @@
-"""
-2025.3.3
-2025.3.5
-4.49.0
-0.15.2
-__UNSLOTH_VERSIONING__
-"""
 from torch import Tensor
 import torch
 import torch.nn as nn
@@ -120,7 +113,7 @@ class UnslothEfficientGRPO(torch.autograd.Function):
             fullgraph = True,
             options = torch_compile_options,
         )
-
+        
         grad_inputs_chunks = torch.chunk(grad_inputs,        chunks = n_chunks, dim = 0)
         new_hidden_states  = torch.chunk(_new_hidden_states, chunks = n_chunks, dim = 0)
         old_hidden_states  = torch.chunk(_old_hidden_states, chunks = n_chunks, dim = 0)
@@ -1089,20 +1082,14 @@ class _UnslothGRPOTrainer(Trainer):
                 self, _input_ids, logits_to_keep, completion_mask, advantages,
                 n_chunks = self.args.unsloth_num_chunks,
             )
-
+        
         # Log the metrics
         # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
+        self._metrics["completion_length"].append(completion_length.item())
 
         # mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
         # self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
-
-        if "train" in self._metrics:
-            mode = "eval" if self.control.should_evaluate else "train"
-            self._metrics[mode]["completion_length"].append(completion_length.item())
-            self._metrics[mode]["kl"].append(mean_kl.item())
-        else:
-            self._metrics["completion_length"].append(completion_length.item())
-            self._metrics["kl"].append(mean_kl.item())
+        self._metrics["kl"].append(mean_kl.item())
         return loss
 
     def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):

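The metrics hunk above reverts to a flat logging layout: newer TRL versions nest metrics under a "train"/"eval" mode key, which the removed branch detected via "train" in self._metrics. A small illustrative sketch of the two layouts, assuming the defaultdict(list) container TRL uses (not repo code):

    from collections import defaultdict

    # Flat layout restored by this commit: one list per metric name.
    metrics = defaultdict(list)
    metrics["completion_length"].append(42.0)
    metrics["kl"].append(0.01)

    # Nested layout handled by the removed branch: per-mode buckets,
    # picked via "eval" vs. "train" before appending.
    nested = {"train": defaultdict(list), "eval": defaultdict(list)}
    mode = "train"
    nested[mode]["completion_length"].append(42.0)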
+ 0 - 7
src/unsloth_compiled_cache/UnslothKTOTrainer.py

@@ -1,10 +1,3 @@
-"""
-2025.3.3
-2025.3.5
-4.49.0
-0.15.2
-__UNSLOTH_VERSIONING__
-"""
 from torch import Tensor
 import torch
 import torch.nn as nn

+ 0 - 7
src/unsloth_compiled_cache/UnslothNashMDTrainer.py

@@ -1,10 +1,3 @@
-"""
-2025.3.3
-2025.3.5
-4.49.0
-0.15.2
-__UNSLOTH_VERSIONING__
-"""
 from torch import Tensor
 import torch
 import torch.nn as nn

+ 0 - 7
src/unsloth_compiled_cache/UnslothORPOTrainer.py

@@ -1,10 +1,3 @@
-"""
-2025.3.3
-2025.3.5
-4.49.0
-0.15.2
-__UNSLOTH_VERSIONING__
-"""
 from torch import Tensor
 import torch
 import torch.nn as nn

+ 0 - 7
src/unsloth_compiled_cache/UnslothOnlineDPOTrainer.py

@@ -1,10 +1,3 @@
-"""
-2025.3.3
-2025.3.5
-4.49.0
-0.15.2
-__UNSLOTH_VERSIONING__
-"""
 from torch import Tensor
 import torch
 import torch.nn as nn

+ 0 - 7
src/unsloth_compiled_cache/UnslothPPOTrainer.py

@@ -1,10 +1,3 @@
-"""
-2025.3.3
-2025.3.5
-4.49.0
-0.15.2
-__UNSLOTH_VERSIONING__
-"""
 from torch import Tensor
 import torch
 import torch.nn as nn

+ 0 - 7
src/unsloth_compiled_cache/UnslothPRMTrainer.py

@@ -1,10 +1,3 @@
-"""
-2025.3.3
-2025.3.5
-4.49.0
-0.15.2
-__UNSLOTH_VERSIONING__
-"""
 from torch import Tensor
 import torch
 import torch.nn as nn

+ 0 - 7
src/unsloth_compiled_cache/UnslothRLOOTrainer.py

@@ -1,10 +1,3 @@
-"""
-2025.3.3
-2025.3.5
-4.49.0
-0.15.2
-__UNSLOTH_VERSIONING__
-"""
 from torch import Tensor
 import torch
 import torch.nn as nn

+ 0 - 7
src/unsloth_compiled_cache/UnslothRewardTrainer.py

@@ -1,10 +1,3 @@
-"""
-2025.3.3
-2025.3.5
-4.49.0
-0.15.2
-__UNSLOTH_VERSIONING__
-"""
 from torch import Tensor
 import torch
 import torch.nn as nn

+ 102 - 81
src/unsloth_compiled_cache/UnslothSFTTrainer.py

@@ -1,15 +1,8 @@
-"""
-2025.3.3
-2025.3.5
-4.49.0
-0.15.2
-__UNSLOTH_VERSIONING__
-"""
 from torch import Tensor
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
-from trl.trainer.sft_trainer import (Any, AutoModelForCausalLM, AutoTokenizer, BaseImageProcessor, Callable, ConstantLengthDataset, DataCollator, DataCollatorForLanguageModeling, Dataset, EvalPrediction, FeatureExtractionMixin, IterableDataset, Optional, PeftConfig, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTConfig, SFTTrainer, Trainer, TrainerCallback, TrainingArguments, Type, Union, dataclasses, defaultdict, deprecate_kwarg, generate_model_card, get_comet_experiment_url, get_peft_model, is_liger_kernel_available, is_peft_available, is_wandb_available, nn, os, pack_examples, peft, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, transformers, version, warnings, Callable, ConstantLengthDataset, DataCollator, Dataset, IterableDataset, Optional, Union, os, pack_examples, transformers, os)
+from trl.trainer.sft_trainer import (Any, AutoModelForCausalLM, AutoTokenizer, BaseImageProcessor, Callable, ConstantLengthDataset, DataCollator, DataCollatorForLanguageModeling, Dataset, EvalPrediction, FeatureExtractionMixin, IterableDataset, Optional, PartialState, PeftConfig, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTConfig, SFTTrainer, Trainer, TrainerCallback, TrainingArguments, Type, Union, dataclasses, defaultdict, deprecate_kwarg, generate_model_card, get_comet_experiment_url, get_peft_model, is_conversational, is_liger_kernel_available, is_peft_available, is_wandb_available, maybe_apply_chat_template, maybe_convert_to_chatml, nn, os, pack_examples, peft, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, transformers, version, warnings, os)
 
 
 import os
@@ -618,89 +611,117 @@ class _UnslothSFTTrainer(Trainer):
     def _prepare_dataset(
         self,
         dataset: Union[Dataset, IterableDataset],
-        processing_class,
-        args,
+        processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin],
+        args: SFTConfig,
         packing: bool,
         formatting_func: Optional[Callable[[dict], str]],
         dataset_name: str,
     ) -> Union[Dataset, IterableDataset]:
-        # All Unsloth Zoo code licensed under LGPLv3
-        if isinstance(dataset, ConstantLengthDataset): return dataset
-    
-        map_kwargs = {}
-        use_desc = isinstance(dataset, Dataset)
-    
-        # Get max length
-        max_seq_length = getattr(args, "max_length", 0)
-        if max_seq_length == 0: max_seq_length = getattr(args, "max_seq_length", 0)
-        if max_seq_length == 0: max_seq_length = getattr(self, "max_seq_length", 0)
-        if max_seq_length == 0: max_seq_length = getattr(self, "max_seq", 0)
-        dataset_text_field = getattr(args, "dataset_text_field", "text")
-        do_truncation = max_seq_length != 0
-        do_formatting_func = False
-    
-        # Check if already tokenized so skip
-        from transformers import DataCollatorForSeq2Seq
-        column_names = set(next(iter(dataset)).keys())
-        if "input_ids" in column_names:
-            # Most likely forgot data collator!
-            from transformers import DataCollatorForSeq2Seq
-            self.data_collator = DataCollatorForSeq2Seq(processing_class)
+        # Convert the dataset to an IterableDataset if it is a ConstantLengthDataset
+        if isinstance(dataset, ConstantLengthDataset):
             return dataset
-        elif dataset_text_field not in column_names:
-            do_formatting_func = True
-            if formatting_func is None:
-                raise RuntimeError("Unsloth: You must specify a `formatting_func`")
-        pass
-    
-        # Check double BOS tokens
-        if do_formatting_func:
-            test_text = formatting_func(dataset[0])
-            if not isinstance(test_text, list):
-                raise ValueError(
-                    "Unsloth: The `formatting_func` should return a list of processed strings."
+
+        # If the dataset is already preprocessed (tokenized), skip the processing steps.
+        column_names = list(next(iter(dataset)).keys())
+        is_processed = "input_ids" in column_names
+
+        # Build the kwargs for the `map` function
+        map_kwargs = {}
+        if isinstance(dataset, Dataset):  # IterableDataset does not support num_proc
+            map_kwargs["num_proc"] = args.dataset_num_proc
+
+        with PartialState().local_main_process_first():
+            # Apply the formatting function if any
+            if formatting_func is not None and is_processed:
+                warnings.warn(
+                    "You passed a dataset that is already processed (contains an `input_ids` field) together with a "
+                    "formatting function. Therefore `formatting_func` will be ignored. Either remove the "
+                    "`formatting_func` or pass a dataset that is not already processed.",
+                    UserWarning,
                 )
-            test_text = test_text[0]
-        else:
-            test_text = dataset[0][dataset_text_field]
-        chat_template = getattr(processing_class, 'chat_template', None)
-        chat_template = '' if chat_template is None else chat_template
-        add_special_tokens = True
-    
-        if getattr(processing_class, 'bos_token', None) is not None:
-            if test_text.startswith(processing_class.bos_token) or processing_class.bos_token in chat_template:
-                add_special_tokens = False
-                print("Unsloth: We found double BOS tokens - we shall remove one automatically.")
-        pass
-    
-        # Create tokenize function
-        def _tokenize(example):
-            return processing_class(
-                example[dataset_text_field] if not do_formatting_func else formatting_func(example),
-                truncation = do_truncation,
-                max_length = max_seq_length,
-                return_token_type_ids = False,
-                add_special_tokens = add_special_tokens,
+
+            if formatting_func is not None and not is_processed:
+                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                    map_kwargs["desc"] = f"Applying formatting function to {dataset_name} dataset"
+
+                batched = isinstance(formatting_func(next(iter(dataset))), list)
+
+                def _func(example):
+                    return {"text": formatting_func(example)}
+
+                dataset = dataset.map(_func, batched=batched, **map_kwargs)
+
+            # If the dataset is prompt-completion, convert it to language modeling type
+            if "prompt" in dataset.column_names and "completion" in dataset.column_names:
+                key = "messages" if is_conversational(dataset[0]) else "text"
+
+                def concat_prompt_completion(example):
+                    return {key: example["prompt"] + example["completion"]}
+
+                dataset = dataset.map(concat_prompt_completion, remove_columns=["prompt", "completion"])
+
+            # Convert the dataset to ChatML if needed
+            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                map_kwargs["desc"] = f"Converting {dataset_name} dataset to ChatML"
+            dataset = dataset.map(
+                maybe_convert_to_chatml,
+                remove_columns="conversations" if "conversations" in dataset.column_names else None,
+                **map_kwargs,
             )
-        pass
-    
-        map_kwargs["num_proc"] = getattr(args, "dataset_num_proc", 2)
-        if use_desc: map_kwargs["desc"] = f'Tokenizing to ["{dataset_text_field}"]'
-        dataset = dataset.map(_tokenize, batched = True, **map_kwargs)
-    
-        if packing:
-            if max_seq_length == 0:
-                raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.")
-    
-            if use_desc: map_kwargs["desc"] = f"Packing {dataset_name} dataset"
-            dataset = dataset.select_columns("input_ids").map(
-                pack_examples,
-                batched = True,
-                fn_kwargs = {"seq_length": max_seq_length,},
+
+            # Apply the chat template if needed
+            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
+            dataset = dataset.map(
+                maybe_apply_chat_template,
+                fn_kwargs={"tokenizer": processing_class},
+                remove_columns="messages" if "messages" in dataset.column_names else None,  # renamed to "text"
                 **map_kwargs,
             )
+
+            # Tokenize the dataset if needed
+            if not is_processed:
+                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                    map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
+
+                def tokenize(example, processing_class, dataset_text_field):
+                    return processing_class(example[dataset_text_field])
+
+                dataset = dataset.map(
+                    tokenize,
+                    fn_kwargs={"processing_class": processing_class, "dataset_text_field": args.dataset_text_field},
+                    **map_kwargs,
+                )
+
+            # Pack or truncate
+            if packing:
+                if args.max_seq_length is None:
+                    raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.")
+                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                    map_kwargs["desc"] = f"Packing {dataset_name} dataset"
+                dataset = dataset.select_columns("input_ids")
+                dataset = dataset.map(
+                    pack_examples, batched=True, fn_kwargs={"seq_length": args.max_seq_length}, **map_kwargs
+                )
+            elif args.max_seq_length is not None:
+                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                    map_kwargs["desc"] = f"Truncating {dataset_name} dataset"
+
+                def truncate(example, max_seq_length):
+                    return {key: example[key][:max_seq_length] for key in ["input_ids", "attention_mask"]}
+
+                dataset = dataset.map(
+                    truncate,
+                    fn_kwargs={"max_seq_length": args.max_seq_length},
+                    **map_kwargs,
+                )
+
+            # For Liger kernel, ensure only input_ids is present
+            if args.use_liger:
+                dataset = dataset.select_columns("input_ids")
+
         return dataset
-    
+
     def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
         outputs = super().compute_loss(
             model,

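The rewritten _prepare_dataset above follows upstream TRL's staged pipeline: apply a formatting function, merge prompt/completion columns, convert to ChatML, apply the chat template, tokenize, then pack or truncate. A toy, runnable walk-through of the same stages on a datasets.Dataset with a stand-in tokenizer; names here are illustrative, not TRL's:

    from datasets import Dataset

    ds = Dataset.from_list([{"prompt": "Q: what is 2 + 2? ", "completion": "A: 4"}])

    # 1) prompt-completion pairs -> plain language-modeling "text"
    ds = ds.map(lambda ex: {"text": ex["prompt"] + ex["completion"]},
                remove_columns=["prompt", "completion"])

    # 2) tokenize (a real run uses the trainer's processing_class)
    ds = ds.map(lambda ex: {"input_ids": list(range(len(ex["text"].split())))})

    # 3) truncate when packing is off and max_seq_length is set
    max_seq_length = 4
    ds = ds.map(lambda ex: {"input_ids": ex["input_ids"][:max_seq_length]})
    print(ds[0]["input_ids"])  # [0, 1, 2, 3]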
+ 0 - 7
src/unsloth_compiled_cache/UnslothXPOTrainer.py

@@ -1,10 +1,3 @@
-"""
-2025.3.3
-2025.3.5
-4.49.0
-0.15.2
-__UNSLOTH_VERSIONING__
-"""
 from torch import Tensor
 import torch
 import torch.nn as nn

binary
src/unsloth_compiled_cache/__pycache__/UnslothAlignPropTrainer.cpython-311.pyc


binary
src/unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-311.pyc


binary
src/unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-311.pyc


binary
src/unsloth_compiled_cache/__pycache__/UnslothDDPOTrainer.cpython-311.pyc


binary
src/unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-311.pyc


binary
src/unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-311.pyc


binary
src/unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-311.pyc


binary
src/unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-311.pyc


binary
src/unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-311.pyc


binary
src/unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-311.pyc


binary
src/unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-311.pyc


binary
src/unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-311.pyc


binary
src/unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-311.pyc


binary
src/unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-311.pyc


binary
src/unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-311.pyc


binary
src/unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-311.pyc


binary
src/unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-311.pyc