import pytorch_lightning as pl
import torch
from torch import nn
from transformers import AutoModel


class MediumBert(pl.LightningModule):
    """Classifier made of a pretrained transformer encoder plus a linear head.

    The encoder's second output (``out[1]``; the pooled output for BERT-style
    models) is projected to ``num_classes`` logits by a single linear layer.
    """

    def __init__(
        self,
        model_name: str = "asafaya/bert-medium-arabic",
        num_classes: int = 18,
    ) -> None:
        """Load the pretrained encoder and build the classification head.

        Args:
            model_name: Model id or local path passed to
                ``AutoModel.from_pretrained``. Defaults to the checkpoint the
                original code hard-coded.
            num_classes: Number of logits produced by the head. Defaults to
                the original hard-coded value of 18.
        """
        super().__init__()
        self.bert_model = AutoModel.from_pretrained(model_name)
        # Size the head from the encoder's own config rather than a literal
        # (previously 512), so a different checkpoint cannot silently produce
        # a shape mismatch in ``self.fc``.
        hidden_size = self.bert_model.config.hidden_size
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the encoder and return classification logits.

        Args:
            input_ids: Token ids for the encoder; presumably shaped
                (batch, seq_len) -- TODO confirm against the data pipeline.
            attention_mask: Mask passed straight through to the encoder.
            token_type_ids: Optional segment ids. Forwarded only when given,
                so existing two-argument calls behave exactly as before.

        Returns:
            The output of ``self.fc`` applied to the encoder's ``out[1]``.
        """
        encoder_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        if token_type_ids is not None:
            encoder_kwargs["token_type_ids"] = token_type_ids
        out = self.bert_model(**encoder_kwargs)
        # Index 1 works for both tuple and ModelOutput returns; for BertModel
        # it is ``pooler_output``.
        pooled = out[1]
        return self.fc(pooled)