wanzin commited on
Commit
221fdce
1 Parent(s): 0006e80

add device

Browse files
Files changed (1) hide show
  1. dmxMetric.py +6 -1
dmxMetric.py CHANGED
@@ -3,6 +3,7 @@ import lm_eval
3
  from typing import Union, List, Optional
4
  from dmx.compressor.dmx import config_rules, DmxModel
5
  import datasets
 
6
 
7
  _DESCRIPTION = """
8
  Evaluation function using lm-eval with d-Matrix integration.
@@ -54,6 +55,7 @@ class DmxMetric(evaluate.Metric):
54
  batch_size: Optional[Union[int, str]] = None,
55
  max_batch_size: Optional[int] = None,
56
  limit: Optional[Union[int, float]] = None,
 
57
  revision: str = "main",
58
  trust_remote_code: bool = False,
59
  log_samples: bool = True,
@@ -63,7 +65,10 @@ class DmxMetric(evaluate.Metric):
63
  """
64
  Evaluate a model on multiple tasks and metrics using lm-eval with optional d-Matrix integration.
65
  """
66
- model_args = f"pretrained={model},revision={revision},trust_remote_code={str(trust_remote_code)}"
 
 
 
67
 
68
  lm = lm_eval.api.registry.get_model("hf").create_from_arg_string(
69
  model_args,
 
3
  from typing import Union, List, Optional
4
  from dmx.compressor.dmx import config_rules, DmxModel
5
  import datasets
6
+ import torch
7
 
8
  _DESCRIPTION = """
9
  Evaluation function using lm-eval with d-Matrix integration.
 
55
  batch_size: Optional[Union[int, str]] = None,
56
  max_batch_size: Optional[int] = None,
57
  limit: Optional[Union[int, float]] = None,
58
+ device: Optional[str] = None,
59
  revision: str = "main",
60
  trust_remote_code: bool = False,
61
  log_samples: bool = True,
 
65
  """
66
  Evaluate a model on multiple tasks and metrics using lm-eval with optional d-Matrix integration.
67
  """
68
+ if device is None:
69
+ device = "cuda" if torch.cuda.is_available() else "cpu"
70
+
71
+ model_args = f"pretrained={model},revision={revision},trust_remote_code={str(trust_remote_code)},device={device}"
72
 
73
  lm = lm_eval.api.registry.get_model("hf").create_from_arg_string(
74
  model_args,