# Title: Adaptive Dual-Scale Denoising for Dynamic Feature Balancing in Low-Dimensional Diffusion Models
# Experiment description: Modify MLPDenoiser to implement a dual-scale processing approach with two parallel branches: a global branch for the original input and a local branch for an upscaled input. Introduce a learnable, timestep-conditioned weighting factor to dynamically balance the contributions of global and local branches. Train models with both the original and new architecture on all datasets. Compare performance using KL divergence and visual inspection of generated samples. Analyze how the weighting factor evolves during the denoising process and its impact on capturing global structure vs. local details across different datasets and timesteps.
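The sketch below illustrates the intended dual-branch design. It is a minimal PyTorch sketch, not the confirmed implementation: the 2D input size, the learnable linear layer standing in for the "upscaling" of the local branch, the MLP timestep embedding, and all layer widths are assumptions. The two branch outputs are returned separately so the runs below can combine them with either a fixed or a learnable weighting.

```python
import torch
import torch.nn as nn

class DualScaleDenoiser(nn.Module):
    """Two parallel MLP branches: a global branch on the raw 2D input and a
    local branch on an upscaled copy. Layer sizes are illustrative."""

    def __init__(self, embedding_dim: int = 128, hidden_dim: int = 256):
        super().__init__()
        self.time_embed = nn.Sequential(nn.Linear(1, embedding_dim), nn.ReLU())
        self.upscale = nn.Linear(2, 4)  # hypothetical learnable "upscaling" of the input
        self.global_branch = nn.Sequential(
            nn.Linear(2 + embedding_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 2),
        )
        self.local_branch = nn.Sequential(
            nn.Linear(4 + embedding_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 2),
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        t_emb = self.time_embed(t.float().unsqueeze(-1))
        global_out = self.global_branch(torch.cat([x, t_emb], dim=-1))
        local_out = self.local_branch(torch.cat([self.upscale(x), t_emb], dim=-1))
        return global_out, local_out  # combined by the run-specific weighting below
```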
## Run 0: Baseline
Results: {'circle': {'training_time': 37.41756200790405, 'eval_loss': 0.43862981091984704, 'inference_time': 0.17150163650512695, 'kl_divergence': 0.35409057707548985}, 'dino': {'training_time': 36.680198669433594, 'eval_loss': 0.6648215834442002, 'inference_time': 0.17148971557617188, 'kl_divergence': 0.9891262038552158}, 'line': {'training_time': 37.15258550643921, 'eval_loss': 0.8037568644794357, 'inference_time': 0.16029620170593262, 'kl_divergence': 0.16078266200244817}, 'moons': {'training_time': 36.608174562454224, 'eval_loss': 0.6160634865846171, 'inference_time': 0.16797804832458496, 'kl_divergence': 0.08958379744366118}}
Description: Baseline results.
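The kl_divergence values above compare generated samples against the ground-truth 2D distribution. The exact estimator used by the experiment code is not shown here; a common nonparametric choice for sample-based KL in low dimensions is a k-nearest-neighbour estimator, sketched below purely for reference.

```python
import numpy as np
from scipy.spatial import cKDTree

def knn_kl_divergence(real_samples, generated_samples, k=5):
    """k-NN estimate of KL(real || generated) from two 2D sample sets
    (Perez-Cruz-style estimator); a reference sketch, not necessarily the
    estimator used in the experiment code."""
    p = np.asarray(real_samples)
    q = np.asarray(generated_samples)
    n, d = p.shape
    m = len(q)
    # Distance to the k-th neighbour within the real set (skipping the point
    # itself) and to the k-th neighbour in the generated set.
    rho = cKDTree(p).query(p, k=k + 1)[0][:, -1]
    nu = cKDTree(q).query(p, k=k)[0][:, -1]
    return d * np.mean(np.log(nu / rho)) + np.log(m / (n - 1))
```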
## Run 1: Dual-Scale Processing with Fixed Weighting
Description: Implemented a dual-scale processing approach with two parallel branches: a global branch for the original input and a local branch for an upscaled input. Used a fixed weighting factor of 0.5 to combine the outputs of both branches.
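Continuing the architecture sketch above, the fixed combination used in this run amounts to the following (variable names are hypothetical):

```python
# Run 1: constant 50/50 blend of the two branch outputs.
global_out, local_out = model(x_noisy, timesteps)
noise_pred = 0.5 * global_out + 0.5 * local_out
```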
Results: {'circle': {'training_time': 73.06966805458069, 'eval_loss': 0.43969630813964494, 'inference_time': 0.29320263862609863, 'kl_divergence': 0.3689575513483317}, 'dino': {'training_time': 74.27817940711975, 'eval_loss': 0.6613499774499927, 'inference_time': 0.2861502170562744, 'kl_divergence': 0.8196823128731071}, 'line': {'training_time': 76.55267119407654, 'eval_loss': 0.8027192704817828, 'inference_time': 0.274810791015625, 'kl_divergence': 0.1723356430884586}, 'moons': {'training_time': 74.5637640953064, 'eval_loss': 0.6173960363773434, 'inference_time': 0.27197885513305664, 'kl_divergence': 0.09956056764691522}}
Analysis: The dual-scale processing approach with fixed weighting shows mixed results compared to the baseline. KL divergence improves only for 'dino' (0.989 to 0.820), while 'circle', 'line', and 'moons' show small increases. The eval_loss remains close to the baseline, indicating that the model's denoising ability has not changed significantly. However, training time roughly doubles and inference time increases by about 70%, which is expected given the added computation of the second branch. This suggests that a fixed 0.5 weighting is not optimal for all datasets and timesteps, motivating a more adaptive approach.
## Run 2: Adaptive Dual-Scale Processing with Learnable Weighting
Description: Implemented a learnable, timestep-conditioned weighting factor to dynamically balance the contributions of global and local branches. This approach aims to adaptively adjust the importance of global and local features based on the denoising timestep and input characteristics.
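A minimal sketch of how such a timestep-conditioned weighting could look, assuming the weights are produced from the timestep embedding and constrained to a convex combination via a softmax; the layer sizes and exact conditioning are assumptions.

```python
import torch.nn as nn

class TimestepWeightNet(nn.Module):
    """Maps the timestep embedding to two non-negative weights that sum to 1:
    (global weight, local weight)."""

    def __init__(self, embedding_dim: int = 128, hidden_dim: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 2), nn.Softmax(dim=-1),
        )

    def forward(self, t_emb):
        return self.net(t_emb)

# Inside the denoiser forward pass (names as in the sketch above):
#   w = self.weight_net(t_emb)                                  # (batch, 2)
#   noise_pred = w[:, :1] * global_out + w[:, 1:] * local_out   # convex combination
```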
Results: {'circle': {'training_time': 89.83488082885742, 'eval_loss': 0.4358053507707308, 'inference_time': 0.3021073341369629, 'kl_divergence': 0.34716546994971326}, 'dino': {'training_time': 88.4310839176178, 'eval_loss': 0.6636832975365622, 'inference_time': 0.29015278816223145, 'kl_divergence': 0.8708838663821192}, 'line': {'training_time': 81.63592505455017, 'eval_loss': 0.8070394032446625, 'inference_time': 0.35721874237060547, 'kl_divergence': 0.15501561703447317}, 'moons': {'training_time': 83.31885623931885, 'eval_loss': 0.6170386532535943, 'inference_time': 0.26299095153808594, 'kl_divergence': 0.09623687732255731}}
Analysis: The adaptive dual-scale processing approach with learnable weighting shows improvements over both the baseline (Run 0) and the fixed weighting approach (Run 1). Key observations:
1. KL divergence: Improved for 'circle' and 'line' datasets compared to both previous runs. Slightly worse for 'dino' compared to Run 1 but still better than baseline. 'Moons' is marginally better than Run 1 but slightly worse than the baseline (0.0962 vs. 0.0896).
2. Eval loss: Slightly improved or comparable to previous runs across all datasets, indicating consistent or better denoising performance.
3. Training and inference times: Increased compared to Run 1, which is expected due to the additional complexity of the learnable weighting mechanism. However, the performance gains justify this increased computational cost.
4. Overall performance: The adaptive approach seems to better balance global and local features across different datasets, leading to improved generation quality as indicated by the KL divergence metrics.
These results suggest that the learnable, timestep-conditioned weighting factor is effective in dynamically balancing the contributions of global and local branches, leading to improved performance across various low-dimensional datasets.
## Run 3: Analyze Weighting Factor Behavior
Description: To gain insights into how the adaptive weighting mechanism operates, we modified the MLPDenoiser to output the weighting factors along with the denoised sample. We then analyzed how these weights evolve during the denoising process for different datasets and timesteps.
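A hedged sketch of how the weights could be collected during sampling, assuming the modified denoiser returns (noise_pred, weights) and a DDPM-style scheduler exposing num_timesteps and a step() method that returns the previous sample; both interfaces are assumptions.

```python
import torch

@torch.no_grad()
def sample_with_weight_logging(model, scheduler, num_samples=1000):
    """Reverse-diffusion loop that records the mean (global, local) weight at
    every timestep; weight_evolution ends up with shape (num_timesteps, 2)."""
    x = torch.randn(num_samples, 2)
    weight_evolution = []
    for t in reversed(range(scheduler.num_timesteps)):
        t_batch = torch.full((num_samples,), t, dtype=torch.long)
        noise_pred, w = model(x, t_batch)
        weight_evolution.append(w.mean(dim=0).cpu().numpy())
        x = scheduler.step(noise_pred, t, x)  # assumed to return the previous sample
    return x, weight_evolution
```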
Results:
{'circle': {'training_time': 76.7284095287323, 'eval_loss': 0.44064563596644973, 'inference_time': 0.2985854148864746, 'kl_divergence': 0.3610795315315597},
'dino': {'training_time': 81.04552888870239, 'eval_loss': 0.6684170478140302, 'inference_time': 0.2813124656677246, 'kl_divergence': 1.0343572533041825},
'line': {'training_time': 86.87003922462463, 'eval_loss': 0.8020361468310246, 'inference_time': 0.29435014724731445, 'kl_divergence': 0.14756397445109098},
'moons': {'training_time': 82.37207579612732, 'eval_loss': 0.6139750773339625, 'inference_time': 0.2791574001312256, 'kl_divergence': 0.10025829915007056}}
Analysis:
1. Performance Metrics:
- The results vary somewhat compared to Run 2.
- KL divergence improved for 'line' but increased for 'circle', 'moons', and especially 'dino' (0.871 to 1.034, above the baseline).
- Eval losses are comparable to previous runs, indicating consistent denoising performance.
- Training and inference times are similar to Run 2, suggesting that outputting weight factors doesn't significantly impact computational efficiency.
2. Weight Evolution:
- The weight evolution data collected during this run provides valuable insights into how the model balances global and local features across different datasets and timesteps.
- Further analysis of the weight_evolution arrays in the all_results.pkl file will reveal patterns in how the model adapts its focus between global and local features throughout the denoising process.
3. Implications:
- The adaptive weighting mechanism shows promise in dynamically balancing global and local features, with performance broadly maintained on most datasets.
- The run-to-run variations relative to Run 2, particularly the regression on 'dino', suggest that the weighting mechanism is sensitive to initialization and training dynamics, which could be an area for further investigation and potential improvement.
Next Steps:
To further understand and potentially improve the adaptive dual-scale processing approach, we should analyze the weight evolution patterns and consider ways to stabilize or enhance the weighting mechanism's behavior.
## Run 4: Visualize and Analyze Weight Evolution
Description: In this run, we focused on visualizing and analyzing the weight evolution data collected in Run 3. We modified the plot.py script to create new visualizations that show how the weights for global and local features change across timesteps for each dataset. This analysis helps us understand the model's behavior and potentially identify areas for improvement in the adaptive weighting mechanism.
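A minimal sketch of the kind of visualization added to plot.py, assuming the weight evolution for each dataset is stored as an array of shape (num_timesteps, 2) with columns (global weight, local weight); the data layout and figure layout are assumptions.

```python
import matplotlib.pyplot as plt
import numpy as np

def plot_weight_evolution(weight_evolution_by_dataset, out_path="weight_evolution.png"):
    """2x2 grid, one subplot per dataset, showing how the global and local
    weights change across denoising steps."""
    fig, axes = plt.subplots(2, 2, figsize=(10, 8))
    for ax, (name, weights) in zip(axes.flat, weight_evolution_by_dataset.items()):
        weights = np.asarray(weights)
        steps = np.arange(len(weights))
        ax.plot(steps, weights[:, 0], label="global weight")
        ax.plot(steps, weights[:, 1], label="local weight")
        ax.set_title(name)
        ax.set_xlabel("denoising step")
        ax.set_ylabel("weight")
        ax.set_ylim(0, 1)
        ax.legend()
    fig.tight_layout()
    fig.savefig(out_path)
```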
Results:
{'circle': {'training_time': 79.91087174415588, 'eval_loss': 0.43513242751741044, 'inference_time': 0.2929060459136963, 'kl_divergence': 0.34491080184270567},
'dino': {'training_time': 73.9358651638031, 'eval_loss': 0.6596772278971075, 'inference_time': 0.27817249298095703, 'kl_divergence': 0.8622566282410796},
'line': {'training_time': 72.14862084388733, 'eval_loss': 0.8060770393027674, 'inference_time': 0.2744631767272949, 'kl_divergence': 0.15322529458283543},
'moons': {'training_time': 74.74772787094116, 'eval_loss': 0.6146410070264431, 'inference_time': 0.2653486728668213, 'kl_divergence': 0.09325452685708886}}
Analysis:
1. Performance Metrics:
- The results show consistent performance with previous runs, particularly Run 3.
- KL divergence improves for 'circle', 'dino', and 'moons' compared to Run 3; 'line' is marginally higher but still below the baseline, indicating overall better quality in generated samples.
- Eval losses remain stable, suggesting consistent denoising performance.
- Training and inference times are comparable to previous runs, confirming that the weight visualization doesn't significantly impact computational efficiency.
2. Weight Evolution Visualization:
- The new plot.py script now includes a visualization of weight evolution across timesteps for each dataset.
- This visualization allows us to observe how the model balances global and local features throughout the denoising process.
- Analyzing these plots can provide insights into the adaptive behavior of the model for different datasets and at various stages of denoising.
3. Implications and Insights:
- The improvements in KL divergence on most datasets suggest that the adaptive weighting mechanism is effectively balancing global and local features.
- The stability in eval losses and computational times indicates that the adaptive approach maintains efficiency while improving generation quality.
- The weight evolution plots may reveal patterns in how the model adapts its focus between global and local features, which could inform future improvements to the architecture or training process.
Next Steps:
Based on the insights gained from the weight evolution visualization, we should consider the following:
1. Analyze the weight evolution patterns for each dataset to identify any common trends or dataset-specific behaviors.
2. Investigate if there are specific timesteps or ranges where the balance between global and local features shifts significantly.
3. Consider experimenting with different initializations or architectures for the weight network to see if we can further improve the adaptive behavior.
4. Explore the possibility of incorporating the weight evolution insights into the loss function or training process to guide the model towards more effective feature balancing.
## Run 5: Experiment with Weight Network Architecture
Description: Based on the insights gained from the weight evolution analysis in Run 4, we will modify the weight network architecture to potentially improve its adaptive behavior. We'll implement a slightly deeper network with an additional hidden layer and use a different activation function (e.g., LeakyReLU) to allow for more complex weight computations. This change aims to enable more nuanced adaptations of the global-local feature balance across different datasets and timesteps.
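A sketch of the planned change: the extra hidden layer and LeakyReLU activations come from the description above, while the layer widths are illustrative assumptions.

```python
import torch.nn as nn

# Deeper timestep-conditioned weight network with LeakyReLU activations (Run 5 plan).
weight_net = nn.Sequential(
    nn.Linear(128, 64), nn.LeakyReLU(),
    nn.Linear(64, 32), nn.LeakyReLU(),
    nn.Linear(32, 2), nn.Softmax(dim=-1),
)
```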
## Plot Descriptions
1. Training Loss Plot (train_loss.png):
This figure shows the training loss curves for each dataset (circle, dino, line, and moons) across all runs. The plot is organized as a 2x2 grid, with each subplot representing a different dataset. The x-axis represents the training steps, while the y-axis shows the loss value. Each run is represented by a different color, and the legend indicates which color corresponds to which run (Baseline, Fixed Weighting, Learnable Weighting, Weight Analysis, Weight Visualization, and Improved Weight Network).
Key insights from this plot:
- Comparison of convergence speeds across different runs and datasets
- Identification of any unusual patterns or instabilities in the training process
- Assessment of the impact of different weighting strategies on the training dynamics
2. Generated Images Plot (generated_images.png):
This figure visualizes the generated samples for each dataset and run. The plot is organized as a grid, where each row represents a different run, and each column represents a different dataset (circle, dino, line, and moons). Each subplot is a scatter plot of the generated 2D points, with the x and y axes representing the two dimensions of the data.
Key insights from this plot:
- Visual assessment of the quality of generated samples for each dataset and run
- Comparison of how well each run captures the underlying data distribution
- Identification of any artifacts or issues in the generated samples
3. Weight Evolution Plot (weight_evolution.png):
This figure shows how the weights for global and local features evolve across timesteps for each dataset. The plot is organized as a 2x2 grid, with each subplot representing a different dataset. The x-axis represents the timesteps (from the end of the diffusion process to the beginning), while the y-axis shows the weight values (ranging from 0 to 1). For each run that implements adaptive weighting, there are two lines: one for the global feature weight and one for the local feature weight.
Key insights from this plot:
- Observation of how the balance between global and local features changes throughout the denoising process
- Comparison of weight evolution patterns across different datasets
- Identification of any significant shifts in the global-local balance at specific timesteps
- Assessment of the impact of different weight network architectures on the adaptive behavior
These plots provide a comprehensive visual analysis of our experimental results, allowing for in-depth comparisons across different runs and datasets. They offer valuable insights into the training dynamics, generation quality, and adaptive behavior of our dual-scale processing approach in low-dimensional diffusion models.