SDXL Code Reading - Loss_fn - 2.5

Loss_fn

The denoiser and conditioner discussed earlier are both passed into self.loss_fn during DiffusionEngine's forward pass, so understanding what Loss_fn actually does is an important step in tying the previous parts together.
Since initializing Loss_fn also involves other modules, let's read those modules first.
The corresponding table in the paper:

sigma_sampling

The file is generative-models/sgm/modules/diffusionmodules/sigma_sampling.py.
From the config (sgm.modules.diffusionmodules.sigma_sampling.EDMSampling) we know the class being instantiated is EDMSampling. Looking at the code, it simply draws sigma from a log-normal distribution: log(sigma) is sampled from a normal distribution with mean p_mean and standard deviation p_std.

class EDMSampling:
    """
    Sample sigma from a log-normal distribution, controlled by
    a mean (p_mean) and standard deviation (p_std) in log space.
    """
    def __init__(self, p_mean=-1.2, p_std=1.2):
        self.p_mean = p_mean
        self.p_std = p_std

    def __call__(self, n_samples, rand=None):
        # default() comes from sgm.util: use rand if given, else draw fresh noise
        log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,)))
        return log_sigma.exp()
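A quick standalone sanity check of this sampling scheme (a sketch with `default` inlined, since that helper lives in `sgm.util`): the returned sigmas are strictly positive, and their logarithm has roughly the configured mean and standard deviation.

```python
import torch

torch.manual_seed(0)
p_mean, p_std = -1.2, 1.2  # the defaults above

# sigma = exp(p_mean + p_std * z) with z ~ N(0, 1), i.e. sigma is log-normal
z = torch.randn(10_000)
sigmas = (p_mean + p_std * z).exp()
```

With 10,000 samples, `sigmas.log().mean()` lands close to -1.2 and `sigmas.log().std()` close to 1.2.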

loss_weighting

The file is generative-models/sgm/modules/diffusionmodules/loss_weighting.py.
It implements exactly the weighting formula shown in the paper.

class EDMWeighting(DiffusionLossWeighting):
    def __init__(self, sigma_data: float = 0.5):
        self.sigma_data = sigma_data

    def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
        return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2

StandardDiffusionLoss

init

class StandardDiffusionLoss(nn.Module):
    def __init__(
        self,
        sigma_sampler_config: dict,
        loss_weighting_config: dict,
        loss_type: str = "l2",
        offset_noise_level: float = 0.0,
        batch2model_keys: Optional[Union[str, List[str]]] = None,
    ):
        super().__init__()

        assert loss_type in ["l2", "l1", "lpips"]

        self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
        self.loss_weighting = instantiate_from_config(loss_weighting_config)

        self.loss_type = loss_type
        self.offset_noise_level = offset_noise_level
        
        # LPIPS is a loss that measures perceptual similarity between images.
        # Unlike pixel-wise losses (such as L1 or L2), LPIPS uses a pretrained deep
        # convolutional network to capture similarity closer to human visual perception.
        if loss_type == "lpips":
            self.lpips = LPIPS().eval()

        if not batch2model_keys:
            batch2model_keys = []
        
        # keys for extra model inputs taken from the batch
        if isinstance(batch2model_keys, str):
            batch2model_keys = [batch2model_keys]

        self.batch2model_keys = set(batch2model_keys)

forward

The real forward computation happens in the _forward function; forward just builds the conditioning vectors and then calls _forward.

    def get_noised_input(
        self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor
    ) -> torch.Tensor:
        noised_input = input + noise * sigmas_bc
        return noised_input

    def forward(
        self,
        network: nn.Module,
        denoiser: Denoiser,
        conditioner: GeneralConditioner,
        input: torch.Tensor,
        batch: Dict,
    ) -> torch.Tensor:
        cond = conditioner(batch)
        return self._forward(network, denoiser, cond, input, batch)

_forward

    def _forward(
        self,
        network: nn.Module,
        denoiser: Denoiser,
        cond: Dict,
        input: torch.Tensor,
        batch: Dict,
    ) -> Tuple[torch.Tensor, Dict]:
        # there may be some extra model inputs
        additional_model_inputs = {
            key: batch[key] for key in self.batch2model_keys.intersection(batch)
        }
        # sample sigmas from the log-normal distribution
        sigmas = self.sigma_sampler(input.shape[0]).to(input)

        # if self.offset_noise_level > 0.0, add an offset to the noise;
        # the offset itself is also Gaussian noise, broadcast across spatial dims
        noise = torch.randn_like(input)
        if self.offset_noise_level > 0.0:
            # note: self.n_frames is not set in the __init__ shown above; it appears
            # to be used by the video variants and is None for plain image training
            offset_shape = (
                (input.shape[0], 1, input.shape[2])
                if self.n_frames is not None
                else (input.shape[0], input.shape[1])
            )
            noise = noise + self.offset_noise_level * append_dims(
                torch.randn(offset_shape, device=input.device),
                input.ndim,
            )
        
        # noise the image
        sigmas_bc = append_dims(sigmas, input.ndim)
        noised_input = self.get_noised_input(sigmas_bc, noise, input)

        # model output, i.e. the denoiser output (the score-matching / denoising objective)
        model_output = denoiser(
            network, noised_input, sigmas, cond, **additional_model_inputs
        )
        # compute the loss weights
        w = append_dims(self.loss_weighting(sigmas), input.ndim)
        
        # the return value is the actual loss, e.g. the weighted l2 loss
        return self.get_loss(model_output, input, w)

    def get_loss(self, model_output, target, w):
        if self.loss_type == "l2":
            return torch.mean(
                (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
            )
        elif self.loss_type == "l1":
            return torch.mean(
                (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
            )
        elif self.loss_type == "lpips":
            loss = self.lpips(model_output, target).reshape(-1)
            return loss
        else:
            raise NotImplementedError(f"Unknown loss type {self.loss_type}")
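Putting the pieces together, here is a minimal toy sketch of the training step that _forward performs (hypothetical shapes; `append_dims` is inlined, there is no network, and the "denoiser" is faked as a perfect one so the loss should come out zero):

```python
import torch

def append_dims(x, target_dims):
    # pad trailing singleton dims so x broadcasts against the input tensor
    return x[(...,) + (None,) * (target_dims - x.ndim)]

b, c, h, w_ = 4, 3, 8, 8         # hypothetical latent batch shape
x = torch.randn(b, c, h, w_)     # clean input y

# 1. sample sigma from the log-normal EDM distribution
sigmas = (-1.2 + 1.2 * torch.randn(b)).exp()
sigmas_bc = append_dims(sigmas, x.ndim)

# 2. noise the input: noised = y + n * sigma
noise = torch.randn_like(x)
noised = x + noise * sigmas_bc

# 3. the denoiser D(noised, sigma) is trained to predict the clean input;
#    here we stand in a perfect denoiser to check the loss vanishes
model_output = x.clone()

# 4. EDM weighting and per-sample weighted l2 loss
sigma_data = 0.5
w = append_dims((sigmas**2 + sigma_data**2) / (sigmas * sigma_data) ** 2, x.ndim)
loss = torch.mean((w * (model_output - x) ** 2).reshape(b, -1), 1)
```

The real code differs only in that step 3 calls the actual denoiser-wrapped network and the sampler/weighting objects are instantiated from the config.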

Summary

loss_fn essentially implements Equations 2 and 3 from the paper.

One quite counterintuitive point here is that the denoiser output D() in the score-matching loss is trained to fit the actual clean image input y itself, not the added noise.