SDXL代码阅读-Denoiser-2.3
2024-09-24
2 min read
基于score matching的 denoising
在DiffusionEngine中,我们初始化了denoiser,在loss和sample都会调用它,现在来看一下denosier大致做了什么。
denoiser
文件在generative-models/sgm/modules/diffusionmodules/denoiser.py
代码很短,也确实没有什么可说的,跟截图公式基本对应。离散的去噪器实现感觉像是没写完的样子,看起来没有什么用,因为只是修改了sigma的来源,通过Discretization去生成sigma,但是本来的denoiser类接收的sigma就是通过Discretization生成的吧。
class Denoiser(nn.Module):
def __init__(self, scaling_config: Dict):
super().__init__()
self.scaling: DenoiserScaling = instantiate_from_config(scaling_config)
def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
return sigma
def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
return c_noise
def forward(
self,
network: nn.Module,
input: torch.Tensor,
sigma: torch.Tensor,
cond: Dict,
**additional_model_inputs,
) -> torch.Tensor:
sigma = self.possibly_quantize_sigma(sigma)
sigma_shape = sigma.shape
sigma = append_dims(sigma, input.ndim)
# 重点就在这里,对应论文公式的几个参数
c_skip, c_out, c_in, c_noise = self.scaling(sigma)
c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
# 这里return的结果就对应论文的公式
return (
network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out
+ input * c_skip
)
DenoiserScaling
这个就是用来跟论文里表格进行匹配,生成对应的公式参数
注意sigma_data的值来自于配置文件,我看了mnist里面设置的就是1.0
class DenoiserScaling(ABC):
@abstractmethod
def __call__(
self, sigma: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
pass
class EDMScaling:
def __init__(self, sigma_data: float = 0.5):
self.sigma_data = sigma_data
def __call__(
self, sigma: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
c_noise = 0.25 * sigma.log()
return c_skip, c_out, c_in, c_noise