roszcz commited on
Commit
b31ba68
1 Parent(s): 8da77f1

show more datasets

Browse files
Files changed (1) hide show
  1. app.py +33 -2
app.py CHANGED
@@ -6,13 +6,44 @@ from datasets import load_dataset
6
 
7
 
8
  def main():
9
- dataset = load_dataset("epr-labs/pijamia-midi-v1", split="train")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  record_idx = st.number_input(
12
  label="record id",
13
  min_value=0,
14
  max_value=len(dataset) - 1,
15
- value=0,
16
  )
17
 
18
  record = dataset[record_idx]
 
6
 
7
 
8
  def main():
9
+
10
+ available_datasets = [
11
+ "pijamia-midi-v1",
12
+ "lakh-lmd-full",
13
+ "giant-midi-sustain-v2",
14
+ "maestro-sustain-v2",
15
+ ]
16
+ dataset_name = st.selectbox(
17
+ label="Select dataset",
18
+ options=available_datasets,
19
+ )
20
+ preview_dataset(dataset_name)
21
+
22
+
23
+ def preview_dataset(dataset_name: str):
24
+ dataset = load_dataset(f"epr-labs/{dataset_name}", split="train[100:200]")
25
+
26
+ st.write(f"### Dataset: {dataset_name}")
27
+
28
+ st.write(f"""
29
+ Number of records: {len(dataset)}
30
+ """)
31
+ code = f"""
32
+ dataset = load_dataset("epr-labs/{dataset_name}", split="train")
33
+
34
+ record = dataset[321]
35
+ piece = MidiPiece.from_huggingface(record)
36
+
37
+ # Playback in streamlit
38
+ streamlit_pianoroll.from_fortepyan(piece)
39
+ """
40
+ st.code(code, language="python")
41
 
42
  record_idx = st.number_input(
43
  label="record id",
44
  min_value=0,
45
  max_value=len(dataset) - 1,
46
+ value=50,
47
  )
48
 
49
  record = dataset[record_idx]