Update app.py
Browse files
app.py
CHANGED
@@ -30,8 +30,13 @@ guidance_scale = st.sidebar.slider("Guidance Scale", min_value=1.0, max_value=10
|
|
30 |
random_seed = st.sidebar.number_input("Random Seed", min_value=0, value=42, step=1)
|
31 |
latent_t_per_second = 12.8
|
32 |
|
|
|
33 |
# Helper function: Plot linear STFT spectrogram
|
34 |
def plot_spectrogram(waveform, sample_rate, title):
|
|
|
|
|
|
|
|
|
35 |
plt.figure(figsize=(10, 4))
|
36 |
spectrogram = torch.stft(
|
37 |
torch.tensor(waveform),
|
@@ -44,7 +49,7 @@ def plot_spectrogram(waveform, sample_rate, title):
|
|
44 |
np.log1p(spectrogram),
|
45 |
aspect="auto",
|
46 |
origin="lower",
|
47 |
-
extent=[0, waveform
|
48 |
cmap="viridis",
|
49 |
)
|
50 |
plt.colorbar(format="%+2.0f dB")
|
@@ -54,6 +59,7 @@ def plot_spectrogram(waveform, sample_rate, title):
|
|
54 |
plt.tight_layout()
|
55 |
st.pyplot(plt)
|
56 |
|
|
|
57 |
# Process Button
|
58 |
if uploaded_file and st.button("Enhance Audio"):
|
59 |
st.write("Processing audio...")
|
|
|
30 |
random_seed = st.sidebar.number_input("Random Seed", min_value=0, value=42, step=1)
|
31 |
latent_t_per_second = 12.8
|
32 |
|
33 |
+
# Helper function: Plot linear STFT spectrogram
|
34 |
# Helper function: Plot linear STFT spectrogram
|
35 |
def plot_spectrogram(waveform, sample_rate, title):
|
36 |
+
# Ensure waveform is a 1D tensor
|
37 |
+
if len(waveform.shape) > 1:
|
38 |
+
waveform = waveform.squeeze() # Remove extra dimensions
|
39 |
+
|
40 |
plt.figure(figsize=(10, 4))
|
41 |
spectrogram = torch.stft(
|
42 |
torch.tensor(waveform),
|
|
|
49 |
np.log1p(spectrogram),
|
50 |
aspect="auto",
|
51 |
origin="lower",
|
52 |
+
extent=[0, len(waveform) / sample_rate, 0, sample_rate / 2],
|
53 |
cmap="viridis",
|
54 |
)
|
55 |
plt.colorbar(format="%+2.0f dB")
|
|
|
59 |
plt.tight_layout()
|
60 |
st.pyplot(plt)
|
61 |
|
62 |
+
|
63 |
# Process Button
|
64 |
if uploaded_file and st.button("Enhance Audio"):
|
65 |
st.write("Processing audio...")
|