首页
/ 深入理解Diffusers中的Flux Control训练技术

深入理解Diffusers中的Flux Control训练技术

2026-02-04 05:22:42作者:齐添朝

概述

在扩散模型领域,Flux Control技术是一种创新的结构条件控制方法,它能够通过额外的控制信号(如深度图、姿态图等)来引导图像生成过程。本文将深入探讨如何在Diffusers框架中训练Flux Control LoRA模型,以及相关的技术细节和最佳实践。

Flux Control技术原理

Flux Control的核心思想是通过扩展模型的输入特征空间来实现条件控制。具体来说:

  1. 原始Flux模型的输入特征维度为64,对应待去噪的潜在空间
  2. Flux Control将其扩展为128维,其中前64维保持不变,后64维用于编码控制信号
  3. 这种扩展发生在x_embedder层,该层将组合后的潜在向量投影到网络期望的特征维度

这种设计使得模型能够在保持原有生成能力的同时,有效地融合控制信号的信息。

准备工作

由于Flux模型是受控访问的,使用前需要完成以下步骤:

  1. 访问Flux.1 [dev]模型页面并填写申请表格
  2. 接受访问条款后,使用以下命令登录:
huggingface-cli login

训练流程详解

基础训练配置

Flux Control的训练可以通过以下关键参数进行配置:

accelerate launch train_control_lora_flux.py \
  --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
  --dataset_name="raulc0399/open_pose_controlnet" \
  --output_dir="pose-control-lora" \
  --mixed_precision="bf16" \
  --train_batch_size=1 \
  --rank=64 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --use_8bit_adam \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=5000 \
  --validation_image="openpose.png" \
  --validation_prompt="A couple, 4k photo, highly detailed" \
  --offload \
  --seed="0" \
  --push_to_hub

高级训练选项

训练脚本提供了几个值得关注的高级选项:

  1. LoRA偏置训练:通过use_lora_bias参数可以额外训练lora_B层的偏置项
  2. 归一化层训练train_norm_layers参数允许训练归一化尺度参数
  3. 层选择lora_layers参数可以指定应用LoRA的层,如"all-linear"表示所有线性层

DeepSpeed集成

为了提升训练效率,可以使用DeepSpeed的Zero2系统优化。配置示例如下:

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

使用时通过--config_file参数指定配置文件。

推理流程

训练完成后,可以使用以下流程进行推理:

  1. 安装必要的依赖:
pip install controlnet_aux
  1. 推理代码示例:
from controlnet_aux import OpenposeDetector
from diffusers import FluxControlPipeline
from diffusers.utils import load_image
from PIL import Image
import numpy as np
import torch 

# 初始化管道
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", 
    torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("...")  # 替换为训练好的LoRA权重路径

# 准备姿态条件
open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg"
image = load_image(url)
image = open_pose(image, detect_resolution=512, image_resolution=1024)
image = np.array(image)[:, :, ::-1]           
image = Image.fromarray(np.uint8(image))

# 生成图像
prompt = "A couple, 4k photo, highly detailed"
gen_images = pipe(
    prompt=prompt,
    control_image=image,
    num_inference_steps=50,
    joint_attention_kwargs={"scale": 0.9},
    guidance_scale=25., 
).images[0]
gen_images.save("output.png")

完整微调方案

除了LoRA训练外,Diffusers还提供了完整的微调方案:

accelerate launch --config_file=accelerate_ds2.yaml train_control_flux.py \
  --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
  --dataset_name="raulc0399/open_pose_controlnet" \
  --output_dir="pose-control" \
  --mixed_precision="bf16" \
  --train_batch_size=2 \
  --dataloader_num_workers=4 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --use_8bit_adam \
  --proportion_empty_prompts=0.2 \
  --learning_rate=5e-5 \
  --adam_weight_decay=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="cosine" \
  --lr_warmup_steps=1000 \
  --checkpointing_steps=1000 \
  --max_train_steps=10000 \
  --validation_steps=200 \
  --validation_image "2_pose_1024.jpg" "3_pose_1024.jpg" \
  --validation_prompt "two friends sitting by each other enjoying a day at the park, full hd, cinematic" "person enjoying a day at the park, full hd, cinematic" \
  --offload \
  --seed="0" \
  --push_to_hub

完整微调后的推理流程略有不同,需要加载训练好的Transformer模型:

transformer = FluxTransformer2DModel.from_pretrained("...")  # 替换为训练好的模型路径
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  
    transformer=transformer, 
    torch_dtype=torch.bfloat16
).to("cuda")

注意事项

  1. 当前提供的训练脚本主要用于教育和实验目的,可能需要针对特定条件进行调整
  2. 脚本未进行内存优化,但可以通过--offload参数将VAE和文本编码器在不用时卸载到CPU
  3. 虽然当前未提供直接工具,但可以从完整微调的模型中提取LoRA权重

通过本文的详细讲解,读者应该能够全面理解Flux Control技术的原理和实现方式,并能够在Diffusers框架中有效地训练和应用这一技术。

登录后查看全文
热门项目推荐
相关项目推荐