Martijn van Beers committed on
Commit
decf1d6
1 Parent(s): 86d2882

Always explain based on the positive class for IG


With Integrated Gradients, always use the positive class
as the target, so that the visualization consistently shows
green for words that contribute to a positive prediction
and red for words that contribute to a negative one.
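For a binary softmax head the two class probabilities sum to one, so attributions toward class 0 are the exact negation of attributions toward class 1; targeting the predicted class would therefore flip every sign (and so every color) on negatively-predicted examples. A minimal sketch with a toy classifier makes this concrete; the linear model and random input below are illustrative stand-ins, not the repo's code:

import torch
from captum.attr import IntegratedGradients

torch.manual_seed(0)

# Toy stand-in for a binary sentiment classifier; NOT the repo's model.
model = torch.nn.Sequential(torch.nn.Linear(4, 2), torch.nn.Softmax(dim=-1))
ig = IntegratedGradients(model)

x = torch.randn(1, 4)
pred = model(x).argmax(dim=-1).item()

attr_pred = ig.attribute(x, target=pred)  # sign convention flips with pred
attr_pos = ig.attribute(x, target=1)      # sign always means "toward positive"

# p0 + p1 == 1, so gradients (and hence IG scores) for the two targets are
# exact negations of each other; when pred == 0 the two colorings disagree.
assert torch.allclose(ig.attribute(x, target=0), -attr_pos, atol=1e-6)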

Files changed (1)
  1. lib/integrated_gradients.py +3 -3
lib/integrated_gradients.py CHANGED
@@ -35,8 +35,8 @@ class IntegratedGradientsExplainer:
     def run_attribution_model(self, input_ids, attention_mask, index=None, layer=None, steps=20):
         try:
             output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
-            if index is None:
-                index = output.argmax(axis=-1).item()
+            # if index is None:
+            #     index = output.argmax(axis=-1).item()
 
             ablator = LayerIntegratedGradients(self.custom_forward, layer)
             input_tensor = input_ids
@@ -45,7 +45,7 @@ class IntegratedGradientsExplainer:
                 inputs=input_ids,
                 baselines=self.ref_token_id,
                 additional_forward_args=(attention_mask),
-                target=index,
+                target=1,
                 n_steps=steps,
             )
             return self.summarize_attributions(attributions).unsqueeze_(0), output, index
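The method relies on self.custom_forward, which this diff does not show; with Captum's LayerIntegratedGradients it is typically a wrapper returning one score per class, so that target=1 selects the positive-class score. A self-contained sketch of the overall pattern, where the model name, layer choice, and forward wrapper are assumptions rather than the repo's actual code:

import torch
from captum.attr import LayerIntegratedGradients
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Hypothetical model choice; the repo's model and custom_forward are not
# shown in this commit.
name = "distilbert-base-uncased-finetuned-sst-2-english"
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForSequenceClassification.from_pretrained(name)
model.eval()

def custom_forward(input_ids, attention_mask=None):
    logits = model(input_ids=input_ids, attention_mask=attention_mask)[0]
    return torch.softmax(logits, dim=-1)  # one probability per class

enc = tok("a gorgeous, witty film", return_tensors="pt")
lig = LayerIntegratedGradients(custom_forward, model.distilbert.embeddings)

attributions = lig.attribute(
    inputs=enc["input_ids"],
    baselines=tok.pad_token_id,            # plays the role of ref_token_id
    additional_forward_args=(enc["attention_mask"],),
    target=1,                              # positive class, as in this commit
    n_steps=20,
)

# Collapse the embedding dimension and normalize, in the spirit of the
# summarize_attributions helper the diff calls.
scores = attributions.sum(dim=-1).squeeze(0)
scores = scores / scores.norm()
print(list(zip(tok.convert_ids_to_tokens(enc["input_ids"][0]), scores.tolist())))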