Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,322 Bytes
9d3cb0a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import inspect
import typing
from functools import wraps
from . import util
def format_figure(func):
    """Decorator for formatting figures produced by the code below.

    Keyword arguments whose names match parameters of
    :py:func:`audiotools.core.util.format_figure` are removed from the call
    to ``func`` and passed to ``util.format_figure`` instead, which is called
    after ``func`` finishes. All other arguments are forwarded to ``func``
    unchanged.

    See :py:func:`audiotools.core.util.format_figure` for more.

    Parameters
    ----------
    func : Callable
        Plotting function that is decorated by this function.

    Returns
    -------
    Callable
        The wrapped function. It returns whatever ``func`` returns.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        f_keys = inspect.signature(util.format_figure).parameters.keys()
        f_kwargs = {}
        # Iterate over a snapshot of the keys because we pop from kwargs.
        for k in list(kwargs):
            if k in f_keys:
                f_kwargs[k] = kwargs.pop(k)
        result = func(*args, **kwargs)
        util.format_figure(**f_kwargs)
        # Propagate the plotting function's return value instead of
        # silently discarding it.
        return result

    return wrapper
class DisplayMixin:
    """Plotting and logging helpers intended to be mixed into ``AudioSignal``.

    The methods read ``clone``, ``stft_data``, ``preemphasis``,
    ``magnitude``, ``log_magnitude``, ``mel_spectrogram``, ``audio_data``
    and ``sample_rate`` from ``self``, so the host class must provide them.
    """

    @format_figure
    def specshow(
        self,
        preemphasis: bool = False,
        x_axis: str = "time",
        y_axis: str = "linear",
        n_mels: int = 128,
        **kwargs,
    ):
        """Displays a spectrogram, using ``librosa.display.specshow``.

        The spectrogram is recomputed on a clone of the signal, log-scaled,
        and reduced with ``[0].mean(axis=0)`` (assumes a
        ``(batch, channels, freq, time)`` layout, i.e. the first batch item
        averaged over channels -- TODO confirm) before plotting.

        Parameters
        ----------
        preemphasis : bool, optional
            Whether or not to apply preemphasis, which makes high
            frequency detail easier to see, by default False
        x_axis : str, optional
            How to label the x axis, by default "time"
        y_axis : str, optional
            How to label the y axis, by default "linear"
        n_mels : int, optional
            If displaying a mel spectrogram with ``y_axis = "mel"``,
            this controls the number of mels, by default 128.
        kwargs : dict, optional
            Keyword arguments matching :py:func:`audiotools.core.util.format_figure`
            are forwarded to it by the ``format_figure`` decorator; the
            remaining ones are passed to ``librosa.display.specshow``.
        """
        import librosa
        import librosa.display

        # Always re-compute the STFT data before showing it, in case
        # it changed.
        signal = self.clone()
        signal.stft_data = None

        if preemphasis:
            signal.preemphasis()

        if y_axis == "mel":
            # 20 * log10 of the mel spectrogram, floored at 1e-5 and
            # normalized so that its maximum is 0.
            log_mag = 20 * signal.mel_spectrogram(n_mels).clamp(1e-5).log10()
            log_mag -= log_mag.max()
        else:
            # Only the linear path needs the magnitude reference, so it is
            # not computed when the mel branch produces the image.
            ref = signal.magnitude.max()
            log_mag = signal.log_magnitude(ref_value=ref)

        # detach().cpu() lets tensors that live on an accelerator or are
        # tracked by autograd be converted to NumPy.
        librosa.display.specshow(
            log_mag.detach().cpu().numpy()[0].mean(axis=0),
            x_axis=x_axis,
            y_axis=y_axis,
            sr=signal.sample_rate,
            **kwargs,
        )

    @format_figure
    def waveplot(self, x_axis: str = "time", **kwargs):
        """Displays a waveform plot, using ``librosa.display.waveshow``.

        Falls back to ``librosa.display.waveplot`` on librosa versions that
        do not provide ``waveshow``. The plotted waveform is
        ``audio_data[0].mean(dim=0)`` (presumably the first batch item
        averaged over channels -- confirm against the host class).

        Parameters
        ----------
        x_axis : str, optional
            How to label the x axis, by default "time"
        kwargs : dict, optional
            Keyword arguments matching :py:func:`audiotools.core.util.format_figure`
            are forwarded to it by the ``format_figure`` decorator; the
            remaining ones are passed to the librosa waveform plotting function.
        """
        import librosa
        import librosa.display

        audio_data = self.audio_data[0].mean(dim=0)
        # detach() so tensors tracked by autograd can be converted to NumPy.
        audio_data = audio_data.detach().cpu().numpy()

        plot_fn = "waveshow" if hasattr(librosa.display, "waveshow") else "waveplot"
        wave_plot_fn = getattr(librosa.display, plot_fn)
        wave_plot_fn(audio_data, x_axis=x_axis, sr=self.sample_rate, **kwargs)

    @format_figure
    def wavespec(self, x_axis: str = "time", **kwargs):
        """Displays a waveform plot stacked above a spectrogram.

        The figure is split into a 6-row grid: the top row holds the
        waveform drawn by :py:meth:`waveplot`, and the remaining five rows
        hold the spectrogram drawn by :py:meth:`specshow`.

        Parameters
        ----------
        x_axis : str, optional
            How to label the x axis, by default "time"
        kwargs : dict, optional
            Keyword arguments matching :py:func:`audiotools.core.util.format_figure`
            are forwarded to it by the ``format_figure`` decorator; the rest
            go to :py:func:`audiotools.core.display.DisplayMixin.specshow`.
        """
        import matplotlib.pyplot as plt
        from matplotlib.gridspec import GridSpec

        gs = GridSpec(6, 1)
        plt.subplot(gs[0, :])
        self.waveplot(x_axis=x_axis)
        plt.subplot(gs[1:, :])
        self.specshow(x_axis=x_axis, **kwargs)

    def write_audio_to_tb(
        self,
        tag: str,
        writer,
        step: typing.Optional[int] = None,
        plot_fn: typing.Optional[typing.Union[typing.Callable, str]] = "specshow",
        **kwargs,
    ):
        """Writes a signal and its spectrogram to Tensorboard. Will show up
        under the Audio and Images tab in Tensorboard.

        Only ``audio_data[0, 0]`` is logged as audio.

        Parameters
        ----------
        tag : str
            Tag to write signal to (e.g. ``clean/sample_0.wav``). The image will be
            written to the corresponding ``.png`` tag (e.g. ``clean/sample_0.png``),
            obtained by replacing the last occurrence of ``"wav"`` in ``tag``.
        writer : SummaryWriter
            A SummaryWriter object from PyTorch library.
        step : int, optional
            The step to write the signal to, by default None
        plot_fn : typing.Union[typing.Callable, str], optional
            How to create the image: a callable, or the name of a method on
            this object. Set to ``None`` to avoid plotting, by default "specshow"
        kwargs : dict, optional
            Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or
            whatever ``plot_fn`` is set to.
        """
        import matplotlib.pyplot as plt

        audio_data = self.audio_data[0, 0].detach().cpu()
        sample_rate = self.sample_rate
        writer.add_audio(tag, audio_data, step, sample_rate)

        if plot_fn is not None:
            if isinstance(plot_fn, str):
                plot_fn = getattr(self, plot_fn)
            # A freshly created figure is already empty and current.
            fig = plt.figure()
            plot_fn(**kwargs)
            # Replace only the last "wav" (normally the file extension) so
            # that path components such as "wavs/" are left intact.
            image_tag = "png".join(tag.rsplit("wav", 1))
            writer.add_figure(image_tag, fig, step)

    def save_image(
        self,
        image_path: str,
        plot_fn: typing.Union[typing.Callable, str] = "specshow",
        **kwargs,
    ):
        """Save AudioSignal spectrogram (or whatever ``plot_fn`` is set to) to
        a specified file.

        The current pyplot figure is cleared, drawn into, saved, and then
        closed.

        Parameters
        ----------
        image_path : str
            Where to save the file to.
        plot_fn : typing.Union[typing.Callable, str], optional
            How to create the image: a callable, or the name of a method on
            this object, by default "specshow"
        kwargs : dict, optional
            Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or
            whatever ``plot_fn`` is set to.
        """
        import matplotlib.pyplot as plt

        if isinstance(plot_fn, str):
            plot_fn = getattr(self, plot_fn)

        plt.clf()
        try:
            plot_fn(**kwargs)
            plt.savefig(image_path, bbox_inches="tight", pad_inches=0)
        finally:
            # Close the figure even when plotting or saving fails so figures
            # do not accumulate in pyplot's global state.
            plt.close()
|