---
license: apache-2.0
tags:
  - generated_from_trainer
metrics:
  - f1
model-index:
  - name: weights
    results: []
datasets:
  - librispeech_asr
---

# wav2vec2-large-xlsr-53-gender-recognition-librispeech

This model is a fine-tuned version of facebook/wav2vec2-xls-r-300m on Librispeech-clean-100 for gender recognition. It achieves the following results on the evaluation set:

- Loss: 0.0061
- F1: 0.9993

## Compute your inferences
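For a quick, single-file check, the checkpoint also works with the Transformers `audio-classification` pipeline. This is a minimal sketch (the audio path is a placeholder); the batched, collator-based approach from the card follows below:

```python
from transformers import pipeline

classifier = pipeline(
    "audio-classification",
    model="alefiury/wav2vec2-large-xlsr-53-gender-recognition-librispeech",
)

# Returns a list of {"label": ..., "score": ...} dicts, e.g. [{"label": "male", "score": 0.99}, ...]
print(classifier("/path/to/audio.wav"))
```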

```python
import numpy as np
import torch
import torchaudio
from typing import Dict, List, Optional, Union

from torch.utils.data import DataLoader
from transformers import (
    AutoFeatureExtractor,
    AutoModelForAudioClassification,
    Wav2Vec2Processor,
)


class DataCollator:
    def __init__(
        self,
        processor: Wav2Vec2Processor,
        sampling_rate: int = 16000,
        padding: Union[bool, str] = True,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        label2id: Dict = None,
        max_audio_len: int = 5
    ):

        self.processor = processor
        self.sampling_rate = sampling_rate

        self.padding = padding
        self.max_length = max_length
        self.pad_to_multiple_of = pad_to_multiple_of

        self.label2id = label2id

        self.max_audio_len = max_audio_len

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Load each audio file, convert it to mono, resample it to the target
        # sampling rate, truncate it to `max_audio_len` seconds, and let the
        # processor pad the batch.
        input_features = []
        for feature in features:
            speech_array, sampling_rate = torchaudio.load(feature["input_values"])

            # Transform to Mono
            speech_array = torch.mean(speech_array, dim=0, keepdim=True)

            if sampling_rate != self.sampling_rate:
                transform = torchaudio.transforms.Resample(sampling_rate, self.sampling_rate)
                speech_array = transform(speech_array)
                sampling_rate = self.sampling_rate

            effective_size_len = sampling_rate * self.max_audio_len

            if speech_array.shape[-1] > effective_size_len:
                speech_array = speech_array[:, :effective_size_len]

            speech_array = speech_array.squeeze().numpy()
            input_tensor = self.processor(speech_array, sampling_rate=sampling_rate).input_values
            input_tensor = np.squeeze(input_tensor)

            input_features.append({"input_values": input_tensor})

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        return batch


label2id = {
    "female": 0,
    "male": 1
}

id2label = {
    0: "female",
    1: "male"
}

num_labels = 2

feature_extractor = AutoFeatureExtractor.from_pretrained("alefiury/wav2vec2-large-xlsr-53-gender-recognition-librispeech")
model = AutoModelForAudioClassification.from_pretrained(
    pretrained_model_name_or_path="alefiury/wav2vec2-large-xlsr-53-gender-recognition-librispeech",
    num_labels=num_labels,
    label2id=label2id,
    id2label=id2label,
)

data_collator = DataCollator(
    feature_extractor,
    sampling_rate=16000,
    padding=True,
    label2id=label2id
)

test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=16,
    collate_fn=data_collator,
    shuffle=False,
    num_workers=10
)

preds = predict(test_dataloader=test_dataloader, model=model)
```
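The snippet above references a `test_dataset` and a `predict` helper that are not defined in the card. Below is a minimal sketch of both, assuming each dataset item carries an audio file path under `"input_values"` (as the collator expects) and simple argmax decoding; the names `audio_paths` and `predict` are illustrative, not part of the original code:

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical test set: each item only needs an "input_values" key holding an audio file path.
audio_paths = ["/path/to/audio_1.wav", "/path/to/audio_2.wav"]  # replace with your own files
test_dataset = [{"input_values": path} for path in audio_paths]

def predict(test_dataloader: DataLoader, model) -> list:
    """Run the model over the dataloader and return predicted label ids."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()

    predictions = []
    with torch.no_grad():
        for batch in test_dataloader:
            logits = model(batch["input_values"].to(device)).logits
            predictions.extend(torch.argmax(logits, dim=-1).cpu().tolist())
    return predictions
```

Mapping the returned ids through `id2label` turns them into the final `"female"`/`"male"` predictions.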

## Training and evaluation data

The Librispeech-clean-100 dataset was used to train the model, with 70% of the data used for training, 10% for validation, and 20% for testing.
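The exact splitting procedure is not given in the card, but a 70/10/20 split of Librispeech-clean-100 can be reproduced along these lines with the `datasets` library (a sketch; the seed and split order are assumptions, and gender labels come from the LibriSpeech speaker metadata keyed by `speaker_id`):

```python
from datasets import load_dataset

# train-clean-100 subset of LibriSpeech
dataset = load_dataset("librispeech_asr", "clean", split="train.100")

# 70% train / 30% held out, then split the held-out part into 10% validation / 20% test
splits = dataset.train_test_split(test_size=0.3, seed=42)
val_test = splits["test"].train_test_split(test_size=2 / 3, seed=42)

train_ds, val_ds, test_ds = splits["train"], val_test["train"], val_test["test"]
```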

## Training hyperparameters

The following hyperparameters were used during training (a matching `TrainingArguments` sketch follows the list):

- learning_rate: 3e-05
- train_batch_size: 4
- eval_batch_size: 4
- seed: 42
- gradient_accumulation_steps: 4
- total_train_batch_size: 16
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- lr_scheduler_warmup_ratio: 0.1
- num_epochs: 1
- mixed_precision_training: Native AMP
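These settings map roughly onto the following `TrainingArguments`; treat this as a sketch for orientation (the `output_dir` value is an assumption), not the exact configuration that was used:

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="weights",            # assumed; matches the model-index name
    learning_rate=3e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,   # 4 x 4 = total train batch size of 16
    num_train_epochs=1,
    lr_scheduler_type="linear",
    warmup_ratio=0.1,
    seed=42,
    fp16=True,                       # native AMP mixed precision
)
```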

## Training results

| Training Loss | Epoch | Step | Validation Loss | F1     |
|:-------------:|:-----:|:----:|:---------------:|:------:|
| 0.002         | 1.0   | 1248 | 0.0061          | 0.9993 |

## Framework versions

- Transformers 4.28.0
- Pytorch 2.0.0+cu118
- Tokenizers 0.13.3