Tag: diffusion

  • Qwen-Image

    GPT-4o’s image generation was a remarkable event, beyond the brief Ghiblification of all social media. GPT-4o offered significantly more steerability than earlier image generation models, while delivering image quality in the ballpark of the best diffusion models. Qwen-Image gives a similar level of fidelity and accuracy, and is an open-weights model with a pretty decent technical report: QwenLM/Qwen-Image.

    While I was fairly familiar with diffusion models, I wasn’t really familiar with the backbone of this model, the multimodal diffusion transformer (MMDiT). Rather than just read about it, I vibed up a repo with Claude Code that goes step by step through the architectures, training on good old MNIST: ianbarber/diffusion-edu — which spat out this:

    This ended up being a helpful way to go step by step through the evolution of diffusion models. 

    Loss/Target

    Modern image generation really kicked off with GANs. GANs were a clever idea that exploited the fact that we are better at building classifiers than generators, using one to bootstrap the other. A generator would produce an image, the discriminator would be shown the generated image alongside a real one and have to predict which was which, and both networks were scored on how well they did at their tasks. This was effective, but challenging to train. The generator also had to start from somewhere, and what it effectively started from was noise: the generator would begin with fairly random output, and the discriminator would learn to tell that noise apart from real images.
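
    For contrast with the diffusion losses below, here is a minimal sketch of a single GAN training step (the names G, D, real, B and latent_dim are stand-ins, not from the repo):

    z = torch.randn(B, latent_dim)
    fake = G(z)  # The generator maps noise to an image
    
    # Discriminator loss: score real images as 1, generated images as 0
    d_loss = (F.binary_cross_entropy_with_logits(D(real), torch.ones(B, 1))
              + F.binary_cross_entropy_with_logits(D(fake.detach()), torch.zeros(B, 1)))
    
    # Generator loss: try to get its fakes scored as real
    g_loss = F.binary_cross_entropy_with_logits(D(fake), torch.ones(B, 1))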

    The clever idea Jonathan Ho and co. had with DDPM was to focus on that noise: what if, instead of learning to generate images, we learned to remove noise from them? In the snippet below we:

    • Pick a timestep between 0 and 1000
    • Generate some noise
    • Add an amount of noise to the training image proportional to the timestep
    • Get the model to predict the noise, given the time step
    • Calculate the loss as the mean squared error between the known noise and the predicted noise
    # Sample a random timestep for each image in the batch
    t = torch.randint(0, 1000, (B,), device=device)
    
    # Add noise to the image, scaled by the schedule at that timestep
    eps = torch.randn_like(x0)
    alpha_t = self.alpha_schedule(t).view(-1, 1, 1, 1)  # reshape to broadcast over (B, C, H, W)
    xt = torch.sqrt(alpha_t) * x0 + torch.sqrt(1 - alpha_t) * eps
    
    # Predict the noise we just added
    eps_pred = self.model(xt, t, cond)
    
    return F.mse_loss(eps_pred, eps)  

    This pretty much worked! You needed quite a few timesteps (around 1000), but the model would learn to separate noise from data. Then you can reverse the process to generate: start from pure noise, predict the noise at the current timestep and remove a little of it, inject a bit of fresh noise, step down to the next timestep, and repeat until you reach a clean image.
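
    As a sketch of that reverse process (DDPM ancestral sampling with a standard linear beta schedule; model, cond and the tensor shapes are stand-ins, and alphas_bar plays the role of alpha_t in the training snippet):

    # The usual linear noise schedule and its cumulative products
    betas = torch.linspace(1e-4, 0.02, 1000)
    alphas = 1.0 - betas
    alphas_bar = torch.cumprod(alphas, dim=0)
    
    x = torch.randn(B, C, H, W)  # Start from pure noise
    for i in reversed(range(1000)):
        t = torch.full((B,), i, dtype=torch.long)
        eps_pred = model(x, t, cond)  # Predict the noise at this step
    
        # Remove the predicted noise (the DDPM posterior mean)...
        x = (x - betas[i] / torch.sqrt(1 - alphas_bar[i]) * eps_pred) / torch.sqrt(alphas[i])
    
        # ...then inject a little fresh noise, except at the very last step
        if i > 0:
            x = x + torch.sqrt(betas[i]) * torch.randn_like(x)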

    Song et al. followed this up with DDIM, identifying that one of the reasons you need so many steps is that you are injecting new noise at every step. If you fix the noise up front and make the sampling update deterministic, you can generate in more like 50 steps than 1000: 

    x = torch.randn(*x_shape)  # Start with pure noise
    
    for i in reversed(range(1, steps + 1)):
      t = torch.full((B,), i / steps)
      t_prev = torch.full((B,), (i - 1) / steps)
      if target == TargetMode.EPS:  # The model predicts the noise
        eps = model(x, t, cond)
        a_t = alpha_schedule(t).view(-1, 1, 1, 1)
        a_prev = alpha_schedule(t_prev).view(-1, 1, 1, 1)
        x0_pred = (x - torch.sqrt(1 - a_t) * eps) / torch.sqrt(a_t)      # Estimate the clean image
        x = torch.sqrt(a_prev) * x0_pred + torch.sqrt(1 - a_prev) * eps  # Deterministic: no fresh noise

    The next step, in 2021, was Classifier-Free Guidance from Ho and Salimans. The clever idea was to pass a conditioning variable through to the model: in the MNIST case, the digit label we are training on. However, during training we sometimes zero it out. This means the model learns to generate both conditionally (for the specific digit) and unconditionally (just in whichever direction looks best). 

    if cond is not None and self.cfg_dropout_prob > 0:
      mask = torch.rand(B, 1, 1) < self.cfg_dropout_prob
    
      cond = cond * ~mask  # Zero out conditioning randomly
    
    return F.mse_loss(self.model(xt, t, cond), target)

    This gets useful at generation time. When sampling, we can sample both conditionally and unconditionally, and diff out the unconditioned part: 

    # Run model twice: with and without conditioning
    cond_pred = model(x, t, cond)
    uncond_pred = model(x, t, None)
    
    # Amplify the difference
    return uncond_pred + cfg_scale * (cond_pred - uncond_pred)

    If you imagine the sampling process as denoising, this is saying there is a “best” direction given the condition, and a “best” direction overall. By subtracting out the overall direction and amplifying the conditional one, we get clearer steerability, and effectively the model serves as its own classifier at each step. 

    Also in 2021, Song et al. published Score-Based Generative Modeling through Stochastic Differential Equations. They framed diffusion as a Stochastic Differential Equation (SDE): effectively a regular differential equation dx = f(x, t)dt with an additional noise term, dx = f(x, t)dt + g(t)dw¹. The g(t) factor controls how much random noise is injected.

    The contribution of the paper is that they worked out how to reframe this without the dw noise term, turning it into an “Ordinary” Differential Equation (ODE) with no random component: the probability flow ODE, dx = [f(x, t) − ½ g(t)² ∇x log p_t(x)] dt. The model then defines a deterministic velocity field whose trajectories end up producing the same distributions as the noisy version.

    Salimans & Ho were not done, and proposed another improvement to loss in V-Parameterization in the Imagen paper. One of the challenges with predicting the noise (eps above) is that when you get pretty close to a finished image there isn’t much noise, so the prediction isn’t particularly good. Similarly, when you are starting with pure noise it’s predicting almost everything, so also doesn’t give much information. Predicting the noise involves estimating both the clean sample and the noise. Some reordering lets you predict a single value, the velocity field (or gradients) which combines the clean sample (alpha), the noise (eps) the time step and the current (noised) sample. By having the model predict that we can balance between predicting the image and the noise, giving better results better at extremes. 

    # v mixes noise and signal, weighted by the schedule at this timestep
    v_target = alpha_t * eps - sigma_t * x0
    v_pred = self.model(xt, t, cond)
    
    return F.mse_loss(v_pred, v_target)

    Finally (on the loss) we get to flow matching, from folks at Meta FAIR (Flow Matching) and UT Austin (Rectified Flow). Rather than making the target a schedule-weighted blend of signal and noise, why not just predict the straight-line path from noise to data? Compare the v_target below to the one above: 

    t = torch.rand(B, 1, 1, 1)
    z = torch.randn_like(x0)
    
    # Straight line: xt = (1-t)*x0 + t*z
    xt = (1 - t) * x0 + t * z
    
    # Learn the velocity field pointing from noise to data
    v_target = x0 - z  # The straight path direction
    v_pred = self.model(xt, t.squeeze(), cond)
    
    return F.mse_loss(v_pred, v_target)

    Flow matching models often converge faster during training and can generate good samples with fewer steps. They also tend to have more consistent quality across different sampling step counts.
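
    Sampling from a flow model is then just numerical integration of the learned velocity field from noise back towards data. A minimal Euler sketch (the step count and names are illustrative):

    x = torch.randn(B, C, H, W)  # Start at t=1: pure noise
    steps = 20
    dt = 1.0 / steps
    
    for i in reversed(range(steps)):
        t = torch.full((B,), (i + 1) / steps)
        v = model(x, t, cond)  # Predicted direction from noise towards data (x0 - z)
        x = x + v * dt         # One Euler step along the (roughly) straight path
    
    # x should now be an (approximate) sample from the data distribution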

    Architecture

    All of that evolution was about the loss function and sampling, and we haven’t really discussed the model architecture itself. The original diffusion models used U-Nets: a series of convolutions that downsample the (latent) visual information into fewer dimensions, then upsample it back, with skip connections between the two halves (giving the namesake U shape). But post-ChatGPT the Transformer was ascendant, so in 2023 Peebles and Xie proposed swapping out the U-Net for a stack of transformer blocks in the Diffusion Transformers (DiT) paper. 

    class DiTTiny(nn.Module):
        def __init__(self, embed_dim=256, depth=6, num_patches=196):
            super().__init__()
            # Patchify the image (like ViT)
            self.patch_embed = PatchEmbed(patch_size=2)
    
            # Learned positional and timestep embeddings
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            self.t_embed = nn.Linear(1, embed_dim)
    
            # Stack of transformer blocks
            self.blocks = nn.ModuleList([
                TransformerBlock(embed_dim) for _ in range(depth)
            ])
    
        def forward(self, x, t, cond=None):
            # Convert image to patches
            x = self.patch_embed(x)  # (B, num_patches, embed_dim)
    
            # Add positional encoding, plus the timestep embedding on every token
            # (the real DiT conditions on t via adaLN; adding it is the simplest version)
            x = x + self.pos_embed + self.t_embed(t[:, None].float())[:, None, :]
    
            # Transform through attention layers
            for block in self.blocks:
                x = block(x)
    
            # Reshape back to image
            return self.unpatchify(x)

    This looks like a regular transformer, but with patches (segments of the image) rather than text tokens, as in ViT understanding models. The transformer block will also look pretty familiar: 

    class TransformerBlock(nn.Module):
      def __init__(self, dim, heads=8, mlp_ratio=4.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ln2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
          nn.Linear(dim, int(dim*mlp_ratio)), nn.GELU(), nn.Linear(int(dim*mlp_ratio), dim)
        )
    
      def forward(self, x):
        h = self.ln1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        x = x + self.mlp(self.ln2(x))
        
        return x

    They got good results, and more importantly it was easier to scale up to more compute and larger inputs. For what it’s worth, I found DiTs a bit tricky to train on small datasets (like the MNIST example), but didn’t spend much time on it, since: 

    MMDiTs emerged in 2024 and were used for Stable Diffusion 3 and Flux, largely setting the standard for image quality. The idea is to process image and text tokens in parallel streams that can attend to each other, reminiscent of cross-encoder models.

    class MMDiTTiny(nn.Module):
        def __init__(self, patch_dim=4, img_dim=256, txt_dim=256, depth=6):
            super().__init__()
            # Separate encoders for each modality
            self.img_encoder = nn.Linear(patch_dim, img_dim)
            self.txt_encoder = nn.Linear(txt_dim, txt_dim)
    
            # Joint transformer blocks
            self.blocks = nn.ModuleList([
                CrossTransformerBlock(img_dim, txt_dim) for _ in range(depth)
            ])
    
        def forward(self, img, t, txt):
            # Process both modalities (for CFG, txt can be a learned "null" embedding)
            img_tokens = self.img_encoder(patchify(img))
            txt_tokens = self.txt_encoder(txt)
    
            # Image tokens attend to the text tokens in every block
            # (timestep conditioning omitted for brevity)
            for block in self.blocks:
                img_tokens = block(img_tokens, txt_tokens)
    
            return unpatchify(img_tokens)

    MMDiT models demonstrate great prompt adherence and can handle complex requests. In the full design both streams attend to each other jointly, so the text representations are refined alongside the image as it is generated; the sketch below simplifies this to one-directional cross-attention from image to text.

    class CrossTransformerBlock(nn.Module):
        """Cross-attention: query = image tokens, key/value = text tokens."""
    
        def __init__(self, dim_img, dim_txt, heads=8, mlp_ratio=4.0):
            super().__init__()
            self.q_proj = nn.Linear(dim_img, dim_img)
            self.k_proj = nn.Linear(dim_txt, dim_img)
            self.v_proj = nn.Linear(dim_txt, dim_img)
    
            self.attn = nn.MultiheadAttention(dim_img, heads, batch_first=True)
    
            self.ln_q = nn.LayerNorm(dim_img)
            self.ln = nn.LayerNorm(dim_img)
            self.mlp = nn.Sequential(
                nn.Linear(dim_img, int(dim_img*mlp_ratio)), nn.GELU(), nn.Linear(int(dim_img*mlp_ratio), dim_img)
            )
    
        def forward(self, x_img, x_txt):
            q = self.q_proj(self.ln_q(x_img))
            k = self.k_proj(x_txt)
            v = self.v_proj(x_txt)
    
            x = x_img + self.attn(q, k, v, need_weights=False)[0]
            x = x + self.mlp(self.ln(x))
    
            return x

    Here, in the cross-attention block, the image tokens provide the queries and the text tokens the keys and values; the attention output is added back onto the image stream. 

    Putting this all together, you can see the evolution of the common diffusion baselines across both scale and steerability:

    1. DDPM: Clean but slow. The baseline everything else improves on.
    2. SD1-style (UNet + Epsilon + CFG): The first practical system. Good quality, reasonable speed, follows prompts well with CFG.
    3. SD2-style (UNet + V-param + CFG): Slightly better contrast and stability, especially at high resolutions.
    4. SD3-style (MMDiT + Flow): The current state-of-the-art. Fastest training, best prompt adherence, most efficient sampling.

    Back to Qwen

    The Qwen-Image model is a good, practical example of scaling this up. It uses an existing multimodal model² to encode text and image inputs, a pretrained VAE³ to translate between pixel and latent space, and an MMDiT as its backbone. Using strong (understanding) models for the encoding does a lot to enhance the steerability of the MMDiT’s results. 
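
    Schematically, the pieces compose something like this (a sketch of the flow described in the report, not Qwen’s actual code; mm_encoder, mmdit, vae and the latent shapes are stand-ins):

    # The prompt (and any reference image) is encoded by the multimodal LM
    txt_tokens = mm_encoder(prompt, ref_image)   # e.g. hidden states from Qwen2.5-VL
    
    # Generation happens in the VAE's latent space, not pixel space
    z = torch.randn(B, C_lat, H_lat, W_lat)
    steps, dt = 50, 1.0 / 50
    
    for i in reversed(range(steps)):             # Flow-matching sampler, as above
        t = torch.full((B,), (i + 1) / steps)
        v = mmdit(z, t, txt_tokens)              # The MMDiT backbone predicts the velocity
        z = z + v * dt
    
    image = vae.decode(z)                        # Back to pixels via the pretrained VAE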

    The MMDiT sketch above glosses over positional information entirely. In real systems the text and image tokens are combined into a single sequence for joint attention: positional embeddings are added to the image tokens, and the text tokens are appended after them. This works, but makes it awkward to adapt to different image resolutions.

    Seedream introduced Scaling RoPE⁴, which instead puts the origin of the image positional encoding at the centre of the image, treats the text tokens as a 2D shape of [1, L], and applies 2D RoPE to both text and image tokens. This worked better, but some positions were confusable between text and image latents, meaning the model couldn’t always tell them apart. The Qwen team update this with MSRoPE: they apply positional encoding across both dimensions of the text tokens and concatenate the text along the diagonal of the image grid:

    This design allows MSRoPE to leverage resolution scaling advantages on the image side while maintaining functional equivalence to 1D-RoPE on the text side, thereby obviating the need to determine the optimal positional encoding for text.
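
    As a rough sketch of the position assignment as I read it (illustrative only; the offsets and exact layout are my guess, not the paper’s implementation; H, W and L are stand-ins):

    # Image token at (row, col) of an H x W latent grid keeps its 2D position
    img_pos = [(r, c) for r in range(H) for c in range(W)]
    
    # Text token i sits on the diagonal just past the image grid: both coordinates
    # are equal and advance by one per token, so from the text's point of view it
    # behaves like ordinary 1D RoPE
    txt_pos = [(max(H, W) + i, max(H, W) + i) for i in range(L)]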

    The resolution independence is important for the training recipe. The model is progressively trained with images starting at 256×256 and increasing in steps up to 1328×1328, in a variety of aspect ratios. They follow it up with post-training consisting of SFT on organized, high-quality image-text pairs and DPO against preference pairs judged by human raters⁵. Finally, they do a GRPO stage with a “reward model”, though it isn’t clear whether that’s based on the aforementioned preference data or some other secret sauce. 

    While we don’t know how GPT-image is trained, this recipe certainly gives comparable results. I was surprised to learn that the combination of a strong text-and-image encoding model plus an MMDiT⁶ gives this level of steerability and fidelity. As usual, it’s exciting to have open models and papers that bring these concepts together! 

    1. It’s w because the noise is a Wiener process, also known as standard Brownian motion. I am heavily conditioned to think of this as the motion in a cup of tea thanks to HHGTTG
      ↩︎
    2. Qwen 2.5-VL ↩︎
    3. Interestingly, a video one from Wan-2.1 ↩︎
    4. A roughly similar idea was around as column-wise position encoding, as I understand it. 
      ↩︎
    5.  The same prompt with two different seeds, and — if present — a reference image
      ↩︎
    6. And a lot of very carefully curated and programmatically generated data, to be fair
      ↩︎