cyrusyc commited on
Commit
bcf00b7
1 Parent(s): ba1879f

fix colormap not updating due to cache

Browse files
.gitignore CHANGED
@@ -4,6 +4,7 @@
4
  *.traj
5
  mlip_arena/tasks/*/*/*/
6
  lab/
 
7
 
8
  # Byte-compiled / optimized / DLL files
9
  __pycache__/
 
4
  *.traj
5
  mlip_arena/tasks/*/*/*/
6
  lab/
7
+ manuscripts/
8
 
9
  # Byte-compiled / optimized / DLL files
10
  __pycache__/
mlip_arena/tasks/run.py CHANGED
@@ -311,4 +311,8 @@ def md(
311
 
312
  traj.close()
313
 
314
- return {"runtime": end_time - start_time, "n_steps": n_steps}
 
 
 
 
 
311
 
312
  traj.close()
313
 
314
+ return {
315
+ "atoms": atoms,
316
+ "runtime": end_time - start_time,
317
+ "n_steps": n_steps,
318
+ }
serve/tasks/homonuclear-diatomics.py CHANGED
@@ -55,15 +55,18 @@ palette_colors = list(color_palettes.values())
55
  palette_name = vis.selectbox("Color sequence", options=palette_names, index=22)
56
 
57
  color_sequence = color_palettes[palette_name] # type: ignore
58
-
59
- DATA_DIR = Path("mlip_arena/tasks/diatomics")
60
  if not mlip_methods and not dft_methods:
61
  st.stop()
62
 
 
63
  @st.cache_data
64
  def get_data(mlip_methods, dft_methods):
 
 
65
  dfs = [
66
- pd.read_json(DATA_DIR / REGISTRY[method]["family"] / "homonuclear-diatomics.json")
 
 
67
  for method in mlip_methods
68
  ]
69
  dfs.extend(
@@ -76,6 +79,7 @@ def get_data(mlip_methods, dft_methods):
76
  df.drop_duplicates(inplace=True, subset=["name", "method"])
77
  return df
78
 
 
79
  df = get_data(mlip_methods, dft_methods)
80
 
81
  method_color_mapping = {
@@ -83,18 +87,12 @@ method_color_mapping = {
83
  for i, method in enumerate(df["method"].unique())
84
  }
85
 
86
- # img_dir = Path('./images')
87
- # img_dir.mkdir(exist_ok=True)
88
-
89
 
90
  @st.cache_data
91
- def get_plots(df, energy_plot, force_plot):
92
-
93
  figs = []
94
 
95
  for i, symbol in enumerate(chemical_symbols[1:]):
96
-
97
-
98
  rows = df[df["name"] == symbol + symbol]
99
 
100
  if rows.empty:
@@ -187,7 +185,7 @@ def get_plots(df, energy_plot, force_plot):
187
  xanchor="right",
188
  y=1,
189
  yanchor="top",
190
- bgcolor="rgba(0, 0, 0, 0)"
191
  # entrywidth=0.3,
192
  # entrywidthmode='fraction',
193
  ),
@@ -219,19 +217,17 @@ def get_plots(df, energy_plot, force_plot):
219
  ),
220
  )
221
 
222
-
223
  # cols[i % ncols].plotly_chart(fig, use_container_width=True)
224
 
225
  figs.append(fig)
226
-
227
  return figs
228
  # fig.write_image(format='svg', file=img_dir / f"{name}.svg")
229
 
230
 
231
- figs = get_plots(df, energy_plot, force_plot)
232
 
233
  for i, fig in enumerate(figs):
234
  if i % ncols == 0:
235
  cols = st.columns(ncols)
236
  cols[i % ncols].plotly_chart(fig, use_container_width=True)
237
-
 
55
  palette_name = vis.selectbox("Color sequence", options=palette_names, index=22)
56
 
57
  color_sequence = color_palettes[palette_name] # type: ignore
 
 
58
  if not mlip_methods and not dft_methods:
59
  st.stop()
60
 
61
+
62
  @st.cache_data
63
  def get_data(mlip_methods, dft_methods):
64
+ DATA_DIR = Path("mlip_arena/tasks/diatomics")
65
+
66
  dfs = [
67
+ pd.read_json(
68
+ DATA_DIR / REGISTRY[method]["family"] / "homonuclear-diatomics.json"
69
+ )
70
  for method in mlip_methods
71
  ]
72
  dfs.extend(
 
79
  df.drop_duplicates(inplace=True, subset=["name", "method"])
80
  return df
81
 
82
+
83
  df = get_data(mlip_methods, dft_methods)
84
 
85
  method_color_mapping = {
 
87
  for i, method in enumerate(df["method"].unique())
88
  }
89
 
 
 
 
90
 
91
  @st.cache_data
92
+ def get_plots(df, energy_plot: bool, force_plot: bool, method_color_mapping: dict):
 
93
  figs = []
94
 
95
  for i, symbol in enumerate(chemical_symbols[1:]):
 
 
96
  rows = df[df["name"] == symbol + symbol]
97
 
98
  if rows.empty:
 
185
  xanchor="right",
186
  y=1,
187
  yanchor="top",
188
+ bgcolor="rgba(0, 0, 0, 0)",
189
  # entrywidth=0.3,
190
  # entrywidthmode='fraction',
191
  ),
 
217
  ),
218
  )
219
 
 
220
  # cols[i % ncols].plotly_chart(fig, use_container_width=True)
221
 
222
  figs.append(fig)
223
+
224
  return figs
225
  # fig.write_image(format='svg', file=img_dir / f"{name}.svg")
226
 
227
 
228
+ figs = get_plots(df, energy_plot, force_plot, method_color_mapping)
229
 
230
  for i, fig in enumerate(figs):
231
  if i % ncols == 0:
232
  cols = st.columns(ncols)
233
  cols[i % ncols].plotly_chart(fig, use_container_width=True)