SDXL Code Reading - Conditioner - 2.4
2024-09-24
6 min read
Conditioner
In DiffusionEngine, self.conditioner is likewise instantiated from the config file. It is responsible for generating the conditioning vectors, and it is also used inside loss_fn, so let's take a look at it.
GeneralConditioner
Code path: generative-models/sgm/modules/encoders/modules.py
This is the self.conditioner instantiated directly from the config file. Its main job is to initialize the embedding models and then, controlled by a few parameters, generate the conditional and unconditional embedding vectors.
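As a concrete example, here is a minimal sketch of a single emb_models entry, written as a plain Python dict (an OmegaConf ListConfig behaves the same way here). The target and params values are illustrative only; the remaining keys are exactly the ones read by __init__ below:

# One hypothetical entry of the emb_models list; "target"/"params" are
# illustrative, the other keys are consumed by GeneralConditioner.__init__.
emb_config = {
    "target": "sgm.modules.encoders.modules.ClassEmbedder",
    "params": {"embed_dim": 512, "n_classes": 1000},
    "is_trainable": False,       # freeze the embedder (no grads, eval mode)
    "ucg_rate": 0.1,             # chance of dropping this condition per sample
    "input_key": "cls",          # which entry of the batch dict it consumes
    # "legacy_ucg_value": "",    # optional: raw-input replacement value for ucg
}

For reference, in the SDXL base config this list holds two frozen CLIP text encoders (both with input_key "txt") plus several ConcatTimestepEmbedderND instances for the size/crop conditioning.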
class GeneralConditioner(nn.Module):
    # map the number of dimensions of an embedder's output to a conditioning key:
    # 2D -> "vector" (pooled/class embeddings), 3D -> "crossattn" (token sequences),
    # 4D -> "concat" (image-like feature maps)
    OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat"}  # , 5: "concat"}
    # the dimension along which outputs that share the same key are concatenated
    KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1, "cond_view": 1, "cond_motion": 1}
def __init__(self, emb_models: Union[List, ListConfig]):
super().__init__()
embedders = []
        # instantiate each embedding model and configure its attributes
for n, embconfig in enumerate(emb_models):
embedder = instantiate_from_config(embconfig)
assert isinstance(
embedder, AbstractEmbModel
), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
embedder.is_trainable = embconfig.get("is_trainable", False)
embedder.ucg_rate = embconfig.get("ucg_rate", 0.0)
if not embedder.is_trainable:
embedder.train = disabled_train
for param in embedder.parameters():
param.requires_grad = False
embedder.eval()
print(
f"Initialized embedder #{n}: {embedder.__class__.__name__} "
f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
)
            # every embedder config must provide either 'input_key' or 'input_keys'
if "input_key" in embconfig:
embedder.input_key = embconfig["input_key"]
elif "input_keys" in embconfig:
embedder.input_keys = embconfig["input_keys"]
else:
raise KeyError(
f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}"
)
            # if legacy_ucg_value is set, initialize a dedicated random number generator
embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
if embedder.legacy_ucg_val is not None:
embedder.ucg_prng = np.random.RandomState()
embedders.append(embedder)
self.embedders = nn.ModuleList(embedders)
def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict:
        # with probability ucg_rate, replace entries of the batch with legacy_ucg_val
assert embedder.legacy_ucg_val is not None
p = embedder.ucg_rate
val = embedder.legacy_ucg_val
for i in range(len(batch[embedder.input_key])):
if embedder.ucg_prng.choice(2, p=[1 - p, p]):
batch[embedder.input_key][i] = val
return batch
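    # Worked example (illustrative numbers): with ucg_rate = 0.2 and
    # legacy_ucg_value = "", ucg_prng.choice(2, p=[0.8, 0.2]) returns 1 for
    # roughly 20% of the samples, so about one caption in five in
    # batch[input_key] is swapped for the empty string -- classifier-free
    # guidance dropout applied to the raw inputs rather than to the embeddings.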
def forward(
self, batch: Dict, force_zero_embeddings: Optional[List] = None
) -> Dict:
output = dict()
if force_zero_embeddings is None:
force_zero_embeddings = []
        # legacy_ucg_val decides whether the raw batch is preprocessed first, via possibly_get_ucg_val above
for embedder in self.embedders:
embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
with embedding_context():
if hasattr(embedder, "input_key") and (embedder.input_key is not None):
if embedder.legacy_ucg_val is not None:
batch = self.possibly_get_ucg_val(embedder, batch)
emb_out = embedder(batch[embedder.input_key])
elif hasattr(embedder, "input_keys"):
emb_out = embedder(*[batch[k] for k in embedder.input_keys])
assert isinstance(
emb_out, (torch.Tensor, list, tuple)
), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
if not isinstance(emb_out, (list, tuple)):
emb_out = [emb_out]
            # post-process each embedding output
for emb in emb_out:
                # determine out_key (throughout SD, data is passed around almost entirely as dicts)
if embedder.input_key in ["cond_view", "cond_motion"]:
out_key = embedder.input_key
else:
out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
                # if ucg_rate > 0 and legacy_ucg_val is None, draw a Bernoulli mask and
                # multiply it into the embedding, zeroing whole samples with probability ucg_rate
if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
emb = (
expand_dims_like(
torch.bernoulli(
(1.0 - embedder.ucg_rate)
* torch.ones(emb.shape[0], device=emb.device)
),
emb,
)
* emb
)
                # if input_key is listed in force_zero_embeddings, zero the embedding out
if (
hasattr(embedder, "input_key")
and embedder.input_key in force_zero_embeddings
):
emb = torch.zeros_like(emb)
                # if out_key already exists in the output dict, concatenate the new embedding; otherwise assign it
if out_key in output:
output[out_key] = torch.cat(
(output[out_key], emb), self.KEY2CATDIM[out_key]
)
else:
output[out_key] = emb
return output
def get_unconditional_conditioning(
self,
batch_c: Dict,
batch_uc: Optional[Dict] = None,
force_uc_zero_embeddings: Optional[List[str]] = None,
force_cond_zero_embeddings: Optional[List[str]] = None,
):
        # compute the conditional and the unconditional embeddings in one pass
if force_uc_zero_embeddings is None:
force_uc_zero_embeddings = []
ucg_rates = list()
for embedder in self.embedders:
ucg_rates.append(embedder.ucg_rate)
            # temporarily set every ucg_rate to 0, compute the conditional and unconditional outputs, then restore the rates
embedder.ucg_rate = 0.0
c = self(batch_c, force_cond_zero_embeddings)
uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings)
for embedder, rate in zip(self.embedders, ucg_rates):
embedder.ucg_rate = rate
return c, uc
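Putting it together, here is a minimal usage sketch (assumes the generative-models repo is importable; the parameter values and the "cls" input key are made up for illustration, while the API calls follow the code above):

import torch
from sgm.modules.encoders.modules import GeneralConditioner

# Build a conditioner with a single ClassEmbedder entry (hypothetical values).
conditioner = GeneralConditioner(
    emb_models=[
        {
            "target": "sgm.modules.encoders.modules.ClassEmbedder",
            "params": {"embed_dim": 512, "n_classes": 1000},
            "is_trainable": False,
            "ucg_rate": 0.1,
            "input_key": "cls",
        }
    ]
)

batch = {"cls": torch.randint(0, 1000, (4,))}  # four class labels

# forward(): the 2D (4, 512) embedding is routed to the "vector" key.
c = conditioner(batch)
print(c["vector"].shape)  # torch.Size([4, 512])

# Conditional/unconditional pair for classifier-free guidance; ucg_rate is
# forced to 0 internally, and the "cls" embedding in uc is zeroed out.
c, uc = conditioner.get_unconditional_conditioning(
    batch, force_uc_zero_embeddings=["cls"]
)
print(uc["vector"].abs().max())  # tensor(0.)

The c/uc pair produced here is what later feeds the two denoiser passes of classifier-free guidance.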
ClassEmbedder
Above we saw that an embedder is instantiated from the config entry sgm.modules.encoders.modules.ClassEmbedder, so let's take a quick look at ClassEmbedder too (there is not much to it: it simply runs a torch nn.Embedding in its forward pass).
class ClassEmbedder(AbstractEmbModel):
def __init__(self, embed_dim, n_classes=1000, add_sequence_dim=False):
super().__init__()
self.embedding = nn.Embedding(n_classes, embed_dim)
self.n_classes = n_classes
self.add_sequence_dim = add_sequence_dim
def forward(self, c):
c = self.embedding(c)
if self.add_sequence_dim:
c = c[:, None, :]
return c
def get_unconditional_conditioning(self, bs, device="cuda"):
        # build the unconditional-class tensor and return it wrapped in a dict
uc_class = (
self.n_classes - 1
) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
uc = torch.ones((bs,), device=device) * uc_class
uc = {self.key: uc.long()}
return uc
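A quick shape check (a sketch; the values are arbitrary):

import torch
from sgm.modules.encoders.modules import ClassEmbedder

embedder = ClassEmbedder(embed_dim=512, n_classes=1000)
print(embedder(torch.tensor([3, 7])).shape)  # torch.Size([2, 512])

# with add_sequence_dim=True, a length-1 sequence axis is inserted
seq_embedder = ClassEmbedder(embed_dim=512, add_sequence_dim=True)
print(seq_embedder(torch.tensor([3, 7])).shape)  # torch.Size([2, 1, 512])

One caveat worth noticing: get_unconditional_conditioning reads self.key, but __init__ never sets a key attribute, so as written it appears to rely on key being assigned on the instance from outside.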