SDXL代码阅读-Sampler推理(EDM)-1.2
本文紧接上一节的sampler讲解采样器的第三个部分,采样器方法本身。
Sampler
主要的方法在BaseDiffusionSampler类中(init、denoise、prepare_sampling_loop),会基于一些不同的采样方法有一些子类。
BaseDiffusionSampler
# 提供默认值的机制
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
class BaseDiffusionSampler:
def __init__(
self,
discretization,
guider,
num_steps: Union[int, None] = None,
verbose: bool = False,
device: str = "cuda",
):
self.num_steps = num_steps
self.discretization = discretization
self.guider = guider
self.verbose = verbose
self.device = device
def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
'''
用于准备迭代采样的元素
x: 输入
cond: 条件输入
uc: 无条件输入,可能用于一些特殊的生成设置或对比实验。
'''
# 通过调度器生成噪声的标准差sigma
sigmas = self.discretization(
self.num_steps if num_steps is None else num_steps, device=self.device
)
uc = default(uc, cond)
x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
num_sigmas = len(sigmas)
# s_in 作为一个全 1 张量,通常用于对 sigma 进行缩放操作。在采样过程中,s_in 与 sigma 相乘,确保 sigma 在批次维度上正确广播。
# 在采样步骤中,s_in 确保了每一步的 sigma 值可以正确地应用于整个批次的输入张量 x。(来自于chatgpt)
s_in = x.new_ones([x.shape[0]])
return x, s_in, sigmas, num_sigmas, cond, uc
def denoise(self, x, denoiser, sigma, cond, uc):
'''
用guider指导消噪的过程
'''
denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc))
denoised = self.guider(denoised, sigma)
return denoised
def get_sigma_gen(self, num_sigmas):
sigma_generator = range(num_sigmas - 1)
if self.verbose:
print("#" * 30, " Sampling setting ", "#" * 30)
print(f"Sampler: {self.__class__.__name__}")
print(f"Discretization: {self.discretization.__class__.__name__}")
print(f"Guider: {self.guider.__class__.__name__}")
sigma_generator = tqdm(
sigma_generator,
total=num_sigmas,
desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps",
)
return sigma_generator
单步采样
定义了两个方法,一个是sample_step,这个由具体子类实现;一个是euler_step,这个是欧拉法解ODE。
class SingleStepDiffusionSampler(BaseDiffusionSampler):
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs):
raise NotImplementedError
def euler_step(self, x, d, dt):
return x + dt * d
EDM Sampler
在EDM Sampler中,最重要的方法就是sample_step。
在许多基于扩散模型的采样方法中,噪声标准差 是时间的一个函数。例如,在某些扩散模型中, 是时间 t 的单调函数。通过这种方式, 实际上隐含地编码了时间信息。因此, 的变化量可以用来近似时间步长。
具体来说,如果我们假设 是时间 t 的单调函数,那么 可以表示为 ,并且 可以表示为 。因此,next_sigma - sigma_hat 实际上是 ,这可以用来近似时间步长 Δt。这种方法简化了时间步长的计算,同时保持了采样过程的稳定性。(来自于GPT)
class EDMSampler(SingleStepDiffusionSampler):
def __init__(
self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs
):
super().__init__(*args, **kwargs)
self.s_churn = s_churn
self.s_tmin = s_tmin
self.s_tmax = s_tmax
self.s_noise = s_noise
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0):
"""
sigma: 当前步的噪声标准差
next_sigma: 下一步的噪声标准差
denoiser: 消噪器
x: 当前输入
cond: 条件输入
uc: 无条件输入
gamma: 调节参数, 用于给输入x增加一些随机噪声
"""
# 计算调节后的sigma_hat
sigma_hat = sigma * (gamma + 1.0)
# 如果gamma > 0, 也就是sigma受到了调节,就对输入x增加一些随机扰动
# x + eps * sigma_hat ** 0.5
if gamma > 0:
eps = torch.randn_like(x) * self.s_noise
x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)
d = to_d(x, sigma_hat, denoised)
dt = append_dims(next_sigma - sigma_hat, x.ndim)
euler_step = self.euler_step(x, d, dt)
x = self.possible_correction_step(
euler_step, x, d, dt, next_sigma, denoiser, cond, uc
)
return x
def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
x, cond, uc, num_steps
)
# 计算gamma的值
for i in self.get_sigma_gen(num_sigmas):
gamma = (
min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
if self.s_tmin <= sigmas[i] <= self.s_tmax
else 0.0
)
x = self.sampler_step(
s_in * sigmas[i],
s_in * sigmas[i + 1],
denoiser,
x,
cond,
uc,
gamma,
)
return x
class EulerEDMSampler(EDMSampler):
def possible_correction_step(
self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
):
return euler_step