# Title: Enhancing Diffusion Models with Generative Adversarial Networks for Improved Sample Quality
# Experiment description: In this experiment, we integrate a GAN framework into the diffusion model. Specifically, we: (1) implement a simple discriminator network, a small MLP, to distinguish between real and generated samples; (2) modify the MLPDenoiser to include an adversarial loss term alongside the existing reconstruction loss, with a gradient penalty to improve training stability; (3) adjust the training loop to alternate between training the discriminator and the denoiser, so that the denoiser learns to produce more realistic samples from the discriminator's feedback; (4) train the GAN-enhanced diffusion model on the same datasets; and (5) compare the results in terms of training time, evaluation loss, and KL divergence, and assess sample quality by qualitative visual inspection.
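
The notes do not spell out the exact objective, so the following is only an assumed sketch of how the pieces in the description could fit together: a noise-prediction reconstruction loss plus a weighted adversarial term for the denoiser, and a discriminator loss with a WGAN-GP-style gradient penalty. The weights $\lambda_{\text{adv}}$ and $\lambda_{\text{gp}}$, the cross-entropy form of the adversarial terms, the generated sample $\tilde{x}$, and the interpolate $\hat{x}$ are all hypothetical choices for illustration, not details taken from this experiment.

```latex
\begin{aligned}
\mathcal{L}_{\text{denoiser}}
  &= \mathbb{E}\,\bigl\lVert \epsilon - \epsilon_\theta(x_t, t) \bigr\rVert_2^2
   \;+\; \lambda_{\text{adv}}\, \mathbb{E}\bigl[-\log D_\phi(\tilde{x})\bigr], \\
\mathcal{L}_{\text{disc}}
  &= \mathbb{E}\bigl[-\log D_\phi(x)\bigr]
   + \mathbb{E}\bigl[-\log\bigl(1 - D_\phi(\tilde{x})\bigr)\bigr]
   \;+\; \lambda_{\text{gp}}\, \mathbb{E}\Bigl[\bigl(\lVert \nabla_{\hat{x}} D_\phi(\hat{x}) \rVert_2 - 1\bigr)^2\Bigr], \\
\hat{x} &= \alpha\, x + (1 - \alpha)\, \tilde{x}, \qquad \alpha \sim \mathcal{U}(0, 1).
\end{aligned}
```

Under this reading, an alternating loop would take one optimizer step on $\mathcal{L}_{\text{disc}}$ with the denoiser frozen, then one step on $\mathcal{L}_{\text{denoiser}}$ with the discriminator frozen.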
## Run 0: Baseline
Results: {'circle': {'training_time': 52.92525362968445, 'eval_loss': 0.43425207659411613, 'inference_time': 0.14317584037780762, 'kl_divergence': 0.3408613096628985}, 'dino': {'training_time': 79.84856963157654, 'eval_loss': 0.665187395899497, 'inference_time': 0.11029982566833496, 'kl_divergence': 1.1213630053295838}, 'line': {'training_time': 54.4330997467041, 'eval_loss': 0.8009099938985332, 'inference_time': 0.11038732528686523, 'kl_divergence': 0.16666481475290315}, 'moons': {'training_time': 54.482470750808716, 'eval_loss': 0.6143880410267569, 'inference_time': 0.1099700927734375, 'kl_divergence': 0.08647709108345572}}
Description: Baseline results.

## Run 2: Adding Gradient Penalty
Results: {'circle': {'training_time': 265.2883794307709, 'eval_loss': 0.434587472006488, 'inference_time': 0.14093685150146484, 'kl_divergence': 0.3604047182045786}, 'dino': {'training_time': 243.74999475479126, 'eval_loss': 0.665170654463951, 'inference_time': 0.11144709587097168, 'kl_divergence': 1.0364102334466179}, 'line': {'training_time': 261.86758947372437, 'eval_loss': 0.8037817526961226, 'inference_time': 0.1273975372314453, 'kl_divergence': 0.14531509758172917}, 'moons': {'training_time': 263.7597990036011, 'eval_loss': 0.6178049460396438, 'inference_time': 0.1425645351409912, 'kl_divergence': 0.10209078394617543}}
Description: In this run, we added a gradient penalty to the adversarial loss to improve training stability. Compared with the baseline (Run 0), training time increased roughly three- to fivefold (for example, from about 53 to about 265 on circle), while evaluation loss changed by less than 1% on every dataset. KL divergence was mixed: it decreased on dino (1.121 to 1.036) and line (0.167 to 0.145) but increased on circle (0.341 to 0.360) and moons (0.086 to 0.102). This suggests that while the gradient penalty may help with training stability, it does not clearly lead to better sample quality in this context.
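
To make the comparison with the baseline concrete, the relative change of each metric can be computed directly from the two result dictionaries. The sketch below is illustrative only: the values are copied from the Run 0 and Run 2 results above and rounded to four decimal places, `inference_time` is left out for brevity, and the helper name `relative_change` is hypothetical.

```python
# Metrics copied from the Run 0 and Run 2 result dictionaries in these notes,
# rounded to four decimal places (inference_time omitted for brevity).
baseline = {
    "circle": {"training_time": 52.9253, "eval_loss": 0.4343, "kl_divergence": 0.3409},
    "dino":   {"training_time": 79.8486, "eval_loss": 0.6652, "kl_divergence": 1.1214},
    "line":   {"training_time": 54.4331, "eval_loss": 0.8009, "kl_divergence": 0.1667},
    "moons":  {"training_time": 54.4825, "eval_loss": 0.6144, "kl_divergence": 0.0865},
}
run_2 = {
    "circle": {"training_time": 265.2884, "eval_loss": 0.4346, "kl_divergence": 0.3604},
    "dino":   {"training_time": 243.7500, "eval_loss": 0.6652, "kl_divergence": 1.0364},
    "line":   {"training_time": 261.8676, "eval_loss": 0.8038, "kl_divergence": 0.1453},
    "moons":  {"training_time": 263.7598, "eval_loss": 0.6178, "kl_divergence": 0.1021},
}


def relative_change(reference, candidate):
    """Return {dataset: {metric: (candidate - reference) / reference}}."""
    return {
        dataset: {
            metric: (candidate[dataset][metric] - ref_value) / ref_value
            for metric, ref_value in metrics.items()
        }
        for dataset, metrics in reference.items()
    }


changes = relative_change(baseline, run_2)
for dataset, metrics in changes.items():
    # Positive values mean the metric went up relative to the baseline.
    summary = ", ".join(f"{metric}: {value:+.1%}" for metric, value in metrics.items())
    print(f"{dataset:>6}  {summary}")
```

For all three metrics shown, lower is better, so a negative relative change is an improvement.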

## Run 3: Fine-Tuning Hyperparameters
Results: {'circle': {'training_time': 273.7930860519409, 'eval_loss': 0.4346648931808179, 'inference_time': 0.12023258209228516, 'kl_divergence': 0.3500326621068686}, 'dino': {'training_time': 253.12539935112, 'eval_loss': 0.6642400203153606, 'inference_time': 0.12917494773864746, 'kl_divergence': 1.0426551796727075}, 'line': {'training_time': 281.75747752189636, 'eval_loss': 0.8052869840046329, 'inference_time': 0.12714266777038574, 'kl_divergence': 0.18203549509286743}, 'moons': {'training_time': 283.6050419807434, 'eval_loss': 0.6191892492039429, 'inference_time': 0.1302032470703125, 'kl_divergence': 0.09779758687759507}}
Description: In this run, we fine-tuned the hyperparameters by adjusting the learning rate and the number of hidden layers in the discriminator. Training time increased slightly compared to Run 2 (by roughly 3% to 8%). Evaluation loss was essentially unchanged, differing from Run 2 by less than 0.3% on every dataset. KL divergence improved slightly on circle (0.360 to 0.350) and moons (0.102 to 0.098), was marginally worse on dino (1.036 to 1.043), and worsened on line (0.145 to 0.182). Overall, fine-tuning the hyperparameters had a limited and dataset-dependent impact on the model's performance.

## Run 4: Changing Beta Schedule to Quadratic
Results: {'circle': {'training_time': 267.81446051597595, 'eval_loss': 0.37992295295076295, 'inference_time': 0.17848443984985352, 'kl_divergence': 0.4426549895288475}, 'dino': {'training_time': 273.8629205226898, 'eval_loss': 0.6424662062274221, 'inference_time': 0.13181495666503906, 'kl_divergence': 0.5710263317303715}, 'line': {'training_time': 287.80195713043213, 'eval_loss': 0.8638582183881793, 'inference_time': 0.13017058372497559, 'kl_divergence': 0.3497732784746851}, 'moons': {'training_time': 274.91384768486023, 'eval_loss': 0.6413522321549828, 'inference_time': 0.12932515144348145, 'kl_divergence': 0.22292086418253548}}
Description: In this run, we changed the beta schedule from "linear" to "quadratic" to see whether it improves the model's performance. Training time stayed within about 10% of Run 3, rising on dino and line and falling on circle and moons. Evaluation loss improved on circle (0.435 to 0.380) and dino (0.664 to 0.642) but worsened on line (0.805 to 0.864) and moons (0.619 to 0.641). KL divergence dropped sharply on dino (1.043 to 0.571) but rose on circle (0.350 to 0.443), line (0.182 to 0.350), and moons (0.098 to 0.223). This suggests that the effect of the quadratic beta schedule depends strongly on the dataset.
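
For reference, here is a minimal sketch of how the two schedules differ, assuming the common convention in which the quadratic schedule interpolates linearly in the square root of beta and then squares the result. The endpoint values `beta_start=1e-4` and `beta_end=0.02`, the 100 timesteps, and the function names are assumptions for illustration, not settings taken from these runs.

```python
def make_betas(num_timesteps, schedule="linear", beta_start=1e-4, beta_end=0.02):
    """Return the per-step noise variances beta_1..beta_T as a list.

    beta_start and beta_end are assumed DDPM-style defaults, not values
    taken from these notes.
    """
    if num_timesteps < 2:
        raise ValueError("num_timesteps must be at least 2")
    fractions = [i / (num_timesteps - 1) for i in range(num_timesteps)]
    if schedule == "linear":
        # Interpolate beta itself linearly between the endpoints.
        return [beta_start + f * (beta_end - beta_start) for f in fractions]
    if schedule == "quadratic":
        # Interpolate sqrt(beta) linearly, then square.
        lo, hi = beta_start ** 0.5, beta_end ** 0.5
        return [(lo + f * (hi - lo)) ** 2 for f in fractions]
    raise ValueError(f"unknown schedule: {schedule!r}")


def alphas_cumprod(betas):
    """Cumulative product of (1 - beta_t): the signal fraction left at each step."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out


linear = make_betas(100, "linear")
quadratic = make_betas(100, "quadratic")
print(f"beta at step 50:  linear={linear[49]:.5f}  quadratic={quadratic[49]:.5f}")
print(
    "signal kept at T: "
    f"linear={alphas_cumprod(linear)[-1]:.3f}  "
    f"quadratic={alphas_cumprod(quadratic)[-1]:.3f}"
)
```

With shared endpoints, the quadratic schedule adds noise more slowly in the early steps and leaves more of the original signal at the final step, which is one plausible reason the two schedules behave differently across datasets.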

# Plot Descriptions

## Plot 1: Training Loss
Filename: train_loss.png
Description: This plot shows the training loss over time for each dataset across the different runs. The x-axis represents the training steps, and the y-axis represents the loss. Each subplot corresponds to one dataset (circle, dino, line, moons). The legend identifies the runs: Baseline, Gradient Penalty, Fine-Tuned Hyperparameters, and Quadratic Beta Schedule. The plot shows how the training loss evolves for each configuration and dataset.

## Plot 2: Generated Samples
Filename: generated_images.png
Description: This plot visualizes the generated samples for each dataset across the different runs. Each row corresponds to one run, and each column corresponds to one dataset (circle, dino, line, moons). The scatter plots show the generated samples in 2D space. The legend identifies the runs: Baseline, Gradient Penalty, Fine-Tuned Hyperparameters, and Quadratic Beta Schedule. The plot supports a qualitative assessment of sample quality for each configuration and dataset.