Category: posts

  • Useful Reasoning Behaviors

    [2503.01307] Cognitive Behaviors that Enable Self-Improving Reasoners, or, Four Habits of Highly Effective STaRs

    Very useful insight in this paper out of Stanford.

    Test-time inference has emerged as a powerful paradigm for enabling language models to “think” longer and more carefully about complex challenges, much like skilled human experts. While reinforcement learning (RL) can drive self-improvement in language models on verifiable tasks, some models exhibit substantial gains while others quickly plateau.

    The authors were running a reasoning post-training process on both Qwen 2.5 3B and Llama 3.2 3B. They noticed that while both learned, Llama was consistently worse than Qwen, which feels odd as both models are strong. Looking at the reasoning approaches exhibited, they observed four distinct strategies:

    • verification
    • backtracking
    • subgoal setting
    • backward chaining 

    They noticed that Qwen exhibited these behaviors more from the base model, and those behaviors were enhanced more in the RL process.
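
    To make the idea concrete, here’s a toy sketch of the kind of trace analysis involved: counting marker phrases for each of the four behaviors in a reasoning trace. The marker phrases are my own illustrative stand-ins; the paper uses an LLM-based classifier, not keyword matching.

```python
import re

# Illustrative marker phrases for each behavior -- my own guesses,
# NOT the paper's actual (LLM-based) behavior classifier.
BEHAVIOR_MARKERS = {
    "verification": [r"let me check", r"verify", r"double-check"],
    "backtracking": [r"wait,", r"that's wrong", r"let me try a different"],
    "subgoal setting": [r"first,", r"next,", r"step \d"],
    "backward chaining": [r"working backwards", r"to end up with"],
}

def count_behaviors(trace: str) -> dict:
    """Count how often each behavior's markers appear in a trace."""
    trace = trace.lower()
    return {
        behavior: sum(len(re.findall(p, trace)) for p in patterns)
        for behavior, patterns in BEHAVIOR_MARKERS.items()
    }

trace = ("First, compute 12 * 2 = 24. Wait, that's wrong -- let me try a "
         "different split. Let me check: 12 + 24 = 36. Verify: 36 < 120.")
counts = count_behaviors(trace)
```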

    While the larger Llama-3.1-70B showed generally increased activation of these behaviors compared to Llama-3.2-3B, this improvement was notably uneven — backtracking, in particular, remained limited even in the larger model.

    They then generated some custom reasoning traces that intentionally demonstrated all 4 behaviors, using Claude.

    We generate these datasets using Claude-3.5-Sonnet, leveraging its ability to produce reasoning trajectories with precisely specified behavioral characteristics. While Claude does not always produce the correct answer (see Fig. 9), it consistently demonstrates the requested reasoning patterns, providing clean behavioral primitives for our analysis.

    They found that when using the SFT set before RL they closed most of the gap between Llama and Qwen. They also found that it isn’t even important that the reasoning traces are correct – demonstrating the behavior is more important than the reasoning itself at the SFT stage.

    Priming models with cognitive behaviors, by a small amount of finetuning, enabled significant performance gains even in models that initially lack these capabilities. Remarkably, this holds even when primed with incorrect solutions that exhibit the target behavioral patterns, suggesting that cognitive behaviors matter more than solution accuracy. 

    This fits with the elicitation idea — the SFT is training a style. By having increased activation of the reasoning styles the RL process is more able to explore these capabilities and reinforce extended reasoning generation.

    This also fits with my mental model that a base model’s capabilities are often pretty underexplored: a combination of targeted SFT + RL seems to be a very powerful elicitation tool!

  • Vibe Coding & Code Review

    Vibe coding is trendy right now, and it’s part of a bigger shift toward AI-generated code. One of the questions I have about this, at scale, is how it will interact with code review.

    When you go public, you have to comply with Sarbanes-Oxley. SOX mandates a separation of duties, which in practical terms usually means having a process by which changes are tested and approved by someone other than the author before being deployed, and often that is implemented through code review. Companies like Google, Meta, Snapchat, and Uber also have compliance obligations around privacy and security thanks to FTC consent decrees. While these mandates don’t directly say “do code reviews” (they mostly mandate programs), in practice code reviews are key checkpoints for compliance.

    As we use more model-generated code, traditional code review processes will have to change. Maybe engineers become reviewers of AI-generated code, or we start clearly separating AI-written code from the manually-reviewed stuff—similar to what already happens with some current codegen tools.

    Companies will have to think about ownership and stewardship of changes. Code review requirements like OWNERS or Readability at Google require an approver who isn’t the original author: if the change is LLM authored, is it acceptable for there to be no human author, or is it owned by the person who kicked off the workflow, precluding them from being the reviewer? If we have automated detect-change flows set up (e.g. upgrading downstreams when a dependency changes), is it an individual or a team/oncall that owns the change?

    The bottleneck is human attention and focus. It is plausible to envisage 10-1000x more changes in a large code base, but current review practices simply wouldn’t do an effective job: at best, many changes would be rubber-stamped. Work like the diff risk score shows you can do some degree of triage or prioritization with models. Conceptually you could see extending this to privacy, security and more specific types of risk scoring.

    Work like Policy Zones associates compliance requirements with data and asserts controls in the data flow of the system. This may be easier to scale than trying to validate on code changes.

    Static analysis can also catch patterns of bad usage, like recent work in detecting scraping opportunities. This feels especially important for LLM generated code where a well diffused but bad pattern might be easier to generate than a more novel but safe pattern.

    I expect to see more pressure on developer infrastructure teams to build out capabilities for risk detection, automatic validation and embedding policy information into code or data. There will be an advantage to these being open, industry standard approaches as foundation models will do a better job and require less fine tuning to company specific idioms.

  • Post-Training & Elicitation

    Nathan Lambert of the Allen Institute writes about their (very strong) Olmo 2 32B release, and the just released Gemma 3 model from Google. One of the many interesting points:

    Comparing Gemma 3 27B to OLMo 32B, the pretraining evaluations for both are super similar, but Gemma 3 scores are way better after post-training. The ceiling on post-training expectations has been shifting extremely fast among open models.

    Given that Google has about the best crawling infrastructure in the world, and that Ai2 have published the complete pretraining dataset used for Olmo, I think this is slightly surprising. You can see the benchmarks in the blog and technical report: for example, Gemma 3 27B gets 78.8 on WinoGrande from pretraining (a little below Gemma 2, as it happens) while Olmo 2 32B gets 78.7.

    The vibes have definitely shifted to post-training as the source of model differentiation, opening the question of what exactly is happening there. Nathan also posted about that recently, linking to this post by Mohit Ragavendra of Scale and Georgia Tech:

    The post looks at the Superficial Alignment Hypothesis, which is (largely) the claim that post-training is just preference tuning for behaviors the base model can already do.

    […]

    It initially seems like “Less Is More” in the sense that the LIMA model response was highly preferred by the GPT-4 evaluator for Math prompts (in-line with the work’s original claim). However, these model responses were also largely incorrect – the accuracy of models fine-tuned specifically for Math was substantially better, with the same data budget. If we went by subjective win-rate comparisons, we would have picked a model that was significantly worse.

    In the post (and the two linked papers) Mohit breaks down how post-training actually helps. Starting with SFT, the work shows that mimicking style happens quickly, with relatively few samples.

    with just a hundred finetuning examples, the model’s formatting mistakes were virtually solved – the model was perfect at mimicking the expected style. However, the model took a lot more supervised finetuning data to get better at reasoning – the substance of the task.

    They find though that, largely, more-is-more when it comes to SFT, but that there is a power-law style scaling curve: big gains initially followed by slower, marginal gains. Adding in RL doesn’t change the fundamental curve, but it does shift it, leading more efficiently to the model gaining the reasoning capabilities they were training towards:

    Preference data offers a weaker signal compared to supervised finetuning data. So, running DPO directly on the base model on reasoning tasks, is asking the model to learn a completely different response style from its reference model, with a weaker signal, while penalizing for being different from the reference model. Small amount of SFT on the base model teaches it the reasoning style and PFT can use the reward signal to focus on reasoning within the required response space.

    I did wonder when reading this whether the results would look different with an online process (like PPO), rather than an offline. Luckily, Mohit links to another recent paper on this topic:

    We prove that under idealized assumptions, online and offline PFT techniques should return policies of equivalent quality

    but also

    we observe that despite the lack of information-theoretic separation, online PFT out-performs offline PFT across different sampling distributions, labelers, and model sizes. Furthermore, it appears to be “easier” to learn a global RM than it is to learn a local RM, leading to higher validation likelihood.

    The result here seems to be that the reward function is simply easier to model, and the reward model helps “translate” the problem across distributions.

    This all feels like a continuum: at some level the superficial alignment hypothesis is directionally correct, but it’s not that “superficial”. The base models have a lot of capabilities that are hard to elicit, and fine-tuning/post-training can juice them effectively, while adding some learning of its own (as more data is better!)

    The best way of performing that elicitation turns out to be solving different problems at different levels: SFT for format, then RL for the deeper capability, and having a reward model effectively simplifies the learning process again.

  • Quantization in PyTorch

    Jerry Zhang recently posted a couple of updates on the evolution of the quantization APIs in PyTorch, and the unification around TorchAO.

    If you haven’t spent much time around quantization, or are used mainly to quantizing LLMs via tools, the range of options can be pretty gnarly.

    Quantization compresses the model by replacing a number format with a wide range (like float32) with a narrower one (like int8). To recover an approximation of the original value you track a scale factor and a zero point (this is sometimes referred to as affine quantization).

    For example, if you have a float32 layer, but all of the parameters within it are between 1 and 10, then using an int8 (256 values) to represent the whole float32 range would compress all those values into just a handful of buckets, losing a huge amount of information. Instead, you set the scale factor to cover the range of values actually present. This lowers the quantization error: the difference between the dequantized value and the original.
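
    The arithmetic is simple enough to sketch in plain Python. A minimal illustration of per-tensor affine quantization to int8, assuming the observed range of the layer is [0, 10]:

```python
def affine_params(rmin: float, rmax: float, qmin: int = -128, qmax: int = 127):
    """Compute scale and zero point mapping [rmin, rmax] onto [qmin, qmax]."""
    scale = (rmax - rmin) / (qmax - qmin)
    zero_point = round(qmin - rmin / scale)
    return scale, zero_point

def quantize(x: float, scale: float, zero_point: int,
             qmin: int = -128, qmax: int = 127) -> int:
    # round to the nearest int8 bucket and clamp to the representable range
    return max(qmin, min(qmax, round(x / scale) + zero_point))

def dequantize(q: int, scale: float, zero_point: int) -> float:
    return (q - zero_point) * scale

scale, zp = affine_params(0.0, 10.0)   # observed range of the layer
x = 7.3
q = quantize(x, scale, zp)             # int8 representation
x_hat = dequantize(q, scale, zp)       # reconstructed value
error = abs(x - x_hat)                 # quantization error, at most scale / 2
```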

    That’s easy to do for the weights of the model, but you need to handle the activations as well (so you are multiplying matrices of the same type). “Static” quantization determines the scale factor and zero point for activations up front, while “dynamic” quantization calculates them at inference time, resulting in better accuracy. The downside is that dynamic quantization generally only works on CPUs, as it’s inherently data-dependent.

    You don’t have to quantize everything the same way, and it’s common to see quantization schemes in the format A16W8. That indicates the weights are quantized to 8-bit (usually int8), but the activations are kept at 16-bit (usually float16 or bfloat16). In those cases, the weights are upcast to a matching dtype at compute time, but you still benefit from the faster loading and lower persistent memory usage.
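
    The mechanics of a weight-only (W8) scheme can be sketched in plain Python: store int8 weights plus one scale, and upcast back to float at compute time. This is an illustration of the idea, not any particular library’s kernel:

```python
def quantize_weights(w: list, qmax: int = 127):
    """Symmetric int8 weight-only quantization: integer weights plus one scale."""
    scale = max(abs(v) for v in w) / qmax
    q = [max(-qmax, min(qmax, round(v / scale))) for v in w]
    return q, scale

def forward(x: list, q: list, scale: float) -> float:
    # At compute time the int8 weights are upcast back to float;
    # only storage and memory transfer happen at 8 bits.
    return sum(xi * (qi * scale) for xi, qi in zip(x, q))

weights = [0.5, -1.2, 3.4, 0.01]
q, scale = quantize_weights(weights)

x = [1.0, 1.0, 1.0, 1.0]
exact = sum(xi * wi for xi, wi in zip(x, weights))   # float32-style result
approx = forward(x, q, scale)                        # weight-only quantized result
```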

    The general flow for quantization is:

    • Identify which parts of the model you want to quantize: often you’ll want to quantize some parts, and not others (for example, all the linear layers, but not a softmax)
    • Prepare the parts being quantized by adding quantize and dequantize operations around the normal ops.
    • For static quantization (where scale/zero point are set ahead of time for activations) calibrate the model by sending input data through and observing the ranges.
    • Convert the model, by replacing the layers with their quantized, lower-bit representations, and ensure the appropriate operators are in place to dequantize where outputs feed into non-quantized ops.

    Quantizing after training is Post-Training Quantization (PTQ). Quantization-Aware Training (QAT) introduces quantization during training, allowing the model to learn and partially recover accuracy loss.

    Quantizing in PyTorch

    There are (at least!) 4 different approaches to quantization in PyTorch:

    • Eager Mode: Deprecated, simple to use with quantize_dynamic for quick dynamic quantization.
    • FX Mode: Also deprecated; separates quantization from model code via FX graph tracing but requires FX-traceability.
    • PT2E Mode: Current method using PyTorch 2 export. It captures the model graph, supports backend-specific configs (like XNNPack), and is preferable for exported models.
    • TorchAO Quantization: Latest method optimized for torch.compile. Includes easy-to-use features like autoquant for automatic quantization tuning as well as manual options.

    Eager mode quantization

    The original, and soon to be deprecated, quantization method in PyTorch was eager mode quantization. The simplest version to use is calling torch.ao.quantization.quantize_dynamic. This takes a config (identifying which parts of the model you want to quantize) and a target dtype. That call scales the weights and downcasts them, and injects ops to dynamically quantize activations. The prepare and convert steps are handled automatically, and there is no calibration as the activations will be scaled dynamically.

    Eager mode can also do static quantization with torch.ao.quantization.quantize. This requires a lot more modification: you manually add torch.ao.quantization.QuantStub() and DeQuantStub() around the nn.Module calls you want to operate quantized. Then you torch.ao.quantization.prepare the model and call the forward pass with some example data to calibrate. Prepare adds observers that collect statistics on the activations to determine good zero point and scale values. Finally torch.ao.quantization.convert processes and returns the quantized model.
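
    Assuming an x86 machine where the fbgemm backend is available, a minimal end-to-end version of the eager static flow looks roughly like this:

```python
import torch
import torch.nn as nn

class Small(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.fc = nn.Linear(4, 4)
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        # quantize activations on the way in, dequantize on the way out
        return self.dequant(self.fc(self.quant(x)))

model = Small().eval()
model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
prepared = torch.ao.quantization.prepare(model)   # inserts observers
prepared(torch.randn(32, 4))                      # calibration pass
quantized = torch.ao.quantization.convert(prepared)
```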

    Quantization aware training works the same as before in terms of adding stubs, but rather than calibrating with input data you call torch.ao.quantization.prepare_qat and then run a regular training loop on the model.

    Modifying the model code itself is pretty painful as well, particularly if you are in an active, multi-collaborator code base. This led to the second evolution of PyTorch quantization:

    FX Mode Quantization

    The idea behind FX mode quantization was to give the same range of options as before, without having to change model code. This is particularly helpful when you have (say) a research team developing a model, and production team that is trying to make it fast for inference!

    The FX graph is a graph of operations created by tracing the model, and by working on the FX ops directly the library can make the quantization modifications while leaving the original code untouched. This works through pattern-matching in the FX graph and applying the transforms to add quant and dequant stubs automatically. The downside is it needs your model to be FX-traceable (as it would if you were using TorchScript), which often requires model changes.

    torch.ao.quantization.quantize_fx contains the methods, and they follow the same pattern: a quantization config, then prepare_fx, then convert_fx with options for running input through for calibration or running a training loop for QAT.

    FX tracing and TorchScript have been on the outs for a while due to their mix of complexity and inflexibility, and under the hood both this and the Eager variant use a quantized Tensor type which has been slated for deprecation. So, in their place we have…

    PT2E Quantization

    This one is not deprecated! But it is still changing a bit. The basic idea is very similar to FX quantization, except instead of using FX to capture the graph, we use PyTorch 2’s export feature. torch.ao.quantization.quantize_pt2e offers prepare_pt2e and convert_pt2e, again with an optional calibration/QAT step.

    One big difference is that the quantization config is now backend specific. Prior to calling prepare, you would set up the backend: e.g. for the XNNPack library:

    quantizer = XNNPACKQuantizer().set_global(
        get_symmetric_quantization_config()
    )

    This allows different backends to define quantization set up based on the ops available to them.

    This does require capturing the full model graph with PT2 export, which can be painful, but if you’re going through the pain it’s best to do so with this flow rather than FX!

    If you don’t want to work with full graph capture, there is one other option which integrates with torch.compile:

    TorchAO Quantization

    Also not deprecated! The TorchAO library has a toolkit for quantization, and one particularly nice feature is autoquant:

    model = torchao.autoquant(torch.compile(model, mode='max-autotune'))

    Underneath this will try different quantization schemes to identify the best size reduction vs accuracy reduction tradeoff.

    You can also quantize manually:

    quantize_(m, Int4WeightOnlyConfig(group_size=group_size))

    Here, the quantize_ function wraps the prepare/convert steps for you, so it’s pretty user-friendly!

    Not sure what to use? Jerry’s table can help you decide, but if you’re quantizing for use in PyTorch (and using torch.compile) then prefer TorchAO quantization. If you are quantizing as part of making the model available in non-Python environments then you’ll want to be able to export, and should use the PT2E flow.

  • Byte-Latent Transformers

    Who needs a tokenizer anyway!

    [2412.09871] Byte Latent Transformer: Patches Scale Better Than Tokens

    This paper, from back in December last year, presents an interesting approach to handling raw byte sequences in LLMs without relying on tokenization.

    Vocab sizes for tokenizers have gone up over the last couple of years with attendant gains in usefulness, but this remains a particularly hand-tuned number in the training process. BLT proposes a method that processes raw UTF-8 byte sequences directly, leveraging a dynamic patching mechanism to group bytes into variable-length patches based on entropy.

    Higher-entropy regions receive more attention and shorter patches, while lower-entropy regions can be processed more efficiently.
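
    As a toy illustration of entropy-based patching, the sketch below uses a unigram byte model as a crude stand-in for the paper’s small autoregressive entropy model: compute the surprisal of each byte and start a new patch whenever it crosses a threshold (the threshold value here is arbitrary):

```python
import math
from collections import Counter

def byte_surprisal(data: bytes) -> list:
    """Surprisal (bits) of each byte under a unigram model of the data itself,
    a crude stand-in for the paper's small autoregressive entropy model."""
    counts = Counter(data)
    total = len(data)
    return [-math.log2(counts[b] / total) for b in data]

def patch(data: bytes, threshold: float) -> list:
    """Start a new patch whenever the next byte is 'surprising'."""
    surprisal = byte_surprisal(data)
    patches, start = [], 0
    for i in range(1, len(data)):
        if surprisal[i] > threshold:
            patches.append(data[start:i])
            start = i
    patches.append(data[start:])
    return patches

# predictable runs of 'a' are cheap; the rare bytes trigger patch boundaries
text = b"aaaaaaaaXaaaaaaaaY"
patches = patch(text, threshold=2.0)
```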

    There are conceptually three levels of processing:

    • Local Encoder: A small transformer stack encodes raw byte sequences into higher level representations, which are then structured into patches.
    • Latent Global Transformer: A standard large transformer model operating on patch-level representations
    • Local Decoder: The encoded patches are decoded back into byte sequences, using a cross-attention mechanism to reconstruct text.

    In the paper they show they can achieve parity in pretraining with a traditional tokenized approach in Llama at a similar parameter count, while being more robust and offering some inference-time performance gains. The patching approach allows for allocating compute where it is needed most.

    Retrofitting existing models

    One of the ideas I found most interesting is starting with a traditionally pretrained model. The paper discusses using the main transformer layers from Llama and training the byte latent approach successfully.

    I gave the approach a go with a simplified local encoder, entropy and patching approach, and took the transformer layers from Qwen 2.5 3B, a strong model that could still be trained locally (no corporate resources were harmed, etc).

    The basic approach was replacing the tokenizer, adding a small transformer and patch pooling based on a local entropy measure to generate patches, then cross-attending in some of the Qwen layers. It’s training a new encoder while leveraging Qwen as the backbone of the global transformer, and adding new cross-attention params to make it also the decoder, with the embedding layers at each end chopped off – so a significant domain shift. For inference I leverage the same patch generation process to try to generate effective tokens.

    You can find my Torchtune recipe on GitHub, running through the Alpaca dataset. Thus far I’m still training, so while the loss is improving, I have no idea whether it will turn into something useful. The fact that there is something trainable is fun though, and I have hopes that this kind of technique will lead to some breakthroughs in tokenizer-free models in the future!

  • On career growth

    Particularly at the giant tech firms where there are many, many smart people, folks look at success and try to copy it. They set themselves up for promotions against the definition of what is expected at the next level. But those expectations describe an average, not a person. Real people are spiky.

    No one is great at everything, in every situation. Some of the most successful people create situations where they can use their strengths. They shape the work to fit them. Most of us don’t get that luxury, but we can often pick where to play. I’ve taken on projects and roles that looked reasonable but that I knew weren’t a fit, and the results have ranged from poor to fair, never great. That mismatch is costly.

    Choosing projects, teams, or roles is more significant than choosing how to work on them. Those choices decide whether you work with your strengths or against them. Lean into strengths, lean into things that bring you energy and satisfaction. That doesn’t mean staying comfortable. It means knowing where you do your best work and pushing those capabilities even further, rather than attempting to contort yourself to a generic level N+1.

  • GRPO & Verifiable Rewards

    GRPO (Group Relative Policy Optimization) is an RL technique originally proposed in the DeepSeekMath paper. Instead of using a full-blown value network like PPO does, GRPO samples a group of completions for a given prompt and then computes a relative (normalized) reward for each output. The rewards are “verifiable” because they come from checking the output against ground truth: does the response follow the expected format (i.e. a <think>…</think> block for reasoning and an <answer>…</answer> block for the solution), and is the answer accurate against a predetermined fact? Not every problem fits this model, but there are a bunch that do, including math reasoning with the GSM8K dataset of grade-school math word problems. These look like this:

    “Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read?”
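
    A minimal version of a verifiable reward function for this setup might look like the sketch below. The <think>/<answer> format is from above, but the 0.2/0.8 format-vs-accuracy split is my own illustrative choice:

```python
import re

def verifiable_reward(completion: str, ground_truth: str) -> float:
    """Score a completion: partial credit for following the format,
    full credit for also getting the right answer.
    The 0.2/0.8 split is illustrative, not from any particular paper."""
    reward = 0.0
    m = re.search(r"<think>.*?</think>\s*<answer>(.*?)</answer>",
                  completion, re.DOTALL)
    if m:
        reward += 0.2                       # followed the required format
        if m.group(1).strip() == ground_truth.strip():
            reward += 0.8                   # and got the right answer
    return reward

# Julie: 12 yesterday + 24 today = 36 read, 84 remaining, half is 42
good = ("<think>12 yesterday, 24 today, 84 pages left, half is 42.</think>"
        "<answer>42</answer>")
```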

    How Does the Training Work?

    1. Sampling Completions: For each prompt, the model generates a group of candidate completions. These are produced in inference mode (gradients aren’t collected) using a KV cache for speed (or a dedicated inference engine like vLLM).
    2. Verifiable Reward Calculation: Each completion is scored between 0 and 1—rewarding outputs that follow the prescribed format and yield the correct answer.
    3. Forward Pass for Gradients: Both the “policy” (the model being tuned) and a reference (typically the base, unmodified model) are used for a forward pass with the prompt and completions to compute per-token logits and log-probabilities.
    4. Loss and Backwards: The loss is then calculated as a combination of the (group-averaged) reward and a KL divergence term between the tuned model and the baseline, to constrain learning to similar responses. This loss is backpropagated through the policy model based on the earlier forward pass.
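
    The “group relative” part of step 4 boils down to normalizing each completion’s reward against its group’s statistics; no value network is required. A sketch (the KL term is handled separately and omitted here):

```python
def group_relative_advantages(rewards: list, eps: float = 1e-4) -> list:
    """Normalize each completion's reward against its group's mean and std,
    as in GRPO; the group itself acts as the baseline."""
    n = len(rewards)
    mean = sum(rewards) / n
    var = sum((r - mean) ** 2 for r in rewards) / n
    std = var ** 0.5
    return [(r - mean) / (std + eps) for r in rewards]

# one prompt, four sampled completions: two fully correct, two not
advantages = group_relative_advantages([1.0, 0.0, 1.0, 0.2])
```

    Completions that beat the group average get a positive advantage and are reinforced; those below it are pushed down.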

    Getting it going in TorchTune

    Last weekend I hacked up a quick and dirty version of the training loop in TorchTune, and over a couple of bus rides to Menlo Park cleaned it up into something that could work as a more general recipe (PR). Most of the work goes into the recipe and getting the dataset shaped properly to generate completions. This version—tested on a smaller model (the 1B Llama 3.2 variant, with LoRA)—showed some promising improvements, but I didn’t get to the point of having something converge enough to be confident in the overall recipe. In the DeepSeek R1 paper they discussed trying smaller models, but found 3B was the lowest they were able to get results on with some of their fine-tuning approaches.

    Luckily for everyone, at around the same time Ariel Kwiatkowski also put together a version that included distributed device support, making it easier to experiment on bigger models. This PR is more modular, and I’m excited to see it refined and landed so the recipe is widely available!

    There’s a growing energy around tools like torchtune, and it’s exciting to see how easy it is to “hack on” these ideas. It’s also great to see the techniques show up in other libraries, like HuggingFace’s TRL, which is being used as part of the OpenR1 replication effort!

  • Gradient Accumulation (was) busted

    This weekend I was reading the Tulu v3 paper (link), which offers a deep dive into building robust post-training setups. It is a very good resource for anyone aiming to build really robust fine-tuning workflows. It covers critical elements like dataset selection, synthetic data generation (with example prompts!), strategies for SFT and preference tuning, and various things they struggled with.

    One such struggle was an issue they ran into with gradient accumulation, where the loss was worse with it enabled than without. The community at large also hit this, and fixed it, thanks to an excellent blog post by Unsloth (link).

    The bug

    Gradient accumulation is a technique used to simulate larger batch sizes by accumulating gradients over several smaller batches before performing a backward pass. This approach is particularly useful for managing memory constraints during training, so comes up a lot when post-training on a biggish model with more limited hardware.

    The problem arises when dealing with sequences of varying lengths within these mini-batches. In standard practice, the loss is calculated and normalized by the number of non-padded (i.e., valid) tokens in each sequence. However, when accumulating gradients across multiple mini-batches, each with different sequence lengths, the naive summation of gradients can lead to an incorrect total loss calculation.

    The discrepancy occurs because the cross-entropy loss function normalizes by the number of valid tokens, and this normalization factor can vary between mini-batches. When these normalized losses are accumulated without proper adjustment, the final loss does not match what would have been obtained using a single large batch. This results in a higher observed loss during training when using gradient accumulation compared to full batch training.
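
    The mismatch is easy to demonstrate with plain numbers: averaging the per-batch mean losses of two mini-batches with different valid-token counts is not the same as taking one mean over all the tokens:

```python
# per-token losses for two mini-batches with different valid-token counts
batch1 = [2.0] * 8   # 8 valid tokens, per-token loss 2.0
batch2 = [4.0] * 2   # 2 valid tokens, per-token loss 4.0

# naive gradient accumulation: average the per-batch mean losses
naive = (sum(batch1) / len(batch1) + sum(batch2) / len(batch2)) / 2

# what a single large batch computes: one mean over all valid tokens
full_batch = (sum(batch1) + sum(batch2)) / (len(batch1) + len(batch2))

# the fix: accumulate raw loss sums and token counts across mini-batches,
# and normalize once at the end, matching the full-batch result
```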

    Daniel and co at Unsloth addressed this issue by developing a fix that ensures the accumulated gradients are correctly scaled, accounting for the varying sequence lengths across mini-batches. This aligns the gradient accumulation process with the theoretical foundations of full-batch training, leading to accurate loss calculations and improved training performance.

    Fixes and Workarounds

    Recent updates in both Hugging Face Transformers (pull request) and TorchTune (pull request) offer fixes. And, at least in Evan’s case, a little bit of snark:

    In honor of the day the ML community first discovered the fact that (x1 / n1) + (x2 / n2) != (x1 + x2) / (n1 + n2)

    I really like seeing these small but practical problems pop up, and seeing the community rally around to fix them. I missed this when it happened in October, so glad to look back at it now!

  • nonzero_static in PyTorch

    Natalia Gimelshein added CUDA support for nonzero_static to PyTorch the other day — it’s a feature JAX has had for a while that lets you avoid a bit of data-dependent annoyingness.

    I most often see nonzero pop up in logs: it underlies a lot of boolean mask operations and torch.where. A downside of torch.nonzero is that we don’t know the size of the returned tensor. The shape is data-dependent, which causes pain for the compiler, and when running on an accelerator requires a device-to-host sync. This can significantly slow down otherwise fast operations.

    nonzero_static overcomes this by allowing you to supply a shape for the output — if the actual result is smaller it is padded, if larger the result is truncated. The PR linked above enables that for CUDA. You can’t seamlessly use it with bool masks or torch.where, but you can easily replace the call with one to nonzero_static.

    To see the difference, imagine we have a big tensor where we have some sense of the output shape, for example searching for a one-hot index like this:

    import torch
    
    # Set device
    device = torch.device("cuda:0")
    
    # Generate a large tensor of 0s with exactly one 1 at a specific index on the GPU
    size = 100_000_000
    large_tensor = torch.zeros(size, device=device, dtype=torch.long)
    large_tensor[size // 2] = 1  # Place a single 1 in the middle of the tensor
    
    # Record GPU events to ensure asynchronicity
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    
    start_event.record()
    ones_indices = torch.nonzero(large_tensor)
    end_event.record()
    torch.cuda.synchronize()
    
    elapsed_time_ms = start_event.elapsed_time(end_event)
    print(f"torch.nonzero() execution time: {elapsed_time_ms:.2f} ms")
    print("Indices of 1s:", ones_indices.cpu())  

    That gives us (on my laptop)

    torch.nonzero() execution time: 177.58 ms
    Indices of 1s: tensor([[50000000]])

    If we modify it to use nonzero_static:

    import torch
    
    # Set device
    device = torch.device("cuda:0")
    
    # Generate a large tensor of 0s with exactly one 1 at a specific index on the GPU
    size = 100_000_000
    large_tensor = torch.zeros(size, device=device, dtype=torch.long)
    large_tensor[size // 2] = 1  # Place a single 1 in the middle of the tensor
    
    # Record GPU events to ensure asynchronicity
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    
    start_event.record()
    ones_indices = torch.nonzero_static(large_tensor, size=1)
    end_event.record()
    torch.cuda.synchronize()
    
    elapsed_time_ms = start_event.elapsed_time(end_event)
    print(f"torch.nonzero_static() execution time: {elapsed_time_ms:.2f} ms")
    print("Indices of 1s:", ones_indices.cpu())  
    
    torch.nonzero_static() execution time: 78.53 ms
    Indices of 1s: tensor([[50000000]])

    A very nice improvement!

  • Engineering Culture at Meta

    A question I’ve been asked recently is how Meta compares to other places I’ve worked, or what makes it different.

    From my conversations and observations, those who disliked working at Meta often cited chaos, short-term focus, and internal politics, while those who liked it called out autonomy, speed, and the feeling that they could work on important projects. To explain that disparity, I refer to three values or themes that shape the work culture.

    Individuals are responsible for doing impactful work: “impact” is an important concept at Meta, and having a level-appropriate collection of impactful work at performance review time is important for every engineer. If you find yourself in a situation where your impact feels limited, you are generally responsible for exploring ways to address it, or make a change. 

This means that ICs (individual contributors) at Meta are willing to cross team boundaries to find important work, and will also gravitate towards highly visible projects. They care about how their work is regarded and how it fits into the wider organization. Internal mobility is fairly easy, so folks will leave teams if they can’t find the right kind of work. It’s also reflected in the growth expectations: when I worked at Google and Lyft, they also had expectations that IC3s would become IC4s, and IC4s would become IC5s (though Google later removed this part), but the timelines were somewhat soft. At Meta, they are firm, and expectations ramp at defined intervals as you approach the boundaries.

Practically, this means ICs should expect to identify and collaborate on projects that align with organizational goals and take initiative to push them forward. Managers and leadership provide support, but success heavily depends on individuals’ ability and desire to chart a path, adapt as needed, and ensure their contributions are visible. In general, there’s a strong bias towards getting things done, getting things out, “rough consensus and running code”.

    Dave Anderson has written about how much more helpful he found teams at Meta, vs Amazon where there was a lot of horse trading for collaboration. Part of that is driven, I think, by this responsibility for impact. Having another team’s thanks, or enabling results for them, allows you to claim some credit for their impact with relatively little effort. Conversely, intentionally blocking another team can be seen as gatekeeping, which is frowned upon.

    No gatekeeping: Some version of “Move Fast” has been in Meta’s official values for a long time, and the company still operates at a good clip. Part of that is aided by generally making it easy to go make changes wherever they need to be made. One example of this I use with Google folks is OWNERS files. Google and Meta are both monorepo based, but at Google you have sets of services with clear owners, and touching code in another team’s service requires their full blessing. Meta also operates a monorepo, but there is much more fluidity — folks can land changes anywhere they need to. In part this is because the original Meta product, Facebook itself, is a monolith, but there is a deeper cultural aspect here.

For example, even very senior engineers will very rarely say “no” to something. Instead of outright rejection, feedback is often framed as suggestions or concerns to consider. This encourages risk-taking and innovation, but it also places a significant responsibility on engineers to weigh feedback carefully, address risks, and seek out champions among stakeholders. This is one source of the disconnect between folks who see Meta as a place where it’s ok to take risks and folks who don’t: if you take a risk and it fails, and at PSC (performance review) time someone affirms that they called out the problems that occurred and you didn’t take appropriate measures, you will be dinged. I have seen people interpret this kind of feedback as nits or suggestions, rather than weighing it heavily and convincing others that the risk is well managed ahead of time.

    In general, folks are expected to be helpful, to provide guidance to others, and not to put up walls, so it can be a tricky balance when outside teams or other engineers come in and ride roughshod over a team’s plans or projects. As an engineer, escalating misalignments on goals/priorities to management is usually well supported, and as a manager, putting engineers together to get to technical solutions across teams is expected. Enacting hard blocks where one team can’t achieve their goals because another was in the way is less so.

    The heavy dependence on individuals and relationships, particularly for cross-team projects, is another key theme:

    It’s a social company: Somewhat unsurprisingly for a company that started around a social network, Meta is a pretty social company. The internal social network is a firehose of information, and there are deep networks of connections across the company between senior engineers, managers and executives.

It’s important for engineers to talk about their work in order to find folks interested in it, build connections and relationships with them, and have a good sense not just of their org but of the universe of organizations that they operate within. A fairly common failure pattern is to build a good relationship with one side of an org and ignore another, developing a sizable blind spot that later comes back to be a problem.

The official org chart is a secondary and lagging structure at the company. The more important structure is the informal network of relationships where many things get done, and decisions get made.

This dynamic can sometimes feel political, which is why some describe Meta this way. While there are large-company politics (it is a large, influential company), for most people it’s less about traditional power politics and more about navigating cliques and informal networks of folks who have worked together on multiple projects and have mutual trust and respect. The company leans a lot on strong, senior engineers to drive projects to success, and those folks may not report into the org that is officially doing the work, or may be at a lower or higher level in the org chart than you might expect. They will often work by going directly to the people they know to unblock issues, drive important changes, or get alignment on controversial decisions.

    For big enough changes, org structural changes do follow, but they usually lag rather than lead the work itself.

    What are the downsides and upsides?

As folks who have had a bad time can attest, Meta can be chaotic. There can be parallel implementations, people can swarm on important projects to the detriment of those trying to work on them, and less impactful projects can end up unowned and passed around. It can feel like information overload, with more being published than you can possibly follow. At the same time, there can be an information drought when truly important conversations happen in small, exclusive groups. For example, I’ve seen feedback about a project be shared openly among a small collection of engineers and leaders, without ever clearly reaching the team responsible for developing it.

The flip side is that this combination of values allows Meta to pivot surprisingly fast. Changes that would have required months at some companies can be kicked off in a day, particularly by very senior, well-connected leaders. Senior ICs can take problem descriptions, quickly form an idea of who might have thoughts on them, and pull those people in. Soon you have a loose group (often later structured into a “v-team” or virtual team) that can quickly align and drive change. The lack of gatekeeping, both technically and culturally, reduces the corporate immune reaction to large changes. The incentive for individual impact encourages folks to jump on board important things without having to work out whether that means a team change, or what their long-term situation might be.

  • A very lazy mental model for parallelisms

There is a great deal of technical complexity in distributed data parallel, model parallel, tensor parallel, fully-sharded data parallel, ZeRO 1/2/3, context parallel, etc. To fit these ideas into a picture, I tend to segment in-practice parallelism into three buckets.

    1. Data parallel
    2. Pipeline parallelism
    3. Everything else

    I will tackle these out of order:

    Data parallelism

Data parallel is a very convenient training concept: you take multiple copies of the model, give different data to each, do a forward-backward pass, then all-reduce gradients between them, i.e., aggregate gradients across all devices so they operate as if they had trained on the unified set of data. This lets you train more quickly by using more devices!

    This all-reduce operation can usually be effectively overlapped with other computations, minimizing the overhead of data parallelism, and the copies of the model can be largely anything: a model running on a single GPU, or a cluster of devices running many parallelisms of their own.

Because you get to run a whole batch of the computation through the model before needing comms, it’s relatively network-efficient as well, which makes it plausible to do distributed data parallel over more normal network links (than the ones I’m about to mention).
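The core loop can be sketched without any distributed machinery at all: here is a toy, single-device simulation of two replicas, with a hand-rolled gradient average standing in for the all-reduce, showing that both copies take the same step and stay in sync:

```python
import torch

# Two replicas of a tiny "model" (a single weight vector), each seeing
# different data, as in data parallelism.
w0 = torch.ones(2, requires_grad=True)
w1 = torch.ones(2, requires_grad=True)

x0, x1 = torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])
(w0 * x0).sum().backward()  # replica 0's forward-backward
(w1 * x1).sum().backward()  # replica 1's forward-backward

# "All-reduce": average gradients across replicas.
avg_grad = (w0.grad + w1.grad) / 2

# Each replica applies the same averaged gradient, so parameters
# remain identical, as if trained on the combined data.
for w in (w0, w1):
    with torch.no_grad():
        w -= 0.1 * avg_grad
    w.grad = None

print(w0.data, w1.data)  # identical after the synced step
```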

    Everything else

Almost everything about these parallelism methods is annoying: either the communication collectives are difficult to overlap with other operations, or they must occur more frequently, leading to increased sensitivity to latency.

    This means you want faster links, which means you need very tight networking like, canonically, NVLink, which is Nvidia’s high-speed interconnect between GPUs. This tends to mean you have tensor parallelism, context parallelism etc. implemented within these fast domains, which traditionally is one host (8 GPUs), but with Blackwell can be quite a lot more (up to 36 or 72). There are a lot of options for efficiently packing the compute and scheduling work between devices in there based on the characteristics of the model, and they form this big set of potential parallelisms.
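As a toy illustration of one member of this bucket, here is column-style tensor parallelism for a single matmul, simulated on one device: each shard computes a slice of the output features, and a concatenation (an all-gather, in the real distributed setting) recovers the full result. This is only a sketch of the math, not a real multi-device implementation:

```python
import torch

torch.manual_seed(0)
W = torch.randn(6, 4)   # full weight: out_features x in_features
x = torch.randn(3, 4)   # a small batch of activations

# Split the output features across two "devices".
W_shard0, W_shard1 = W.chunk(2, dim=0)

full = x @ W.T  # what a single device would compute: (3, 6)

# Each shard computes its slice of the output; concatenating along
# the feature dim (the "all-gather") reconstructs the full result.
sharded = torch.cat([x @ W_shard0.T, x @ W_shard1.T], dim=1)

print(torch.allclose(full, sharded))  # True
```

In practice the shards never materialize the full output unless a later op needs it, which is exactly why the placement of these collectives matters so much.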

    Pipeline parallelism 

When your model unit is too big for the available “fast” networking domain, you divide the model into multiple sequential stages (groups of layers) and run different microbatches through different stages concurrently, overlapping one stage’s computation with communication to the next.

    This is painful for all the reasons pipelining in anything is painful but is additionally painful in that you have to write your training code with, effectively, a bunch of if statements for whichever stage it happens to be in. 
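A toy, single-device sketch of the data flow: the two nn.Linear stages stand in for model halves that would live on different devices, and the batch is split into microbatches that flow through the stages. A real schedule would run the stages concurrently and overlap the activation transfers rather than loop sequentially as here:

```python
import torch
import torch.nn as nn

# Hypothetical two-stage split of a small model; in real pipeline
# parallelism each stage lives on a different device.
stage0 = nn.Linear(4, 8)
stage1 = nn.Linear(8, 2)

batch = torch.randn(8, 4)
microbatches = batch.chunk(4)  # 4 microbatches of 2 samples each

outputs = []
for mb in microbatches:
    activations = stage0(mb)      # would be sent over the wire
    outputs.append(stage1(activations))
out = torch.cat(outputs)

print(out.shape)  # torch.Size([8, 2])
```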

    There is some leakage between these buckets (e.g. the first and last stages of the pipeline tend to vary a lot due to dealing with embeddings for input and output), but it’s a reasonable first approximation to treat them as discrete.

    (Or FSDP everything) 

This does presume quite a lot of scale, as the cost of getting everything working well together is non-trivial. FSDP (preferably FSDP2, in PyTorch at least) directly mixes the DP and Everything Else buckets, and works pretty well for most folks up to a decently large (100s+) number of GPUs. So roughly:

    • Smallish model, lots of data: Distributed data parallel.
• Medium-sized model, medium number of GPUs: FSDP2.
    • Anything expensive: units of <parallelisms>, chunked into pipelines if the model is too big, copied into multiple DP copies to speed up training.  
  • PyTorch while_loop

I’ve been following the development of the higher-order ops in PyTorch nightlies for a little bit, and got a chance to try out while_loop. The best examples right now are in the tests, but as another, here’s a Mandelbrot example:

    import torch
    from torch._higher_order_ops.while_loop import while_loop
    import matplotlib.pyplot as plt
    
    def mandelbrot_step(z, c):
        """Performs one iteration of the Mandelbrot sequence."""
        return z**2 + c
    
    def mandelbrot(c, max_iter, threshold):
        """Compute Mandelbrot set membership for a grid of complex numbers."""
        def cond_fn(z, iter_count, mask):
            return torch.any(mask & (iter_count < max_iter))
    
        def body_fn(z, iter_count, mask):
            z_next = mandelbrot_step(z, c)
            diverged = torch.abs(z_next) > threshold
            mask_next = mask & ~diverged
            iter_count_next = iter_count + mask_next
            return z_next, iter_count_next, mask_next
    
        # Initialize variables
        z0 = torch.zeros_like(c)
        iter_count = torch.zeros(c.shape, dtype=torch.int32)
        mask = torch.ones(c.shape, dtype=torch.bool)  # All points start as candidates
        final_state = while_loop(cond_fn, body_fn, (z0, iter_count, mask))
        
        _, iterations, _ = final_state
        return iterations
    
    # Define the grid of complex numbers
    x = torch.linspace(-2.0, 1.0, 500)
    y = torch.linspace(-1.5, 1.5, 500)
    xx, yy = torch.meshgrid(x, y, indexing="xy")  # "xy" so rows track y, matching imshow's extent
    complex_grid = xx + 1j * yy
    
    # Compute the Mandelbrot set
    max_iter = 100
    threshold = 2.0
    mandelbrot_set = mandelbrot(complex_grid, max_iter, threshold)
    
    # Plot the Mandelbrot set
    plt.figure(figsize=(10, 10))
    plt.imshow(mandelbrot_set, extent=(-2, 1, -1.5, 1.5), cmap="inferno")
    plt.colorbar(label="Iteration count")
    plt.title("Mandelbrot Set")
    plt.xlabel("Real")
    plt.ylabel("Imaginary")
    plt.show()

    In general, the only non-obvious things about while_loop are that cond_fn returns a tensor, not a Python bool, so make sure you get your types right, and that the shapes of the carried values must stay consistent from iteration to iteration. If you need more accumulating behavior, look at scan!