Browse Source

2025-3-5 18:29 README.MD大模型选型评估,训练、微调所需计算资源评估;完善训练数据集生成、训练、推理源程序;

zhouyang.xie 4 months ago
parent
commit
b777163795
22 changed files with 250 additions and 14 deletions
  1. 8 4
      README.MD
  2. 11 8
      src/train_model.py
  3. 229 0
      src/train_model_grpo.py
  4. 1 1
      src/unsloth_compiled_cache/UnslothAlignPropTrainer.py
  5. 1 1
      src/unsloth_compiled_cache/UnslothDDPOTrainer.py
  6. BIN
      src/unsloth_compiled_cache/__pycache__/UnslothAlignPropTrainer.cpython-311.pyc
  7. BIN
      src/unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-311.pyc
  8. BIN
      src/unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-311.pyc
  9. BIN
      src/unsloth_compiled_cache/__pycache__/UnslothDDPOTrainer.cpython-311.pyc
  10. BIN
      src/unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-311.pyc
  11. BIN
      src/unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-311.pyc
  12. BIN
      src/unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-311.pyc
  13. BIN
      src/unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-311.pyc
  14. BIN
      src/unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-311.pyc
  15. BIN
      src/unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-311.pyc
  16. BIN
      src/unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-311.pyc
  17. BIN
      src/unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-311.pyc
  18. BIN
      src/unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-311.pyc
  19. BIN
      src/unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-311.pyc
  20. BIN
      src/unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-311.pyc
  21. BIN
      src/unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-311.pyc
  22. BIN
      src/unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-311.pyc

+ 8 - 4
README.MD

@@ -609,19 +609,23 @@ Unsloth 的核心思想是通过 **高效的计算优化**、**内存管理优
 
 #  大语言模型(LLM)应用评估
 
-## 大模型选型
+## 大模型及训练技术选型
 
 **选型原则:**
 1. 开源大模型,可任意修改,无权属纠纷风险,完全免费,具体开源协议如: MIT、Apache等
-2. 大模型参数量规模,及相应综合评分值排名属于前20%范围,投入计算资源的经济成本;
+2. 大模型参数量规模,及相应综合评分值排名属于前20%范围;
+3. 大模型训练、微调技术方案,实现可行,且投入计算资源的经济成本相对较低;
 
 **参考材料**
 <div align=center><img src="./resources/images/LLM评分-Part-deepseek-20250120.png"></div>
 [fine-tuning-vram-requirements]:  https://docs.unsloth.ai/get-started/beginner-start-here/unsloth-requirements#fine-tuning-vram-requirements "Unsloth训练、微调模型对GPU VRAM需求"
 
-**大模型选型:**
+**技术选型:**
+大模型:
 DeepSeek-R1-Distill-Qwen-32B
-理由: 符合选型原则,且技术预研较多使用1.5B、7B规模模型。
+Tiny-R1-32B
+训练、微调技术:Unsloth
+理由: 符合选型原则,且已成功利用 Unsloth ,训练1.5B规模模型。
 
 ## 大模型训练、微调计算资源需求
 前提:基于 GRPO转换后的模型

+ 11 - 8
src/train_model.py

@@ -2,32 +2,35 @@ import os
 import torch
 from unsloth import FastLanguageModel
 from unsloth import is_bfloat16_supported
-from trl import SFTTrainer
+from trl import SFTTrainer, GRPOConfig, GRPOTrainer
 from datasets import load_dataset
 from transformers import TrainingArguments
 
 class ModelTrainer:
-    def __init__(self, model_name, max_seq_length, dtype, load_in_4bit):
+    def __init__(self, model_name, max_seq_length, dtype, load_in_4bit,lora_rank=32):
         # 初始化 ModelTrainer 类,设置模型名称、最大序列长度、数据类型和是否以4位加载
         self.model_name = model_name
         self.max_seq_length = max_seq_length
-        self.dtype = dtype
-        # dtype: 数据类型,如 torch.float16 或 torch.bfloat16
-        self.load_in_4bit = load_in_4bit
-        # load_in_4bit: 是否以4位精度加载模型,用于节省显存
+        self.dtype = dtype         # dtype: 数据类型,如 torch.float16 或 torch.bfloat16
+        self.load_in_4bit = load_in_4bit         # load_in_4bit: 是否以4位精度加载模型,用于节省显存
+        self.lora_rank=lora_rank  #Larger rank = smarter, but slower
 
     def load_model(self):
         # 加载预训练模型和分词器
         model, tokenizer = FastLanguageModel.from_pretrained(
             model_name=self.model_name,
             max_seq_length=self.max_seq_length,
+            load_in_4bit=self.load_in_4bit, # 值为True 以 4 bit量化进行微调,为False LoRA 16bit。这将内存使用量减少了 4 倍,使我们能够在免费的 16GB 内存 GPU 中实际进行微调。4 位量化本质上将权重转换为一组有限的数字以减少内存使用量。这样做的缺点是准确度会下降 1-2%。如果您想要这种微小的额外准确度,请在较大的 GPU(如 H100)上将其设置为 False。
             dtype=self.dtype,
-            load_in_4bit=self.load_in_4bit, # 以 4 位量化进行微调。这将内存使用量减少了 4 倍,使我们能够在免费的 16GB 内存 GPU 中实际进行微调。4 位量化本质上将权重转换为一组有限的数字以减少内存使用量。这样做的缺点是准确度会下降 1-2%。如果您想要这种微小的额外准确度,请在较大的 GPU(如 H100)上将其设置为 False。
+            fast_inference = True, # Enable vLLM fast inference
+            max_lora_rank = lora_rank,
+            gpu_memory_utilization=0.6,# Reduce if out of memory
         )
         
         # 添加 LoRA 适配器
         model = FastLanguageModel.get_peft_model(
             model,
+            max_seq_length=self.max_seq_length,  # 最大上下文(序列)长度
             r=16,  # LoRA 的秩,控制适配器的复杂度
             target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                           "gate_proj", "up_proj", "down_proj"],  # 应用 LoRA 的目标模块
@@ -36,7 +39,6 @@ class ModelTrainer:
             bias="none",     # 是否在 LoRA 中添加偏置,设置为 "none" 以优化性能
             use_gradient_checkpointing="unsloth",  # 使用梯度检查点以节省显存,对于非常长的上下文,使用 True 或 "unsloth"
             random_state=3407,  # 随机种子,确保实验可复现
-            max_seq_length=self.max_seq_length,  # 最大上下文(序列)长度
             use_rslora=False,  # 是否使用 rank stabilized LoRA,当前不支持
             loftq_config=None,  # LoftQ 配置,当前不支持
         )
@@ -95,6 +97,7 @@ if __name__ == "__main__":
     max_seq_length = 2048  # 最大序列长度
     dtype = torch.float16  # 数据类型
     load_in_4bit = True  # 是否以4位精度加载模型
+    lora_rank=32
 
     # 定义训练集和测试集路径
     train_data_path = os.path.join('..', 'data', 'processed', 'train.jsonl')

+ 229 - 0
src/train_model_grpo.py

@@ -0,0 +1,229 @@
+import os
+import torch
+from unsloth import FastLanguageModel
+from unsloth import is_bfloat16_supported
+from trl import SFTTrainer, GRPOConfig, GRPOTrainer
+from datasets import load_dataset
+from transformers import TrainingArguments
+
+import re
+from datasets import load_dataset, Dataset
+from modelscope.msdatasets import MsDataset
+
+# Load and prep dataset
+SYSTEM_PROMPT = """
+Respond in the following format:
+<reasoning>
+...
+</reasoning>
+<answer>
+...
+</answer>
+"""
+
+XML_COT_FORMAT = """\
+<reasoning>
+{reasoning}
+</reasoning>
+<answer>
+{answer}
+</answer>
+"""
+
+def extract_xml_answer(text: str) -> str:
+    answer = text.split("<answer>")[-1]
+    answer = answer.split("</answer>")[0]
+    return answer.strip()
+
+def extract_hash_answer(text: str) -> str | None:
+    if "####" not in text:
+        return None
+    return text.split("####")[1].strip()
+
+# uncomment middle messages for 1-shot prompting
+def get_gsm8k_questions(split = "train") -> Dataset:
+    # data = load_dataset('https://huggingface.co/datasets/openai/gsm8k', 'main')[split] # type: ignore
+    
+    data =  MsDataset.load('openai-mirror/gsm8k', subset_name='main', split=split)
+    data = data.map(lambda x: { # type: ignore
+        'prompt': [
+            {'role': 'system', 'content': SYSTEM_PROMPT},
+            {'role': 'user', 'content': x['question']}
+        ],
+        'answer': extract_hash_answer(x['answer'])
+    }) # type: ignore
+    return data # type: ignore
+
+dataset = get_gsm8k_questions()
+
+# Reward functions
+def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
+    responses = [completion[0]['content'] for completion in completions]
+    q = prompts[0][-1]['content']
+    extracted_responses = [extract_xml_answer(r) for r in responses]
+    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
+    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
+
+def int_reward_func(completions, **kwargs) -> list[float]:
+    responses = [completion[0]['content'] for completion in completions]
+    extracted_responses = [extract_xml_answer(r) for r in responses]
+    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
+
+def strict_format_reward_func(completions, **kwargs) -> list[float]:
+    """Reward function that checks if the completion has a specific format."""
+    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
+    responses = [completion[0]["content"] for completion in completions]
+    matches = [re.match(pattern, r) for r in responses]
+    return [0.5 if match else 0.0 for match in matches]
+
+def soft_format_reward_func(completions, **kwargs) -> list[float]:
+    """Reward function that checks if the completion has a specific format."""
+    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
+    responses = [completion[0]["content"] for completion in completions]
+    matches = [re.match(pattern, r) for r in responses]
+    return [0.5 if match else 0.0 for match in matches]
+
+def count_xml(text) -> float:
+    count = 0.0
+    if text.count("<reasoning>\n") == 1:
+        count += 0.125
+    if text.count("\n</reasoning>\n") == 1:
+        count += 0.125
+    if text.count("\n<answer>\n") == 1:
+        count += 0.125
+        count -= len(text.split("\n</answer>\n")[-1])*0.001
+    if text.count("\n</answer>") == 1:
+        count += 0.125
+        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
+    return count
+
+def xmlcount_reward_func(completions, **kwargs) -> list[float]:
+    contents = [completion[0]["content"] for completion in completions]
+    return [count_xml(c) for c in contents]
+
+class ModelTrainer:
+    def __init__(self, model_name, max_seq_length, dtype, load_in_4bit,lora_rank=32):
+        # 初始化 ModelTrainer 类,设置模型名称、最大序列长度、数据类型和是否以4位加载
+        self.model_name = model_name
+        self.max_seq_length = max_seq_length
+        self.dtype = dtype         # dtype: 数据类型,如 torch.float16 或 torch.bfloat16
+        self.load_in_4bit = load_in_4bit         # load_in_4bit: 是否以4位精度加载模型,用于节省显存
+        self.lora_rank=lora_rank  #Larger rank = smarter, but slower
+
+    def load_model(self):
+        # 加载预训练模型和分词器
+        model, tokenizer = FastLanguageModel.from_pretrained(
+            model_name=self.model_name,
+            max_seq_length=self.max_seq_length,
+            load_in_4bit=self.load_in_4bit, # 值为True 以 4 bit量化进行微调,为False LoRA 16bit。这将内存使用量减少了 4 倍,使我们能够在免费的 16GB 内存 GPU 中实际进行微调。4 位量化本质上将权重转换为一组有限的数字以减少内存使用量。这样做的缺点是准确度会下降 1-2%。如果您想要这种微小的额外准确度,请在较大的 GPU(如 H100)上将其设置为 False。
+            dtype=self.dtype,
+            fast_inference = False, # Enable vLLM fast inference
+            max_lora_rank = lora_rank,
+            gpu_memory_utilization=0.6,# Reduce if out of memory
+        )
+        
+        # 添加 LoRA 适配器
+        model = FastLanguageModel.get_peft_model(
+            model,
+            max_seq_length=self.max_seq_length,  # 最大上下文(序列)长度
+            r=16,  # LoRA 的秩,控制适配器的复杂度
+            target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
+                          "gate_proj", "up_proj", "down_proj"],  # 应用 LoRA 的目标模块
+            lora_alpha=16,  # LoRA 的 alpha 参数,控制适配器的缩放
+            lora_dropout=0,  # LoRA 的 dropout 率,设置为0以优化性能
+            bias="none",     # 是否在 LoRA 中添加偏置,设置为 "none" 以优化性能
+            use_gradient_checkpointing="unsloth",  # 使用梯度检查点以节省显存,对于非常长的上下文,使用 True 或 "unsloth"
+            random_state=3407,  # 随机种子,确保实验可复现
+            use_rslora=False,  # 是否使用 rank stabilized LoRA,当前不支持
+            loftq_config=None,  # LoftQ 配置,当前不支持
+        )
+
+        return model, tokenizer
+
+    def load_data(self, train_data_path):
+        # 加载训练集和测试集
+        train_dataset = load_dataset("json", data_files={"train": train_data_path}, split="train")
+        # train_data_path: 训练数据路径,格式为 JSONL
+        return train_dataset
+
+    def train(self, model, tokenizer, train_dataset):
+        print("is_bfloat16_supported()=",is_bfloat16_supported())
+        training_args = GRPOConfig(
+            use_vllm = True, # use vLLM for fast inference!
+            learning_rate = 5e-6,
+            adam_beta1 = 0.9,
+            adam_beta2 = 0.99,
+            weight_decay = 0.1,
+            warmup_ratio = 0.1,
+            lr_scheduler_type = "cosine",
+            optim = "adamw_8bit",
+            logging_steps = 1,
+            bf16 = is_bfloat16_supported(),
+            fp16 = not is_bfloat16_supported(),
+            per_device_train_batch_size = 1,
+            gradient_accumulation_steps = 1, # Increase to 4 for smoother training
+            num_generations = 8, # Decrease if out of memory
+            max_prompt_length = 256,
+            max_completion_length = 200,
+            # num_train_epochs = 1, # Set to 1 for a full training run
+            max_steps = 250,
+            save_steps = 250,
+            max_grad_norm = 0.1,
+            report_to = "none", # Can use Weights & Biases
+            output_dir = "outputs",
+        )
+
+        # 初始化 SFTTrainer
+        trainer = GRPOTrainer(
+            model = model,
+            processing_class = tokenizer,
+            reward_funcs = [
+                xmlcount_reward_func,
+                soft_format_reward_func,
+                strict_format_reward_func,
+                int_reward_func,
+                correctness_reward_func,
+            ],
+            args = training_args,
+            train_dataset = dataset,
+        )
+        
+        # 训练模型
+        trainer.train()
+        
+        return model
+
+    def save_model(self, model, tokenizer, save_path):
+        # 保存模型和分词器
+        model.save_pretrained(save_path)
+        tokenizer.save_pretrained(save_path)
+        print(f"Model saved to {save_path}")
+
+if __name__ == "__main__":
+    # 配置参数
+    model_name = os.path.join('..', 'models', 'pretrained', 'DeepSeek-R1-Distill-Qwen-1.5B')
+    # model_name: 预训练模型的路径
+    max_seq_length = 2048  # 最大序列长度
+    dtype = torch.float16  # 数据类型
+    load_in_4bit = True  # 是否以4位精度加载模型
+    lora_rank=32
+
+    # 定义训练集和测试集路径
+    train_data_path = os.path.join('..', 'data', 'processed', 'train.jsonl')
+    # train_data_path: 训练数据路径
+
+    # 初始化 ModelTrainer
+    trainer = ModelTrainer(model_name, max_seq_length, dtype, load_in_4bit)
+    
+    # 加载模型和分词器
+    model, tokenizer = trainer.load_model()
+
+    # 加载数据集
+    train_dataset = trainer.load_data(train_data_path)
+
+    # 训练模型
+    model = trainer.train(model, tokenizer, train_dataset)
+
+    # 保存模型
+    save_path = os.path.join('..', 'models', 'trained', 'DeepSeek-R1-Distill-Qwen-1.5B')
+    trainer.save_model(model, tokenizer, save_path)

+ 1 - 1
src/unsloth_compiled_cache/UnslothAlignPropTrainer.py

@@ -120,7 +120,7 @@ class UnslothAlignPropConfig(AlignPropConfig):
     )
     def __init__(
         self,
-        exp_name = 'inference',
+        exp_name = 'train_model_grpo',
         run_name = '',
         seed = 3407,
         log_with = None,

+ 1 - 1
src/unsloth_compiled_cache/UnslothDDPOTrainer.py

@@ -136,7 +136,7 @@ class UnslothDDPOConfig(DDPOConfig):
     )
     def __init__(
         self,
-        exp_name = 'inference',
+        exp_name = 'train_model_grpo',
         run_name = '',
         seed = 3407,
         log_with = None,

BIN
src/unsloth_compiled_cache/__pycache__/UnslothAlignPropTrainer.cpython-311.pyc


BIN
src/unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-311.pyc


BIN
src/unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-311.pyc


BIN
src/unsloth_compiled_cache/__pycache__/UnslothDDPOTrainer.cpython-311.pyc


BIN
src/unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-311.pyc


BIN
src/unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-311.pyc


BIN
src/unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-311.pyc


BIN
src/unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-311.pyc


BIN
src/unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-311.pyc


BIN
src/unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-311.pyc


BIN
src/unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-311.pyc


BIN
src/unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-311.pyc


BIN
src/unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-311.pyc


BIN
src/unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-311.pyc


BIN
src/unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-311.pyc


BIN
src/unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-311.pyc


BIN
src/unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-311.pyc