haoheliu commited on
Commit
38b7cd1
1 Parent(s): b650afc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -1
app.py CHANGED
@@ -30,8 +30,13 @@ guidance_scale = st.sidebar.slider("Guidance Scale", min_value=1.0, max_value=10
30
  random_seed = st.sidebar.number_input("Random Seed", min_value=0, value=42, step=1)
31
  latent_t_per_second = 12.8
32
 
 
33
  # Helper function: Plot linear STFT spectrogram
34
  def plot_spectrogram(waveform, sample_rate, title):
 
 
 
 
35
  plt.figure(figsize=(10, 4))
36
  spectrogram = torch.stft(
37
  torch.tensor(waveform),
@@ -44,7 +49,7 @@ def plot_spectrogram(waveform, sample_rate, title):
44
  np.log1p(spectrogram),
45
  aspect="auto",
46
  origin="lower",
47
- extent=[0, waveform.shape[-1] / sample_rate, 0, sample_rate / 2],
48
  cmap="viridis",
49
  )
50
  plt.colorbar(format="%+2.0f dB")
@@ -54,6 +59,7 @@ def plot_spectrogram(waveform, sample_rate, title):
54
  plt.tight_layout()
55
  st.pyplot(plt)
56
 
 
57
  # Process Button
58
  if uploaded_file and st.button("Enhance Audio"):
59
  st.write("Processing audio...")
 
30
  random_seed = st.sidebar.number_input("Random Seed", min_value=0, value=42, step=1)
31
  latent_t_per_second = 12.8
32
 
33
+ # Helper function: Plot linear STFT spectrogram
34
  # Helper function: Plot linear STFT spectrogram
35
  def plot_spectrogram(waveform, sample_rate, title):
36
+ # Ensure waveform is a 1D tensor
37
+ if len(waveform.shape) > 1:
38
+ waveform = waveform.squeeze() # Remove extra dimensions
39
+
40
  plt.figure(figsize=(10, 4))
41
  spectrogram = torch.stft(
42
  torch.tensor(waveform),
 
49
  np.log1p(spectrogram),
50
  aspect="auto",
51
  origin="lower",
52
+ extent=[0, len(waveform) / sample_rate, 0, sample_rate / 2],
53
  cmap="viridis",
54
  )
55
  plt.colorbar(format="%+2.0f dB")
 
59
  plt.tight_layout()
60
  st.pyplot(plt)
61
 
62
+
63
  # Process Button
64
  if uploaded_file and st.button("Enhance Audio"):
65
  st.write("Processing audio...")