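---
license: mit
---

## Usage

The example below evaluates the classifier on a test set. It expects a `test_dataloader` that you construct yourself, yielding dicts of tensors that include a `labels` entry.

```python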
import torch
from informer_models import InformerConfig, InformerForSequenceClassification

# Load the pretrained classifier weights by repository id.
model = InformerForSequenceClassification.from_pretrained("BrachioLab/supernova-classification")

# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# `test_dataloader` must be supplied by you: each batch is expected to be a dict
# of tensors accepted by the model's forward(), including a "labels" entry.
# Any "objid" entry is dropped before the forward pass.
y_true = []
y_pred = []
for i, batch in enumerate(test_dataloader):
    print(f"processing batch {i}")
    batch = {k: v.to(device) for k, v in batch.items() if k != "objid"}
    with torch.no_grad():
        outputs = model(**batch)
    y_true.extend(batch["labels"].cpu().numpy())
    # argmax over dim=2 assumes logits shaped (batch, 1, num_labels) — TODO confirm.
    # squeeze(-1) rather than squeeze(): with a batch of size 1, squeeze() would also
    # drop the batch axis and produce a 0-d array, which list.extend cannot iterate.
    y_pred.extend(torch.argmax(outputs.logits, dim=2).squeeze(-1).cpu().numpy())

# Guard against an empty dataloader, which would otherwise divide by zero.
if y_true:
    correct = sum(int(t == p) for t, p in zip(y_true, y_pred))
    print(f"accuracy: {correct / len(y_true)}")
else:
    print("accuracy: no samples evaluated")
```