Nathan Fradet committed
Commit 9c80799
1 Parent(s): f5c6506

ruff formatting + changing gradio app loading

Files changed (3)
  1. app.py +14 -3
  2. ece.py +42 -31
  3. tests.py +3 -1
app.py CHANGED
@@ -1,6 +1,17 @@
+"""Application file."""
+
 import evaluate
-from evaluate.utils import launch_gradio_widget
+import gradio as gr
+
+"""module = evaluate.load("Natooz/ece")
+gradio_app = gr.Interface(
+    module,
+    inputs=gr.component(),
+    outputs=[gr.Image(label="Processed Image"), gr.Label(label="Result", num_top_classes=2)],
+    title=module.name,
+)"""
+gradio_app = gr.load("Natooz/ece", src="spaces")
 
 
-module = evaluate.load("Natooz/ece")
-launch_gradio_widget(module)
+if __name__ == "__main__":
+    gradio_app.launch()
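With this change, app.py no longer builds its own widget with launch_gradio_widget; it mirrors the hosted "Natooz/ece" Space through gr.load and serves it from the __main__ guard. The metric itself can still be called directly through the evaluate API. A minimal sketch, not part of the commit, assuming the evaluate and torchmetrics packages are installed; the key of the returned dict and the exact feature schema declared in _info are not visible in this diff.

    import evaluate

    # Load the metric module that the Space wraps.
    ece = evaluate.load("Natooz/ece")

    # Binary case: probability predictions with shape (N,), integer references (N,).
    print(ece.compute(predictions=[0.1, 0.8, 0.6], references=[0, 1, 1]))

    # Multiclass case: predictions with shape (N, C); "num_classes" is inferred from C
    # by _compute in ece.py when it is not passed explicitly (whether the declared
    # feature schema accepts this nested input is not shown in the diff).
    print(ece.compute(predictions=[[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]], references=[0, 1]))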
ece.py CHANGED
@@ -1,27 +1,19 @@
-# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Dict
+"""ECE metric file."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
 
-import evaluate
 import datasets
-from torch import Tensor, LongTensor
+import evaluate
+from torch import LongTensor, Tensor
 from torchmetrics.functional.classification.calibration_error import (
     binary_calibration_error,
     multiclass_calibration_error,
 )
 
+if TYPE_CHECKING:
+    from collections.abc import Iterable
 
 _CITATION = """\
 @InProceedings{huggingface:ece,
@@ -41,7 +33,8 @@ https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.h
 _KWARGS_DESCRIPTION = """
 Calculates how good are predictions given some references, using certain scores
 Args:
-    predictions: list of predictions to score. They must have a shape (N,C,...) if multiclass, or (N,...) if binary.
+    predictions: list of predictions to score. They must have a shape (N,C,...) if
+        multiclass, or (N,...) if binary.
     references: list of reference for each prediction, with a shape (N,...).
 Returns:
     ece: expected calibration error
@@ -65,11 +58,17 @@ Examples:
 @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
 class ECE(evaluate.Metric):
     """
-    Proxy to the BinaryCalibrationError (ECE) metric of the torchmetrics package:
-    https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html
+    Module for the BinaryCalibrationError (ECE) metric of the torchmetrics package.
+
+    https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html.
     """
 
-    def _info(self):
+    def _info(self) -> evaluate.MetricInfo:
+        """
+        Return the module info.
+
+        :return: module info.
+        """
         return evaluate.MetricInfo(
             # This is the description that will appear on the modules page.
             module_type="metric",
@@ -94,31 +93,43 @@ class ECE(evaluate.Metric):
             ],
         )
 
-    def _compute(self, predictions=None, references=None, **kwargs) -> Dict[str, float]:
-        """Returns the ece.
-        See the torchmetrics documentation for more information on the arguments to pass.
+    def _compute(
+        self,
+        predictions: Iterable[float] | None = None,
+        references: Iterable[int] | None = None,
+        **kwargs
+    ) -> dict[str, float]:
+        """
+        Return the Expected Calibration Error (ECE).
+
+        See the torchmetrics documentation for more information on the method.
         https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html
         predictions: (N,C,...) if multiclass or (N,...) if binary
-        references: (N,...)
+        references: (N,...).
 
-        If "num_classes" is not provided in a multiclasses setting, the number maximum label index will
-        be used as "num_classes".
+        If "num_classes" is not provided in a multiclass setting, the number maximum
+        label index will be used as "num_classes".
         """
        # Convert the input
         predictions = Tensor(predictions)
         references = LongTensor(references)
 
         # Determine number of classes / binary or multiclass
-        error_msg = "Expected to have predictions with shape (N,C,...) for multiclass or (N,...) for binary, " \
-            f"and references with shape (N,...), but got {predictions.shape} and {references.shape}"
+        error_msg = (
+            "Expected to have predictions with shape (N,C,...) for multiclass or "
+            "(N,...) for binary, and references with shape (N,...), but got "
+            f"{predictions.shape} and {references.shape}"
+        )
         binary = True
         if predictions.dim() == references.dim() + 1:  # multiclass
             binary = False
             if "num_classes" not in kwargs:
                 kwargs["num_classes"] = int(predictions.shape[1])
         elif predictions.dim() == references.dim() and "num_classes" in kwargs:
-            raise ValueError("You gave the num_classes argument, with predictions and references having the"
-                "same number of dimensions. " + error_msg)
+            raise ValueError(
+                "You gave the num_classes argument, with predictions and references "
+                "having the same number of dimensions. " + error_msg
+            )
         elif predictions.dim() != references.dim():
             raise ValueError("Bad input shape. " + error_msg)
 
tests.py CHANGED
@@ -1,3 +1,5 @@
+"""Test cases."""
+
 test_cases = [
     {
         "predictions": [0, 0],
@@ -14,4 +16,4 @@ test_cases = [
         "references": [1, 1],
         "result": {"metric_score": 0.5}
     }
-]
+]
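How these cases are consumed is not shown in this commit; a hypothetical runner (names below are assumptions, not part of the repository) could loop over them and compare the module's output against the expected "result" entries:

    # Hypothetical test loop; "metric_score" mirrors the expected keys above.
    import evaluate

    from tests import test_cases

    module = evaluate.load("Natooz/ece")
    for case in test_cases:
        computed = module.compute(
            predictions=case["predictions"], references=case["references"]
        )
        print(computed, "expected:", case["result"])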