SDXL代码阅读-Denoiser-2.3

基于score matching的 denoising


在DiffusionEngine中,我们初始化了denoiser,在loss和sample都会调用它,现在来看一下denosier大致做了什么。

denoiser

文件在generative-models/sgm/modules/diffusionmodules/denoiser.py
代码很短,也确实没有什么可说的,跟截图公式基本对应。离散的去噪器实现感觉像是没写完的样子,看起来没有什么用,因为只是修改了sigma的来源,通过Discretization去生成sigma,但是本来的denoiser类接收的sigma就是通过Discretization生成的吧。

class Denoiser(nn.Module):
    def __init__(self, scaling_config: Dict):
        super().__init__()

        self.scaling: DenoiserScaling = instantiate_from_config(scaling_config)

    def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
        return sigma

    def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
        return c_noise

    def forward(
        self,
        network: nn.Module,
        input: torch.Tensor,
        sigma: torch.Tensor,
        cond: Dict,
        **additional_model_inputs,
    ) -> torch.Tensor:
        sigma = self.possibly_quantize_sigma(sigma)
        sigma_shape = sigma.shape
        sigma = append_dims(sigma, input.ndim)
        # 重点就在这里,对应论文公式的几个参数
        c_skip, c_out, c_in, c_noise = self.scaling(sigma)
        c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
        # 这里return的结果就对应论文的公式
        return (
            network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out
            + input * c_skip
        )

DenoiserScaling

这个就是用来跟论文里表格进行匹配,生成对应的公式参数

注意sigma_data的值来自于配置文件,我看了mnist里面设置的就是1.0

class DenoiserScaling(ABC):
    @abstractmethod
    def __call__(
        self, sigma: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        pass


class EDMScaling:
    def __init__(self, sigma_data: float = 0.5):
        self.sigma_data = sigma_data

    def __call__(
        self, sigma: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
        c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
        c_noise = 0.25 * sigma.log()
        return c_skip, c_out, c_in, c_noise