@@ -1,15 +1,8 @@
-"""
-2025.3.3
-2025.3.5
-4.49.0
-0.15.2
-__UNSLOTH_VERSIONING__
-"""
 from torch import Tensor
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
-from trl.trainer.sft_trainer import (Any, AutoModelForCausalLM, AutoTokenizer, BaseImageProcessor, Callable, ConstantLengthDataset, DataCollator, DataCollatorForLanguageModeling, Dataset, EvalPrediction, FeatureExtractionMixin, IterableDataset, Optional, PeftConfig, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTConfig, SFTTrainer, Trainer, TrainerCallback, TrainingArguments, Type, Union, dataclasses, defaultdict, deprecate_kwarg, generate_model_card, get_comet_experiment_url, get_peft_model, is_liger_kernel_available, is_peft_available, is_wandb_available, nn, os, pack_examples, peft, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, transformers, version, warnings, Callable, ConstantLengthDataset, DataCollator, Dataset, IterableDataset, Optional, Union, os, pack_examples, transformers, os)
+from trl.trainer.sft_trainer import (Any, AutoModelForCausalLM, AutoTokenizer, BaseImageProcessor, Callable, ConstantLengthDataset, DataCollator, DataCollatorForLanguageModeling, Dataset, EvalPrediction, FeatureExtractionMixin, IterableDataset, Optional, PartialState, PeftConfig, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTConfig, SFTTrainer, Trainer, TrainerCallback, TrainingArguments, Type, Union, dataclasses, defaultdict, deprecate_kwarg, generate_model_card, get_comet_experiment_url, get_peft_model, is_conversational, is_liger_kernel_available, is_peft_available, is_wandb_available, maybe_apply_chat_template, maybe_convert_to_chatml, nn, os, pack_examples, peft, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, transformers, version, warnings, os)
 
 
 import os
@@ -618,89 +611,117 @@ class _UnslothSFTTrainer(Trainer):
     def _prepare_dataset(
         self,
         dataset: Union[Dataset, IterableDataset],
-        processing_class,
-        args,
+        processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin],
+        args: SFTConfig,
         packing: bool,
         formatting_func: Optional[Callable[[dict], str]],
         dataset_name: str,
     ) -> Union[Dataset, IterableDataset]:
-        # All Unsloth Zoo code licensed under LGPLv3
-        if isinstance(dataset, ConstantLengthDataset): return dataset
-
-        map_kwargs = {}
-        use_desc = isinstance(dataset, Dataset)
-
-        # Get max length
-        max_seq_length = getattr(args, "max_length", 0)
-        if max_seq_length == 0: max_seq_length = getattr(args, "max_seq_length", 0)
-        if max_seq_length == 0: max_seq_length = getattr(self, "max_seq_length", 0)
-        if max_seq_length == 0: max_seq_length = getattr(self, "max_seq", 0)
-        dataset_text_field = getattr(args, "dataset_text_field", "text")
-        do_truncation = max_seq_length != 0
-        do_formatting_func = False
-
-        # Check if already tokenized so skip
-        from transformers import DataCollatorForSeq2Seq
-        column_names = set(next(iter(dataset)).keys())
-        if "input_ids" in column_names:
-            # Most likely forgot data collator!
-            from transformers import DataCollatorForSeq2Seq
-            self.data_collator = DataCollatorForSeq2Seq(processing_class)
+        # Convert the dataset to an IterableDataset if it is a ConstantLengthDataset
+        if isinstance(dataset, ConstantLengthDataset):
             return dataset
-        elif dataset_text_field not in column_names:
-            do_formatting_func = True
-            if formatting_func is None:
-                raise RuntimeError("Unsloth: You must specify a `formatting_func`")
-        pass
-
-        # Check double BOS tokens
-        if do_formatting_func:
-            test_text = formatting_func(dataset[0])
-            if not isinstance(test_text, list):
-                raise ValueError(
-                    "Unsloth: The `formatting_func` should return a list of processed strings."
+
+        # If the dataset is already preprocessed (tokenized), skip the processing steps.
+        column_names = list(next(iter(dataset)).keys())
+        is_processed = "input_ids" in column_names
+
+        # Build the kwargs for the `map` function
+        map_kwargs = {}
+        if isinstance(dataset, Dataset):  # IterableDataset does not support num_proc
+            map_kwargs["num_proc"] = args.dataset_num_proc
+
+        with PartialState().local_main_process_first():
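+            # `local_main_process_first` lets the local main process run the `map`
+            # calls below first, so the processed dataset is written to the cache
+            # once and the remaining ranks load it instead of re-processing.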
+            # Apply the formatting function if any
+            if formatting_func is not None and is_processed:
+                warnings.warn(
+                    "You passed a dataset that is already processed (contains an `input_ids` field) together with a "
+                    "formatting function. Therefore `formatting_func` will be ignored. Either remove the "
+                    "`formatting_func` or pass a dataset that is not already processed.",
+                    UserWarning,
                 )
-            test_text = test_text[0]
-        else:
-            test_text = dataset[0][dataset_text_field]
-        chat_template = getattr(processing_class, 'chat_template', None)
-        chat_template = '' if chat_template is None else chat_template
-        add_special_tokens = True
-
-        if getattr(processing_class, 'bos_token', None) is not None:
-            if test_text.startswith(processing_class.bos_token) or processing_class.bos_token in chat_template:
-                add_special_tokens = False
-                print("Unsloth: We found double BOS tokens - we shall remove one automatically.")
-        pass
-
-        # Create tokenize function
-        def _tokenize(example):
-            return processing_class(
-                example[dataset_text_field] if not do_formatting_func else formatting_func(example),
-                truncation = do_truncation,
-                max_length = max_seq_length,
-                return_token_type_ids = False,
-                add_special_tokens = add_special_tokens,
+
+            if formatting_func is not None and not is_processed:
+                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                    map_kwargs["desc"] = f"Applying formatting function to {dataset_name} dataset"
+
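+                # A formatting function that returns a list is treated as batched:
+                # it maps a batch of examples to a list of strings, so `map` is run
+                # with `batched=True`.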
+                batched = isinstance(formatting_func(next(iter(dataset))), list)
+
+                def _func(example):
+                    return {"text": formatting_func(example)}
+
+                dataset = dataset.map(_func, batched=batched, **map_kwargs)
+
+            # If the dataset is prompt-completion, convert it to language modeling type
+            if "prompt" in dataset.column_names and "completion" in dataset.column_names:
+                key = "messages" if is_conversational(dataset[0]) else "text"
+
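+                # For conversational data, "prompt" and "completion" are lists of
+                # messages, so `+` concatenates the two lists; for plain text they
+                # are strings and `+` concatenates the strings.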
+                def concat_prompt_completion(example):
+                    return {key: example["prompt"] + example["completion"]}
+
+                dataset = dataset.map(concat_prompt_completion, remove_columns=["prompt", "completion"])
+
+            # Convert the dataset to ChatML if needed
+            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                map_kwargs["desc"] = f"Converting {dataset_name} dataset to ChatML"
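+            # `maybe_convert_to_chatml` rewrites "from"/"value"-style conversations
+            # into "role"/"content" messages and leaves other formats untouched.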
+            dataset = dataset.map(
+                maybe_convert_to_chatml,
+                remove_columns="conversations" if "conversations" in dataset.column_names else None,
+                **map_kwargs,
             )
-        pass
-
-        map_kwargs["num_proc"] = getattr(args, "dataset_num_proc", 2)
-        if use_desc: map_kwargs["desc"] = f'Tokenizing to ["{dataset_text_field}"]'
-        dataset = dataset.map(_tokenize, batched = True, **map_kwargs)
-
-        if packing:
-            if max_seq_length == 0:
-                raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.")
-
-            if use_desc: map_kwargs["desc"] = f"Packing {dataset_name} dataset"
-            dataset = dataset.select_columns("input_ids").map(
-                pack_examples,
-                batched = True,
-                fn_kwargs = {"seq_length": max_seq_length,},
+
+            # Apply the chat template if needed
+            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
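+            # `maybe_apply_chat_template` renders conversational "messages" into a
+            # single "text" string using the tokenizer's chat template; plain-text
+            # examples pass through unchanged.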
+            dataset = dataset.map(
+                maybe_apply_chat_template,
+                fn_kwargs={"tokenizer": processing_class},
+                remove_columns="messages" if "messages" in dataset.column_names else None,  # renamed to "text"
                 **map_kwargs,
             )
+
+            # Tokenize the dataset if needed
+            if not is_processed:
+                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                    map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
+
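+                # The tokenizer is called with its defaults here (no truncation or
+                # padding); length handling is done in the packing/truncation step
+                # below.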
+                def tokenize(example, processing_class, dataset_text_field):
+                    return processing_class(example[dataset_text_field])
+
+                dataset = dataset.map(
+                    tokenize,
+                    fn_kwargs={"processing_class": processing_class, "dataset_text_field": args.dataset_text_field},
+                    **map_kwargs,
+                )
+
+            # Pack or truncate
+            if packing:
+                if args.max_seq_length is None:
+                    raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.")
+                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                    map_kwargs["desc"] = f"Packing {dataset_name} dataset"
+                dataset = dataset.select_columns("input_ids")
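+                # `pack_examples` concatenates the tokenized sequences and re-splits
+                # them into contiguous chunks of `seq_length` tokens.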
+                dataset = dataset.map(
+                    pack_examples, batched=True, fn_kwargs={"seq_length": args.max_seq_length}, **map_kwargs
+                )
+            elif args.max_seq_length is not None:
+                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                    map_kwargs["desc"] = f"Truncating {dataset_name} dataset"
+
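+                # Truncation only trims `input_ids` and `attention_mask`; any other
+                # columns survive the `map` unchanged.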
+                def truncate(example, max_seq_length):
+                    return {key: example[key][:max_seq_length] for key in ["input_ids", "attention_mask"]}
+
+                dataset = dataset.map(
+                    truncate,
+                    fn_kwargs={"max_seq_length": args.max_seq_length},
+                    **map_kwargs,
+                )
+
+            # For Liger kernel, ensure only input_ids is present
+            if args.use_liger:
+                dataset = dataset.select_columns("input_ids")
+
         return dataset
-
+
     def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
         outputs = super().compute_loss(
             model,