xszheng2020
commited on
Commit
•
44a6f1e
1
Parent(s):
86e566d
Upload 2 files
Browse files
app.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import sklearn
|
2 |
+
import gradio as gr
|
3 |
+
# import joblib
|
4 |
+
import pandas as pd
|
5 |
+
import numpy as np
|
6 |
+
import lightgbm as lgb
|
7 |
+
from sklearn.model_selection import train_test_split
|
8 |
+
from PIL import Image
|
9 |
+
# import datasets
|
10 |
+
|
11 |
+
# pipe = joblib.load("./model.pkl")
|
12 |
+
|
13 |
+
title = "RegMix"
|
14 |
+
description = "TBD."
|
15 |
+
|
16 |
+
df = pd.read_csv('data.csv')
|
17 |
+
headers = df.columns.tolist()
|
18 |
+
|
19 |
+
inputs = [gr.Dataframe(headers=headers, row_count = (8, "dynamic"), datatype='number', col_count=(4,"fixed"), label="Dataset", interactive=1)]
|
20 |
+
outputs = [gr.ScatterPlot(), gr.Image(), gr.Dataframe(row_count = (2, "dynamic"), col_count=(2, "fixed"), datatype='number', label="Results", headers=["True Loss", "Pred Loss"])]
|
21 |
+
|
22 |
+
def infer(inputs):
|
23 |
+
df = pd.DataFrame(inputs, columns=headers)
|
24 |
+
|
25 |
+
X_columns = df.columns[0:-1]
|
26 |
+
y_column = df.columns[-1]
|
27 |
+
|
28 |
+
df_train, df_val = train_test_split(df, test_size=0.125, random_state=42)
|
29 |
+
|
30 |
+
hyper_params = {
|
31 |
+
'task': 'train',
|
32 |
+
'boosting_type': 'gbdt',
|
33 |
+
'objective': 'regression',
|
34 |
+
'metric': ['l1','l2'],
|
35 |
+
"num_iterations": 1000,
|
36 |
+
'seed': 42,
|
37 |
+
'learning_rate': 1e-2,
|
38 |
+
}
|
39 |
+
|
40 |
+
target = df_train[y_column]
|
41 |
+
eval_target = df_val[y_column]
|
42 |
+
|
43 |
+
np.random.seed(42)
|
44 |
+
|
45 |
+
gbm = lgb.LGBMRegressor(**hyper_params)
|
46 |
+
|
47 |
+
reg = gbm.fit(df_train[X_columns].values, target,
|
48 |
+
eval_set=[(df_val[X_columns].values, eval_target)],
|
49 |
+
eval_metric='l2',
|
50 |
+
callbacks=[
|
51 |
+
lgb.early_stopping(stopping_rounds=3),
|
52 |
+
]
|
53 |
+
)
|
54 |
+
|
55 |
+
predictions = reg.predict(df_val[X_columns].values)
|
56 |
+
df_val['Prediction'] = predictions
|
57 |
+
|
58 |
+
####
|
59 |
+
import matplotlib.pyplot as plt
|
60 |
+
plt.rcParams["font.family"] = "Times New Roman" # !!!!
|
61 |
+
plt.rcParams.update({'font.size': 24})
|
62 |
+
plt.rcParams.update({'axes.labelpad': 20})
|
63 |
+
|
64 |
+
from matplotlib import cm
|
65 |
+
from matplotlib.ticker import LinearLocator
|
66 |
+
|
67 |
+
fig, ax = plt.subplots(figsize=(12, 12), layout='compressed', subplot_kw={"projection": "3d"})
|
68 |
+
|
69 |
+
stride = 0.025
|
70 |
+
X = np.arange(0, 1+stride, stride)
|
71 |
+
Y = np.arange(0, 1+stride, stride)
|
72 |
+
|
73 |
+
X, Y = np.meshgrid(X, Y)
|
74 |
+
Z = []
|
75 |
+
for (x,y) in zip(X.reshape(-1), Y.reshape(-1)):
|
76 |
+
if (x+y)>1:
|
77 |
+
Z.append(np.inf)
|
78 |
+
else:
|
79 |
+
Z.append(
|
80 |
+
reg.predict(np.asarray([x, y, 1-x-y]).reshape(1, -1)
|
81 |
+
)[0])
|
82 |
+
Z = np.asarray(Z).reshape(len(np.arange(0, 1+stride, stride)), len(np.arange(0, 1+stride, stride)))
|
83 |
+
|
84 |
+
# Plot the surface.
|
85 |
+
surf = ax.plot_surface(X, Y, Z,
|
86 |
+
edgecolor='white',
|
87 |
+
lw=0.5, rstride=2, cstride=2,
|
88 |
+
alpha=0.85,
|
89 |
+
cmap='coolwarm',
|
90 |
+
vmin=min(Z[Z!=np.inf]),
|
91 |
+
vmax=max(Z[Z!=np.inf]),
|
92 |
+
# linewidth=8,
|
93 |
+
antialiased=False, )
|
94 |
+
|
95 |
+
ax.zaxis.set_major_locator(LinearLocator(10))
|
96 |
+
ax.zaxis.set_major_formatter('{x:.02f}')
|
97 |
+
|
98 |
+
ax.view_init(elev=25, azim=45, roll=0) #####
|
99 |
+
|
100 |
+
ax.contourf(X, Y, Z, zdir='z',
|
101 |
+
offset=np.min(Z)-0.35,
|
102 |
+
cmap=cm.coolwarm)
|
103 |
+
|
104 |
+
from matplotlib.patches import Circle
|
105 |
+
from mpl_toolkits.mplot3d import art3d
|
106 |
+
|
107 |
+
def add_point(ax, x, y, z, fc = None, ec = None, radius = 0.005):
|
108 |
+
xy_len, z_len = ax.get_figure().get_size_inches()
|
109 |
+
axis_length = [x[1] - x[0] for x in [ax.get_xbound(), ax.get_ybound(), ax.get_zbound()]]
|
110 |
+
axis_rotation = {'z': ((x, y, z), axis_length[1]/axis_length[0]),
|
111 |
+
'y': ((x, z, y), axis_length[2]/axis_length[0]*xy_len/z_len),
|
112 |
+
'x': ((y, z, x), axis_length[2]/axis_length[1]*xy_len/z_len)}
|
113 |
+
for a, ((x0, y0, z0), ratio) in axis_rotation.items():
|
114 |
+
p = Circle((x0, y0), radius, lw=1.5,
|
115 |
+
# width = radius, height = radius*ratio,
|
116 |
+
fc=fc,
|
117 |
+
ec=ec)
|
118 |
+
ax.add_patch(p)
|
119 |
+
art3d.pathpatch_2d_to_3d(p, z=z0, zdir=a)
|
120 |
+
|
121 |
+
|
122 |
+
add_point(ax, X.reshape(-1)[np.argmin(Z)], Y.reshape(-1)[np.argmin(Z)], np.min(Z),
|
123 |
+
fc='Red',
|
124 |
+
ec='Red', radius=0.015)
|
125 |
+
|
126 |
+
add_point(ax, X.reshape(-1)[np.argmin(Z)], Y.reshape(-1)[np.argmin(Z)], np.min(Z)-0.35,
|
127 |
+
fc='Red',
|
128 |
+
ec='Red', radius=0.015)
|
129 |
+
|
130 |
+
|
131 |
+
ax.set_xlabel('Github (%)', fontdict={
|
132 |
+
'size':24
|
133 |
+
})
|
134 |
+
ax.set_ylabel('Hacker News (%)', fontdict={
|
135 |
+
'size':24
|
136 |
+
})
|
137 |
+
|
138 |
+
ax.set_xticks(np.arange(0, 1, 0.2), [str(np.round(num, 1)) for num in np.arange(0, 100, 20)], )
|
139 |
+
ax.set_yticks(np.arange(0, 1, 0.2), [str(np.round(num, 1)) for num in np.arange(0, 100, 20)], )
|
140 |
+
|
141 |
+
ax.set_zticks(np.arange(np.min(Z), np.max(Z[Z!=np.inf]), 0.2), [str(np.round(num, 1)) for num in np.arange(np.min(Z), np.max(Z[Z!=np.inf]), 0.2)], )
|
142 |
+
|
143 |
+
ax.zaxis.labelpad=1
|
144 |
+
|
145 |
+
ax.set_zlim(np.min(Z)-0.35, max(Z[Z!=np.inf])+0.01)
|
146 |
+
ax.set_xlim(0, 1)
|
147 |
+
ax.set_ylim(0, 1)
|
148 |
+
ax.set_box_aspect(aspect=None, zoom=0.775)
|
149 |
+
|
150 |
+
ax.zaxis._axinfo['juggled'] = (1,2,2)
|
151 |
+
|
152 |
+
# Add a color bar which maps values to colors.
|
153 |
+
cbar = fig.colorbar(surf,
|
154 |
+
shrink=0.5,
|
155 |
+
aspect=25, pad=0.01
|
156 |
+
)
|
157 |
+
cbar.ax.set_ylabel('Prediction', fontdict={
|
158 |
+
'size':32
|
159 |
+
},
|
160 |
+
# rotation=270,
|
161 |
+
# labelpad=-90
|
162 |
+
)
|
163 |
+
|
164 |
+
|
165 |
+
filename = "tmp.png"
|
166 |
+
plt.savefig(filename, bbox_inches='tight', pad_inches=0.1)
|
167 |
+
####
|
168 |
+
return [gr.ScatterPlot(
|
169 |
+
value=df_val,
|
170 |
+
x="Prediction",
|
171 |
+
y="Target",
|
172 |
+
title="Scatter",
|
173 |
+
tooltip=["Prediction", "Target"],
|
174 |
+
x_lim=[min(min(predictions), min(df_val[y_column]))-0.25, max(max(predictions), max(df_val[y_column]))+0.25],
|
175 |
+
y_lim=[min(min(predictions), min(df_val[y_column]))-0.25, max(max(predictions), max(df_val[y_column]))+0.25]
|
176 |
+
),
|
177 |
+
gr.Image(Image.open('tmp.png')),
|
178 |
+
df_val[['Target', 'Prediction']], ]
|
179 |
+
|
180 |
+
gr.Interface(infer, inputs = inputs, outputs = outputs, title = title,
|
181 |
+
description = description, examples=[df], cache_examples=False, allow_flagging='never').launch(debug=False)
|
data.csv
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Github,Hacker News,Philpapers,Target
|
2 |
+
0.616,0.227,0.157,5.306806564331055
|
3 |
+
0.025,0.131,0.844,6.012040138244629
|
4 |
+
0.183,0.28,0.538,5.527338981628418
|
5 |
+
0.032,0.548,0.421,5.821677207946777
|
6 |
+
0.002,0.347,0.651,6.326714515686035
|
7 |
+
0.083,0.827,0.09,5.659005641937256
|
8 |
+
0.011,0.478,0.511,6.063345909118652
|
9 |
+
0.699,0.002,0.3,5.410695552825928
|
10 |
+
0.363,0.389,0.249,5.3828229904174805
|
11 |
+
0.542,0.257,0.201,5.31920337677002
|
12 |
+
0.262,0.311,0.427,5.465338230133057
|
13 |
+
0.677,0.251,0.072,5.3108978271484375
|
14 |
+
0.705,0.203,0.092,5.305531024932861
|
15 |
+
0.68,0.208,0.112,5.305020809173584
|
16 |
+
0.109,0.22,0.671,5.645120620727539
|
17 |
+
0.307,0.074,0.619,5.510026931762695
|
18 |
+
0.198,0.69,0.112,5.476624488830566
|
19 |
+
0.197,0.401,0.402,5.487608432769775
|
20 |
+
0.056,0.338,0.606,5.778319358825684
|
21 |
+
0.34,0.256,0.405,5.406895637512207
|
22 |
+
0.534,0.3,0.167,5.321551322937012
|
23 |
+
0.022,0.0,0.977,6.284101486206055
|
24 |
+
0.732,0.228,0.04,5.291826248168945
|
25 |
+
0.18,0.1,0.72,5.596135139465332
|
26 |
+
0.279,0.539,0.182,5.439793586730957
|
27 |
+
0.228,0.442,0.33,5.461205959320068
|
28 |
+
0.278,0.271,0.45,5.446168899536133
|
29 |
+
0.184,0.511,0.305,5.51557731628418
|
30 |
+
0.008,0.542,0.45,6.084725379943848
|
31 |
+
0.659,0.121,0.22,5.333937644958496
|
32 |
+
0.134,0.041,0.825,5.712595462799072
|
33 |
+
0.046,0.002,0.952,6.052066326141357
|
34 |
+
0.022,0.236,0.742,6.000083923339844
|
35 |
+
0.573,0.044,0.382,5.3919196128845215
|
36 |
+
0.57,0.181,0.249,5.336789608001709
|
37 |
+
0.244,0.424,0.333,5.454492568969727
|
38 |
+
0.464,0.031,0.505,5.44106388092041
|
39 |
+
0.349,0.046,0.605,5.503855705261231
|
40 |
+
0.435,0.019,0.545,5.472304344177246
|
41 |
+
0.011,0.571,0.418,6.020427703857422
|
42 |
+
0.083,0.794,0.123,5.637621879577637
|
43 |
+
0.433,0.125,0.442,5.397417545318604
|
44 |
+
0.032,0.457,0.512,5.846551895141602
|
45 |
+
0.248,0.128,0.624,5.5152506828308105
|
46 |
+
0.159,0.747,0.094,5.518942832946777
|
47 |
+
0.03,0.322,0.648,5.870717525482178
|
48 |
+
0.389,0.248,0.363,5.399688720703125
|
49 |
+
0.487,0.234,0.279,5.344883918762207
|
50 |
+
0.385,0.363,0.252,5.373892784118652
|
51 |
+
0.793,0.029,0.178,5.366541862487793
|
52 |
+
0.62,0.38,0.0,5.323075294494629
|
53 |
+
0.024,0.635,0.34,5.89354133605957
|
54 |
+
0.848,0.152,0.0,5.330050468444824
|
55 |
+
0.082,0.257,0.661,5.695749759674072
|
56 |
+
0.111,0.747,0.142,5.571730136871338
|
57 |
+
0.997,0.001,0.002,5.432588577270508
|
58 |
+
0.484,0.064,0.452,5.41372537612915
|
59 |
+
0.257,0.023,0.72,5.593489646911621
|
60 |
+
0.908,0.064,0.028,5.33869743347168
|
61 |
+
0.407,0.575,0.018,5.356371879577637
|
62 |
+
0.716,0.209,0.074,5.299798488616943
|
63 |
+
0.499,0.467,0.034,5.316855430603027
|
64 |
+
0.463,0.09,0.447,5.408260822296143
|
65 |
+
0.347,0.164,0.49,5.455391883850098
|
66 |
+
0.22,0.31,0.47,5.478835105895996
|
67 |
+
0.085,0.899,0.017,5.63015079498291
|
68 |
+
0.831,0.042,0.126,5.347104549407959
|
69 |
+
0.083,0.845,0.072,5.637035846710205
|
70 |
+
0.009,0.352,0.639,6.105093955993652
|
71 |
+
0.373,0.177,0.45,5.426303386688232
|
72 |
+
0.0,1.0,0.0,6.3570756912231445
|
73 |
+
0.001,0.0,0.999,7.201324462890625
|
74 |
+
0.577,0.032,0.391,5.384527683258057
|
75 |
+
0.699,0.248,0.053,5.30719518661499
|
76 |
+
0.131,0.379,0.491,5.566975593566895
|
77 |
+
0.042,0.865,0.093,5.747323036193848
|
78 |
+
0.009,0.773,0.218,6.052563190460205
|
79 |
+
0.593,0.198,0.209,5.319425582885742
|
80 |
+
0.335,0.063,0.602,5.493361473083496
|
81 |
+
0.508,0.2,0.292,5.362349033355713
|
82 |
+
0.001,0.073,0.926,6.6276702880859375
|
83 |
+
0.472,0.164,0.364,5.364439487457275
|
84 |
+
0.021,0.415,0.563,5.923987865447998
|
85 |
+
0.995,0.0,0.005,5.450720310211182
|
86 |
+
0.613,0.221,0.166,5.330704689025879
|
87 |
+
0.238,0.668,0.093,5.449023246765137
|
88 |
+
0.521,0.08,0.399,5.382278919219971
|
89 |
+
0.102,0.138,0.76,5.70582389831543
|
90 |
+
0.627,0.02,0.353,5.393834590911865
|
91 |
+
0.027,0.955,0.018,5.853610515594482
|
92 |
+
0.215,0.713,0.071,5.452286243438721
|
93 |
+
0.265,0.092,0.643,5.527949333190918
|
94 |
+
0.178,0.002,0.82,5.708561897277832
|
95 |
+
0.028,0.029,0.943,6.10589599609375
|
96 |
+
0.002,0.305,0.693,6.344947338104248
|
97 |
+
0.608,0.358,0.034,5.328068733215332
|
98 |
+
0.579,0.226,0.195,5.31940221786499
|
99 |
+
0.171,0.04,0.789,5.646611213684082
|
100 |
+
0.056,0.483,0.461,5.769045352935791
|
101 |
+
0.175,0.358,0.467,5.5205817222595215
|
102 |
+
0.223,0.713,0.065,5.465519428253174
|
103 |
+
0.359,0.095,0.546,5.459996223449707
|
104 |
+
0.051,0.672,0.276,5.736810684204102
|
105 |
+
0.727,0.198,0.075,5.289270401000977
|
106 |
+
0.019,0.203,0.778,6.054460525512695
|
107 |
+
0.12,0.877,0.003,5.571919918060303
|
108 |
+
0.771,0.026,0.203,5.370471477508545
|
109 |
+
0.642,0.091,0.267,5.33230447769165
|
110 |
+
0.209,0.089,0.702,5.567383766174316
|
111 |
+
0.603,0.036,0.361,5.384149551391602
|
112 |
+
0.185,0.31,0.504,5.524736404418945
|
113 |
+
0.489,0.328,0.183,5.332746982574463
|
114 |
+
0.014,0.245,0.742,6.066995620727539
|
115 |
+
0.75,0.241,0.009,5.288382530212402
|
116 |
+
0.527,0.352,0.121,5.32275915145874
|
117 |
+
0.291,0.301,0.408,5.430193901062012
|
118 |
+
0.046,0.755,0.199,5.744025707244873
|
119 |
+
0.031,0.949,0.02,5.835524559020996
|
120 |
+
0.252,0.015,0.733,5.597743988037109
|
121 |
+
0.524,0.004,0.471,5.446127414703369
|
122 |
+
0.0,0.619,0.381,6.416290760040283
|
123 |
+
0.08,0.903,0.017,5.660269260406494
|
124 |
+
0.0,0.87,0.13,6.306185245513916
|
125 |
+
0.209,0.205,0.586,5.519876003265381
|
126 |
+
0.057,0.533,0.41,5.761989116668701
|
127 |
+
0.307,0.597,0.096,5.392508506774902
|
128 |
+
0.008,0.98,0.011,6.026959896087647
|
129 |
+
0.865,0.039,0.096,5.351530075073242
|