sdiazlor HF staff commited on
Commit
2a1af1a
1 Parent(s): 07a8bbc

bug: correct generation parameters

Browse files
src/distilabel_dataset_generator/apps/base.py CHANGED
@@ -427,6 +427,7 @@ def push_dataset_to_hub(
427
 
428
  if task == TEXTCAT_TASK:
429
  if num_labels == 1:
 
430
  features = Features(
431
  {"text": Value("string"), "label": ClassLabel(names=labels)}
432
  )
 
427
 
428
  if task == TEXTCAT_TASK:
429
  if num_labels == 1:
430
+ dataframe["label"] = dataframe["label"].replace("", None)
431
  features = Features(
432
  {"text": Value("string"), "label": ClassLabel(names=labels)}
433
  )
src/distilabel_dataset_generator/apps/textcat.py CHANGED
@@ -53,6 +53,9 @@ def push_dataset_to_hub(
53
  num_labels: int = 1,
54
  ):
55
  original_dataframe = dataframe.copy(deep=True)
 
 
 
56
  labels = get_preprocess_labels(labels)
57
  try:
58
  push_to_hub_base(
@@ -80,6 +83,9 @@ def push_dataset_to_argilla(
80
  labels: List[str] = None,
81
  ) -> pd.DataFrame:
82
  original_dataframe = dataframe.copy(deep=True)
 
 
 
83
  try:
84
  progress(0.1, desc="Setting up user and workspace")
85
  client = get_argilla_client()
 
53
  num_labels: int = 1,
54
  ):
55
  original_dataframe = dataframe.copy(deep=True)
56
+ dataframe = dataframe[
57
+ (dataframe["text"].str.strip() != "") & (dataframe["text"].notna())
58
+ ]
59
  labels = get_preprocess_labels(labels)
60
  try:
61
  push_to_hub_base(
 
83
  labels: List[str] = None,
84
  ) -> pd.DataFrame:
85
  original_dataframe = dataframe.copy(deep=True)
86
+ dataframe = dataframe[
87
+ (dataframe["text"].str.strip() != "") & (dataframe["text"].notna())
88
+ ]
89
  try:
90
  progress(0.1, desc="Setting up user and workspace")
91
  client = get_argilla_client()
src/distilabel_dataset_generator/pipelines/textcat.py CHANGED
@@ -114,9 +114,11 @@ with Pipeline(name="textcat") as pipeline:
114
  "temperature": 0.8,
115
  "max_new_tokens": 2048,
116
  "do_sample": True,
117
- "seed": random.randint(0, 2**32 - 1),
 
118
  }},
119
  ),
 
120
  difficulty={None if difficulty == "mixed" else repr(difficulty)},
121
  clarity={None if clarity == "mixed" else repr(clarity)},
122
  num_generations={num_rows},
@@ -182,11 +184,13 @@ def get_textcat_generator(difficulty, clarity, is_sample):
182
  "temperature": 0.9,
183
  "max_new_tokens": 256 if is_sample else 2048,
184
  "do_sample": True,
185
- "seed": random.randint(0, 2**32 - 1),
 
186
  },
187
  ),
188
  difficulty=None if difficulty == "mixed" else difficulty,
189
  clarity=None if clarity == "mixed" else clarity,
 
190
  )
191
  textcat_generator.load()
192
  return textcat_generator
 
114
  "temperature": 0.8,
115
  "max_new_tokens": 2048,
116
  "do_sample": True,
117
+ "top_k": 50,
118
+ "top_p": 0.95,
119
  }},
120
  ),
121
+ seed=random.randint(0, 2**32 - 1),
122
  difficulty={None if difficulty == "mixed" else repr(difficulty)},
123
  clarity={None if clarity == "mixed" else repr(clarity)},
124
  num_generations={num_rows},
 
184
  "temperature": 0.9,
185
  "max_new_tokens": 256 if is_sample else 2048,
186
  "do_sample": True,
187
+ "top_k": 50,
188
+ "top_p": 0.95,
189
  },
190
  ),
191
  difficulty=None if difficulty == "mixed" else difficulty,
192
  clarity=None if clarity == "mixed" else clarity,
193
+ seed=random.randint(0, 2**32 - 1),
194
  )
195
  textcat_generator.load()
196
  return textcat_generator
src/distilabel_dataset_generator/utils.py CHANGED
@@ -124,4 +124,4 @@ def get_argilla_client() -> Union[rg.Argilla, None]:
124
  return None
125
 
126
  def get_preprocess_labels(labels: Optional[List[str]]) -> List[str]:
127
- return [label.lower().strip() for label in labels] if labels else []
 
124
  return None
125
 
126
  def get_preprocess_labels(labels: Optional[List[str]]) -> List[str]:
127
+ return list(set([label.lower().strip() for label in labels])) if labels else []