SDXL代码阅读-训练(pytorch lightning)-2.2

上一节给PL开了个头,简单讲了一下PL构建一整套模型及训练系统需要写什么,这篇文章继续。
起始参考代码还是generative-models/blob/main/main.py,见662行model = instantiate_from_config(config.model)结合上一节解析的SDXL是如何动态导入模块的,那么这里我们可以直接去配置文件看看它到底是怎么进行model的构造的。

Model

还是以generative-models/configs/example_training/toy/mnist_cond.yaml配置文件作为参考,发现model构造来自于sgm.models.diffusion.DiffusionEngine方法。

DiffusionEngine

通过配置文件是多层target嵌套可以知道DiffusionEngine是由多个module组合init的,参考init代码也可以得知由以下几个模块组成:
1.network
2.denoiser
3.sampler
4.loss_fn
5.conditioner

class DiffusionEngine(pl.LightningModule):
    def __init__(
        self,
        network_config,
        denoiser_config,
        first_stage_config,
        conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        network_wrapper: Union[None, str] = None,
        ckpt_path: Union[None, str] = None,
        use_ema: bool = False,
        ema_decay_rate: float = 0.9999,
        scale_factor: float = 1.0,
        disable_first_stage_autocast=False,
        input_key: str = "jpg",
        log_keys: Union[List, None] = None,
        no_cond_log: bool = False,
        compile_model: bool = False,
        en_and_decode_n_samples_a_time: Optional[int] = None,
    ):
        super().__init__()
        self.log_keys = log_keys
        self.input_key = input_key
        self.optimizer_config = default(
            optimizer_config, {"target": "torch.optim.AdamW"}
        )
        model = instantiate_from_config(network_config)
        self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
            model, compile_model=compile_model
        )

        self.denoiser = instantiate_from_config(denoiser_config)
        self.sampler = (
            instantiate_from_config(sampler_config)
            if sampler_config is not None
            else None
        )
        self.conditioner = instantiate_from_config(
            default(conditioner_config, UNCONDITIONAL_CONFIG)
        )
        self.scheduler_config = scheduler_config
        self._init_first_stage(first_stage_config)

        self.loss_fn = (
            instantiate_from_config(loss_fn_config)
            if loss_fn_config is not None
            else None
        )

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self.model, decay=ema_decay_rate)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        self.scale_factor = scale_factor
        self.disable_first_stage_autocast = disable_first_stage_autocast
        self.no_cond_log = no_cond_log

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path)

        self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time

整个model其实对应着DiffusionEngine,而DiffusionEngine里面初始化的model(network_config)更符合我们以往对于model的定义,对应配置文件也就是sgm.modules.diffusionmodules.openaimodel.UNetModel,后续再分步阅读。

_init_first_stage

DiffusionEngine会有一个init_first_stage的方法,用于直接copy一个network.eval且固定其权重,因为有些part是不需要训练的,比如vae的解码编码、clip等等。

    def _init_first_stage(self, config):
        model = instantiate_from_config(config).eval()
        model.train = disabled_train
        for param in model.parameters():
            param.requires_grad = False
        self.first_stage_model = model

譬如后面的encode、decode就是调用first_stage_model

    @torch.no_grad()
    def encode_first_stage(self, x):
        n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
        n_rounds = math.ceil(x.shape[0] / n_samples)
        all_out = []
        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
            for n in range(n_rounds):
                out = self.first_stage_model.encode(
                    x[n * n_samples : (n + 1) * n_samples]
                )
                all_out.append(out)
        z = torch.cat(all_out, dim=0)
        z = self.scale_factor * z
        return z

    @torch.no_grad()
    def decode_first_stage(self, z):
        z = 1.0 / self.scale_factor * z
        n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])

        n_rounds = math.ceil(z.shape[0] / n_samples)
        all_out = []
        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
            for n in range(n_rounds):
                if isinstance(self.first_stage_model.decoder, VideoDecoder):
                    kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
                else:
                    kwargs = {}
                out = self.first_stage_model.decode(
                    z[n * n_samples : (n + 1) * n_samples], **kwargs
                )
                all_out.append(out)
        out = torch.cat(all_out, dim=0)
        return out

forward

forward函数与pytorch不同点在于,它将计算loss的过程放在了forward里面,因此具体要看loss_fn是怎么写的。

    def forward(self, x, batch):
        loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
        loss_mean = loss.mean()
        loss_dict = {"loss": loss_mean}
        return loss_mean, loss_dict

train

self.input_key是字符串变量,表示图片数据的格式,默认值是"jpg",这里batch参数还不确定是什么,可能要具体看数据集是怎么构造的。

    def get_input(self, batch):
        # assuming unified data format, dataloader returns a dict.
        # image tensors should be scaled to -1 ... 1 and in bchw format
        return batch[self.input_key]

share step
真正的训练step,首先通过get_input对输入做处理,然后对input进行encode操作到潜在空间,然后调用forward函数进行前向推理。

    def shared_step(self, batch: Dict) -> Any:
        x = self.get_input(batch)
        x = self.encode_first_stage(x)
        # global_step 是 PyTorch Lightning 中的一个属性,用于跟踪训练过程中的全局步数。它表示从训练开始到当前已经完成的优化步骤的总数。
        # global_step 是一个自动维护的计数器,通常用于记录和监控训练过程中的指标,尤其是在日志记录和学习率调度等场景中。
        batch["global_step"] = self.global_step
        loss, loss_dict = self(x, batch)
        return loss, loss_dict

training_step核心就是调用了share_step,其余部分基本上就是在记录log

    def training_step(self, batch, batch_idx):
        loss, loss_dict = self.shared_step(batch)

        self.log_dict(
            loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
        )

        self.log(
            "global_step",
            self.global_step,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=False,
        )

        if self.scheduler_config is not None:
            lr = self.optimizers().param_groups[0]["lr"]
            self.log(
                "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
            )

        return loss

configure_optimizers

在前面的训练流程里面出现了self.optimizers(),对应的就是训练过程的优化器,优化器的配置对应的就是configure_optimizers。
configure_optimizers 方法在训练开始前和从检查点恢复训练时被 PyTorch Lightning 自动调用。它用于初始化优化器和学习率调度器,以便在训练过程中正确地更新模型参数和调整学习率。通过这种机制,Lightning 框架能够简化优化器和调度器的管理,使得训练过程更加高效和便捷。

    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.model.parameters())
        for embedder in self.conditioner.embedders:
            if embedder.is_trainable:
                params = params + list(embedder.parameters())
        opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)
            print("Setting up LambdaLR scheduler...")
            # LambdaLR通过用户定义的规则、函数来动态调整学习率
            scheduler = [
                {
                    "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
                    "interval": "step",
                    "frequency": 1,
                }
            ]
            return [opt], scheduler
        return opt

Sample

主要就是调用denosier和sampler

    @torch.no_grad()
    def sample(
        self,
        cond: Dict,
        uc: Union[Dict, None] = None,
        batch_size: int = 16,
        shape: Union[None, Tuple, List] = None,
        **kwargs,
    ):
        randn = torch.randn(batch_size, *shape).to(self.device)

        denoiser = lambda input, sigma, c: self.denoiser(
            self.model, input, sigma, c, **kwargs
        )
        samples = self.sampler(denoiser, randn, cond, uc=uc)
        return samples

log_conditionings

这个函数我理解就是将各种各样的条件输入转换为一种img_size的张量并记录log

    @torch.no_grad()
    def log_conditionings(self, batch: Dict, n: int) -> Dict:
        """
        Defines heuristics to log different conditionings.
        These can be lists of strings (text-to-image), tensors, ints, ...
        """
        image_h, image_w = batch[self.input_key].shape[2:]
        log = dict()

        for embedder in self.conditioner.embedders:
            if (
                (self.log_keys is None) or (embedder.input_key in self.log_keys)
            ) and not self.no_cond_log:
                x = batch[embedder.input_key][:n]
                if isinstance(x, torch.Tensor):
                    if x.dim() == 1:
                        # class-conditional, convert integer to string
                        x = [str(x[i].item()) for i in range(x.shape[0])]
                        xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
                    elif x.dim() == 2:
                        # size and crop cond and the like
                        x = [
                            "x".join([str(xx) for xx in x[i].tolist()])
                            for i in range(x.shape[0])
                        ]
                        xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
                    else:
                        raise NotImplementedError()
                elif isinstance(x, (List, ListConfig)):
                    if isinstance(x[0], str):
                        # strings
                        xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
                    else:
                        raise NotImplementedError()
                else:
                    raise NotImplementedError()
                log[embedder.input_key] = xc
        return log

核心就是判断x到底是什么类型的数据后,来调用log_txt_as_img

def log_txt_as_img(wh, xc, size=10):
    # wh a tuple of (width, height)
    # xc a list of captions to plot
    b = len(xc)
    txts = list()
    for bi in range(b):
        txt = Image.new("RGB", wh, color="white") # 白色背景新图像
        draw = ImageDraw.Draw(txt) # 创建为一个绘图对象
        font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) # 指定字体
        nc = int(40 * (wh[0] / 256))
        if isinstance(xc[bi], list):
            text_seq = xc[bi][0]
        else:
            text_seq = xc[bi]
        lines = "\n".join(
            text_seq[start : start + nc] for start in range(0, len(text_seq), nc)
        )
        
        # 在图上把字体画出来
        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            print("Cant encode string for logging. Skipping.")

        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
        txts.append(txt)
    # 返回一个tensor格式的张量
    txts = np.stack(txts)
    txts = torch.tensor(txts)
    return txts

log_images

这个函数也是记录输入图像到log里面,比如记录了encode前是什么图像信息、直接decode又是什么图像信息,如果sample就再记录sample+decode后的信息。

    @torch.no_grad()
    def log_images(
        self,
        batch: Dict,
        N: int = 8,
        sample: bool = True,
        ucg_keys: List[str] = None,
        **kwargs,
    ) -> Dict:
        # 从embedder层获取key
        conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
        # 确定无条件生成的key(如果有)在前面获取的key里面
        if ucg_keys:
            assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
                "Each defined ucg key for sampling must be in the provided conditioner input keys,"
                f"but we have {ucg_keys} vs. {conditioner_input_keys}"
            )
        else:
            ucg_keys = conditioner_input_keys
        log = dict()

        x = self.get_input(batch)

        # 获取条件和无条件向量嵌入
        c, uc = self.conditioner.get_unconditional_conditioning(
            batch,
            force_uc_zero_embeddings=ucg_keys
            if len(self.conditioner.embedders) > 0
            else [],
        )

        sampling_kwargs = {}

        N = min(x.shape[0], N)
        x = x.to(self.device)[:N]
        log["inputs"] = x
        z = self.encode_first_stage(x) # 对输入图像进行编码
        log["reconstructions"] = self.decode_first_stage(z) # 记录解码图像
        log.update(self.log_conditionings(batch, N)) # 更新日志(条件信息)
        
        # 将条件向量转换到合适的device
        for k in c:
            if isinstance(c[k], torch.Tensor):
                c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
        
        # 如果sample就采样生成结果并记录到log里面
        if sample:
            with self.ema_scope("Plotting"):
                samples = self.sample(
                    c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
                )
            samples = self.decode_first_stage(samples)
            log["samples"] = samples
        return log

model EMA

在DiffusionEngine初始化的时候还进行了ema初始化操作:

if self.use_ema:
            self.model_ema = LitEma(self.model, decay=ema_decay_rate)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

且还有一个上下文方法,ema_scope用于加载模型的ema参数

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

所以我们需要了解一下EMA是什么,有什么作用。

EMA 移动指数平均

指数移动平均(Exponential Moving Average, EMA)是一种常用的技巧,可以帮助模型在训练过程中更稳定,通常用于提高模型的泛化能力和性能。
这种技术在深度学习中被广泛应用,主要有以下几个原因:

1.降低噪声和波动
在模型训练过程中,尤其是在使用随机梯度下降(SGD)或其变种时,参数更新可能会受到噪声和波动的影响。这些噪声和波动可能来自于小批量数据的不稳定性或学习率的变化。EMA 通过对参数进行加权平均,可以平滑掉这些噪声和波动,使得模型参数更加稳定。

2.缓解过拟合
EMA 可以看作是一种正则化技术。通过对参数进行加权平均,EMA 可以抑制参数的过度波动,从而减少过拟合的风险。特别是在训练后期,EMA 可以帮助模型更好地泛化到未见过的数据。

3.提高模型的泛化能力
由于 EMA 在一定程度上平滑了参数更新,它可以帮助模型更好地捕捉数据的总体趋势,而不是过度拟合到训练数据中的细节。这种平滑效应可以提高模型在测试集上的性能,从而提高模型的泛化能力。

4.减少参数的极端值
EMA 可以防止参数出现极端值。极端值可能会导致模型的不稳定性和性能下降。通过对参数进行加权平均,EMA 可以减缓参数的剧烈变化,使得参数更加平滑和稳定。
(来自于chatgpt)

我的理解就是有点类似于集成模型的原理,集成了训练以来所有轮参数模型的一个结果,移动平均就是对所有模型输出结果的一个处理。
公式就是emaw=decayemaw+(1decay)wema_w​=decay⋅ema_w​+(1−decay)⋅w
对应到源文件generative-models/sgm/modules/ema.py

import torch
from torch import nn


class LitEma(nn.Module):
    def __init__(self, model, decay=0.9999, use_num_upates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        self.m_name2s_name = {}
        
        # 为decay和num_updates注册buffer
        self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
        self.register_buffer(
            "num_updates",
            torch.tensor(0, dtype=torch.int)
            if use_num_upates
            else torch.tensor(-1, dtype=torch.int),
        )

        # 遍历原模型的可训练参数,为每个参数对应创建一个ema参数并注册buffer
        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers
                s_name = name.replace(".", "")
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def reset_num_updates(self):
        # 重置num_updates
        del self.num_updates
        self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))

    def forward(self, model):
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            # m_param是原模型的参数
            # shadow_params是ema模型的参数
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    # ema_w = ema_w - (1 - decay)(ema_w - w)
                    #       = decay*ema_w + (1 - decay)*w
                    shadow_params[sname].sub_(
                        one_minus_decay * (shadow_params[sname] - m_param[key])
                    )
                else:
                    assert not key in self.m_name2s_name

    def copy_to(self, model):
        # 把ema参数复制到模型参数
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert not key in self.m_name2s_name

    # 存储和重新载入模型参数
    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)