UnslothAlignPropTrainer.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628
  1. from torch import Tensor
  2. import torch
  3. import torch.nn as nn
  4. from torch.nn import functional as F
  5. from trl.trainer.alignprop_trainer import (Accelerator, AlignPropConfig, AlignPropTrainer, Any, Callable, DDPOStableDiffusionPipeline, Optional, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, warn)
  6. import os
  7. from typing import *
  8. from dataclasses import dataclass, field
  9. from packaging.version import Version
  10. import torch
  11. import numpy as np
  12. from contextlib import nullcontext
  13. from torch.nn import functional as F
  14. torch_compile_options = {
  15. "epilogue_fusion" : True,
  16. "max_autotune" : False,
  17. "shape_padding" : True,
  18. "trace.enabled" : False,
  19. "triton.cudagraphs" : False,
  20. }
  21. @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
  22. def selective_log_softmax(logits, index):
  23. logits = logits.to(torch.float32)
  24. selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
  25. # loop to reduce peak mem consumption
  26. # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
  27. logsumexp_values = torch.logsumexp(logits, dim = -1)
  28. per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
  29. return per_token_logps
  30. @dataclass
  31. class UnslothAlignPropConfig(AlignPropConfig):
  32. """
  33. Configuration class for the [`AlignPropTrainer`].
  34. Using [`~transformers.HfArgumentParser`] we can turn this class into
  35. [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
  36. command line.
  37. Parameters:
  38. exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
  39. Name of this experiment (defaults to the file name without the extension).
  40. run_name (`str`, *optional*, defaults to `""`):
  41. Name of this run.
  42. seed (`int`, *optional*, defaults to `0`):
  43. Random seed for reproducibility.
  44. log_with (`str` or `None`, *optional*, defaults to `None`):
  45. Log with either `"wandb"` or `"tensorboard"`. Check
  46. [tracking](https://huggingface.co/docs/accelerate/usage_guides/tracking) for more details.
  47. log_image_freq (`int`, *optional*, defaults to `1`):
  48. Frequency for logging images.
  49. tracker_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
  50. Keyword arguments for the tracker (e.g., `wandb_project`).
  51. accelerator_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
  52. Keyword arguments for the accelerator.
  53. project_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
  54. Keyword arguments for the accelerator project config (e.g., `logging_dir`).
  55. tracker_project_name (`str`, *optional*, defaults to `"trl"`):
  56. Name of project to use for tracking.
  57. logdir (`str`, *optional*, defaults to `"logs"`):
  58. Top-level logging directory for checkpoint saving.
  59. num_epochs (`int`, *optional*, defaults to `100`):
  60. Number of epochs to train.
  61. save_freq (`int`, *optional*, defaults to `1`):
  62. Number of epochs between saving model checkpoints.
  63. num_checkpoint_limit (`int`, *optional*, defaults to `5`):
  64. Number of checkpoints to keep before overwriting old ones.
  65. mixed_precision (`str`, *optional*, defaults to `"fp16"`):
  66. Mixed precision training.
  67. allow_tf32 (`bool`, *optional*, defaults to `True`):
  68. Allow `tf32` on Ampere GPUs.
  69. resume_from (`str`, *optional*, defaults to `""`):
  70. Path to resume training from a checkpoint.
  71. sample_num_steps (`int`, *optional*, defaults to `50`):
  72. Number of sampler inference steps.
  73. sample_eta (`float`, *optional*, defaults to `1.0`):
  74. Eta parameter for the DDIM sampler.
  75. sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
  76. Classifier-free guidance weight.
  77. train_batch_size (`int`, *optional*, defaults to `1`):
  78. Batch size for training.
  79. train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
  80. Whether to use the 8bit Adam optimizer from `bitsandbytes`.
  81. train_learning_rate (`float`, *optional*, defaults to `1e-3`):
  82. Learning rate.
  83. train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
  84. Beta1 for Adam optimizer.
  85. train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
  86. Beta2 for Adam optimizer.
  87. train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
  88. Weight decay for Adam optimizer.
  89. train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
  90. Epsilon value for Adam optimizer.
  91. train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
  92. Number of gradient accumulation steps.
  93. train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
  94. Maximum gradient norm for gradient clipping.
  95. negative_prompts (`str` or `None`, *optional*, defaults to `None`):
  96. Comma-separated list of prompts to use as negative examples.
  97. truncated_backprop_rand (`bool`, *optional*, defaults to `True`):
  98. If `True`, randomized truncation to different diffusion timesteps is used.
  99. truncated_backprop_timestep (`int`, *optional*, defaults to `49`):
  100. Absolute timestep to which the gradients are backpropagated. Used only if `truncated_backprop_rand=False`.
  101. truncated_rand_backprop_minmax (`tuple[int, int]`, *optional*, defaults to `(0, 50)`):
  102. Range of diffusion timesteps for randomized truncated backpropagation.
  103. push_to_hub (`bool`, *optional*, defaults to `False`):
  104. Whether to push the final model to the Hub.
  105. """
  106. vllm_sampling_params: Optional[Any] = field(
  107. default = None,
  108. metadata = {'help': 'vLLM SamplingParams'},
  109. )
  110. unsloth_num_chunks : Optional[int] = field(
  111. default = -1,
  112. metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
  113. )
  114. def __init__(
  115. self,
  116. exp_name = 'train_model_grpo',
  117. run_name = '',
  118. seed = 3407,
  119. log_with = None,
  120. log_image_freq = 1,
  121. tracker_project_name = 'trl',
  122. logdir = 'logs',
  123. num_epochs = 100,
  124. save_freq = 1,
  125. num_checkpoint_limit = 5,
  126. mixed_precision = 'fp16',
  127. allow_tf32 = True,
  128. resume_from = '',
  129. sample_num_steps = 50,
  130. sample_eta = 1.0,
  131. sample_guidance_scale = 5.0,
  132. train_batch_size = 1,
  133. train_use_8bit_adam = False,
  134. train_learning_rate = 5e-05,
  135. train_adam_beta1 = 0.9,
  136. train_adam_beta2 = 0.999,
  137. train_adam_weight_decay = 0.01,
  138. train_adam_epsilon = 1e-08,
  139. train_gradient_accumulation_steps = 2,
  140. train_max_grad_norm = 1.0,
  141. negative_prompts = None,
  142. truncated_backprop_rand = True,
  143. truncated_backprop_timestep = 49,
  144. push_to_hub = False,
  145. vllm_sampling_params = None,
  146. unsloth_num_chunks = -1,
  147. **kwargs,
  148. ):
  149. super().__init__(
  150. exp_name = exp_name,
  151. run_name = run_name,
  152. seed = seed,
  153. log_with = log_with,
  154. log_image_freq = log_image_freq,
  155. tracker_project_name = tracker_project_name,
  156. logdir = logdir,
  157. num_epochs = num_epochs,
  158. save_freq = save_freq,
  159. num_checkpoint_limit = num_checkpoint_limit,
  160. mixed_precision = mixed_precision,
  161. allow_tf32 = allow_tf32,
  162. resume_from = resume_from,
  163. sample_num_steps = sample_num_steps,
  164. sample_eta = sample_eta,
  165. sample_guidance_scale = sample_guidance_scale,
  166. train_batch_size = train_batch_size,
  167. train_use_8bit_adam = train_use_8bit_adam,
  168. train_learning_rate = train_learning_rate,
  169. train_adam_beta1 = train_adam_beta1,
  170. train_adam_beta2 = train_adam_beta2,
  171. train_adam_weight_decay = train_adam_weight_decay,
  172. train_adam_epsilon = train_adam_epsilon,
  173. train_gradient_accumulation_steps = train_gradient_accumulation_steps,
  174. train_max_grad_norm = train_max_grad_norm,
  175. negative_prompts = negative_prompts,
  176. truncated_backprop_rand = truncated_backprop_rand,
  177. truncated_backprop_timestep = truncated_backprop_timestep,
  178. push_to_hub = push_to_hub,**kwargs)
  179. self.vllm_sampling_params = vllm_sampling_params
  180. self.unsloth_num_chunks = unsloth_num_chunks
  181. pass
  182. class _UnslothAlignPropTrainer(PyTorchModelHubMixin):
  183. """"""
  184. _tag_names = ["trl", "alignprop"]
  185. def __init__(
  186. self,
  187. config: AlignPropConfig,
  188. reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
  189. prompt_function: Callable[[], tuple[str, Any]],
  190. sd_pipeline: DDPOStableDiffusionPipeline,
  191. image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
  192. ):
  193. if image_samples_hook is None:
  194. warn("No image_samples_hook provided; no images will be logged")
  195. self.prompt_fn = prompt_function
  196. self.reward_fn = reward_function
  197. self.config = config
  198. self.image_samples_callback = image_samples_hook
  199. accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
  200. if self.config.resume_from:
  201. self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
  202. if "checkpoint_" not in os.path.basename(self.config.resume_from):
  203. # get the most recent checkpoint in this directory
  204. checkpoints = list(
  205. filter(
  206. lambda x: "checkpoint_" in x,
  207. os.listdir(self.config.resume_from),
  208. )
  209. )
  210. if len(checkpoints) == 0:
  211. raise ValueError(f"No checkpoints found in {self.config.resume_from}")
  212. checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
  213. self.config.resume_from = os.path.join(
  214. self.config.resume_from,
  215. f"checkpoint_{checkpoint_numbers[-1]}",
  216. )
  217. accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
  218. self.accelerator = Accelerator(
  219. log_with=self.config.log_with,
  220. mixed_precision=self.config.mixed_precision,
  221. project_config=accelerator_project_config,
  222. # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
  223. # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
  224. # the total number of optimizer steps to accumulate across.
  225. gradient_accumulation_steps=self.config.train_gradient_accumulation_steps,
  226. **self.config.accelerator_kwargs,
  227. )
  228. is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
  229. if self.accelerator.is_main_process:
  230. self.accelerator.init_trackers(
  231. self.config.tracker_project_name,
  232. config=dict(alignprop_trainer_config=config.to_dict())
  233. if not is_using_tensorboard
  234. else config.to_dict(),
  235. init_kwargs=self.config.tracker_kwargs,
  236. )
  237. logger.info(f"\n{config}")
  238. set_seed(self.config.seed, device_specific=True)
  239. self.sd_pipeline = sd_pipeline
  240. self.sd_pipeline.set_progress_bar_config(
  241. position=1,
  242. disable=not self.accelerator.is_local_main_process,
  243. leave=False,
  244. desc="Timestep",
  245. dynamic_ncols=True,
  246. )
  247. # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
  248. # as these weights are only used for inference, keeping weights in full precision is not required.
  249. if self.accelerator.mixed_precision == "fp16":
  250. inference_dtype = torch.float16
  251. elif self.accelerator.mixed_precision == "bf16":
  252. inference_dtype = torch.bfloat16
  253. else:
  254. inference_dtype = torch.float32
  255. self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
  256. self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
  257. self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
  258. trainable_layers = self.sd_pipeline.get_trainable_layers()
  259. self.accelerator.register_save_state_pre_hook(self._save_model_hook)
  260. self.accelerator.register_load_state_pre_hook(self._load_model_hook)
  261. # Enable TF32 for faster training on Ampere GPUs,
  262. # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
  263. if self.config.allow_tf32:
  264. torch.backends.cuda.matmul.allow_tf32 = True
  265. self.optimizer = self._setup_optimizer(
  266. trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
  267. )
  268. self.neg_prompt_embed = self.sd_pipeline.text_encoder(
  269. self.sd_pipeline.tokenizer(
  270. [""] if self.config.negative_prompts is None else self.config.negative_prompts,
  271. return_tensors="pt",
  272. padding="max_length",
  273. truncation=True,
  274. max_length=self.sd_pipeline.tokenizer.model_max_length,
  275. ).input_ids.to(self.accelerator.device)
  276. )[0]
  277. # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
  278. # more memory
  279. self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
  280. if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
  281. unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
  282. self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
  283. else:
  284. self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
  285. if config.resume_from:
  286. logger.info(f"Resuming from {config.resume_from}")
  287. self.accelerator.load_state(config.resume_from)
  288. self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
  289. else:
  290. self.first_epoch = 0
  291. def compute_rewards(self, prompt_image_pairs):
  292. reward, reward_metadata = self.reward_fn(
  293. prompt_image_pairs["images"], prompt_image_pairs["prompts"], prompt_image_pairs["prompt_metadata"]
  294. )
  295. return reward
  296. def step(self, epoch: int, global_step: int):
  297. """
  298. Perform a single step of training.
  299. Args:
  300. epoch (int): The current epoch.
  301. global_step (int): The current global step.
  302. Side Effects:
  303. - Model weights are updated
  304. - Logs the statistics to the accelerator trackers.
  305. - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
  306. Returns:
  307. global_step (int): The updated global step.
  308. """
  309. info = defaultdict(list)
  310. self.sd_pipeline.unet.train()
  311. for _ in range(self.config.train_gradient_accumulation_steps):
  312. with self.accelerator.accumulate(self.sd_pipeline.unet), self.autocast(), torch.enable_grad():
  313. prompt_image_pairs = self._generate_samples(
  314. batch_size=self.config.train_batch_size,
  315. )
  316. rewards = self.compute_rewards(prompt_image_pairs)
  317. prompt_image_pairs["rewards"] = rewards
  318. rewards_vis = self.accelerator.gather(rewards).detach().cpu().numpy()
  319. loss = self.calculate_loss(rewards)
  320. self.accelerator.backward(loss)
  321. if self.accelerator.sync_gradients:
  322. self.accelerator.clip_grad_norm_(
  323. self.trainable_layers.parameters()
  324. if not isinstance(self.trainable_layers, list)
  325. else self.trainable_layers,
  326. self.config.train_max_grad_norm,
  327. )
  328. self.optimizer.step()
  329. self.optimizer.zero_grad()
  330. info["reward_mean"].append(rewards_vis.mean())
  331. info["reward_std"].append(rewards_vis.std())
  332. info["loss"].append(loss.item())
  333. # Checks if the accelerator has performed an optimization step behind the scenes
  334. if self.accelerator.sync_gradients:
  335. # log training-related stuff
  336. info = {k: torch.mean(torch.tensor(v)) for k, v in info.items()}
  337. info = self.accelerator.reduce(info, reduction="mean")
  338. info.update({"epoch": epoch})
  339. self.accelerator.log(info, step=global_step)
  340. global_step += 1
  341. info = defaultdict(list)
  342. else:
  343. raise ValueError(
  344. "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
  345. )
  346. # Logs generated images
  347. if self.image_samples_callback is not None and global_step % self.config.log_image_freq == 0:
  348. self.image_samples_callback(prompt_image_pairs, global_step, self.accelerator.trackers[0])
  349. if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
  350. self.accelerator.save_state()
  351. return global_step
  352. def calculate_loss(self, rewards):
  353. """
  354. Calculate the loss for a batch of an unpacked sample
  355. Args:
  356. rewards (torch.Tensor):
  357. Differentiable reward scalars for each generated image, shape: [batch_size]
  358. Returns:
  359. loss (torch.Tensor)
  360. (all of these are of shape (1,))
  361. """
  362. # Loss is specific to Aesthetic Reward function used in AlignProp (https://huggingface.co/papers/2310.03739)
  363. loss = 10.0 - (rewards).mean()
  364. return loss
  365. def loss(
  366. self,
  367. advantages: torch.Tensor,
  368. clip_range: float,
  369. ratio: torch.Tensor,
  370. ):
  371. unclipped_loss = -advantages * ratio
  372. clipped_loss = -advantages * torch.clamp(
  373. ratio,
  374. 1.0 - clip_range,
  375. 1.0 + clip_range,
  376. )
  377. return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
  378. def _setup_optimizer(self, trainable_layers_parameters):
  379. if self.config.train_use_8bit_adam:
  380. import bitsandbytes
  381. optimizer_cls = bitsandbytes.optim.AdamW8bit
  382. else:
  383. optimizer_cls = torch.optim.AdamW
  384. return optimizer_cls(
  385. trainable_layers_parameters,
  386. lr=self.config.train_learning_rate,
  387. betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
  388. weight_decay=self.config.train_adam_weight_decay,
  389. eps=self.config.train_adam_epsilon,
  390. )
  391. def _save_model_hook(self, models, weights, output_dir):
  392. self.sd_pipeline.save_checkpoint(models, weights, output_dir)
  393. weights.pop() # ensures that accelerate doesn't try to handle saving of the model
  394. def _load_model_hook(self, models, input_dir):
  395. self.sd_pipeline.load_checkpoint(models, input_dir)
  396. models.pop() # ensures that accelerate doesn't try to handle loading of the model
  397. def _generate_samples(self, batch_size, with_grad=True, prompts=None):
  398. """
  399. Generate samples from the model
  400. Args:
  401. batch_size (int): Batch size to use for sampling
  402. with_grad (bool): Whether the generated RGBs should have gradients attached to it.
  403. Returns:
  404. prompt_image_pairs (dict[Any])
  405. """
  406. prompt_image_pairs = {}
  407. sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
  408. if prompts is None:
  409. prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
  410. else:
  411. prompt_metadata = [{} for _ in range(batch_size)]
  412. prompt_ids = self.sd_pipeline.tokenizer(
  413. prompts,
  414. return_tensors="pt",
  415. padding="max_length",
  416. truncation=True,
  417. max_length=self.sd_pipeline.tokenizer.model_max_length,
  418. ).input_ids.to(self.accelerator.device)
  419. prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
  420. if with_grad:
  421. sd_output = self.sd_pipeline.rgb_with_grad(
  422. prompt_embeds=prompt_embeds,
  423. negative_prompt_embeds=sample_neg_prompt_embeds,
  424. num_inference_steps=self.config.sample_num_steps,
  425. guidance_scale=self.config.sample_guidance_scale,
  426. eta=self.config.sample_eta,
  427. truncated_backprop_rand=self.config.truncated_backprop_rand,
  428. truncated_backprop_timestep=self.config.truncated_backprop_timestep,
  429. truncated_rand_backprop_minmax=self.config.truncated_rand_backprop_minmax,
  430. output_type="pt",
  431. )
  432. else:
  433. sd_output = self.sd_pipeline(
  434. prompt_embeds=prompt_embeds,
  435. negative_prompt_embeds=sample_neg_prompt_embeds,
  436. num_inference_steps=self.config.sample_num_steps,
  437. guidance_scale=self.config.sample_guidance_scale,
  438. eta=self.config.sample_eta,
  439. output_type="pt",
  440. )
  441. images = sd_output.images
  442. prompt_image_pairs["images"] = images
  443. prompt_image_pairs["prompts"] = prompts
  444. prompt_image_pairs["prompt_metadata"] = prompt_metadata
  445. return prompt_image_pairs
  446. def train(self, epochs: Optional[int] = None):
  447. """
  448. Train the model for a given number of epochs
  449. """
  450. global_step = 0
  451. if epochs is None:
  452. epochs = self.config.num_epochs
  453. for epoch in range(self.first_epoch, epochs):
  454. global_step = self.step(epoch, global_step)
  455. def _save_pretrained(self, save_directory):
  456. self.sd_pipeline.save_pretrained(save_directory)
  457. self.create_model_card()
  458. def create_model_card(
  459. self,
  460. model_name: Optional[str] = None,
  461. dataset_name: Optional[str] = None,
  462. tags: Union[str, list[str], None] = None,
  463. ):
  464. """
  465. Creates a draft of a model card using the information available to the `Trainer`.
  466. Args:
  467. model_name (`str` or `None`, *optional*, defaults to `None`):
  468. Name of the model.
  469. dataset_name (`str` or `None`, *optional*, defaults to `None`):
  470. Name of the dataset used for training.
  471. tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
  472. Tags to be associated with the model card.
  473. """
  474. if not self.is_world_process_zero():
  475. return
  476. if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
  477. base_model = self.model.config._name_or_path
  478. else:
  479. base_model = None
  480. tags = tags or []
  481. if isinstance(tags, str):
  482. tags = [tags]
  483. if hasattr(self.model.config, "unsloth_version"):
  484. tags.append("unsloth")
  485. citation = textwrap.dedent("""\
  486. @article{prabhudesai2024aligning,
  487. title = {{Aligning Text-to-Image Diffusion Models with Reward Backpropagation}},
  488. author = {Mihir Prabhudesai and Anirudh Goyal and Deepak Pathak and Katerina Fragkiadaki},
  489. year = 2024,
  490. eprint = {arXiv:2310.03739}
  491. }""")
  492. model_card = generate_model_card(
  493. base_model=base_model,
  494. model_name=model_name,
  495. hub_model_id=self.hub_model_id,
  496. dataset_name=dataset_name,
  497. tags=tags,
  498. wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
  499. comet_url=get_comet_experiment_url(),
  500. trainer_name="AlignProp",
  501. trainer_citation=citation,
  502. paper_title="Aligning Text-to-Image Diffusion Models with Reward Backpropagation",
  503. paper_id="2310.03739",
  504. )
  505. model_card.save(os.path.join(self.args.output_dir, "README.md"))
  506. class UnslothAlignPropTrainer(_UnslothAlignPropTrainer):
  507. """
  508. The AlignPropTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models.
  509. Note, this trainer is heavily inspired by the work here: https://github.com/mihirp1998/AlignProp/
  510. As of now only Stable Diffusion based pipelines are supported
  511. Attributes:
  512. config (`AlignPropConfig`):
  513. Configuration object for AlignPropTrainer. Check the documentation of `PPOConfig` for more details.
  514. reward_function (`Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]`):
  515. Reward function to be used
  516. prompt_function (`Callable[[], tuple[str, Any]]`):
  517. Function to generate prompts to guide model
  518. sd_pipeline (`DDPOStableDiffusionPipeline`):
  519. Stable Diffusion pipeline to be used for training.
  520. image_samples_hook (`Optional[Callable[[Any, Any, Any], Any]]`):
  521. Hook to be called to log images
  522. """
  523. def __init__(
  524. self,
  525. config,
  526. reward_function,
  527. prompt_function,
  528. sd_pipeline,
  529. image_samples_hook = None,
  530. **kwargs
  531. ):
  532. if args is None: args = UnslothAlignPropConfig()
  533. other_metrics = []
  534. from unsloth_zoo.logging_utils import PatchRLStatistics
  535. PatchRLStatistics('alignprop_trainer', other_metrics)
  536. super().__init__(
  537. config = config,
  538. reward_function = reward_function,
  539. prompt_function = prompt_function,
  540. sd_pipeline = sd_pipeline,
  541. image_samples_hook = image_samples_hook,**kwargs)
  542. pass