import argparse
import os
import sys

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, TensorDataset

import pandas as pd

from src.dataset import TokenizerDataset
from src.bert import BERT
from src.pretrainer import BERTFineTuneTrainer1
from src.vocab import Vocab
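# Fine-tunes a pre-trained BERT encoder (src.bert.BERT) for hint classification:
# input sequences are read from a plain-text file, integer labels from the
# 'last_hint_class' column of a CSV, and training is handled by BERTFineTuneTrainer1.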
class CustomBERTModel(nn.Module):
    def __init__(self, vocab_size, output_dim, pre_trained_model_path):
        super().__init__()
        # Hidden size of the encoder; it must match the pre-trained checkpoint and be
        # divisible by the number of attention heads (768 / 12 = 64).
        self.hidden = 768
        self.bert = BERT(vocab_size=vocab_size, hidden=self.hidden, n_layers=12, attn_heads=12, dropout=0.1)

        # Load the pre-trained encoder weights; the checkpoint is expected to be a state_dict.
        checkpoint = torch.load(pre_trained_model_path, map_location=torch.device('cpu'))
        if isinstance(checkpoint, dict):
            self.bert.load_state_dict(checkpoint)
        else:
            raise TypeError(f"Expected state_dict, got {type(checkpoint)} instead.")

        # Classification head applied to the [CLS] token representation.
        self.fc = nn.Linear(self.hidden, output_dim)

    def forward(self, sequence, segment_info):
        x = self.bert(sequence, segment_info)  # (batch, seq_len, hidden)
        cls_embeddings = x[:, 0]               # embedding of the first ([CLS]) token
        logits = self.fc(cls_embeddings)
        return logits
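# Typical use (shapes illustrative): token ids and segment labels of shape
# (batch, seq_len) go in, logits of shape (batch, output_dim) come out, e.g.
#   model = CustomBERTModel(len(vocab.vocab), output_dim=2, pre_trained_model_path=path)
#   logits = model(input_ids, segment_labels)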
def preprocess_labels(label_csv_path):
    """Read integer class labels from the 'last_hint_class' column of a CSV."""
    try:
        labels_df = pd.read_csv(label_csv_path)
        labels = labels_df['last_hint_class'].values.astype(int)
        return torch.tensor(labels, dtype=torch.long)
    except Exception as e:
        print(f"Error reading dataset file: {e}")
        return None
def preprocess_data(data_path, vocab, max_length=128):
    """Tokenize newline-separated sequences into fixed-length input id and segment label tensors."""
    try:
        with open(data_path, 'r') as f:
            sequences = f.readlines()
    except Exception as e:
        print(f"Error reading data file: {e}")
        return None, None

    if len(sequences) == 0:
        raise ValueError(f"No sequences found in data file {data_path}. Check the file content.")

    tokenized_sequences = []

    for sequence in sequences:
        sequence = sequence.strip()
        if not sequence:
            continue

        # Encode, then truncate or pad to exactly max_length tokens.
        encoded = vocab.to_seq(sequence, seq_len=max_length)
        encoded = encoded[:max_length] + [vocab.vocab.get('[PAD]', 0)] * max(0, max_length - len(encoded))
        segment_label = [0] * max_length

        tokenized_sequences.append({
            'input_ids': torch.tensor(encoded),
            'segment_label': torch.tensor(segment_label)
        })

    if not tokenized_sequences:
        raise ValueError("Tokenization resulted in an empty list. Check the sequences and tokenization logic.")

    # Defensive filter: drop anything that did not end up at the expected length.
    tokenized_sequences = [t for t in tokenized_sequences if len(t['input_ids']) == max_length]

    if not tokenized_sequences:
        raise ValueError("All tokenized sequences are of unexpected length. This suggests an issue with the tokenization logic.")

    input_ids = torch.stack([t['input_ids'] for t in tokenized_sequences], dim=0)
    segment_labels = torch.stack([t['segment_label'] for t in tokenized_sequences], dim=0)

    print(f"Input IDs shape: {input_ids.shape}")
    print(f"Segment labels shape: {segment_labels.shape}")

    return input_ids, segment_labels
def collate_fn(batch):
    """Collate a list of samples into a single batch dict of stacked tensors.

    Samples may be dicts with 'input_ids', 'label', and 'segment_label' keys, or
    (input_ids, segment_label, label) tuples as yielded by the TensorDataset built in main().
    """
    inputs = []
    labels = []
    segment_labels = []

    for item in batch:
        if item is None:
            continue

        if isinstance(item, dict):
            inputs.append(item['input_ids'].unsqueeze(0))
            labels.append(item['label'].unsqueeze(0))
            segment_labels.append(item['segment_label'].unsqueeze(0))
        elif isinstance(item, (tuple, list)) and len(item) == 3:
            input_ids, segment_label, label = item
            inputs.append(input_ids.unsqueeze(0))
            labels.append(label.unsqueeze(0))
            segment_labels.append(segment_label.unsqueeze(0))

    if len(inputs) == 0 or len(segment_labels) == 0:
        print("Empty batch encountered. Returning None to skip this batch.")
        return None

    try:
        inputs = torch.cat(inputs, dim=0)
        labels = torch.cat(labels, dim=0)
        segment_labels = torch.cat(segment_labels, dim=0)
    except Exception as e:
        print(f"Error concatenating tensors: {e}")
        return None

    return {
        'input': inputs,
        'label': labels,
        'segment_label': segment_labels
    }
def custom_collate_fn(batch):
    processed_batch = collate_fn(batch)

    # Fall back to a dummy all-zero batch so the DataLoader never yields None.
    if processed_batch is None or len(processed_batch['input']) == 0:
        return {
            'input': torch.zeros((1, 128), dtype=torch.long),
            'label': torch.zeros((1,), dtype=torch.long),
            'segment_label': torch.zeros((1, 128), dtype=torch.long)
        }

    return processed_batch
def train_without_progress_status(trainer, epoch, shuffle):
    expected_keys = ('input', 'label', 'segment_label')  # keys produced by custom_collate_fn above

    for epoch_idx in range(epoch):
        print(f"EP_train:{epoch_idx}:")
        for batch in trainer.train_data:
            if batch is None:
                continue

            if isinstance(batch, str):
                print(f"Error: Received a string instead of a dictionary in batch: {batch}")
                raise ValueError(f"Unexpected string in batch: {batch}")

            if not isinstance(batch, dict):
                print(f"Error: Expected batch to be a dictionary but got {type(batch)} instead.")
                raise ValueError(f"Invalid batch structure: {batch}")

            if not all(key in batch for key in expected_keys):
                print(f"Error: Batch missing expected keys. Batch keys: {batch.keys()}")
                raise ValueError("Batch does not contain expected keys.")

            if not all(isinstance(value, torch.Tensor) for value in batch.values()):
                print(f"Error: Expected all values in batch to be tensors, but got: {batch}")
                raise ValueError("Batch contains non-tensor values.")

            try:
                print(f"Batch Structure: {batch}")
                trainer.iteration(epoch_idx, batch)
            except Exception as e:
                print(f"Error during batch processing: {e}")
                sys.stdout.flush()
                raise e
def main(opt):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(torch.cuda.is_available())
    print(torch.cuda.device_count())

    # Load the vocabulary built during pre-training.
    vocab = Vocab(opt.vocab_file)
    vocab.load_vocab()

    # Tokenize the input sequences and read the corresponding class labels.
    input_ids, segment_labels = preprocess_data(opt.data_path, vocab, max_length=128)
    labels = preprocess_labels(opt.dataset)

    if input_ids is None or segment_labels is None or labels is None:
        print("Error in preprocessing data. Exiting.")
        return

    input_ids = input_ids.to(device)
    segment_labels = segment_labels.to(device)
    labels = labels.to(device)  # preprocess_labels already returns a LongTensor

    # 80/20 train/validation split.
    dataset = TensorDataset(input_ids, segment_labels, labels)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    # DataLoaders use the custom collate function so degenerate batches are replaced
    # with a dummy batch instead of crashing training.
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        collate_fn=custom_collate_fn
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=32,
        shuffle=False,
        collate_fn=custom_collate_fn
    )

    # Two-class classification head on top of the pre-trained encoder.
    custom_model = CustomBERTModel(
        vocab_size=len(vocab.vocab),
        output_dim=2,
        pre_trained_model_path=opt.pre_trained_model_path
    ).to(device)
    # Fine-tune using the repository's trainer, which wraps the BERT encoder directly.
    trainer = BERTFineTuneTrainer1(
        bert=custom_model.bert,
        vocab_size=len(vocab.vocab),
        train_dataloader=train_dataloader,
        test_dataloader=val_dataloader,
        lr=5e-5,
        num_labels=2,
        with_cuda=torch.cuda.is_available(),
        log_freq=10,
        workspace_name=opt.output_dir,
        log_folder_path=opt.log_folder_path
    )

    trainer.train(epoch=20)

    # Save the complete fine-tuned model object (not just the state_dict).
    os.makedirs(opt.output_dir, exist_ok=True)
    output_model_file = os.path.join(opt.output_dir, 'fine_tuned_model_2.pth')
    torch.save(custom_model, output_model_file)
    print(f'Model saved to {output_model_file}')
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Fine-tune BERT model.')
    parser.add_argument('--dataset', type=str, default='/home/jupyter/bert/dataset/hint_based/ratio_proportion_change_3/er/er_train.csv', help='Path to the dataset file.')
    parser.add_argument('--data_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/gt/er.txt', help='Path to the input sequence file.')
    parser.add_argument('--output_dir', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/output/hint_classification', help='Directory to save the fine-tuned model.')
    parser.add_argument('--pre_trained_model_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/output/pretrain:1800ms:64hs:4l:8a:50s:64b:1000e:-5lr/bert_trained.seq_encoder.model.ep68', help='Path to the pre-trained BERT model.')
    parser.add_argument('--vocab_file', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/pretraining/vocab.txt', help='Path to the vocabulary file.')
    parser.add_argument('--log_folder_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/logs/oct_logs', help='Path to the folder for saving logs.')

    opt = parser.parse_args()
    main(opt)
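# Example invocation (script name is illustrative; every argument has a default
# that can be overridden on the command line):
#   python fine_tune_hint_classifier.py --output_dir ./output --log_folder_path ./logs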