SDXL代码阅读-训练(pytorch lightning)-2.1
SDXL是使用PL(pytorch lightning)进行训练的,其实PL与pytorch在写法上区别不算很大,只不过pytorch需要构造模型,然后零散的去写train、loss、optimizer等。PL构建的不是一个模型,而是基于一个模型的一套系统,这套系统包含该模型的train、loss等等,PL实际就是对pytorch的方法的封装和简化,让模型构建、训练、测试整套流程代码简洁易懂。
对应的,我们可以通过SDXL代码阅读顺便来学习一下pytorch lightning,因此代码对应pytorch从几个方面分开记录:
1.数据集构造
2.模型构造
3.Loss与Optimizer构造
4.训练step和测试
5.log(optional)
学习的代码参考generative-models/blob/main/main.py
SDXL 模块导入
在记录之前,首先需要了解一下SDXL是如何进行模块的搭建的,通过这种方法我们可以编辑配置文件的内容来按需引入module中的方法,并给予其想要的参数。
举个例子:搭建模型的时候我希望引入一个attention模块里的SE-attn方法来构建网络,并给予参数1、参数2初始化。过了几天我可能又发现另一个可以尝试的attn比如self-attn,那么我只需要在attention模块里写好self-attn,在配置文件里编辑好self-attn及对应的参数,就不需要再修改我的model构建文件,每次去更换import各种方法。
当然,可以直接通过import *的方式去导入所有的方法,但我觉得配置文件的方式更加清晰明了,并且通过配置文件就可以一目了然的看到这次搭建的model里面具体用了哪些module、function,而不用具体看代码。
配置文件写法
参考的配置文件路径在generative-models/configs/example_training/toy/mnist_cond.yaml
简单的来说,target字段就是想要import的模块、方法,对应params就是初始化该方法的参数。
model:
base_learning_rate: 1.0e-4
target: sgm.models.diffusion.DiffusionEngine
params:
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
params:
sigma_data: 1.0
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
in_channels: 1
out_channels: 1
model_channels: 32
attention_resolutions: []
num_res_blocks: 4
channel_mult: [1, 2, 2]
num_head_channels: 32
num_classes: sequential
adm_in_channels: 128
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: True
input_key: cls
ucg_rate: 0.2
target: sgm.modules.encoders.modules.ClassEmbedder
params:
embed_dim: 128
n_classes: 10
first_stage_config:
target: sgm.models.autoencoder.IdentityFirstStage
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
params:
loss_weighting_config:
target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
params:
sigma_data: 1.0
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
sampler_config:
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
params:
num_steps: 50
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
guider_config:
target: sgm.modules.diffusionmodules.guiders.VanillaCFG
params:
scale: 3.0
data:
target: sgm.data.mnist.MNISTLoader
params:
batch_size: 512
num_workers: 1
lightning:
modelcheckpoint:
params:
every_n_train_steps: 5000
callbacks:
metrics_over_trainsteps_checkpoint:
params:
every_n_train_steps: 25000
image_logger:
target: main.ImageLogger
params:
disabled: False
batch_frequency: 1000
max_images: 16
increase_log_steps: True
log_first_step: False
log_images_kwargs:
use_ema_scope: False
N: 16
n_rows: 4
trainer:
devices: 0,
benchmark: True
num_sanity_val_steps: 0
accumulate_grad_batches: 1
max_epochs: 20
instantiate_from_config
该函数用于解析配置文件,就是确定一下要有target参数。
main.py到处都有这个函数,包括之前讲采样过程也是。
def instantiate_from_config(config):
if not "target" in config:
if config == "__is_first_stage__":
return None
elif config == "__is_unconditional__":
return None
raise KeyError("Expected key `target` to instantiate.")
# print("--------------------------------------")
# print(config)
# print(**config)
return get_obj_from_str(config["target"])(**config.get("params", dict()))
get_obj_from_str
从字符串中去import module里面的class方法
def get_obj_from_str(string, reload=False, invalidate_cache=True):
# rsplit方法分割字符串一次,获得module和class名称
module, cls = string.rsplit(".", 1)
# 调用 importlib.invalidate_caches(),使导入系统的缓存失效。这在文件系统发生变化时可能会有用。
if invalidate_cache:
importlib.invalidate_caches()
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp) # 重新加载模块,确保获取最新版本。
# 使用 importlib.import_module(module_path) 导入一个模块时,Python 会将模块加载为一个对象。
# 这个模块对象的属性包括模块中定义的所有内容,比如类、函数、和变量。
# getattr从模块对象中获取名为 cls 的属性。
return getattr(importlib.import_module(module, package=None), cls)
实际上调用instantiate_from_config就是通过该方法完成了一个module.class(params)的初始化过程。
PytorchLightning 构造
dataloader构建
其实dataloader构建可以按照pytorch的dataset、dataloader类来构建,也可以用LightningDataModule来构建。虽然 LightningDataModule 没有严格的“必须”实现的方法,但为了充分利用其功能,通常会实现下述几个方法:prepare_data、setup、train_dataloader、val_dataloader、test_dataloader。
下面会主要讲一下SDXL用LightningDataModule来构造dataloader。
MNIST数据
本质上还是用的pytorch的dataloader,但还是实现了PL需要的一些方法
import pytorch_lightning as pl
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
class MNISTDataDictWrapper(Dataset):
def __init__(self, dset):
super().__init__()
self.dset = dset
def __getitem__(self, i):
x, y = self.dset[i]
return {"jpg": x, "cls": y}
def __len__(self):
return len(self.dset)
class MNISTLoader(pl.LightningDataModule):
def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True):
super().__init__()
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
)
self.batch_size = batch_size
self.num_workers = num_workers
self.prefetch_factor = prefetch_factor if num_workers > 0 else 0
self.shuffle = shuffle
self.train_dataset = MNISTDataDictWrapper(
torchvision.datasets.MNIST(
root=".data/", train=True, download=True, transform=transform
)
)
self.test_dataset = MNISTDataDictWrapper(
torchvision.datasets.MNIST(
root=".data/", train=False, download=True, transform=transform
)
)
def prepare_data(self):
pass
def train_dataloader(self):
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor,
)
def test_dataloader(self):
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor,
)
def val_dataloader(self):
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor,
)
custom dataset
他们自己写了个sdata的库来create dataset或者dataloader,但dataloader最终的本质还是pytorch的dataloader,所以简单看一下继承PL的模块还是实现那些方法就行了
from typing import Optional
import torchdata.datapipes.iter
import webdataset as wds
from omegaconf import DictConfig
from pytorch_lightning import LightningDataModule
try:
from sdata import create_dataset, create_dummy_dataset, create_loader
except ImportError as e:
print("#" * 100)
print("Datasets not yet available")
print("to enable, we need to add stable-datasets as a submodule")
print("please use ``git submodule update --init --recursive``")
print("and do ``pip install -e stable-datasets/`` from the root of this repo")
print("#" * 100)
exit(1)
class StableDataModuleFromConfig(LightningDataModule):
def __init__(
self,
train: DictConfig,
validation: Optional[DictConfig] = None,
test: Optional[DictConfig] = None,
skip_val_loader: bool = False,
dummy: bool = False,
):
super().__init__()
self.train_config = train
assert (
"datapipeline" in self.train_config and "loader" in self.train_config
), "train config requires the fields `datapipeline` and `loader`"
self.val_config = validation
if not skip_val_loader:
if self.val_config is not None:
assert (
"datapipeline" in self.val_config and "loader" in self.val_config
), "validation config requires the fields `datapipeline` and `loader`"
else:
print(
"Warning: No Validation datapipeline defined, using that one from training"
)
self.val_config = train
self.test_config = test
if self.test_config is not None:
assert (
"datapipeline" in self.test_config and "loader" in self.test_config
), "test config requires the fields `datapipeline` and `loader`"
self.dummy = dummy
if self.dummy:
print("#" * 100)
print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
print("#" * 100)
def setup(self, stage: str) -> None:
print("Preparing datasets")
if self.dummy:
data_fn = create_dummy_dataset
else:
data_fn = create_dataset
self.train_datapipeline = data_fn(**self.train_config.datapipeline)
if self.val_config:
self.val_datapipeline = data_fn(**self.val_config.datapipeline)
if self.test_config:
self.test_datapipeline = data_fn(**self.test_config.datapipeline)
def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
loader = create_loader(self.train_datapipeline, **self.train_config.loader)
return loader
def val_dataloader(self) -> wds.DataPipeline:
return create_loader(self.val_datapipeline, **self.val_config.loader)
def test_dataloader(self) -> wds.DataPipeline:
return create_loader(self.test_datapipeline, **self.test_config.loader)
区别是这里没有实现了prepare_data了,大概是sdata里面的datapipeline做了这件事情了。