sps44 commited on
Commit
457c342
1 Parent(s): 642e447

Embedding dtypes

Browse files
Files changed (1) hide show
  1. run.py +3 -1
run.py CHANGED
@@ -87,7 +87,9 @@ if __name__ == "__main__":
87
  ds = datasets.concatenate_datasets([ds, ds_enrichment], split=ds.split, axis=1)
88
 
89
  dtypes = {"DistanceToDriverAhead": spotlight.Sequence1D, "RPM": spotlight.Sequence1D, "Speed": spotlight.Sequence1D, "nGear": spotlight.Sequence1D,
90
- "Throttle": spotlight.Sequence1D, "Brake": spotlight.Sequence1D, "DRS": spotlight.Sequence1D, "X": spotlight.Sequence1D, "Y": spotlight.Sequence1D, "Z": spotlight.Sequence1D}
 
 
91
 
92
  for col in ds.column_names:
93
  if "embedding" in col and isinstance(ds.features[col], datasets.Sequence):
 
87
  ds = datasets.concatenate_datasets([ds, ds_enrichment], split=ds.split, axis=1)
88
 
89
  dtypes = {"DistanceToDriverAhead": spotlight.Sequence1D, "RPM": spotlight.Sequence1D, "Speed": spotlight.Sequence1D, "nGear": spotlight.Sequence1D,
90
+ "Throttle": spotlight.Sequence1D, "Brake": spotlight.Sequence1D, "DRS": spotlight.Sequence1D, "X": spotlight.Sequence1D, "Y": spotlight.Sequence1D, "Z": spotlight.Sequence1D,
91
+ 'RPM_emb': spotlight.Embedding, 'Speed_emb': spotlight.Embedding, 'nGear_emb': spotlight.Embedding, 'Throttle_emb': spotlight.Embedding, 'Brake_emb': spotlight.Embedding,
92
+ 'X_emb': spotlight.Embedding, 'Y_emb': spotlight.Embedding, 'Z_emb': spotlight.Embedding}
93
 
94
  for col in ds.column_names:
95
  if "embedding" in col and isinstance(ds.features[col], datasets.Sequence):