BlendMMM commited on
Commit
3c56188
1 Parent(s): a3f74f1

Update pages/7_Build_Response_Curves.py

Browse files
Files changed (1) hide show
  1. pages/7_Build_Response_Curves.py +212 -212
pages/7_Build_Response_Curves.py CHANGED
@@ -1,213 +1,213 @@
1
- import streamlit as st
2
- import plotly.express as px
3
- import numpy as np
4
- import plotly.graph_objects as go
5
- from utilities_with_panel import channel_name_formating, load_authenticator, initialize_data
6
- from sklearn.metrics import r2_score
7
- from collections import OrderedDict
8
- from classes import class_from_dict,class_to_dict
9
- import pickle
10
- import json
11
- import pandas as pd
12
- from utilities import (
13
- load_local_css,
14
- set_header,
15
- channel_name_formating,
16
- )
17
-
18
- for k, v in st.session_state.items():
19
- if k not in ['logout', 'login','config'] and not k.startswith('FormSubmitter'):
20
- st.session_state[k] = v
21
-
22
- def s_curve(x,K,b,a,x0):
23
- return K / (1 + b*np.exp(-a*(x-x0)))
24
-
25
- def save_scenario(scenario_name):
26
- """
27
- Save the current scenario with the mentioned name in the session state
28
-
29
- Parameters
30
- ----------
31
- scenario_name
32
- Name of the scenario to be saved
33
- """
34
- if 'saved_scenarios' not in st.session_state:
35
- st.session_state = OrderedDict()
36
-
37
- #st.session_state['saved_scenarios'][scenario_name] = st.session_state['scenario'].save()
38
- st.session_state['saved_scenarios'][scenario_name] = class_to_dict(st.session_state['scenario'])
39
- st.session_state['scenario_input'] = ""
40
- print(type(st.session_state['saved_scenarios']))
41
- with open('../saved_scenarios.pkl', 'wb') as f:
42
- pickle.dump(st.session_state['saved_scenarios'],f)
43
-
44
-
45
- def reset_curve_parameters():
46
- del st.session_state['K']
47
- del st.session_state['b']
48
- del st.session_state['a']
49
- del st.session_state['x0']
50
-
51
- def update_response_curve():
52
- # st.session_state['rcs'][selected_channel_name]['K'] = st.session_state['K']
53
- # st.session_state['rcs'][selected_channel_name]['b'] = st.session_state['b']
54
- # st.session_state['rcs'][selected_channel_name]['a'] = st.session_state['a']
55
- # st.session_state['rcs'][selected_channel_name]['x0'] = st.session_state['x0']
56
- # rcs = st.session_state['rcs']
57
- _channel_class = st.session_state['scenario'].channels[selected_channel_name]
58
- _channel_class.update_response_curves({
59
- 'K' : st.session_state['K'],
60
- 'b' : st.session_state['b'],
61
- 'a' : st.session_state['a'],
62
- 'x0' : st.session_state['x0']})
63
-
64
-
65
- # authenticator = st.session_state.get('authenticator')
66
- # if authenticator is None:
67
- # authenticator = load_authenticator()
68
-
69
- # name, authentication_status, username = authenticator.login('Login', 'main')
70
- # auth_status = st.session_state.get('authentication_status')
71
-
72
- # if auth_status == True:
73
- # is_state_initiaized = st.session_state.get('initialized',False)
74
- # if not is_state_initiaized:
75
- # print("Scenario page state reloaded")
76
-
77
- # Sprint4 - if used_response_metrics is not blank, then select one of the used_response_metrics, else target is revenue by default
78
- st.set_page_config(layout='wide')
79
- load_local_css('styles.css')
80
- set_header()
81
- def panel_fetch(file_selected):
82
- raw_data_mmm_df = pd.read_excel(file_selected, sheet_name="RAW DATA MMM")
83
-
84
- if "Panel" in raw_data_mmm_df.columns:
85
- panel = list(set(raw_data_mmm_df["Panel"]))
86
- else:
87
- raw_data_mmm_df = None
88
- panel = None
89
-
90
- return panel
91
-
92
- metrics_selected='Revenue'
93
-
94
- file_selected = (
95
- f"Overview_data_test_panel@#{metrics_selected}.xlsx"
96
- )
97
- panel_list = panel_fetch(file_selected)
98
-
99
-
100
-
101
- if "used_response_metrics" in st.session_state and st.session_state['used_response_metrics']!=[]:
102
- sel_target_col = st.selectbox("Select the response metric", st.session_state['used_response_metrics'])
103
- target_col = sel_target_col.lower().replace(" ", "_").replace('-', '').replace(':', '').replace("__", "_")
104
- else :
105
- sel_target_col = 'Total Approved Accounts - Revenue'
106
- target_col = 'total_approved_accounts_revenue'
107
-
108
- st.subheader("Build response curves")
109
-
110
-
111
- st.session_state['selected_markets']= st.selectbox(
112
- "Select Markets",
113
- ["Total Market"] + panel_list,
114
- index=0,
115
- )
116
- initialize_data(target_col,st.session_state['selected_markets'])
117
-
118
-
119
-
120
- channels_list = st.session_state['channels_list']
121
- selected_channel_name = st.selectbox('Channel', st.session_state['channels_list'], format_func=channel_name_formating,on_change=reset_curve_parameters)
122
-
123
- rcs = {}
124
- for channel_name in channels_list:
125
- rcs[channel_name] = st.session_state['scenario'].channels[channel_name].response_curve_params
126
- # rcs = st.session_state['rcs']
127
-
128
-
129
- if 'K' not in st.session_state:
130
- st.session_state['K'] = rcs[selected_channel_name]['K']
131
- if 'b' not in st.session_state:
132
- st.session_state['b'] = rcs[selected_channel_name]['b']
133
- if 'a' not in st.session_state:
134
- st.session_state['a'] = rcs[selected_channel_name]['a']
135
- if 'x0' not in st.session_state:
136
- st.session_state['x0'] = rcs[selected_channel_name]['x0']
137
-
138
- x = st.session_state['actual_input_df'][selected_channel_name].values
139
- y = st.session_state['actual_contribution_df'][selected_channel_name].values
140
-
141
- power = (np.ceil(np.log(x.max()) / np.log(10) )- 3)
142
-
143
- # fig = px.scatter(x, s_curve(x/10**power,
144
- # st.session_state['K'],
145
- # st.session_state['b'],
146
- # st.session_state['a'],
147
- # st.session_state['x0']))
148
-
149
- fig = px.scatter(x=x, y=y)
150
- fig.add_trace(go.Scatter(x=sorted(x), y=s_curve(sorted(x)/10**power,st.session_state['K'],
151
- st.session_state['b'],
152
- st.session_state['a'],
153
- st.session_state['x0']),
154
- line=dict(color='red')))
155
-
156
- fig.update_layout(title_text="Response Curve",showlegend=False)
157
- fig.update_annotations(font_size=10)
158
- fig.update_xaxes(title='Spends')
159
- fig.update_yaxes(title=sel_target_col)
160
-
161
- st.plotly_chart(fig,use_container_width=True)
162
-
163
- r2 = r2_score(y, s_curve(x / 10**power,
164
- st.session_state['K'],
165
- st.session_state['b'],
166
- st.session_state['a'],
167
- st.session_state['x0']))
168
-
169
- st.metric('R2',round(r2,2))
170
- columns = st.columns(4)
171
-
172
- with columns[0]:
173
- st.number_input('K',key='K',format="%0.5f")
174
- with columns[1]:
175
- st.number_input('b',key='b',format="%0.5f")
176
- with columns[2]:
177
- st.number_input('a',key='a',step=0.0001,format="%0.5f")
178
- with columns[3]:
179
- st.number_input('x0',key='x0',format="%0.5f")
180
-
181
-
182
- st.button('Update parameters',on_click=update_response_curve)
183
- st.button('Reset parameters',on_click=reset_curve_parameters)
184
- scenario_name = st.text_input('Scenario name', key='scenario_input',placeholder='Scenario name',label_visibility='collapsed')
185
- st.button('Save', on_click=lambda : save_scenario(scenario_name),disabled=len(st.session_state['scenario_input']) == 0)
186
-
187
- file_name = st.text_input('rcs download file name', key='file_name_input',placeholder='file name',label_visibility='collapsed')
188
- st.download_button(
189
- label="Download response curves",
190
- data=json.dumps(rcs),
191
- file_name=f"{file_name}.json",
192
- mime="application/json",
193
- disabled= len(file_name) == 0,
194
- )
195
-
196
-
197
- def s_curve_derivative(x, K, b, a, x0):
198
- # Derivative of the S-curve function
199
- return a * b * K * np.exp(-a * (x - x0)) / ((1 + b * np.exp(-a * (x - x0))) ** 2)
200
-
201
- # Parameters of the S-curve
202
- K = st.session_state['K']
203
- b = st.session_state['b']
204
- a = st.session_state['a']
205
- x0 = st.session_state['x0']
206
-
207
- # Optimized spend value obtained from the tool
208
- optimized_spend = st.number_input('value of x') # Replace this with your optimized spend value
209
-
210
- # Calculate the slope at the optimized spend value
211
- slope_at_optimized_spend = s_curve_derivative(optimized_spend, K, b, a, x0)
212
-
213
  st.write("Slope ", slope_at_optimized_spend)
 
1
+ import streamlit as st
2
+ import plotly.express as px
3
+ import numpy as np
4
+ import plotly.graph_objects as go
5
+ from utilities_with_panel import channel_name_formating, load_authenticator, initialize_data
6
+ from sklearn.metrics import r2_score
7
+ from collections import OrderedDict
8
+ from classes import class_from_dict,class_to_dict
9
+ import pickle
10
+ import json
11
+ import pandas as pd
12
+ from utilities import (
13
+ load_local_css,
14
+ set_header,
15
+ channel_name_formating,
16
+ )
17
+
18
+ for k, v in st.session_state.items():
19
+ if k not in ['logout', 'login','config'] and not k.startswith('FormSubmitter'):
20
+ st.session_state[k] = v
21
+
22
+ def s_curve(x,K,b,a,x0):
23
+ return K / (1 + b*np.exp(-a*(x-x0)))
24
+
25
+ def save_scenario(scenario_name):
26
+ """
27
+ Save the current scenario with the mentioned name in the session state
28
+
29
+ Parameters
30
+ ----------
31
+ scenario_name
32
+ Name of the scenario to be saved
33
+ """
34
+ if 'saved_scenarios' not in st.session_state:
35
+ st.session_state = OrderedDict()
36
+
37
+ #st.session_state['saved_scenarios'][scenario_name] = st.session_state['scenario'].save()
38
+ st.session_state['saved_scenarios'][scenario_name] = class_to_dict(st.session_state['scenario'])
39
+ st.session_state['scenario_input'] = ""
40
+ print(type(st.session_state['saved_scenarios']))
41
+ with open('../saved_scenarios.pkl', 'wb') as f:
42
+ pickle.dump(st.session_state['saved_scenarios'],f)
43
+
44
+
45
+ def reset_curve_parameters():
46
+ del st.session_state['K']
47
+ del st.session_state['b']
48
+ del st.session_state['a']
49
+ del st.session_state['x0']
50
+
51
+ def update_response_curve():
52
+ # st.session_state['rcs'][selected_channel_name]['K'] = st.session_state['K']
53
+ # st.session_state['rcs'][selected_channel_name]['b'] = st.session_state['b']
54
+ # st.session_state['rcs'][selected_channel_name]['a'] = st.session_state['a']
55
+ # st.session_state['rcs'][selected_channel_name]['x0'] = st.session_state['x0']
56
+ # rcs = st.session_state['rcs']
57
+ _channel_class = st.session_state['scenario'].channels[selected_channel_name]
58
+ _channel_class.update_response_curves({
59
+ 'K' : st.session_state['K'],
60
+ 'b' : st.session_state['b'],
61
+ 'a' : st.session_state['a'],
62
+ 'x0' : st.session_state['x0']})
63
+
64
+
65
+ # authenticator = st.session_state.get('authenticator')
66
+ # if authenticator is None:
67
+ # authenticator = load_authenticator()
68
+
69
+ # name, authentication_status, username = authenticator.login('Login', 'main')
70
+ # auth_status = st.session_state.get('authentication_status')
71
+
72
+ # if auth_status == True:
73
+ # is_state_initiaized = st.session_state.get('initialized',False)
74
+ # if not is_state_initiaized:
75
+ # print("Scenario page state reloaded")
76
+
77
+ # Sprint4 - if used_response_metrics is not blank, then select one of the used_response_metrics, else target is revenue by default
78
+ st.set_page_config(layout='wide')
79
+ load_local_css('styles.css')
80
+ set_header()
81
+ def panel_fetch(file_selected):
82
+ raw_data_mmm_df = pd.read_excel(file_selected, sheet_name="RAW DATA MMM")
83
+
84
+ if "Panel" in raw_data_mmm_df.columns:
85
+ panel = list(set(raw_data_mmm_df["Panel"]))
86
+ else:
87
+ raw_data_mmm_df = None
88
+ panel = None
89
+
90
+ return panel
91
+
92
+ metrics_selected='revenue'
93
+
94
+ file_selected = (
95
+ f"Overview_data_test_panel@#{metrics_selected}.xlsx"
96
+ )
97
+ panel_list = panel_fetch(file_selected)
98
+
99
+
100
+
101
+ if "used_response_metrics" in st.session_state and st.session_state['used_response_metrics']!=[]:
102
+ sel_target_col = st.selectbox("Select the response metric", st.session_state['used_response_metrics'])
103
+ target_col = sel_target_col.lower().replace(" ", "_").replace('-', '').replace(':', '').replace("__", "_")
104
+ else :
105
+ sel_target_col = 'Total Approved Accounts - Revenue'
106
+ target_col = 'total_approved_accounts_revenue'
107
+
108
+ st.subheader("Build response curves")
109
+
110
+
111
+ st.session_state['selected_markets']= st.selectbox(
112
+ "Select Markets",
113
+ ["Total Market"] + panel_list,
114
+ index=0,
115
+ )
116
+ initialize_data(target_col,st.session_state['selected_markets'])
117
+
118
+
119
+
120
+ channels_list = st.session_state['channels_list']
121
+ selected_channel_name = st.selectbox('Channel', st.session_state['channels_list'], format_func=channel_name_formating,on_change=reset_curve_parameters)
122
+
123
+ rcs = {}
124
+ for channel_name in channels_list:
125
+ rcs[channel_name] = st.session_state['scenario'].channels[channel_name].response_curve_params
126
+ # rcs = st.session_state['rcs']
127
+
128
+
129
+ if 'K' not in st.session_state:
130
+ st.session_state['K'] = rcs[selected_channel_name]['K']
131
+ if 'b' not in st.session_state:
132
+ st.session_state['b'] = rcs[selected_channel_name]['b']
133
+ if 'a' not in st.session_state:
134
+ st.session_state['a'] = rcs[selected_channel_name]['a']
135
+ if 'x0' not in st.session_state:
136
+ st.session_state['x0'] = rcs[selected_channel_name]['x0']
137
+
138
+ x = st.session_state['actual_input_df'][selected_channel_name].values
139
+ y = st.session_state['actual_contribution_df'][selected_channel_name].values
140
+
141
+ power = (np.ceil(np.log(x.max()) / np.log(10) )- 3)
142
+
143
+ # fig = px.scatter(x, s_curve(x/10**power,
144
+ # st.session_state['K'],
145
+ # st.session_state['b'],
146
+ # st.session_state['a'],
147
+ # st.session_state['x0']))
148
+
149
+ fig = px.scatter(x=x, y=y)
150
+ fig.add_trace(go.Scatter(x=sorted(x), y=s_curve(sorted(x)/10**power,st.session_state['K'],
151
+ st.session_state['b'],
152
+ st.session_state['a'],
153
+ st.session_state['x0']),
154
+ line=dict(color='red')))
155
+
156
+ fig.update_layout(title_text="Response Curve",showlegend=False)
157
+ fig.update_annotations(font_size=10)
158
+ fig.update_xaxes(title='Spends')
159
+ fig.update_yaxes(title=sel_target_col)
160
+
161
+ st.plotly_chart(fig,use_container_width=True)
162
+
163
+ r2 = r2_score(y, s_curve(x / 10**power,
164
+ st.session_state['K'],
165
+ st.session_state['b'],
166
+ st.session_state['a'],
167
+ st.session_state['x0']))
168
+
169
+ st.metric('R2',round(r2,2))
170
+ columns = st.columns(4)
171
+
172
+ with columns[0]:
173
+ st.number_input('K',key='K',format="%0.5f")
174
+ with columns[1]:
175
+ st.number_input('b',key='b',format="%0.5f")
176
+ with columns[2]:
177
+ st.number_input('a',key='a',step=0.0001,format="%0.5f")
178
+ with columns[3]:
179
+ st.number_input('x0',key='x0',format="%0.5f")
180
+
181
+
182
+ st.button('Update parameters',on_click=update_response_curve)
183
+ st.button('Reset parameters',on_click=reset_curve_parameters)
184
+ scenario_name = st.text_input('Scenario name', key='scenario_input',placeholder='Scenario name',label_visibility='collapsed')
185
+ st.button('Save', on_click=lambda : save_scenario(scenario_name),disabled=len(st.session_state['scenario_input']) == 0)
186
+
187
+ file_name = st.text_input('rcs download file name', key='file_name_input',placeholder='file name',label_visibility='collapsed')
188
+ st.download_button(
189
+ label="Download response curves",
190
+ data=json.dumps(rcs),
191
+ file_name=f"{file_name}.json",
192
+ mime="application/json",
193
+ disabled= len(file_name) == 0,
194
+ )
195
+
196
+
197
+ def s_curve_derivative(x, K, b, a, x0):
198
+ # Derivative of the S-curve function
199
+ return a * b * K * np.exp(-a * (x - x0)) / ((1 + b * np.exp(-a * (x - x0))) ** 2)
200
+
201
+ # Parameters of the S-curve
202
+ K = st.session_state['K']
203
+ b = st.session_state['b']
204
+ a = st.session_state['a']
205
+ x0 = st.session_state['x0']
206
+
207
+ # Optimized spend value obtained from the tool
208
+ optimized_spend = st.number_input('value of x') # Replace this with your optimized spend value
209
+
210
+ # Calculate the slope at the optimized spend value
211
+ slope_at_optimized_spend = s_curve_derivative(optimized_spend, K, b, a, x0)
212
+
213
  st.write("Slope ", slope_at_optimized_spend)