Understanding InstaFlow/Rectified Flow
Hi! I usually do posts on Medium, but I wanted to try out the Hugging Face blog, so this will be my first blog post here! The reasons I am interested in InstaFlow/rectified flows are:
- We talked about this in the Eleuther Diffusion Reading group and it sounded interesting
- I wanted to make a PR for this in diffusers (in this issue). I'll add code once the PR is done!
What is InstaFlow/Rectified Flow?
Rectified flow is a method of finetuning diffusion models so that you can generate images in just 1 step, whereas traditionally you need around 12 steps. InstaFlow is just that, applied to Stable Diffusion. If you want to test it out, check out the demo here!
So, let's first look into how rectified flows work.
Rectified Flows
Rectified Flows were introduced in the paper "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow".
The paper deals with the idea of not just generating images in one step but, more generally, of having two image distributions, $\pi_0$, which can be horses, and $\pi_1$, which can be zebras, and then making the flow between them as straight as possible. The extension of this to one-step generation is that if we make the flow between a noise distribution and the real image distribution as short and straight as possible, instant generation of images becomes possible!
This is called an image-to-image translation, or transport mapping, problem.
Transport Mapping Problem in GANs
The transport mapping problem has seen its fair share of interest in almost any field that deals with images. For example, CycleGAN and its subsequent improvements, StarGAN and StarGANv2, use GANs to learn the mapping from images of one domain to another! The main method used to make this work is called the cycle consistency loss: given an image of a zebra, you turn it into a horse, then back into a zebra, and make sure it's still the same zebra. For more details, check out the link!
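To make the cycle consistency idea concrete, here is a tiny numeric sketch. The `g`/`f` mappings are toy stand-ins for the two generators, not real networks; CycleGAN penalizes the round trip with an L1 loss like the one below:

```python
import numpy as np

def cycle_consistency_loss(g, f, x):
    """CycleGAN-style cycle consistency (sketch): map horse -> zebra with g,
    back to horse with f, and penalize how far we drift from the original."""
    return np.mean(np.abs(f(g(x)) - x))  # L1 penalty, as in CycleGAN

x = np.ones(4)          # stand-in for a horse image
g = lambda v: v + 2.0   # toy "horse -> zebra" mapping
f = lambda v: v - 2.0   # toy inverse "zebra -> horse" mapping
print(cycle_consistency_loss(g, f, x))  # 0.0 -- the cycle is perfect
```

If `f` did not exactly invert `g`, the loss would be positive, which is what pushes the two generators to stay consistent.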
However, GANs by themselves suffer from severe training instability, and the generated images are not as high-quality as those from diffusion models, so that's where this method comes in!
Back to Rectified Flows
Now, let us sample $X_0 \sim \pi_0$ and $X_1 \sim \pi_1$. Back in our analogy, this means we will get a particular image of a horse, $X_0$, from the distribution of images of horses, $\pi_0$, and a particular zebra image, $X_1$, from the distribution of zebra images, $\pi_1$.
Now, we take the pair of images $(X_0, X_1)$. Here, we will define a parameter $t$ between 0 and 1 (inclusive) which says how far between $X_0$ and $X_1$ we are. So $X_{0.5}$ will be exactly halfway there.
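As a toy sketch of this interpolation, assuming the straight-line path between the pair (NumPy arrays stand in for images here):

```python
import numpy as np

def interpolate(x0, x1, t):
    """Straight-line interpolation between the pair: X_t = t*X_1 + (1-t)*X_0."""
    return t * x1 + (1 - t) * x0

x0 = np.zeros(4)  # stand-in for a "horse" image (or noise)
x1 = np.ones(4)   # stand-in for a "zebra" image
print(interpolate(x0, x1, 0.5))  # [0.5 0.5 0.5 0.5] -- exactly halfway
```

At $t=0$ we recover $X_0$, and at $t=1$ we recover $X_1$.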
Now, while in traditional diffusion models we can do this, the pathway is not exactly straight. In fact, it can be very roundabout, as can be seen below.
So now, let's move from pathway $X$ to pathway $Z$. We still want the same endpoints, in that $Z_0 = X_0$ and $Z_1 = X_1$, but we want the path between the points to be as straight as possible. This is written as

$$dZ_t = v(Z_t, t)\,dt$$
Now, first of all, $d$ means a very tiny step. And $v$ here means the velocity at the point $Z_t$ at timestep $t$. What this means is that if we keep pushing $Z_t$ in the direction of $v(Z_t, t)$, recomputing $v$ each time, then we will reach $Z_1$ and we will successfully have a zebra.
Now, as we mentioned plenty of times before, we want a straight line. What will this mean in the context of $v$? First of all, $v$ should be constant, because a straight line should push the same amount in the same direction regardless of what $t$ is. Secondly, it should be the shortest path. This means that when we do the integration (which just means adding up all the steps along the entire path of $Z$), the total displacement should be $X_1 - X_0$.
Now, in formal terms, this will mean

$$\min_v \int_0^1 \mathbb{E}\left[\left\lVert (X_1 - X_0) - v(X_t, t)\right\rVert^2\right] \mathrm{d}t$$

with

$$X_t = t X_1 + (1 - t) X_0.$$

Since the optimal $v$ is the expected straight-line direction given where we currently are, this can also be written as

$$v(x, t) = \mathbb{E}\left[X_1 - X_0 \mid X_t = x\right]$$

too!
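A minimal Monte-Carlo sketch of this objective (the arrays and the `ideal_v` field below are toy stand-ins, not a real model):

```python
import numpy as np

def rectified_flow_loss(v_fn, x0, x1, t):
    """One Monte-Carlo sample of the rectified flow objective:
    || (X_1 - X_0) - v(X_t, t) ||^2  with  X_t = t*X_1 + (1-t)*X_0."""
    xt = t * x1 + (1 - t) * x0
    target = x1 - x0  # the constant straight-line velocity
    return np.mean((target - v_fn(xt, t)) ** 2)

# Sanity check: with only one (X_0, X_1) pair, the ideal field
# v(x, t) = X_1 - X_0 achieves exactly zero loss.
rng = np.random.default_rng(0)
x0, x1 = rng.normal(size=4), rng.normal(size=4)
ideal_v = lambda xt, t: x1 - x0
print(rectified_flow_loss(ideal_v, x0, x1, 0.3))  # 0.0
```

In training, `v_fn` would be a neural network and the loss would be averaged over many pairs and random values of $t$.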
A quick sidenote here: for DDIMs, the interpolation instead takes the form $X_t = \alpha_t X_1 + \beta_t X_0$ with non-linear coefficients $\alpha_t$ and $\beta_t$, so the paths are curved.
If people are interested, I can link some theory background for this here!
Now, the paper goes into some very interesting math, which I'll skip in this blog, but which I recommend checking out if you like math and differential equations and want to see why the above doesn't fall into certain pitfalls.
Now, in practice, as you may have guessed, $v$ will be our Stable Diffusion model, $X_0$ will be the initial noise, and $X_1$ will be the output image. So one strategy, as I understand it, is that we can record a huge dataset of initial noise and the corresponding output images from Stable Diffusion. Then, we can finetune a Stable Diffusion model so that the predicted velocity always points along the straight line between the given pair $(X_0, X_1)$. So the overall algorithm is
One great thing about this, as can be seen from the algorithm, is that since we defined $t$ to be between 0 and 1, we can just add $v \times 1$ to $X_0$ to get $X_1$!
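This one-step property can be sanity-checked with a tiny Euler integrator (a sketch; the constant field `v` stands in for a perfectly straightened model):

```python
import numpy as np

def euler_sample(v_fn, x0, n_steps):
    """Integrate dZ_t = v(Z_t, t) dt from t=0 to t=1 with Euler steps."""
    z, dt = x0.copy(), 1.0 / n_steps
    for i in range(n_steps):
        z = z + v_fn(z, i * dt) * dt
    return z

# With a perfectly straight (constant) velocity field, a single Euler step
# x0 + v*1 already lands exactly on the endpoint -- extra steps change nothing.
x0 = np.zeros(3)
v = lambda z, t: np.array([1.0, 2.0, 3.0])
print(euler_sample(v, x0, 1))  # [1. 2. 3.]
```

For a curved field (like a regular diffusion model), the one-step and many-step results would disagree, which is exactly why straightening matters.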
Reflow
Now, one problem: isn't this straightening a bit too trivial a solution? Is there really no error associated with it? And the answer is that there is! The solution is that once you get your best possible path from $X_0$ to $X_1$, you just apply rectified flow to that path again and again until it finally becomes straight, as you can see below
The algorithm is
However, the paper mentions that while doing reflow makes the path straighter and shorter, it comes at the cost of the final $Z_1$ deviating too much from the true image distribution.
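The reflow loop can be sketched as follows. All the functions here are hypothetical placeholders for the real sampling and training steps, not diffusers APIs:

```python
def reflow(x0_samples, sample_x1, train_flow, n_rounds=2):
    """Reflow loop (sketch). `sample_x1(flow, x0)` draws the X_1 endpoint for
    an X_0 using the current flow (or the original data/model when flow is
    None), and `train_flow` stands in for a full rectified-flow training run."""
    flow = None
    for _ in range(n_rounds):
        # Re-pair each X_0 with the endpoint the *current* flow produces,
        # then retrain on these straighter couplings.
        pairs = [(x0, sample_x1(flow, x0)) for x0 in x0_samples]
        flow = train_flow(pairs)
    return flow

# Toy stand-ins just to show the control flow: the "flow" is the pair list.
sample_x1 = lambda flow, x0: x0 + 1.0
train_flow = lambda pairs: pairs
print(reflow([0.0, 1.0], sample_x1, train_flow))  # [(0.0, 1.0), (1.0, 2.0)]
```

Each round replaces the original (independent) couplings with couplings generated by the previous flow, which is what progressively straightens the paths.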
Distillation
Now, given we have a reflowed model that can predict the velocity, we can distill it. For this, InstaFlow uses the objective

$$\min_\theta \; \mathbb{E}\left[\mathbb{D}\left(X_0 + v_\theta(X_0),\; X_1\right)\right]$$

where $\mathbb{D}$ is a similarity loss between images.
Essentially, what this does is that instead of trying to predict the velocity that, when added to $X_0$, will become $X_1$, we are trying to directly predict $X_1$, which is pretty interesting.
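A toy version of this distillation objective, using plain MSE as a stand-in for the LPIPS similarity loss the paper uses:

```python
import numpy as np

def distill_loss(v_fn, x0, x1, d=lambda a, b: np.mean((a - b) ** 2)):
    """Distillation sketch: match the one-step jump x0 + v(x0) directly to x1.
    The paper uses an LPIPS similarity loss for d; plain MSE is used here."""
    return d(x0 + v_fn(x0), x1)

x0, x1 = np.zeros(3), np.array([1.0, 2.0, 3.0])
perfect_v = lambda x: x1 - x0  # toy model whose single jump is already exact
print(distill_loss(perfect_v, x0, x1))  # 0.0
```

Note there is no timestep and no intermediate $X_t$ any more: the student is graded purely on its single jump from $X_0$.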
Now, the paper also has math applying this to DDIMs, given their non-linear paths, but since we are mainly concerned with getting a PR done, let's move on to InstaFlow!
InstaFlow
InstaFlow is pretty much rectified flow applied to Stable Diffusion. For some statistics, they trained with 199 A100 days = 4,776 A100 GPU hours, which should cost around 5,000 dollars for institutions with deals, or 10,000 dollars for those without. This is pretty cheap considering Stable Diffusion 2.1 was trained for around 200,000 GPU hours, which translates to roughly 200,000 dollars. It can generate images in 0.12 seconds on an A100, which makes sense as it is a 1-step model.
Training algorithm
As can be seen above, the algorithm is pretty much exactly the same, except we condition on text, while the original rectified flow was unconditional. Then, there's an extra step for distilling. The authors observed that reflow was very important for good quality.
Instaflow training setup
They used a subset of prompts from laion2B-en. The target images are generated with 25 steps of the DPM solver with a guidance scale of 6.0. For distillation, they used the LPIPS loss (with a network which I assume is VGG) to capture high-level similarities between images (faces, objects, etc.). Finally, they used a batch size of 32 and 8 A100 GPUs for training, with the AdamW optimizer.
TODO list
Overall, this is it! So, as a TODO list for the PR, we need to:
- Figure out how to map epsilon to velocity. My understanding is that we ignore the DDPM/epsilon objective during rectified flow and just have the UNet output $v$ directly.
- Make a script to generate the latent noise, images, and text to save to the dataset
- Make rectified flow/reflow script
- Make distillation script
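For the first TODO item, note that under the rectified-flow interpolation $X_t = tX_1 + (1-t)X_0$ with $X_0$ being the noise, a noise prediction can be converted to a velocity algebraically. This is a sketch under that assumption only; it does not hold for the DDPM noise schedule, and it breaks down at $t = 0$:

```python
import numpy as np

def epsilon_to_velocity(eps, xt, t):
    """Convert a noise prediction to a rectified-flow velocity (sketch).
    Assumes X_t = t*X_1 + (1-t)*X_0 with X_0 = noise, so eps = X_0 and
    v = X_1 - X_0 = (X_t - eps) / t.  Undefined at t = 0, and NOT valid
    for the DDPM parameterization, which uses a different schedule."""
    return (xt - eps) / t

# Consistency check on a synthetic pair: recover v exactly from (eps, xt, t).
rng = np.random.default_rng(0)
x0, x1, t = rng.normal(size=4), rng.normal(size=4), 0.7
xt = t * x1 + (1 - t) * x0
print(np.allclose(epsilon_to_velocity(x0, xt, t), x1 - x0))  # True
```

That said, as noted in the TODO above, the simpler route for the PR is probably to have the UNet output $v$ directly and skip this conversion entirely.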