Skip to content

Commit

Permalink
Add beta sigmas to other schedulers and update docs (#9538)
Browse files Browse the repository at this point in the history
  • Loading branch information
hlky authored Sep 30, 2024
1 parent f9fd511 commit c4a8979
Show file tree
Hide file tree
Showing 12 changed files with 551 additions and 28 deletions.
1 change: 1 addition & 0 deletions docs/source/en/api/schedulers/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Many schedulers are implemented from the [k-diffusion](https://github.com/crowso
| sgm_uniform | init with `timestep_spacing="trailing"` |
| simple | init with `timestep_spacing="trailing"` |
| exponential | init with `timestep_spacing="linspace"`, `use_exponential_sigmas=True` |
| beta | init with `timestep_spacing="linspace"`, `use_beta_sigmas=True` |

All schedulers are built from the base [`SchedulerMixin`] class which implements low level utilities shared by all schedulers.

Expand Down
53 changes: 50 additions & 3 deletions src/diffusers/schedulers/scheduling_deis_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from ..utils import deprecate, is_scipy_available
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


if is_scipy_available():
import scipy.stats


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
num_diffusion_timesteps,
Expand Down Expand Up @@ -113,6 +117,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
Expand Down Expand Up @@ -141,11 +148,16 @@ def __init__(
lower_order_final: bool = True,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
Expand Down Expand Up @@ -263,6 +275,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
Expand Down Expand Up @@ -396,6 +411,38 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
return sigmas

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""

# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None

if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None

sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

sigmas = torch.Tensor(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
scipy.stats.beta.ppf(timestep, alpha, beta)
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
]
]
)
return sigmas

def convert_model_output(
self,
model_output: torch.Tensor,
Expand Down
55 changes: 52 additions & 3 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from ..utils import deprecate, is_scipy_available
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


if is_scipy_available():
import scipy.stats


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
num_diffusion_timesteps,
Expand Down Expand Up @@ -163,6 +167,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
use_lu_lambdas (`bool`, *optional*, defaults to `False`):
Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
Expand Down Expand Up @@ -209,6 +216,7 @@ def __init__(
euler_at_final: bool = False,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
use_lu_lambdas: Optional[bool] = False,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
lambda_min_clipped: float = -float("inf"),
Expand All @@ -217,8 +225,12 @@ def __init__(
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
):
if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
Expand Down Expand Up @@ -337,6 +349,8 @@ def set_timesteps(
raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")
if timesteps is not None and self.config.use_exponential_sigmas:
raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
if timesteps is not None and self.config.use_beta_sigmas:
raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")

if timesteps is not None:
timesteps = np.array(timesteps).astype(np.int64)
Expand Down Expand Up @@ -388,6 +402,9 @@ def set_timesteps(
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

Expand Down Expand Up @@ -542,6 +559,38 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
return sigmas

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""

# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None

if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None

sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

sigmas = torch.Tensor(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
scipy.stats.beta.ppf(timestep, alpha, beta)
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
]
]
)
return sigmas

def convert_model_output(
self,
model_output: torch.Tensor,
Expand Down
54 changes: 51 additions & 3 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from ..utils import deprecate, is_scipy_available
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


if is_scipy_available():
import scipy.stats


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
num_diffusion_timesteps,
Expand Down Expand Up @@ -126,6 +130,9 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
lambda_min_clipped (`float`, defaults to `-inf`):
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
cosine (`squaredcos_cap_v2`) noise schedule.
Expand Down Expand Up @@ -161,13 +168,18 @@ def __init__(
euler_at_final: bool = False,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
Expand Down Expand Up @@ -219,6 +231,7 @@ def __init__(
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.use_karras_sigmas = use_karras_sigmas
self.use_exponential_sigmas = use_exponential_sigmas
self.use_beta_sigmas = use_beta_sigmas

@property
def step_index(self):
Expand Down Expand Up @@ -276,6 +289,9 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_max = (
Expand Down Expand Up @@ -416,6 +432,38 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
return sigmas

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""

# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None

if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None

sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

sigmas = torch.Tensor(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
scipy.stats.beta.ppf(timestep, alpha, beta)
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
]
]
)
return sigmas

# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
def convert_model_output(
self,
Expand Down
Loading

0 comments on commit c4a8979

Please sign in to comment.