charleslin4 commited on
Commit
a9853a7
1 Parent(s): 9b78f9c

Incomplete editing implementation

Browse files
Files changed (3) hide show
  1. algs/lu.py +1 -1
  2. app.py +10 -7
  3. config.py +1 -0
algs/lu.py CHANGED
@@ -67,7 +67,7 @@ class LU(EditableModel):
67
 
68
  return output
69
 
70
- def edit(self, batch, condition=None):
71
  edit_model = self.model.eval()
72
  if "bert" in self.config.model.name.lower():
73
  _, encoder_states = self.model(**batch, output_hidden_states=True)
 
67
 
68
  return output
69
 
70
+ def edit(self, batch, condition=None, detach_history=False):
71
  edit_model = self.model.eval()
72
  if "bert" in self.config.model.name.lower():
73
  _, encoder_states = self.model(**batch, output_hidden_states=True)
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import streamlit as st
2
  import pandas as pd
3
  import time
 
4
  import importlib
5
  from torch.cuda import is_available as use_cuda
6
 
@@ -28,13 +29,11 @@ def reset():
28
  st.session_state.model_outputs.drop(st.session_state.edits.index, inplace=True)
29
 
30
  selected_alg = st.session_state.alg_selector
31
- selected_alg_idx = EDIT_ALGS.index(selected_alg)
32
-
 
 
33
  with st.spinner('Loading model...'):
34
- alg_abbrv = selected_alg[:selected_alg.index(":")]
35
- alg_module = importlib.import_module(f"algs.{alg_abbrv.lower()}")
36
- alg_class = getattr(alg_module, alg_abbrv.upper())
37
- st.session_state.config = getattr(config, f"{alg_abbrv.lower()}_config")
38
  st.session_state.editable_model = alg_class(
39
  st.session_state.model,
40
  st.session_state.config,
@@ -44,7 +43,11 @@ def reset():
44
  def apply_edit():
45
  st.session_state.edits.loc[len(st.session_state.edits)] = [str(edit_input), str(edit_label)]
46
 
47
- ############# Actually do the edit to the model
 
 
 
 
48
 
49
  def sample_model():
50
  input_str = str(test_input)
 
1
  import streamlit as st
2
  import pandas as pd
3
  import time
4
+ import copy
5
  import importlib
6
  from torch.cuda import is_available as use_cuda
7
 
 
29
  st.session_state.model_outputs.drop(st.session_state.edits.index, inplace=True)
30
 
31
  selected_alg = st.session_state.alg_selector
32
+ alg_abbrv = selected_alg[:selected_alg.index(":")]
33
+ alg_module = importlib.import_module(f"algs.{alg_abbrv.lower()}")
34
+ alg_class = getattr(alg_module, alg_abbrv.upper())
35
+ st.session_state.config = getattr(config, f"{alg_abbrv.lower()}_config")
36
  with st.spinner('Loading model...'):
 
 
 
 
37
  st.session_state.editable_model = alg_class(
38
  st.session_state.model,
39
  st.session_state.config,
 
43
  def apply_edit():
44
  st.session_state.edits.loc[len(st.session_state.edits)] = [str(edit_input), str(edit_label)]
45
 
46
+ with st.spinner("Editing model..."):
47
+ input_ids = st.session_state.tokenizer(str(edit_input), return_tensors="pt")["input_ids"].to(st.session_state.device)
48
+ label_ids = st.session_state.tokenizer(str(edit_label), return_tensors="pt")["input_ids"].to(st.session_state.device)
49
+ edit_sample = {"input_ids": input_ids, "labels": label_ids}
50
+ st.session_state.editable_model, _ = st.session_state.editable_model.edit(edit_sample, detach_history=True)
51
 
52
  def sample_model():
53
  input_str = str(test_input)
config.py CHANGED
@@ -24,6 +24,7 @@ ft_config = OmegaConf.create({
24
  "device": "cuda" if use_cuda() else "cpu",
25
  "edit_lr": 5e-6,
26
  "train_base": False,
 
27
  "ft": {
28
  "verbose": False,
29
  "max_edit_steps": 100,
 
24
  "device": "cuda" if use_cuda() else "cpu",
25
  "edit_lr": 5e-6,
26
  "train_base": False,
27
+ "grad_clip": 100,
28
  "ft": {
29
  "verbose": False,
30
  "max_edit_steps": 100,