SDXL代码阅读-Sampler推理(EDM)-1.2

本文紧接上一节的sampler讲解采样器的第三个部分,采样器方法本身。

Sampler

主要的方法在BaseDiffusionSampler类中(init、denoise、prepare_sampling_loop),会基于一些不同的采样方法有一些子类。

BaseDiffusionSampler

# 提供默认值的机制
def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

class BaseDiffusionSampler:
    def __init__(
        self,
        discretization,
        guider,
        num_steps: Union[int, None] = None,
        verbose: bool = False,
        device: str = "cuda",
    ):
        self.num_steps = num_steps
        self.discretization = discretization
        self.guider = guider
        self.verbose = verbose
        self.device = device

    def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
        '''
            用于准备迭代采样的元素
            x: 输入
            cond: 条件输入
            uc: 无条件输入,可能用于一些特殊的生成设置或对比实验。
        '''
        # 通过调度器生成噪声的标准差sigma
        sigmas = self.discretization(
            self.num_steps if num_steps is None else num_steps, device=self.device
        )
        uc = default(uc, cond)

        x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
        num_sigmas = len(sigmas)

        # s_in 作为一个全 1 张量,通常用于对 sigma 进行缩放操作。在采样过程中,s_in 与 sigma 相乘,确保 sigma 在批次维度上正确广播。
        # 在采样步骤中,s_in 确保了每一步的 sigma 值可以正确地应用于整个批次的输入张量 x。(来自于chatgpt)
        s_in = x.new_ones([x.shape[0]])

        return x, s_in, sigmas, num_sigmas, cond, uc

    def denoise(self, x, denoiser, sigma, cond, uc):
        '''
            用guider指导消噪的过程
        '''
        denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc))
        denoised = self.guider(denoised, sigma)
        return denoised

    def get_sigma_gen(self, num_sigmas):
        sigma_generator = range(num_sigmas - 1)
        if self.verbose:
            print("#" * 30, " Sampling setting ", "#" * 30)
            print(f"Sampler: {self.__class__.__name__}")
            print(f"Discretization: {self.discretization.__class__.__name__}")
            print(f"Guider: {self.guider.__class__.__name__}")
            sigma_generator = tqdm(
                sigma_generator,
                total=num_sigmas,
                desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps",
            )
        return sigma_generator

单步采样

定义了两个方法,一个是sample_step,这个由具体子类实现;一个是euler_step,这个是欧拉法解ODE。

class SingleStepDiffusionSampler(BaseDiffusionSampler):
    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs):
        raise NotImplementedError

    def euler_step(self, x, d, dt):
        return x + dt * d

EDM Sampler

在EDM Sampler中,最重要的方法就是sample_step。

σ^=σ×(γ+1)x=x+ϵ×σ^2,ifγ>0μθ=denoiser(x,σ^,c,uc)d=xμθσ^dt=σnextσ^xnext=x+d×dt\hat\sigma = \sigma \times (\gamma + 1) \\ x = x + \epsilon \times \sqrt{\hat\sigma^2} , \quad if \quad \gamma > 0 \\ \mu_{\theta} = denoiser(x, \hat\sigma, c, uc) \\ d = \frac {x - \mu_{\theta}} {\hat\sigma} \\ dt = \sigma_{next} - \hat\sigma x_{next} = x + d \times dt

在许多基于扩散模型的采样方法中,噪声标准差 σ\sigma 是时间的一个函数。例如,在某些扩散模型中,σ\sigma 是时间 t 的单调函数。通过这种方式,σ\sigma 实际上隐含地编码了时间信息。因此,σ\sigma 的变化量可以用来近似时间步长。

具体来说,如果我们假设 σ\sigma 是时间 t 的单调函数,那么 σ(t)\sigma(t) 可以表示为 σt\sigma_t,并且 σ(t+Δt)\sigma(t + Δt) 可以表示为 σt+Δt\sigma_{t+Δt}。因此,next_sigma - sigma_hat 实际上是 σt+Δtσt\sigma_{t+Δt} - \sigma_t,这可以用来近似时间步长 Δt。这种方法简化了时间步长的计算,同时保持了采样过程的稳定性。(来自于GPT)

class EDMSampler(SingleStepDiffusionSampler):
    def __init__(
        self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs
    ):
        super().__init__(*args, **kwargs)

        self.s_churn = s_churn
        self.s_tmin = s_tmin
        self.s_tmax = s_tmax
        self.s_noise = s_noise

    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0):
        """
            sigma: 当前步的噪声标准差
            next_sigma: 下一步的噪声标准差
            denoiser: 消噪器
            x: 当前输入
            cond: 条件输入
            uc: 无条件输入
            gamma: 调节参数, 用于给输入x增加一些随机噪声
        """
        
        # 计算调节后的sigma_hat
        sigma_hat = sigma * (gamma + 1.0)
        # 如果gamma > 0, 也就是sigma受到了调节,就对输入x增加一些随机扰动
        # x + eps * sigma_hat ** 0.5
        if gamma > 0:
            eps = torch.randn_like(x) * self.s_noise
            x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5

        denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)
        d = to_d(x, sigma_hat, denoised)
        dt = append_dims(next_sigma - sigma_hat, x.ndim)

        euler_step = self.euler_step(x, d, dt)
        x = self.possible_correction_step(
            euler_step, x, d, dt, next_sigma, denoiser, cond, uc
        )
        return x

    def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x, cond, uc, num_steps
        )
        # 计算gamma的值
        for i in self.get_sigma_gen(num_sigmas):
            gamma = (
                min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
                if self.s_tmin <= sigmas[i] <= self.s_tmax
                else 0.0
            )
            x = self.sampler_step(
                s_in * sigmas[i],
                s_in * sigmas[i + 1],
                denoiser,
                x,
                cond,
                uc,
                gamma,
            )

        return x

class EulerEDMSampler(EDMSampler):
    def possible_correction_step(
        self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
    ):
        return euler_step