"""Evaluate a saved NeMo ASR model on a test manifest and report its word error rate (WER)."""

import argparse
from typing import Dict, Optional

import nemo.collections.asr as nemo_asr
import torch
from omegaconf import open_dict


def evaluate_model(
    model_path: Optional[str] = None,
    test_manifest: Optional[str] = None,
    batch_size: int = 1,
) -> Dict:
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Restore the model from its .nemo checkpoint and switch to evaluation mode.
    model = nemo_asr.models.ASRModel.restore_from(restore_path=model_path)
    model.to(device)
    model.eval()
    # Reuse the validation dataset config, pointing it at the test manifest,
    # then build the test dataloader from it.
    with open_dict(model.cfg):
        model.cfg.validation_ds.manifest_filepath = test_manifest
        model.cfg.validation_ds.batch_size = batch_size

    model.setup_test_data(model.cfg.validation_ds)

    # Per-batch WER numerators (word errors) and denominators (reference word counts).
    wer_nums = []
    wer_denoms = []
    # Run greedy decoding over the test set and accumulate WER statistics.
    with torch.no_grad():
        for test_batch in model.test_dataloader():
            targets = test_batch[2].to(device)
            targets_lengths = test_batch[3].to(device)

            log_probs, encoded_len, greedy_predictions = model(
                input_signal=test_batch[0].to(device),
                input_signal_length=test_batch[1].to(device),
            )

            # Score this batch with the model's internal WER helper, then reset it.
            model._wer.update(greedy_predictions, targets, targets_lengths)
            _, wer_num, wer_denom = model._wer.compute()
            model._wer.reset()
            wer_nums.append(wer_num.detach().cpu().numpy())
            wer_denoms.append(wer_denom.detach().cpu().numpy())

            # Release GPU memory held by this batch.
            del test_batch, log_probs, targets, targets_lengths, encoded_len, greedy_predictions

    # Sum all numerators and denominators across batches before dividing.
    wer_score = sum(wer_nums) / sum(wer_denoms)
    results = {"WER_score": wer_score}
    print(results)
    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate an ASR model's WER on a test manifest.")
    parser.add_argument("--model_path", default=None, help="Path to the .nemo model file to evaluate.")
    parser.add_argument("--test_manifest", default=None, help="Path to the test manifest JSON file.")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for evaluation.")
    args = parser.parse_args()

    evaluate_model(
        model_path=args.model_path,
        test_manifest=args.test_manifest,
        batch_size=args.batch_size,
    )
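# Example invocation (hypothetical paths; assumes this script is saved as evaluate_wer.py
# and that the test manifest is in NeMo's JSON-lines format, with "audio_filepath",
# "duration", and "text" fields on each line):
#
#   python evaluate_wer.py --model_path checkpoints/asr_model.nemo \
#       --test_manifest data/test_manifest.json --batch_size 32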