## Code for Anonymous Submission to NeurIPS 2024

Paper ID: 531

FasterDiT: Towards Faster Diffusion Transformers Training without Architecture Modification

## Introduction

Diffusion Transformers (DiT) have attracted significant research attention, but they converge slowly. In this paper, we aim to accelerate DiT training without any architectural modification. We identify two issues in the training process: first, certain training strategies do not perform consistently well across different data; second, supervision at specific timesteps has limited effectiveness. In response, we make the following contributions: (1) We introduce a new perspective for interpreting why these strategies fail. Specifically, we slightly extend the definition of the Signal-to-Noise Ratio (SNR) and suggest observing the Probability Density Function (PDF) of the SNR to understand what makes a strategy robust, or not, to the data. (2) We conduct extensive experiments, reporting over one hundred results, and from them empirically derive a unified acceleration strategy from the perspective of the PDF. (3) We develop a new supervision method that further accelerates DiT training. Building on these, we propose FasterDiT, an exceedingly simple and practical design strategy. With only a few lines of code changes, it achieves 2.30 FID on ImageNet at 256 resolution at 1000k iterations, which is comparable to DiT (2.27 FID) while training 7 times faster.
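The SNR-PDF perspective in (1) can be made concrete with a short numerical sketch. It is an illustration only, under two assumptions: SiT's linear path x_t = t * x1 + (1 - t) * x0 (x1 = data, x0 = unit Gaussian noise, matching the `sample` code in this README), and an extended SNR that simply folds in the data standard deviation, SNR(t) = (t * data_std / (1 - t))^2. The paper's exact definition may differ, and the helper names (`log_snr`, `logit_normal`, `log_snr_pdf`) are hypothetical. The sketch compares the log-SNR PDF under uniform and logit-normal timestep sampling, for two data scales.

```python
import numpy as np


def log_snr(t, data_std=1.0):
    # ASSUMED extended SNR for x_t = t * x1 + (1 - t) * x0 with std(x1) = data_std and
    # unit-variance noise x0: SNR(t) = (t * data_std)^2 / (1 - t)^2, returned as a natural log.
    return 2.0 * (np.log(data_std) + np.log(t) - np.log1p(-t))


def logit_normal(mu, sigma, size, rng):
    # t = 1 / (1 + exp(-z)) with z ~ N(mu, sigma^2), the same transform as `sample_logit_normal`.
    return 1.0 / (1.0 + np.exp(-rng.normal(mu, sigma, size)))


def log_snr_pdf(t, data_std, bins=60, value_range=(-15.0, 15.0)):
    # Empirical PDF (normalized histogram) of log-SNR over a batch of sampled timesteps.
    density, edges = np.histogram(log_snr(t, data_std), bins=bins,
                                  range=value_range, density=True)
    centers = 0.5 * (edges[:-1] + edges[1:])
    return centers, density


rng = np.random.default_rng(0)
n = 200_000
samplers = {
    # Uniform t, kept slightly away from 0 and 1 so the log-SNR stays finite.
    "uniform": rng.uniform(1e-5, 1.0 - 1e-5, n),
    "logit-normal": logit_normal(0.0, 1.0, n, rng),
}

stats = {}  # (sampler, data_std) -> (mean, std, location of the histogram peak)
for name, t in samplers.items():
    for data_std in (1.0, 0.5):
        s = log_snr(t, data_std)
        centers, density = log_snr_pdf(t, data_std)
        peak = centers[density.argmax()]
        stats[(name, data_std)] = (s.mean(), s.std(), peak)
        print(f"{name:>12} data_std={data_std}: mean={s.mean():+.2f} "
              f"std={s.std():.2f} peak~{peak:+.2f}")
```

Under these assumptions the logit-normal case has a closed form: logit(t) ~ N(0, 1), so log SNR = 2 ln(data_std) + 2 logit(t) ~ N(2 ln(data_std), 4). Rescaling the data shifts the whole PDF by 2 ln(data_std) (about -1.39 for data_std = 0.5) without changing its shape, while switching from uniform to logit-normal timesteps narrows it (std roughly 3.6 versus 2).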

## Using the Code

Our code is mainly built on the official SiT repository: https://github.com/willisma/SiT

### Key points
Our main modifications are:
1. We add ```magic_number``` as an additional hyperparameter to ```train.py```, ```sample.py``` and ```sample_ddp.py```, because correcting the concentration of the SNR PDF is important.

2. We sample timesteps from a logit-normal distribution instead of a uniform one (```transport/transport.py``` L104).
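```
import numpy as np
import torch as th
from scipy.stats import norm


class Transport:
    # ... (__init__, check_interval and the other SiT methods are omitted here)

    def sample_logit_normal(self, mu, sigma, size=1):
        """Draw `size` values from a logit-normal distribution on (0, 1)."""
        # Sample z ~ N(mu, sigma^2)
        samples = norm.rvs(loc=mu, scale=sigma, size=size)

        # Squash into (0, 1) with the logistic function: t = 1 / (1 + exp(-z))
        samples = 1 / (1 + np.exp(-samples))

        # NumPy array -> float32 tensor
        return th.tensor(samples, dtype=th.float32)

    def sample(self, x1):
        """Sample x0 and t based on the shape of x1 (if needed).

        Args:
            x1: data point, shape [batch, *dim]
        """
        x0 = th.randn_like(x1)
        t0, t1 = self.check_interval(self.train_eps, self.sample_eps)
        # Original SiT (uniform timesteps): t = th.rand((x1.shape[0],)) * (t1 - t0) + t0
        # Logit-normal timesteps (mu=0, sigma=1), rescaled to the interval [t0, t1]
        t = self.sample_logit_normal(0, 1, size=x1.shape[0]) * (t1 - t0) + t0
        t = t.to(x1)  # match the device and dtype of x1
        return t, x0, x1
```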

3. We add a direction supervision term, based on cosine similarity, to the training loss (```transport/transport.py``` L155).
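```
from torch.nn.functional import cosine_similarity
# mean_flat is SiT's helper that averages over all non-batch dimensions.

# Excerpt from the loss computation in transport/transport.py
# model_output: predicted velocity; ut: target velocity (both [batch, C, H, W])
# Cosine similarity across channels (dim=1) at each spatial location -> [batch, H, W]
cos_sim = cosine_similarity(model_output, ut, dim=1)
# Velocity MSE plus a direction term that penalizes (1 - cosine similarity)
terms['loss'] = mean_flat((model_output - ut) ** 2) + mean_flat(1 - cos_sim)
```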
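The two training-side changes above (logit-normal timesteps and the direction term) can be combined in a small, self-contained sketch. It is an illustration, not the repository's code: it assumes SiT's linear path x_t = t * x1 + (1 - t) * x0 with velocity target ut = x1 - x0, re-implements `mean_flat` locally, replaces the scipy round-trip with `torch.sigmoid` applied to Gaussian samples (the same distribution), and uses hypothetical helpers `direction_aware_loss` and `training_loss` with a toy model. None of these helper names belong to the repository's API.

```python
import torch as th
import torch.nn.functional as F


def mean_flat(x):
    # Average over every dimension except the batch dimension -> shape [batch]
    # (local re-implementation of the SiT helper of the same name).
    return x.mean(dim=list(range(1, x.ndim)))


def sample_logit_normal(mu, sigma, size, generator=None):
    # t = sigmoid(z), z ~ N(mu, sigma^2): a torch-only stand-in for the scipy-based sampler.
    return th.sigmoid(th.randn(size, generator=generator) * sigma + mu)


def direction_aware_loss(model_output, ut):
    # Velocity MSE plus a direction term: 1 - cosine similarity across channels (dim=1),
    # computed at every spatial location and then averaged per sample.
    cos_sim = F.cosine_similarity(model_output, ut, dim=1)
    return mean_flat((model_output - ut) ** 2) + mean_flat(1 - cos_sim)


def training_loss(model, x1, t0=0.0, t1=1.0, generator=None):
    # ASSUMED linear path: x_t = t * x1 + (1 - t) * x0 (x1 = data, x0 = Gaussian noise),
    # so the target velocity is ut = dx_t/dt = x1 - x0.
    x0 = th.randn(x1.shape, generator=generator)
    t = sample_logit_normal(0.0, 1.0, (x1.shape[0],), generator) * (t1 - t0) + t0
    t_b = t.view(-1, *([1] * (x1.ndim - 1)))  # [B] -> [B, 1, 1, 1] for broadcasting
    xt = t_b * x1 + (1 - t_b) * x0
    ut = x1 - x0
    return direction_aware_loss(model(xt, t), ut)


if __name__ == "__main__":
    g = th.Generator().manual_seed(0)
    x1 = th.randn(8, 4, 32, 32, generator=g)  # e.g. a batch of 4-channel latents
    toy_model = th.nn.Conv2d(4, 4, kernel_size=3, padding=1)
    loss = training_loss(lambda xt, t: toy_model(xt), x1, generator=g)
    print(loss.shape, loss.mean().item())
```

Because the cosine term depends only on direction, it vanishes whenever the prediction is a positive multiple of the target, leaving the MSE term to penalize the magnitude error; a prediction pointing the opposite way pays an extra 2 at every location.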

## Train

Our main results can be reproduced by training with the following command (replace N with the number of GPUs on the node):
```
torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-XL/2 --data-path /path/to/imagenet/train --wandb
```