NotXia commited on
Commit
6d735c8
1 Parent(s): 95ec5ce

Add scores in output

Browse files
Files changed (1) hide show
  1. pipeline.py +12 -2
pipeline.py CHANGED
@@ -25,6 +25,8 @@ class ExtSummPipeline(Pipeline):
25
  strategy_args : any
26
  Parameters of the strategy.
27
 
 
 
28
  Outputs
29
  -------
30
  selected_sents : list[str]
@@ -32,6 +34,8 @@ class ExtSummPipeline(Pipeline):
32
 
33
  selected_idxs : list[int]
34
  List of the indexes of the selected sentences in the original input
 
 
35
  """
36
 
37
 
@@ -44,6 +48,8 @@ class ExtSummPipeline(Pipeline):
44
  postprocess_kwargs["strategy"] = kwargs["strategy"]
45
  if "strategy_args" in kwargs:
46
  postprocess_kwargs["strategy_args"] = kwargs["strategy_args"]
 
 
47
 
48
  return {}, {}, postprocess_kwargs
49
 
@@ -95,7 +101,11 @@ class ExtSummPipeline(Pipeline):
95
  return { "predictions": out_predictions, "sentences": sentences }
96
 
97
 
98
- def postprocess(self, args, strategy: str="count", strategy_args=3):
99
  predictions = args["predictions"]
100
  sentences = args["sentences"]
101
- return select(sentences, predictions, strategy, strategy_args)
 
 
 
 
 
25
  strategy_args : any
26
  Parameters of the strategy.
27
 
28
+ out_scores : bool
29
+ If True, the score for each sentence is returned.
30
  Outputs
31
  -------
32
  selected_sents : list[str]
 
34
 
35
  selected_idxs : list[int]
36
  List of the indexes of the selected sentences in the original input
37
+
38
+ sents_scores : Tensor (optional)
39
  """
40
 
41
 
 
48
  postprocess_kwargs["strategy"] = kwargs["strategy"]
49
  if "strategy_args" in kwargs:
50
  postprocess_kwargs["strategy_args"] = kwargs["strategy_args"]
51
+ if "out_scores" in kwargs:
52
+ postprocess_kwargs["out_scores"] = kwargs["out_scores"]
53
 
54
  return {}, {}, postprocess_kwargs
55
 
 
101
  return { "predictions": out_predictions, "sentences": sentences }
102
 
103
 
104
+ def postprocess(self, args, strategy: str="count", strategy_args=3, out_scores=False):
105
  predictions = args["predictions"]
106
  sentences = args["sentences"]
107
+ out = select(sentences, predictions, strategy, strategy_args)
108
+
109
+ if out_scores: out += (predictions,)
110
+
111
+ return out