Martijn van Beers
commited on
Commit
•
86d2882
1
Parent(s):
2578029
Remove unwanted fields in visualization
Browse filesRemove fields from the captum visualizationrecord that either don't make
sense for the demo, or we are unsure about their interpretation
- lib/util.py +12 -11
lib/util.py
CHANGED
@@ -31,10 +31,10 @@ class PyTMinMaxScalerVectorized(object):
|
|
31 |
def visualize_text(datarecords, legend=True):
|
32 |
dom = ["<table width: 100%>"]
|
33 |
rows = [
|
34 |
-
"<tr><th>True Label</th>"
|
35 |
"<th>Predicted Label</th>"
|
36 |
"<th>Attribution Label</th>"
|
37 |
-
"<th>Attribution Score</th>"
|
38 |
"<th>Word Importance</th>"
|
39 |
]
|
40 |
for datarecord in datarecords:
|
@@ -42,16 +42,17 @@ def visualize_text(datarecords, legend=True):
|
|
42 |
"".join(
|
43 |
[
|
44 |
"<tr>",
|
45 |
-
visualization.format_classname(datarecord.true_class),
|
46 |
-
visualization.format_classname(
|
47 |
-
"{0} ({1:.2f})".format(
|
48 |
-
datarecord.pred_class
|
49 |
-
)
|
50 |
-
),
|
|
|
51 |
visualization.format_classname(datarecord.attr_class),
|
52 |
-
visualization.format_classname(
|
53 |
-
"{0:.2f}".format(datarecord.attr_score)
|
54 |
-
),
|
55 |
visualization.format_word_importances(
|
56 |
datarecord.raw_input_ids, datarecord.word_attributions
|
57 |
),
|
|
|
31 |
def visualize_text(datarecords, legend=True):
|
32 |
dom = ["<table width: 100%>"]
|
33 |
rows = [
|
34 |
+
# "<tr><th>True Label</th>"
|
35 |
"<th>Predicted Label</th>"
|
36 |
"<th>Attribution Label</th>"
|
37 |
+
# "<th>Attribution Score</th>"
|
38 |
"<th>Word Importance</th>"
|
39 |
]
|
40 |
for datarecord in datarecords:
|
|
|
42 |
"".join(
|
43 |
[
|
44 |
"<tr>",
|
45 |
+
# visualization.format_classname(datarecord.true_class),
|
46 |
+
# visualization.format_classname(
|
47 |
+
# "{0} ({1:.2f})".format(
|
48 |
+
# datarecord.pred_class#, datarecord.pred_prob
|
49 |
+
# )
|
50 |
+
# ),
|
51 |
+
visualization.format_classname(datarecord.pred_class),
|
52 |
visualization.format_classname(datarecord.attr_class),
|
53 |
+
# visualization.format_classname(
|
54 |
+
# "{0:.2f}".format(datarecord.attr_score)
|
55 |
+
# ),
|
56 |
visualization.format_word_importances(
|
57 |
datarecord.raw_input_ids, datarecord.word_attributions
|
58 |
),
|