SDXL代码阅读-Loss_fn-2.5
2024-09-25
4 min read
Loss_fn
前面讲的denoiser、conditioner在DiffusionEngine的forward阶段都被送入self.loss_fn进行计算,因此Loss_fn具体做了什么也是串联前面内容的重要部分。
由于Loss_fn的初始化也涉及到其他模块,先来阅读这些模块。
对应论文表格:
sigma_sampling
文件来自于generative-models/sgm/modules/diffusionmodules/sigma_sampling.py
从配置文件(sgm.modules.diffusionmodules.sigma_sampling.EDMSampling)可以知道导入的是EDMSampling,看了一下代码,实际上就是从正态分布里面采样。
class EDMSampling:
    """Sample per-example noise levels sigma whose logarithm is normally
    distributed: log(sigma) ~ N(p_mean, p_std**2)."""

    def __init__(self, p_mean=-1.2, p_std=1.2):
        # Mean and standard deviation of log(sigma).
        self.p_mean = p_mean
        self.p_std = p_std

    def __call__(self, n_samples, rand=None):
        # Use the caller-supplied standard-normal draw when given,
        # otherwise draw a fresh one of shape (n_samples,).
        gaussian = default(rand, torch.randn((n_samples,)))
        return torch.exp(self.p_mean + self.p_std * gaussian)
loss_weighting
文件来自于generative-models/sgm/modules/diffusionmodules/loss_weighting.py
也就是完全对应论文图片
class EDMWeighting(DiffusionLossWeighting):
    """Per-sigma loss weight: (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2."""

    def __init__(self, sigma_data: float = 0.5):
        # Assumed standard deviation of the data distribution.
        self.sigma_data = sigma_data

    def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
        numerator = sigma**2 + self.sigma_data**2
        denominator = (sigma * self.sigma_data) ** 2
        return numerator / denominator
StandardDiffusionLoss
init
class StandardDiffusionLoss(nn.Module):
    """Denoising score-matching loss used by DiffusionEngine.

    Builds the sigma sampler and per-sigma loss weighting from their configs
    and selects the distance (l2 / l1 / lpips) used to compare the denoiser
    output against the clean input.
    """

    def __init__(
        self,
        sigma_sampler_config: dict,
        loss_weighting_config: dict,
        loss_type: str = "l2",
        offset_noise_level: float = 0.0,
        batch2model_keys: Optional[Union[str, List[str]]] = None,
    ):
        super().__init__()

        # Validate with a real exception instead of `assert`, which is
        # silently stripped when Python runs with -O.
        if loss_type not in ("l2", "l1", "lpips"):
            raise ValueError(f"Unknown loss type {loss_type}")

        self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
        self.loss_weighting = instantiate_from_config(loss_weighting_config)
        self.loss_type = loss_type
        self.offset_noise_level = offset_noise_level

        # _forward reads self.n_frames in its offset-noise branch; define it
        # here so that branch cannot raise AttributeError on a plain instance.
        self.n_frames = None

        # LPIPS measures perceptual similarity between images via a
        # pretrained CNN, unlike the pixel-wise l1/l2 distances.
        if loss_type == "lpips":
            self.lpips = LPIPS().eval()

        # Keys copied verbatim from the batch into the network call.
        if not batch2model_keys:
            batch2model_keys = []
        if isinstance(batch2model_keys, str):
            batch2model_keys = [batch2model_keys]
        self.batch2model_keys = set(batch2model_keys)
forward
真正的前向推理在_forward函数,forward函数就是生成了condition向量然后调用了_forward函数
def get_noised_input(
    self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor
) -> torch.Tensor:
    """Diffuse the clean input: x_noised = x + sigma * eps.

    ``sigmas_bc`` is expected to already be broadcastable against ``input``
    (same rank, per-sample sigma).
    """
    scaled_noise = sigmas_bc * noise
    return input + scaled_noise
def forward(
    self,
    network: nn.Module,
    denoiser: Denoiser,
    conditioner: GeneralConditioner,
    input: torch.Tensor,
    batch: Dict,
) -> torch.Tensor:
    """Embed the batch into conditioning vectors, then delegate the actual
    loss computation to _forward."""
    conditioning = conditioner(batch)
    loss = self._forward(network, denoiser, conditioning, input, batch)
    return loss
_forward
def _forward(
    self,
    network: nn.Module,
    denoiser: Denoiser,
    cond: Dict,
    input: torch.Tensor,
    batch: Dict,
) -> Tuple[torch.Tensor, Dict]:
    """Noise the clean input at sampled sigmas, denoise it, and return the
    weighted per-sample loss from get_loss."""
    # Extra inputs forwarded verbatim from the batch to the network.
    additional_model_inputs = {
        key: batch[key] for key in self.batch2model_keys.intersection(batch)
    }

    # One sigma per batch element from the configured sampler
    # (log-normal for EDMSampling), moved to input's device/dtype.
    sigmas = self.sigma_sampler(input.shape[0]).to(input)

    noise = torch.randn_like(input)
    if self.offset_noise_level > 0.0:
        # Offset noise: add a noise component that is constant across the
        # trailing spatial dims (one value per batch/channel entry).
        # getattr guards against instances that never define n_frames —
        # the original read self.n_frames directly, which raises
        # AttributeError because __init__ does not set it.
        n_frames = getattr(self, "n_frames", None)
        offset_shape = (
            (input.shape[0], 1, input.shape[2])
            if n_frames is not None
            else (input.shape[0], input.shape[1])
        )
        noise = noise + self.offset_noise_level * append_dims(
            torch.randn(offset_shape, device=input.device),
            input.ndim,
        )

    # Diffuse the input: x_noised = x + sigma * eps.
    sigmas_bc = append_dims(sigmas, input.ndim)
    noised_input = self.get_noised_input(sigmas_bc, noise, input)

    # Denoiser prediction D(x_noised, sigma, cond) — score matching:
    # trained to reconstruct the clean input.
    model_output = denoiser(
        network, noised_input, sigmas, cond, **additional_model_inputs
    )

    # Per-sigma loss weighting, broadcast to input's rank.
    w = append_dims(self.loss_weighting(sigmas), input.ndim)
    return self.get_loss(model_output, input, w)
def get_loss(self, model_output, target, w):
    """Reduce the weighted distance between model_output and target to one
    value per batch element, according to self.loss_type."""
    batch_size = target.shape[0]
    if self.loss_type == "l1":
        weighted = w * (model_output - target).abs()
        return weighted.reshape(batch_size, -1).mean(dim=1)
    if self.loss_type == "l2":
        weighted = w * (model_output - target) ** 2
        return weighted.reshape(batch_size, -1).mean(dim=1)
    if self.loss_type == "lpips":
        # Perceptual distance from the pretrained LPIPS network,
        # flattened to one value per sample.
        return self.lpips(model_output, target).reshape(-1)
    raise NotImplementedError(f"Unknown loss type {self.loss_type}")
小结
loss_fn相当于就是完成了论文中公式2、3
但是这里非常反直觉的一个点是,得分匹配计算的D()用于拟合实际图片输入y