import torch
import torch.nn as nn

from src.bert import BERT


class CustomBERTModel(nn.Module):
    def __init__(self, vocab_size, output_dim, pre_trained_model_path):
        super().__init__()
        hidden_size = 768
        # Backbone encoder: 4-layer BERT with 8 attention heads and hidden size 768.
        self.bert = BERT(vocab_size=vocab_size, hidden=hidden_size,
                         n_layers=4, attn_heads=8, dropout=0.1)

        # Load the pre-trained weights onto CPU; the model can be moved
        # to another device afterwards.
        checkpoint = torch.load(pre_trained_model_path, map_location=torch.device('cpu'))
        if isinstance(checkpoint, dict):
            self.bert.load_state_dict(checkpoint)
        else:
            raise TypeError(f"Expected state_dict, got {type(checkpoint)} instead.")

        # Classification head applied to the [CLS] token representation.
        self.fc = nn.Linear(hidden_size, output_dim)

    def forward(self, sequence, segment_info):
        # Move inputs onto the same device as the model parameters.
        sequence = sequence.to(next(self.parameters()).device)
        segment_info = segment_info.to(sequence.device)

        # Per-token hidden states: (batch_size, seq_len, hidden_size).
        x = self.bert(sequence, segment_info)
        print(f"BERT output shape: {x.shape}")  # debug

        # Use the first ([CLS]) token's hidden state as the sequence embedding.
        cls_embeddings = x[:, 0]
        print(f"CLS Embeddings shape: {cls_embeddings.shape}")  # debug

        logits = self.fc(cls_embeddings)
        return logits
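

# A minimal usage sketch, not part of the original source: the vocab size,
# output dimension, batch shape, and checkpoint path below are placeholder
# assumptions, and running it requires a real state_dict saved at that path.
if __name__ == "__main__":
    model = CustomBERTModel(vocab_size=30000, output_dim=2,
                            pre_trained_model_path="bert_checkpoint.pt")  # hypothetical path
    model.eval()

    batch_size, seq_len = 2, 16
    tokens = torch.randint(0, 30000, (batch_size, seq_len))
    segments = torch.zeros(batch_size, seq_len, dtype=torch.long)

    with torch.no_grad():
        logits = model(tokens, segments)
    print(logits.shape)  # expected: torch.Size([2, 2])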