SDXL代码阅读-训练(pytorch lightning)-2.1

SDXL是使用PL(pytorch lightning)进行训练的,其实PL与pytorch在写法上区别不算很大,只不过pytorch需要构造模型,然后零散的去写train、loss、optimizer等。PL构建的不是一个模型,而是基于一个模型的一套系统,这套系统包含该模型的train、loss等等,PL实际就是对pytorch的方法的封装和简化,让模型构建、训练、测试整套流程代码简洁易懂。
对应的,我们可以通过SDXL代码阅读顺便来学习一下pytorch lightning,因此代码对应pytorch从几个方面分开记录:
1.数据集构造
2.模型构造
3.Loss与Optimizer构造
4.训练step和测试
5.log(optional)

学习的代码参考generative-models/blob/main/main.py

SDXL 模块导入

在记录之前,首先需要了解一下SDXL是如何进行模块的搭建的,通过这种方法我们可以编辑配置文件的内容来按需引入module中的方法,并给予其想要的参数。

举个例子:搭建模型的时候我希望引入一个attention模块里的SE-attn方法来构建网络,并给予参数1、参数2初始化。过了几天我可能又发现另一个可以尝试的attn比如self-attn,那么我只需要在attention模块里写好self-attn,在配置文件里编辑好self-attn及对应的参数,就不需要再修改我的model构建文件,每次去更换import各种方法。

当然,可以直接通过import *的方式去导入所有的方法,但我觉得配置文件的方式更加清晰明了,并且通过配置文件就可以一目了然的看到这次搭建的model里面具体用了哪些module、function,而不用具体看代码。

配置文件写法

参考的配置文件路径在generative-models/configs/example_training/toy/mnist_cond.yaml
简单的来说,target字段就是想要import的模块、方法,对应params就是初始化该方法的参数。

model:
  base_learning_rate: 1.0e-4
  target: sgm.models.diffusion.DiffusionEngine
  params:
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
          params:
            sigma_data: 1.0

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        in_channels: 1
        out_channels: 1
        model_channels: 32
        attention_resolutions: []
        num_res_blocks: 4
        channel_mult: [1, 2, 2]
        num_head_channels: 32
        num_classes: sequential
        adm_in_channels: 128

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          - is_trainable: True
            input_key: cls
            ucg_rate: 0.2
            target: sgm.modules.encoders.modules.ClassEmbedder
            params:
              embed_dim: 128
              n_classes: 10

    first_stage_config:
      target: sgm.models.autoencoder.IdentityFirstStage

    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
      params:
        loss_weighting_config:
          target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
          params:
            sigma_data: 1.0
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50

        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization

        guider_config:
          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
          params:
            scale: 3.0

data:
  target: sgm.data.mnist.MNISTLoader
  params:
    batch_size: 512
    num_workers: 1

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 25000

    image_logger:
      target: main.ImageLogger
      params:
        disabled: False
        batch_frequency: 1000
        max_images: 16
        increase_log_steps: True
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          N: 16
          n_rows: 4

  trainer:
    devices: 0,
    benchmark: True
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_epochs: 20

instantiate_from_config

该函数用于解析配置文件,就是确定一下要有target参数。
main.py到处都有这个函数,包括之前讲采样过程也是。

def instantiate_from_config(config):
    if not "target" in config:
        if config == "__is_first_stage__":
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    # print("--------------------------------------")
    # print(config)
    # print(**config)
    return get_obj_from_str(config["target"])(**config.get("params", dict()))

get_obj_from_str

从字符串中去import module里面的class方法

def get_obj_from_str(string, reload=False, invalidate_cache=True):
    # rsplit方法分割字符串一次,获得module和class名称
    module, cls = string.rsplit(".", 1)
    # 调用 importlib.invalidate_caches(),使导入系统的缓存失效。这在文件系统发生变化时可能会有用。
    if invalidate_cache:
        importlib.invalidate_caches()
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp) # 重新加载模块,确保获取最新版本。
    
    # 使用 importlib.import_module(module_path) 导入一个模块时,Python 会将模块加载为一个对象。
    # 这个模块对象的属性包括模块中定义的所有内容,比如类、函数、和变量。
    # getattr从模块对象中获取名为 cls 的属性。
    return getattr(importlib.import_module(module, package=None), cls)

实际上调用instantiate_from_config就是通过该方法完成了一个module.class(params)的初始化过程。

PytorchLightning 构造

dataloader构建

其实dataloader构建可以按照pytorch的dataset、dataloader类来构建,也可以用LightningDataModule来构建。虽然 LightningDataModule 没有严格的“必须”实现的方法,但为了充分利用其功能,通常会实现下述几个方法:prepare_data、setup、train_dataloader、val_dataloader、test_dataloader。
下面会主要讲一下SDXL用LightningDataModule来构造dataloader。

MNIST数据

本质上还是用的pytorch的dataloader,但还是实现了PL需要的一些方法

import pytorch_lightning as pl
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class MNISTDataDictWrapper(Dataset):
    def __init__(self, dset):
        super().__init__()
        self.dset = dset

    def __getitem__(self, i):
        x, y = self.dset[i]
        return {"jpg": x, "cls": y}

    def __len__(self):
        return len(self.dset)


class MNISTLoader(pl.LightningDataModule):
    def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True):
        super().__init__()

        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
        )

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor if num_workers > 0 else 0
        self.shuffle = shuffle
        self.train_dataset = MNISTDataDictWrapper(
            torchvision.datasets.MNIST(
                root=".data/", train=True, download=True, transform=transform
            )
        )
        self.test_dataset = MNISTDataDictWrapper(
            torchvision.datasets.MNIST(
                root=".data/", train=False, download=True, transform=transform
            )
        )

    def prepare_data(self):
        pass

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            prefetch_factor=self.prefetch_factor,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            prefetch_factor=self.prefetch_factor,
        )

    def val_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            prefetch_factor=self.prefetch_factor,
        )

custom dataset

他们自己写了个sdata的库来create dataset或者dataloader,但dataloader最终的本质还是pytorch的dataloader,所以简单看一下继承PL的模块还是实现那些方法就行了

from typing import Optional

import torchdata.datapipes.iter
import webdataset as wds
from omegaconf import DictConfig
from pytorch_lightning import LightningDataModule

try:
    from sdata import create_dataset, create_dummy_dataset, create_loader
except ImportError as e:
    print("#" * 100)
    print("Datasets not yet available")
    print("to enable, we need to add stable-datasets as a submodule")
    print("please use ``git submodule update --init --recursive``")
    print("and do ``pip install -e stable-datasets/`` from the root of this repo")
    print("#" * 100)
    exit(1)


class StableDataModuleFromConfig(LightningDataModule):
    def __init__(
        self,
        train: DictConfig,
        validation: Optional[DictConfig] = None,
        test: Optional[DictConfig] = None,
        skip_val_loader: bool = False,
        dummy: bool = False,
    ):
        super().__init__()
        self.train_config = train
        assert (
            "datapipeline" in self.train_config and "loader" in self.train_config
        ), "train config requires the fields `datapipeline` and `loader`"

        self.val_config = validation
        if not skip_val_loader:
            if self.val_config is not None:
                assert (
                    "datapipeline" in self.val_config and "loader" in self.val_config
                ), "validation config requires the fields `datapipeline` and `loader`"
            else:
                print(
                    "Warning: No Validation datapipeline defined, using that one from training"
                )
                self.val_config = train

        self.test_config = test
        if self.test_config is not None:
            assert (
                "datapipeline" in self.test_config and "loader" in self.test_config
            ), "test config requires the fields `datapipeline` and `loader`"

        self.dummy = dummy
        if self.dummy:
            print("#" * 100)
            print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
            print("#" * 100)

    def setup(self, stage: str) -> None:
        print("Preparing datasets")
        if self.dummy:
            data_fn = create_dummy_dataset
        else:
            data_fn = create_dataset

        self.train_datapipeline = data_fn(**self.train_config.datapipeline)
        if self.val_config:
            self.val_datapipeline = data_fn(**self.val_config.datapipeline)
        if self.test_config:
            self.test_datapipeline = data_fn(**self.test_config.datapipeline)

    def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
        loader = create_loader(self.train_datapipeline, **self.train_config.loader)
        return loader

    def val_dataloader(self) -> wds.DataPipeline:
        return create_loader(self.val_datapipeline, **self.val_config.loader)

    def test_dataloader(self) -> wds.DataPipeline:
        return create_loader(self.test_datapipeline, **self.test_config.loader)

区别是这里没有实现了prepare_data了,大概是sdata里面的datapipeline做了这件事情了。