SDXL Code Reading - Pipeline Inference - 1.1

The overall flow follows generative-models/tests/inference/test_inference.py; most of the important methods come from sgm/inference/api.py.

Here I rewrite the code from test_inference.py by hand, dropping the pytest flow and initializing some objects directly as instances instead of going through configs.
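For reference, a minimal driver in the spirit of that rewrite could look like the sketch below (my own example, not the original test code; it assumes the generative-models repo is importable and the SDXL base checkpoint sits under the default checkpoints/ directory):

from sgm.inference.api import (
    ModelArchitecture,
    Sampler,
    SamplingParams,
    SamplingPipeline,
)

pipeline = SamplingPipeline(ModelArchitecture.SDXL_V1_BASE)
params = SamplingParams(sampler=Sampler.DPMPP2M, steps=30)
images = pipeline.text_to_image(
    params=params,
    prompt="a photograph of an astronaut riding a horse",
    negative_prompt="",
    samples=1,
)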

Basic parameters

This part is basically just definitions and declarations of data classes and enum types:

  • ModelArchitecture is the enum of model architectures; we will just use SDXL_V1_BASE
  • Sampler is the enum of samplers
  • Discretization is the enum of schedulers; there are only two: the legacy DDPM scheduler and the EDM scheduler
  • Guider is the enum of guider classes that steer the denoising process; again only two: identity and vanilla CFG (scaled mixing)
  • SamplingParams is a dataclass that gathers every parameter needed throughout the sampling process
  • SamplingSpec is similar to SamplingParams but holds model-specific parameters; it exists because different model architectures may need different settings
# api.py
class ModelArchitecture(str, Enum):
    SD_2_1 = "stable-diffusion-v2-1"
    SD_2_1_768 = "stable-diffusion-v2-1-768"
    SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
    SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
    SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
    SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"


class Sampler(str, Enum):
    EULER_EDM = "EulerEDMSampler"
    HEUN_EDM = "HeunEDMSampler"
    EULER_ANCESTRAL = "EulerAncestralSampler"
    DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler"
    DPMPP2M = "DPMPP2MSampler"
    LINEAR_MULTISTEP = "LinearMultistepSampler"


class Discretization(str, Enum):
    LEGACY_DDPM = "LegacyDDPMDiscretization"
    EDM = "EDMDiscretization"


class Guider(str, Enum):
    VANILLA = "VanillaCFG"
    IDENTITY = "IdentityGuider"


class Thresholder(str, Enum):
    NONE = "None"

# The @dataclass decorator auto-generates common special methods such as __init__, __repr__ and __eq__, simplifying the class definition
@dataclass
class SamplingParams:
    width: int = 1024
    height: int = 1024
    steps: int = 50
    sampler: Sampler = Sampler.DPMPP2M
    discretization: Discretization = Discretization.LEGACY_DDPM
    guider: Guider = Guider.VANILLA
    thresholder: Thresholder = Thresholder.NONE
    scale: float = 6.0
    aesthetic_score: float = 5.0
    negative_aesthetic_score: float = 5.0
    img2img_strength: float = 1.0
    orig_width: int = 1024
    orig_height: int = 1024
    crop_coords_top: int = 0
    crop_coords_left: int = 0
    sigma_min: float = 0.0292
    sigma_max: float = 14.6146
    rho: float = 3.0
    s_churn: float = 0.0
    s_tmin: float = 0.0
    s_tmax: float = 999.0
    s_noise: float = 1.0
    eta: float = 1.0
    order: int = 4


@dataclass
class SamplingSpec:
    width: int
    height: int
    channels: int
    factor: int
    is_legacy: bool
    config: str
    ckpt: str
    is_guided: bool


model_specs = {
    ModelArchitecture.SD_2_1: SamplingSpec(
        height=512,
        width=512,
        channels=4,
        factor=8,
        is_legacy=True,
        config="sd_2_1.yaml",
        ckpt="v2-1_512-ema-pruned.safetensors",
        is_guided=True,
    ),
    ModelArchitecture.SD_2_1_768: SamplingSpec(
        height=768,
        width=768,
        channels=4,
        factor=8,
        is_legacy=True,
        config="sd_2_1_768.yaml",
        ckpt="v2-1_768-ema-pruned.safetensors",
        is_guided=True,
    ),
    ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
        height=1024,
        width=1024,
        channels=4,
        factor=8,
        is_legacy=False,
        config="sd_xl_base.yaml",
        ckpt="sd_xl_base_0.9.safetensors",
        is_guided=True,
    ),
    ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
        height=1024,
        width=1024,
        channels=4,
        factor=8,
        is_legacy=True,
        config="sd_xl_refiner.yaml",
        ckpt="sd_xl_refiner_0.9.safetensors",
        is_guided=True,
    ),
    ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
        height=1024,
        width=1024,
        channels=4,
        factor=8,
        is_legacy=False,
        config="sd_xl_base.yaml",
        ckpt="sd_xl_base_1.0.safetensors",
        is_guided=True,
    ),
    ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
        height=1024,
        width=1024,
        channels=4,
        factor=8,
        is_legacy=True,
        config="sd_xl_refiner.yaml",
        ckpt="sd_xl_refiner_1.0.safetensors",
        is_guided=True,
    ),
}
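A quick way to read these specs: channels and factor determine the latent shape the sampler actually works in. A small illustration (my own arithmetic, not from the source file):

spec = model_specs[ModelArchitecture.SDXL_V1_BASE]
# the UNet denoises latents of shape [channels, height / factor, width / factor]
latent_shape = (
    spec.channels,               # 4
    spec.height // spec.factor,  # 1024 // 8 = 128
    spec.width // spec.factor,   # 1024 // 8 = 128
)
print(latent_shape)  # (4, 128, 128)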

SamplingPipeline

With all the parameters declared, we can set up the inference flow. The first step is to initialize a sampling pipeline.
The pipeline covers four main pieces:
1. Initialization (loading the model)
2. text2img
3. img2img
4. SDXL also has a refine pass, which is essentially just img2img
I will mainly walk through 1 and 2, since 3 and 4 are analogous.

class SamplingPipeline:
    def __init__(
        self,
        model_id: ModelArchitecture,  # the model id, i.e. the enum class above
        model_path="checkpoints",  # where the checkpoints live
        config_path="configs/inference",  # where the matching configs live
        device="cuda",
        use_fp16=True,
    ) -> None:
        if model_id not in model_specs:
            raise ValueError(f"Model {model_id} not supported")
        self.model_id = model_id
        self.specs = model_specs[self.model_id]  # look up the per-model parameters defined above via model_id
        self.config = str(pathlib.Path(config_path, self.specs.config))
        self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
        self.device = device
        # the core step of initialization is loading the model correctly
        self.model = self._load_model(device=device, use_fp16=use_fp16)

    def _load_model(self, device="cuda", use_fp16=True):
        # load the model from config; the loader mainly distinguishes between
        # .ckpt and .safetensors checkpoints, and prints the names of any
        # missing/unexpected parts of the model's state dict
        config = OmegaConf.load(self.config)
        model = load_model_from_config(config, self.ckpt)
        if model is None:
            raise ValueError(f"Model {self.model_id} could not be loaded")
        model.to(device)
        # note the model is split into several parts: conditioner, model (the UNet), denoiser, ...
        if use_fp16:
            model.conditioner.half()
            model.model.half()
        return model
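load_model_from_config (also in api.py) is roughly the following; this is a condensed sketch from memory, so details may differ from the repo:

import torch
from safetensors.torch import load_file as load_safetensors
from sgm.util import instantiate_from_config

def load_model_from_config(config, ckpt=None):
    model = instantiate_from_config(config.model)  # build the module tree from the yaml
    if ckpt is not None:
        if ckpt.endswith("ckpt"):
            sd = torch.load(ckpt, map_location="cpu")["state_dict"]
        elif ckpt.endswith("safetensors"):
            sd = load_safetensors(ckpt)
        else:
            raise NotImplementedError
        missing, unexpected = model.load_state_dict(sd, strict=False)
        # this is where the missing/unexpected structure names get printed
        print(f"missing keys: {missing}")
        print(f"unexpected keys: {unexpected}")
    model.eval()
    return model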

    def text_to_image(
        self,
        params: SamplingParams,
        prompt: str,
        negative_prompt: str = "",
        samples: int = 1,
        return_latents: bool = False,
    ):
        sampler = get_sampler_config(params)  # initialize the sampler
        value_dict = asdict(params)  # convert the dataclass directly into a dict
        value_dict["prompt"] = prompt
        value_dict["negative_prompt"] = negative_prompt
        value_dict["target_width"] = params.width
        value_dict["target_height"] = params.height
        # do_sample is where the real work happens
        return do_sample(
            self.model,
            sampler,
            value_dict,
            samples,
            params.height,
            params.width,
            self.specs.channels,
            self.specs.factor,
            force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
            return_latents=return_latents,
            filter=None,
        )

    def image_to_image(
        self,
        params: SamplingParams,
        image,
        prompt: str,
        negative_prompt: str = "",
        samples: int = 1,
        return_latents: bool = False,
    ):
        sampler = get_sampler_config(params)

        if params.img2img_strength < 1.0:
            sampler.discretization = Img2ImgDiscretizationWrapper(
                sampler.discretization,
                strength=params.img2img_strength,
            )
        height, width = image.shape[2], image.shape[3]
        value_dict = asdict(params)
        value_dict["prompt"] = prompt
        value_dict["negative_prompt"] = negative_prompt
        value_dict["target_width"] = width
        value_dict["target_height"] = height
        return do_img2img(
            image,
            self.model,
            sampler,
            value_dict,
            samples,
            force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
            return_latents=return_latents,
            filter=None,
        )
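Img2ImgDiscretizationWrapper is what implements the usual img2img behavior of skipping the early high-noise steps: it wraps a discretization and keeps only the final strength fraction of the sigma schedule. A simplified sketch of the idea (the repo version also logs the sigmas; details may differ):

import torch

class Img2ImgDiscretizationWrapper:
    """Keep only the last `strength` fraction of the wrapped sigma schedule."""

    def __init__(self, discretization, strength: float = 1.0):
        assert 0.0 <= strength <= 1.0
        self.discretization = discretization
        self.strength = strength

    def __call__(self, *args, **kwargs):
        sigmas = self.discretization(*args, **kwargs)  # descending: high noise -> low noise
        sigmas = torch.flip(sigmas, (0,))              # ascending, so slicing keeps the low-noise tail
        sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
        return torch.flip(sigmas, (0,))                # back to descending order

So with strength=0.5 and 50 steps, only the 25 lowest-noise sigmas survive, which is why a smaller strength stays closer to the input image.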

    def refiner(
        self,
        params: SamplingParams,
        image,
        prompt: str,
        negative_prompt: Optional[str] = None,
        samples: int = 1,
        return_latents: bool = False,
    ):
        sampler = get_sampler_config(params)
        value_dict = {
            "orig_width": image.shape[3] * 8,
            "orig_height": image.shape[2] * 8,
            "target_width": image.shape[3] * 8,
            "target_height": image.shape[2] * 8,
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "crop_coords_top": 0,
            "crop_coords_left": 0,
            "aesthetic_score": 6.0,
            "negative_aesthetic_score": 2.5,
        }

        return do_img2img(
            image,
            self.model,
            sampler,
            value_dict,
            samples,
            skip_encode=True,
            return_latents=return_latents,
            filter=None,
        )

sampler

The sampler breaks down into three pieces: the discretization method (noise scheduler), the guider (denoising controller), and the sampler loop itself.

def get_sampler_config(params: SamplingParams):
    # build the sampler from configs; two configs are prepared first:
    # one for the discretization, one for the guider
    discretization_config = get_discretization_config(params)
    guider_config = get_guider_config(params)
    if params.sampler == Sampler.EULER_EDM:
        return EulerEDMSampler(
            num_steps=params.steps,
            discretization_config=discretization_config,
            guider_config=guider_config,
            s_churn=params.s_churn,
            s_tmin=params.s_tmin,
            s_tmax=params.s_tmax,
            s_noise=params.s_noise,
            verbose=True,
        )
    if params.sampler == Sampler.HEUN_EDM:
        return HeunEDMSampler(
            num_steps=params.steps,
            discretization_config=discretization_config,
            guider_config=guider_config,
            s_churn=params.s_churn,
            s_tmin=params.s_tmin,
            s_tmax=params.s_tmax,
            s_noise=params.s_noise,
            verbose=True,
        )
    if params.sampler == Sampler.EULER_ANCESTRAL:
        return EulerAncestralSampler(
            num_steps=params.steps,
            discretization_config=discretization_config,
            guider_config=guider_config,
            eta=params.eta,
            s_noise=params.s_noise,
            verbose=True,
        )
    if params.sampler == Sampler.DPMPP2S_ANCESTRAL:
        return DPMPP2SAncestralSampler(
            num_steps=params.steps,
            discretization_config=discretization_config,
            guider_config=guider_config,
            eta=params.eta,
            s_noise=params.s_noise,
            verbose=True,
        )
    if params.sampler == Sampler.DPMPP2M:
        return DPMPP2MSampler(
            num_steps=params.steps,
            discretization_config=discretization_config,
            guider_config=guider_config,
            verbose=True,
        )
    if params.sampler == Sampler.LINEAR_MULTISTEP:
        return LinearMultistepSampler(
            num_steps=params.steps,
            discretization_config=discretization_config,
            guider_config=guider_config,
            order=params.order,
            verbose=True,
        )

    raise ValueError(f"unknown sampler {params.sampler}!")
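get_discretization_config and get_guider_config simply translate the enums into the {"target": ..., "params": ...} dicts that instantiate_from_config (next section) expects. Paraphrasing api.py from memory for the discretization case (the guider case is analogous; exact strings may differ):

def get_discretization_config(params: SamplingParams):
    if params.discretization == Discretization.LEGACY_DDPM:
        return {
            "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
        }
    if params.discretization == Discretization.EDM:
        return {
            "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
            "params": {
                "sigma_min": params.sigma_min,
                "sigma_max": params.sigma_max,
                "rho": params.rho,
            },
        }
    raise ValueError(f"unknown discretization {params.discretization}")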

discretization

Once the discretization config is obtained, the sampler's initialization consumes it: at line 31 of /sgm/modules/diffusionmodules/sampling.py you can see instantiate_from_config being called to build a discretization object. The machinery is the instantiate_from_config and get_obj_from_str helpers in the utils.py module, which selectively import a module based on a config field. This pattern of importing modules through configuration files is worth learning from.
Finally, the discretization classes are all implemented in sgm/modules/diffusionmodules/discretizer.py, annotated piece by piece below.
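The two helpers boil down to something like this (condensed from sgm/util.py; details may differ):

import importlib

def get_obj_from_str(string: str):
    # "pkg.module.ClassName" -> the ClassName object, imported on demand
    module, cls = string.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config):
    # config looks like {"target": "pkg.module.Class", "params": {...}}
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))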

Base class

In one sentence: Discretization generates the sequence of $\sigma_t$ values (in effect, the time steps from the papers), once the number of sampling steps and the sigma range are fixed.

class Discretization:
    def __call__(self, n, do_append_zero=True, device="cpu", flip=False):
        # the key piece is get_sigmas: the parent class declares it as an
        # abstract method and the subclasses implement it; this is where the
        # EDM and DDPM variants differ
        sigmas = self.get_sigmas(n, device=device)
        # optionally append a trailing zero sigma (the final, noise-free step)
        sigmas = append_zero(sigmas) if do_append_zero else sigmas
        return sigmas if not flip else torch.flip(sigmas, (0,))

    @abstractmethod
    def get_sigmas(self, n, device):
        pass

EDM

See the EDM paper (Karras et al., 2022, "Elucidating the Design Space of Diffusion-Based Generative Models"):

class EDMDiscretization(Discretization):
    def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho

    def get_sigmas(self, n, device="cpu"):
        # exactly matches the schedule given in the EDM paper
        ramp = torch.linspace(0, 1, n, device=device)
        min_inv_rho = self.sigma_min ** (1 / self.rho)
        max_inv_rho = self.sigma_max ** (1 / self.rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
        return sigmas
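For intuition, calling it with the SDXL defaults from SamplingParams (my own numbers, computed by hand, so treat them as approximate):

disc = EDMDiscretization(sigma_min=0.0292, sigma_max=14.6146, rho=3.0)
sigmas = disc(5)  # __call__ appends a trailing zero, so 6 values come back
# roughly tensor([14.6146, 6.9745, 2.6076, 0.5976, 0.0292, 0.0000]):
# monotonically decreasing from sigma_max to sigma_min, then the final 0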

DDPM

Generate the $\beta_t$ schedule (linear in $\sqrt{\beta}$ space, then squared); the noise strength over time is controlled by $\beta_t$:

def make_beta_schedule(
    schedule,
    n_timestep,
    linear_start=1e-4,
    linear_end=2e-2,
):
    if schedule == "linear":
        # note: linear in sqrt(beta) space, then squared ("scaled linear")
        betas = (
            torch.linspace(
                linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
            )
            ** 2
        )
    else:
        raise NotImplementedError(f"schedule '{schedule}' is not supported here")
    return betas.numpy()

At initialization we first generate the linear $\beta_t$ sequence, then compute $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{i=1}^t \alpha_i$. $\bar{\alpha}_t$ can be read as the noise strength accumulated up to time step $t$.

# Pick num_substeps roughly equally spaced values between max_step - 1 and 0, returned in ascending order.
def generate_roughly_equally_spaced_steps(
    num_substeps: int, max_step: int
) -> np.ndarray:
    return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
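A worked example (my own check, assuming the usual 1000 training steps):

# np.linspace(999, 0, 5, endpoint=False) -> [999.0, 799.2, 599.4, 399.6, 199.8]
# .astype(int)                           -> [999, 799, 599, 399, 199]
# [::-1]                                 -> [199, 399, 599, 799, 999]
print(generate_roughly_equally_spaced_steps(5, 1000))  # [199 399 599 799 999]

Note that timestep 0 itself is never included.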


class LegacyDDPMDiscretization(Discretization):
    def __init__(
        self,
        linear_start=0.00085,
        linear_end=0.0120,
        num_timesteps=1000,
    ):
        super().__init__()
        self.num_timesteps = num_timesteps
        # generate the beta sequence
        betas = make_beta_schedule(
            "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
        )
        alphas = 1.0 - betas
        # compute alpha-bar (the cumulative product of the alphas)
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.to_torch = partial(torch.tensor, dtype=torch.float32)

Finally, the sequence of standard deviations is generated as $\sigma_t = \sqrt{\frac{1 - \bar{\alpha}_t}{\bar{\alpha}_t}}$:

    def get_sigmas(self, n, device="cpu"):
        # if n is below the 1000 training steps, re-pick n timesteps out of
        # the 1000 and index the matching alpha-bar values
        if n < self.num_timesteps:
            timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
            alphas_cumprod = self.alphas_cumprod[timesteps]
        elif n == self.num_timesteps:
            alphas_cumprod = self.alphas_cumprod
        else:
            raise ValueError

        to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
        sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        return torch.flip(sigmas, (0,))
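A sanity check worth noting (my own observation): the sigma_min=0.0292 and sigma_max=14.6146 defaults in SamplingParams are exactly the endpoints this legacy schedule produces at the full 1000 steps, so the EDM discretization defaults were chosen to match it.

disc = LegacyDDPMDiscretization()
sigmas = disc(1000)           # 1001 values: 1000 sigmas plus the appended zero
print(sigmas[0], sigmas[-2])  # ~14.6146 and ~0.0292 -- the SamplingParams defaults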

Guider

A guider has two jobs: preparing the denoiser's input arguments (the prepare_inputs method), and lightly post-processing the denoised result (__call__).

prepare_inputs

Preparing the denoiser's inputs follows an essentially fixed recipe that is the same across guiders.
x is the input, s is sigma, and c and uc are the conditional and unconditional embedding dicts.
The main job is to merge c and uc into a single dict (doubling x and s to match).

def prepare_inputs(self, x, s, c, uc):
    c_out = dict()

    for k in c:
        if k in ["vector", "crossattn", "concat"]:
            c_out[k] = torch.cat((uc[k], c[k]), 0)
        else:
            assert c[k] == uc[k]
            c_out[k] = c[k]
    return torch.cat([x] * 2), torch.cat([s] * 2), c_out
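The effect on shapes: the batch dimension doubles, so a single denoiser forward pass computes both the unconditional and the conditional prediction. A toy illustration with dummy tensors (the dims here are assumed for illustration, not taken from the model):

import torch

x = torch.randn(2, 4, 128, 128)               # two latents
s = torch.ones(2)                             # one sigma per latent
c = {"crossattn": torch.randn(2, 77, 2048)}   # conditional embeddings
uc = {"crossattn": torch.randn(2, 77, 2048)}  # unconditional embeddings

# prepare_inputs concatenates uc before c along the batch dimension:
x2 = torch.cat([x] * 2)  # shape [4, 4, 128, 128]
c_out = {"crossattn": torch.cat((uc["crossattn"], c["crossattn"]), 0)}  # [4, 77, 2048]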

call guider

After denoising there are two outputs, a conditional one and an unconditional one, concatenated together as x (the guider's input).
IdentityGuider, for instance, simply returns x unchanged. VanillaCFG instead mixes the two halves:

    def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        x_u, x_c = x.chunk(2)  # split the denoiser output into unconditional and conditional halves
        x_pred = x_u + self.scale * (x_c - x_u)  # classifier-free guidance: mix according to scale
        return x_pred
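Note that with scale = 1.0 the formula collapses to x_pred = x_c, i.e. the purely conditional output. For contrast, IdentityGuider is the degenerate version: no batch doubling is needed and the output passes straight through (condensed from memory of sgm/modules/diffusionmodules/guiders.py; details may differ):

class IdentityGuider:
    def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        return x  # nothing to mix: there is no unconditional branch

    def prepare_inputs(self, x, s, c, uc):
        # no duplication: the denoiser only ever sees the conditional inputs
        c_out = {k: c[k] for k in c}
        return x, s, c_out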