SDXL Code Reading - Training (PyTorch Lightning) - 2.2
The previous section gave a brief introduction to PyTorch Lightning (PL), outlining what needs to be written to build a complete model and training system with it. This article continues from there.
The starting reference is still generative-models/blob/main/main.py. At line 662 we have model = instantiate_from_config(config.model). Combined with the previous section's analysis of how SDXL dynamically imports modules, we can go straight to the config file to see how the model is actually constructed.
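As a quick recap of that dynamic import mechanism, instantiate_from_config in sgm/util.py roughly boils down to the simplified sketch below (paraphrased; the real helper also handles reload/cache options and a few sentinel configs):

import importlib


def get_obj_from_str(string: str):
    # "sgm.models.diffusion.DiffusionEngine" -> import sgm.models.diffusion, return DiffusionEngine
    module, cls = string.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)


def instantiate_from_config(config):
    # a config node is a dict/OmegaConf entry with a "target" class path and optional "params"
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))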
Model
Using generative-models/configs/example_training/toy/mnist_cond.yaml as the reference config again, we find that the model is constructed from the sgm.models.diffusion.DiffusionEngine class.
DiffusionEngine
From the multiple levels of nested target entries in the config file, we can tell that DiffusionEngine is assembled from several modules at init time. Its __init__ code confirms that it consists of the following parts:
1. network
2. denoiser
3. sampler
4. loss_fn
5. conditioner
class DiffusionEngine(pl.LightningModule):
    def __init__(
        self,
        network_config,
        denoiser_config,
        first_stage_config,
        conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
        network_wrapper: Union[None, str] = None,
        ckpt_path: Union[None, str] = None,
        use_ema: bool = False,
        ema_decay_rate: float = 0.9999,
        scale_factor: float = 1.0,
        disable_first_stage_autocast=False,
        input_key: str = "jpg",
        log_keys: Union[List, None] = None,
        no_cond_log: bool = False,
        compile_model: bool = False,
        en_and_decode_n_samples_a_time: Optional[int] = None,
    ):
        super().__init__()
        self.log_keys = log_keys
        self.input_key = input_key
        self.optimizer_config = default(
            optimizer_config, {"target": "torch.optim.AdamW"}
        )
        model = instantiate_from_config(network_config)
        self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
            model, compile_model=compile_model
        )

        self.denoiser = instantiate_from_config(denoiser_config)
        self.sampler = (
            instantiate_from_config(sampler_config)
            if sampler_config is not None
            else None
        )
        self.conditioner = instantiate_from_config(
            default(conditioner_config, UNCONDITIONAL_CONFIG)
        )
        self.scheduler_config = scheduler_config
        self._init_first_stage(first_stage_config)

        self.loss_fn = (
            instantiate_from_config(loss_fn_config)
            if loss_fn_config is not None
            else None
        )

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self.model, decay=ema_decay_rate)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        self.scale_factor = scale_factor
        self.disable_first_stage_autocast = disable_first_stage_autocast
        self.no_cond_log = no_cond_log

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path)

        self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
The model as a whole corresponds to DiffusionEngine, while the model initialized inside it (from network_config) is closer to what we usually mean by "the model". In the config file this is sgm.modules.diffusionmodules.openaimodel.UNetModel, which we will read through step by step later.
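Note that the UNet is not stored directly: it is wrapped by the class named in network_wrapper, defaulting to OPENAIUNETWRAPPER, i.e. sgm.modules.diffusionmodules.wrappers.OpenAIWrapper. From my reading, the wrapper's job is roughly what the sketch below shows (a paraphrase, not the verbatim repo code; the class name UNetWrapperSketch is mine): it channel-concatenates any "concat" conditioning onto the input and maps the "crossattn"/"vector" entries of the conditioning dict onto the UNet's context/y arguments.

import torch
from torch import nn


class UNetWrapperSketch(nn.Module):
    """Simplified stand-in for OpenAIWrapper: holds the UNet and routes the
    conditioning dict into the UNet's keyword arguments."""

    def __init__(self, diffusion_model: nn.Module, compile_model: bool = False):
        super().__init__()
        self.diffusion_model = torch.compile(diffusion_model) if compile_model else diffusion_model

    def forward(self, x, t, c: dict, **kwargs):
        # channel-wise concat conditioning (if any) is appended to the noisy input
        x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
        return self.diffusion_model(
            x,
            timesteps=t,
            context=c.get("crossattn", None),  # cross-attention conditioning, e.g. text embeddings
            y=c.get("vector", None),           # vector conditioning, e.g. pooled/size embeddings
            **kwargs,
        )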
_init_first_stage
DiffusionEngine has an _init_first_stage method that instantiates the first-stage model in eval mode and freezes its weights, because parts like the VAE encoder/decoder do not need to be trained (frozen conditioner embedders such as CLIP are handled in a similar way inside the conditioner).
def _init_first_stage(self, config):
    model = instantiate_from_config(config).eval()
    model.train = disabled_train
    for param in model.parameters():
        param.requires_grad = False
    self.first_stage_model = model
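The disabled_train used above comes from sgm/util.py and is essentially a no-op override, so that a later call to model.train() can no longer switch the frozen module back into training mode:

def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self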
For example, the encode/decode methods below operate through first_stage_model:
@torch.no_grad()
def encode_first_stage(self, x):
    n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
    n_rounds = math.ceil(x.shape[0] / n_samples)
    all_out = []
    with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
        for n in range(n_rounds):
            out = self.first_stage_model.encode(
                x[n * n_samples : (n + 1) * n_samples]
            )
            all_out.append(out)
    z = torch.cat(all_out, dim=0)
    z = self.scale_factor * z
    return z
@torch.no_grad()
def decode_first_stage(self, z):
    z = 1.0 / self.scale_factor * z
    n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
    n_rounds = math.ceil(z.shape[0] / n_samples)
    all_out = []
    with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
        for n in range(n_rounds):
            if isinstance(self.first_stage_model.decoder, VideoDecoder):
                kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
            else:
                kwargs = {}
            out = self.first_stage_model.decode(
                z[n * n_samples : (n + 1) * n_samples], **kwargs
            )
            all_out.append(out)
    out = torch.cat(all_out, dim=0)
    return out
forward
Unlike a typical PyTorch forward, this forward computes the loss directly, so the details depend on how loss_fn is written.
def forward(self, x, batch):
    loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
    loss_mean = loss.mean()
    loss_dict = {"loss": loss_mean}
    return loss_mean, loss_dict
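To make this concrete before we read the real implementation: loss_fn_config in the example configs typically points to sgm.modules.diffusionmodules.loss.StandardDiffusionLoss. The call signature loss_fn(model, denoiser, conditioner, x, batch) corresponds roughly to the sketch below. This is my own simplification, not the repo code; sigma_sampler and loss_weighting are placeholder callables standing in for the configured components.

import torch


def diffusion_loss_sketch(network, denoiser, conditioner, x, batch, sigma_sampler, loss_weighting):
    # hypothetical sketch of a denoising-loss computation
    cond = conditioner(batch)                                   # embed the conditioning from the batch
    sigmas = sigma_sampler(x.shape[0]).to(x)                    # one noise level per sample
    noise = torch.randn_like(x)
    sigmas_bc = sigmas.view(-1, *([1] * (x.ndim - 1)))          # broadcast sigmas over the latent dims
    noised_x = x + noise * sigmas_bc                            # add noise in latent space
    pred = denoiser(network, noised_x, sigmas, cond)            # predict the denoised latents
    w = loss_weighting(sigmas).view(-1, *([1] * (x.ndim - 1)))  # per-sigma loss weight
    # per-sample weighted MSE; DiffusionEngine.forward then takes loss.mean()
    return torch.mean((w * (pred - x) ** 2).reshape(x.shape[0], -1), dim=1)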
train
self.input_key is a string giving the key under which the image data is stored in the batch, with a default of "jpg". What exactly batch contains is not clear yet; that depends on how the dataset is constructed.
def get_input(self, batch):
    # assuming unified data format, dataloader returns a dict.
    # image tensors should be scaled to -1 ... 1 and in bchw format
    return batch[self.input_key]
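For intuition, with the default input_key="jpg" the dataloader is expected to yield a dict-like batch, something like the hypothetical example below (the exact keys depend on the dataset and the conditioner's embedders):

import torch

# hypothetical batch for a class-conditional toy setup
batch = {
    "jpg": torch.rand(4, 3, 32, 32) * 2 - 1,  # images in [-1, 1], BCHW
    "cls": torch.randint(0, 10, (4,)),        # class labels consumed by a conditioner embedder
}
x = batch["jpg"]  # what get_input(batch) would return with input_key="jpg"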
shared_step
This is the actual training step: it first fetches the input via get_input, then encodes it into the latent space with encode_first_stage, and finally calls forward to compute the loss.
def shared_step(self, batch: Dict) -> Any:
    x = self.get_input(batch)
    x = self.encode_first_stage(x)
    # global_step is a PyTorch Lightning attribute that tracks the total number of
    # optimizer steps completed since training started. It is maintained automatically
    # and is typically used for logging and learning-rate scheduling.
    batch["global_step"] = self.global_step
    loss, loss_dict = self(x, batch)
    return loss, loss_dict
The core of training_step is simply calling shared_step; the rest is basically logging.
def training_step(self, batch, batch_idx):
    loss, loss_dict = self.shared_step(batch)

    self.log_dict(
        loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
    )

    self.log(
        "global_step",
        self.global_step,
        prog_bar=True,
        logger=True,
        on_step=True,
        on_epoch=False,
    )

    if self.scheduler_config is not None:
        lr = self.optimizers().param_groups[0]["lr"]
        self.log(
            "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
        )

    return loss
configure_optimizers
The call to self.optimizers() in the training loop above returns the optimizer used for training; the optimizer itself is set up in configure_optimizers.
configure_optimizers is called automatically by PyTorch Lightning before training starts (and when resuming from a checkpoint). It initializes the optimizer and learning-rate scheduler so that the model parameters and learning rate are updated correctly during training; this is how Lightning centralizes optimizer and scheduler management.
def configure_optimizers(self):
    lr = self.learning_rate
    params = list(self.model.parameters())
    for embedder in self.conditioner.embedders:
        if embedder.is_trainable:
            params = params + list(embedder.parameters())
    opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
    if self.scheduler_config is not None:
        scheduler = instantiate_from_config(self.scheduler_config)
        print("Setting up LambdaLR scheduler...")
        # LambdaLR adjusts the learning rate via a user-defined function/rule
        scheduler = [
            {
                "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
                "interval": "step",
                "frequency": 1,
            }
        ]
        return [opt], scheduler
    return opt
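The instantiate_optimizer_from_config helper used above lives on the same class; from my reading it roughly builds the optimizer class named in optimizer_config["target"] with the collected params and learning rate (paraphrased below, check the repo for the exact code). Note also that the instantiated scheduler object only needs to expose a schedule callable mapping the global step to an LR multiplier, which is then wrapped into torch's LambdaLR.

def instantiate_optimizer_from_config(self, params, lr, cfg):
    # e.g. cfg = {"target": "torch.optim.AdamW", "params": {...}}
    return get_obj_from_str(cfg["target"])(
        params, lr=lr, **cfg.get("params", dict())
    )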
Sample
This mainly calls the denoiser and the sampler.
@torch.no_grad()
def sample(
    self,
    cond: Dict,
    uc: Union[Dict, None] = None,
    batch_size: int = 16,
    shape: Union[None, Tuple, List] = None,
    **kwargs,
):
    randn = torch.randn(batch_size, *shape).to(self.device)

    denoiser = lambda input, sigma, c: self.denoiser(
        self.model, input, sigma, c, **kwargs
    )
    samples = self.sampler(denoiser, randn, cond, uc=uc)
    return samples
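A hypothetical usage sketch (engine and batch are placeholder names, assuming 4-channel 32x32 latents; shape is the latent shape without the batch dimension):

# hypothetical usage sketch
c, uc = engine.conditioner.get_unconditional_conditioning(batch)  # cond / uncond embeddings
z = engine.sample(c, uc=uc, batch_size=4, shape=(4, 32, 32))      # sample in latent space
images = engine.decode_first_stage(z)                             # decode back to pixel space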
log_conditionings
As I understand it, this function converts the various conditioning inputs into image-sized tensors so that they can be logged.
@torch.no_grad()
def log_conditionings(self, batch: Dict, n: int) -> Dict:
    """
    Defines heuristics to log different conditionings.
    These can be lists of strings (text-to-image), tensors, ints, ...
    """
    image_h, image_w = batch[self.input_key].shape[2:]

    log = dict()

    for embedder in self.conditioner.embedders:
        if (
            (self.log_keys is None) or (embedder.input_key in self.log_keys)
        ) and not self.no_cond_log:
            x = batch[embedder.input_key][:n]
            if isinstance(x, torch.Tensor):
                if x.dim() == 1:
                    # class-conditional, convert integer to string
                    x = [str(x[i].item()) for i in range(x.shape[0])]
                    xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
                elif x.dim() == 2:
                    # size and crop cond and the like
                    x = [
                        "x".join([str(xx) for xx in x[i].tolist()])
                        for i in range(x.shape[0])
                    ]
                    xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
                else:
                    raise NotImplementedError()
            elif isinstance(x, (List, ListConfig)):
                if isinstance(x[0], str):
                    # strings
                    xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
                else:
                    raise NotImplementedError()
            else:
                raise NotImplementedError()
            log[embedder.input_key] = xc
    return log
The core logic is to determine what type of data x is and then call log_txt_as_img:
def log_txt_as_img(wh, xc, size=10):
    # wh a tuple of (width, height)
    # xc a list of captions to plot
    b = len(xc)
    txts = list()
    for bi in range(b):
        txt = Image.new("RGB", wh, color="white")  # new image with a white background
        draw = ImageDraw.Draw(txt)  # create a drawing object for it
        font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)  # load the font
        nc = int(40 * (wh[0] / 256))  # characters per line, scaled with image width
        if isinstance(xc[bi], list):
            text_seq = xc[bi][0]
        else:
            text_seq = xc[bi]
        lines = "\n".join(
            text_seq[start : start + nc] for start in range(0, len(text_seq), nc)
        )
        # draw the text onto the image
        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            print("Cant encode string for logging. Skipping.")

        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
        txts.append(txt)
    # return the result as a torch tensor
    txts = np.stack(txts)
    txts = torch.tensor(txts)
    return txts
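A quick usage example with hypothetical values: rendering two captions onto 256x256 white canvases yields a tensor of shape (2, 3, 256, 256) with values in [-1, 1], which can then be logged like any other image batch.

# hypothetical usage; requires data/DejaVuSans.ttf to be available
captions = ["a hand-written digit 3", "32x32 crop at (0, 0)"]
txt_imgs = log_txt_as_img((256, 256), captions, size=16)
print(txt_imgs.shape)  # torch.Size([2, 3, 256, 256]), values in [-1, 1]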
log_images
This function also logs input images: it records what the image looks like before encoding, what a direct encode-decode reconstruction looks like, and, if sample is set, the result of sampling followed by decoding as well.
@torch.no_grad()
def log_images(
    self,
    batch: Dict,
    N: int = 8,
    sample: bool = True,
    ucg_keys: List[str] = None,
    **kwargs,
) -> Dict:
    # collect the input keys from the conditioner's embedders
    conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
    # make sure the unconditional-guidance keys (if given) are among those keys
    if ucg_keys:
        assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
            "Each defined ucg key for sampling must be in the provided conditioner input keys,"
            f"but we have {ucg_keys} vs. {conditioner_input_keys}"
        )
    else:
        ucg_keys = conditioner_input_keys
    log = dict()

    x = self.get_input(batch)

    # get the conditional and unconditional embeddings
    c, uc = self.conditioner.get_unconditional_conditioning(
        batch,
        force_uc_zero_embeddings=ucg_keys
        if len(self.conditioner.embedders) > 0
        else [],
    )

    sampling_kwargs = {}

    N = min(x.shape[0], N)
    x = x.to(self.device)[:N]
    log["inputs"] = x
    z = self.encode_first_stage(x)  # encode the input images
    log["reconstructions"] = self.decode_first_stage(z)  # log the decoded reconstructions
    log.update(self.log_conditionings(batch, N))  # add the conditioning logs

    # move the conditioning tensors to the right device
    for k in c:
        if isinstance(c[k], torch.Tensor):
            c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))

    # if sample is set, generate samples and log them as well
    if sample:
        with self.ema_scope("Plotting"):
            samples = self.sample(
                c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
            )
        samples = self.decode_first_stage(samples)
        log["samples"] = samples
    return log
model EMA
During DiffusionEngine initialization, the EMA is also set up:
if self.use_ema:
    self.model_ema = LitEma(self.model, decay=ema_decay_rate)
    print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
There is also a context manager, ema_scope, which temporarily loads the model's EMA weights:
@contextmanager
def ema_scope(self, context=None):
    if self.use_ema:
        self.model_ema.store(self.model.parameters())
        self.model_ema.copy_to(self.model)
        if context is not None:
            print(f"{context}: Switched to EMA weights")
    try:
        yield None
    finally:
        if self.use_ema:
            self.model_ema.restore(self.model.parameters())
            if context is not None:
                print(f"{context}: Restored training weights")
So we need to understand what EMA is and what it does.
EMA (Exponential Moving Average)
The exponential moving average (EMA) is a common technique that helps keep model training stable and is typically used to improve generalization and performance.
It is widely used in deep learning mainly for the following reasons:
1. Reducing noise and fluctuation
During training, especially with stochastic gradient descent (SGD) and its variants, parameter updates are affected by noise and fluctuation, e.g. from the randomness of mini-batches or changes in the learning rate. By keeping a weighted average of the parameters, EMA smooths out this noise and makes the parameters more stable.
2. Mitigating overfitting
EMA can be viewed as a form of regularization. Averaging the parameters suppresses excessive fluctuation and thereby reduces the risk of overfitting; especially late in training, EMA helps the model generalize better to unseen data.
3. Improving generalization
Because EMA smooths the parameter updates, it helps the model capture the overall trend of the data rather than overfitting to details of the training set, which tends to improve test-set performance and generalization.
4. Avoiding extreme parameter values
Extreme parameter values can cause instability and degrade performance. The weighted average slows down drastic parameter changes, keeping the parameters smoother and more stable.
(summarized by ChatGPT)
My own intuition is that it is a bit like model ensembling: the EMA weights effectively aggregate the model parameters from all previous training steps, with the exponential moving average acting as the weighting scheme over them.
The update formula (matching the comment in the code below) is
ema_w = decay * ema_w + (1 - decay) * w
i.e. the EMA weight is an exponentially decaying average of all past model weights.
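A tiny numeric sketch of this update (not from the repo), including the warm-up that LitEma applies below, where the effective decay is capped at (1 + n) / (10 + n) for the first updates:

# minimal EMA demo
decay, ema_w, num_updates = 0.9999, 0.0, 0
for w in [1.0, 1.0, 1.0, 1.0, 1.0]:  # pretend the trained weight settles at 1.0
    num_updates += 1
    d = min(decay, (1 + num_updates) / (10 + num_updates))  # warm-up: small effective decay early on
    ema_w = d * ema_w + (1 - d) * w
    print(num_updates, round(d, 4), round(ema_w, 4))
# the EMA weight moves toward 1.0 quickly at first, then more slowly as d grows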
This corresponds to the source file generative-models/sgm/modules/ema.py:
import torch
from torch import nn


class LitEma(nn.Module):
    def __init__(self, model, decay=0.9999, use_num_upates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        self.m_name2s_name = {}
        # register buffers for decay and num_updates
        self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
        self.register_buffer(
            "num_updates",
            torch.tensor(0, dtype=torch.int)
            if use_num_upates
            else torch.tensor(-1, dtype=torch.int),
        )

        # iterate over the model's trainable parameters and register a
        # corresponding EMA buffer for each of them
        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers
                s_name = name.replace(".", "")
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def reset_num_updates(self):
        # reset num_updates
        del self.num_updates
        self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))

    def forward(self, model):
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            # m_param holds the original model's parameters
            # shadow_params holds the EMA copies
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    # ema_w = ema_w - (1 - decay)(ema_w - w)
                    #       = decay*ema_w + (1 - decay)*w
                    shadow_params[sname].sub_(
                        one_minus_decay * (shadow_params[sname] - m_param[key])
                    )
                else:
                    assert not key in self.m_name2s_name

    def copy_to(self, model):
        # copy the EMA parameters into the model's parameters
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert not key in self.m_name2s_name

    # store and restore the model parameters
    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
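One last piece that ties this together: the EMA buffers have to be refreshed once per training batch. If I recall the repo correctly, DiffusionEngine overrides Lightning's on_train_batch_end hook to call the LitEma module after each optimizer step, roughly:

def on_train_batch_end(self, *args, **kwargs):
    if self.use_ema:
        self.model_ema(self.model)  # LitEma.forward: blend current weights into the EMA buffers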