Spaces:
Running
Running
add equiformer, escn; add leaderboard
Browse files- serve/models/leaderboard.py +116 -0
serve/models/leaderboard.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
DATA_DIR = Path("mlip_arena/tasks/diatomics")
|
6 |
+
methods = ["MACE-MP", "Equiformer", "CHGNet", "MACE-OFF"]
|
7 |
+
dfs = [pd.read_json(DATA_DIR / method.lower() / "homonuclear-diatomics.json") for method in methods]
|
8 |
+
df = pd.concat(dfs, ignore_index=True)
|
9 |
+
|
10 |
+
table = pd.DataFrame(columns=["Model", "No. of supported elements", "No. of reversed forces", "Energy-consistent forces"])
|
11 |
+
|
12 |
+
for method in df["method"].unique():
|
13 |
+
rows = df[df["method"] == method]
|
14 |
+
new_row = {
|
15 |
+
"Model": method,
|
16 |
+
"No. of supported elements": len(rows["name"].unique()),
|
17 |
+
"No. of reversed forces": None, # Replace with actual logic if available
|
18 |
+
"Energy-consistent forces": None # Replace with actual logic if available
|
19 |
+
}
|
20 |
+
table = pd.concat([table, pd.DataFrame([new_row])], ignore_index=True)
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
# Define the data
|
25 |
+
# data = {
|
26 |
+
# "Metrics": [
|
27 |
+
# "No. of supported elements",
|
28 |
+
# "No. of reversed forces",
|
29 |
+
# "Energy-consistent forces",
|
30 |
+
# ],
|
31 |
+
# "MACE-MP(M)": ["10", "5", "Yes"],
|
32 |
+
# "CHGNet": ["20", "3", "No"],
|
33 |
+
# "Equiformer": ["15", "7", "Yes"]
|
34 |
+
# }
|
35 |
+
|
36 |
+
# # Convert the data to a DataFrame
|
37 |
+
# df = pd.DataFrame(data)
|
38 |
+
|
39 |
+
# # Set the 'Metrics' column as the index
|
40 |
+
# df.set_index("Metrics", inplace=True)
|
41 |
+
|
42 |
+
# # Transpose the DataFrame
|
43 |
+
# df = df.T
|
44 |
+
|
45 |
+
# Apply custom CSS to center the table
|
46 |
+
# Create the Streamlit table
|
47 |
+
|
48 |
+
table.set_index("Model", inplace=True)
|
49 |
+
|
50 |
+
|
51 |
+
s = table.style.background_gradient(
|
52 |
+
cmap="Spectral",
|
53 |
+
subset=["No. of supported elements"],
|
54 |
+
vmin=0, vmax=120
|
55 |
+
)
|
56 |
+
|
57 |
+
|
58 |
+
st.markdown("# Leaderboard")
|
59 |
+
st.dataframe(s, use_container_width=True)
|
60 |
+
|
61 |
+
# Define custom CSS for table
|
62 |
+
# custom_css = """
|
63 |
+
# <style>
|
64 |
+
# table {
|
65 |
+
# width: 100%;
|
66 |
+
# border-collapse: collapse;
|
67 |
+
# }
|
68 |
+
# th, td {
|
69 |
+
# border: 1px solid #ddd;
|
70 |
+
# padding: 8px;
|
71 |
+
# }
|
72 |
+
# th {
|
73 |
+
# background-color: #4CAF50;
|
74 |
+
# color: white;
|
75 |
+
# text-align: left;
|
76 |
+
# }
|
77 |
+
# tr:nth-child(even) {
|
78 |
+
# background-color: #f2f2f2;
|
79 |
+
# }
|
80 |
+
# tr:hover {
|
81 |
+
# background-color: #ddd;
|
82 |
+
# }
|
83 |
+
# </style>
|
84 |
+
# """
|
85 |
+
|
86 |
+
# # Display the table with custom CSS
|
87 |
+
# st.markdown(custom_css, unsafe_allow_html=True)
|
88 |
+
# st.markdown(table.to_html(index=False), unsafe_allow_html=True)
|
89 |
+
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
# import numpy as np
|
95 |
+
# import plotly.figure_factory as ff
|
96 |
+
# import streamlit as st
|
97 |
+
|
98 |
+
# st.markdown("# Dashboard")
|
99 |
+
|
100 |
+
# # Add histogram data
|
101 |
+
# x1 = np.random.randn(200) - 2
|
102 |
+
# x2 = np.random.randn(200)
|
103 |
+
# x3 = np.random.randn(200) + 2
|
104 |
+
|
105 |
+
# # Group data together
|
106 |
+
# hist_data = [x1, x2, x3]
|
107 |
+
|
108 |
+
# group_labels = ["Group 1", "Group 2", "Group 3"]
|
109 |
+
|
110 |
+
# # Create distplot with custom bin_size
|
111 |
+
# fig = ff.create_distplot(
|
112 |
+
# hist_data, group_labels, bin_size=[.1, .25, .5]
|
113 |
+
# )
|
114 |
+
|
115 |
+
# # Plot!
|
116 |
+
# st.plotly_chart(fig, use_container_width=True)
|