File size: 2,248 Bytes
3d4323f
 
 
 
 
de66b6c
3d4323f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import gradio as gr
import json
import pandas as pd
import collections
import scipy.signal
import numpy as np
from functools import partial
from openwakeword.model import Model

# One shared openWakeWord detector, built with the library's default
# arguments (presumably its bundled pre-trained wake-word models — confirm
# against the installed openwakeword version). process_audio calls its
# predict() on every audio frame.
model: Model = Model()

# Define function to process audio
def process_audio(audio, state=None):
    """Score a streamed microphone chunk with openWakeWord and plot the results.

    Args:
        audio: ``(sample_rate, samples)`` pair from the ``gr.Audio(type="numpy")``
            input. ``samples`` is resampled to 16 kHz when needed and fed to
            the model in 1280-sample frames. ``None`` adds no new scores.
        state: Per-model score history: a ``defaultdict`` mapping each model
            name to a ``deque`` of its most recent scores (maxlen 60 when
            created here). ``None`` (the default) starts an empty history,
            so calls that omit it never share one by accident.

    Returns:
        ``(plot, state)``: a LinePlot update showing every model's recent
        scores (or a no-op ``gr.update()`` while nothing has been scored
        yet), and the updated history to pass back in on the next call.
    """
    target_rate = 16000  # sample rate the model is fed
    frame_len = 1280  # samples passed to each model.predict call
    history_len = 60  # scores kept (and plotted) per model

    # Build the history here rather than as a default argument: a mutable
    # default is created once and shared by every call that omits `state`.
    if state is None:
        state = collections.defaultdict(partial(collections.deque, maxlen=history_len))

    if audio is not None:
        sample_rate, samples = audio

        # Resample audio to 16khz if needed; 16 kHz input is used as-is.
        if sample_rate != target_rate:
            samples = scipy.signal.resample(
                samples, int(float(samples.shape[0]) / sample_rate * target_rate)
            )

        # Predict on whole frames only. A trailing partial frame is dropped
        # rather than re-appending the previous frame's scores.
        n_whole = len(samples) - len(samples) % frame_len
        for start in range(0, n_whole, frame_len):
            prediction = model.predict(samples[start:start + frame_len])
            for key, score in prediction.items():
                scores = state[key]
                # Left-pad a newly seen model's history with zeros so its
                # line spans the whole plotting window immediately.
                if not scores:
                    scores.extend(np.zeros(history_len))
                scores.append(score)

    # Nothing scored yet (no audio, or less than one frame so far).
    # pd.concat would raise on an empty list, so leave the plot unchanged.
    if not state:
        return gr.update(), state

    # Make line plot: one long-format frame, one row per (model, time step).
    df = pd.concat(
        [
            pd.DataFrame({"x": np.arange(len(scores)), "y": list(scores), "Model": key})
            for key, scores in state.items()
        ]
    )
    plot = gr.LinePlot().update(value=df, x="x", y="y", color="Model", y_lim=(0, 1), tooltip="Model",
                                width=600, height=300, x_title="Time (frames)", y_title="Model Score",
                                color_legend_position="bottom")

    # Manually adjust how the legend is displayed by editing the serialized
    # chart spec Gradio produced (presumably Altair/Vega-Lite JSON).
    spec = json.loads(plot["value"]["plot"])
    legend = spec["layer"][0]["encoding"]["color"]["legend"]
    legend.update(direction="vertical", columns=4, labelFontSize=12, titleFontSize=14)
    plot["value"]["plot"] = json.dumps(spec)

    return plot, state

# Wire up a live Gradio app: streamed microphone audio plus per-session
# state go into process_audio; the score plot plus updated state come out.
microphone = gr.Audio(source="microphone", type="numpy", streaming=True, show_label=False)
score_plot = gr.LinePlot(show_label=False)

gr_int = gr.Interface(
    fn=process_audio,
    inputs=[microphone, "state"],
    outputs=[score_plot, "state"],
    live=True,
    css=".flex {flex-direction: column} .gr-panel {width: 100%}",
)

# Start the server as soon as this script runs.
gr_int.launch()