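"""Fine-tune a pre-trained BERT encoder for hint classification.

Loads a saved BERT checkpoint, attaches a linear classification head
(CustomBERTModel), tokenizes interaction sequences with a custom Vocab,
reads integer labels from the `last_hint_class` column of a CSV, and
trains with BERTFineTuneTrainer1.
"""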
import argparse
import os
import sys
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, TensorDataset
from src.bert import BERT
from src.pretrainer import BERTFineTuneTrainer1
from src.vocab import Vocab
import pandas as pd
class CustomBERTModel(nn.Module):
    """BERT encoder with a linear classification head on the [CLS] token."""

    def __init__(self, vocab_size, output_dim, pre_trained_model_path):
        super().__init__()
        self.hidden = 768  # Must match the hidden size of the pre-trained checkpoint.
        self.bert = BERT(vocab_size=vocab_size, hidden=self.hidden, n_layers=12, attn_heads=12, dropout=0.1)
        # Load the pre-trained encoder weights.
        checkpoint = torch.load(pre_trained_model_path, map_location=torch.device('cpu'))
        if isinstance(checkpoint, dict):
            self.bert.load_state_dict(checkpoint)
        else:
            raise TypeError(f"Expected state_dict, got {type(checkpoint)} instead.")
        self.fc = nn.Linear(self.hidden, output_dim)

    def forward(self, sequence, segment_info):
        x = self.bert(sequence, segment_info)  # (batch_size, seq_len, hidden)
        cls_embeddings = x[:, 0]               # [CLS] embeddings: (batch_size, hidden)
        logits = self.fc(cls_embeddings)       # (batch_size, output_dim)
        return logits
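# Quick shape sanity check (a minimal sketch; the vocab size, dummy inputs, and
# 'bert.pth' path below are illustrative, not part of the training setup):
#
#   model = CustomBERTModel(vocab_size=1000, output_dim=2, pre_trained_model_path='bert.pth')
#   tokens = torch.zeros((4, 128), dtype=torch.long)
#   segments = torch.zeros((4, 128), dtype=torch.long)
#   logits = model(tokens, segments)
#   assert logits.shape == (4, 2)  # (batch_size, output_dim)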
def preprocess_labels(label_csv_path):
    """Read integer class labels from the `last_hint_class` column of a CSV."""
    try:
        labels_df = pd.read_csv(label_csv_path)
        labels = labels_df['last_hint_class'].values.astype(int)
        return torch.tensor(labels, dtype=torch.long)
    except Exception as e:
        print(f"Error reading dataset file: {e}")
        return None
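# The label CSV is expected to contain a `last_hint_class` column of integer
# class ids; the binary setup in main() (output_dim=2, num_labels=2) implies
# values in {0, 1}.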
def preprocess_data(data_path, vocab, max_length=128):
    """Tokenize one sequence per line, then pad/truncate each to max_length."""
    try:
        with open(data_path, 'r') as f:
            sequences = f.readlines()
    except Exception as e:
        print(f"Error reading data file: {e}")
        return None, None
    if len(sequences) == 0:
        raise ValueError(f"No sequences found in data file {data_path}. Check the file content.")

    tokenized_sequences = []
    for sequence in sequences:
        sequence = sequence.strip()
        if not sequence:
            continue
        encoded = vocab.to_seq(sequence, seq_len=max_length)
        # Truncate first, then pad with [PAD] (falling back to id 0) to exactly max_length.
        encoded = encoded[:max_length]
        encoded += [vocab.vocab.get('[PAD]', 0)] * (max_length - len(encoded))
        segment_label = [0] * max_length  # Single-segment inputs.
        tokenized_sequences.append({
            'input_ids': torch.tensor(encoded),
            'segment_label': torch.tensor(segment_label)
        })
    if not tokenized_sequences:
        raise ValueError("Tokenization resulted in an empty list. Check the sequences and tokenization logic.")

    input_ids = torch.stack([t['input_ids'] for t in tokenized_sequences], dim=0)
    segment_labels = torch.stack([t['segment_label'] for t in tokenized_sequences], dim=0)
    print(f"Input IDs shape: {input_ids.shape}")
    print(f"Segment labels shape: {segment_labels.shape}")
    return input_ids, segment_labels
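# For a file with N non-empty lines and max_length=128, preprocess_data returns
# input_ids and segment_labels both of shape (N, 128). N must match the number
# of rows in the label CSV, since main() zips them into a TensorDataset.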
def collate_fn(batch):
    """Collate TensorDataset tuples (or dicts) into a single batch dictionary."""
    inputs = []
    labels = []
    segment_labels = []
    for item in batch:
        if item is None:
            continue
        if isinstance(item, (tuple, list)):
            # TensorDataset yields (input_ids, segment_label, label) tuples.
            input_ids, segment_label, label = item
            inputs.append(input_ids.unsqueeze(0))
            labels.append(label.unsqueeze(0))
            segment_labels.append(segment_label.unsqueeze(0))
        elif isinstance(item, dict):
            inputs.append(item['input_ids'].unsqueeze(0))
            labels.append(item['label'].unsqueeze(0))
            segment_labels.append(item['segment_label'].unsqueeze(0))
    if len(inputs) == 0 or len(segment_labels) == 0:
        print("Empty batch encountered. Returning None to skip this batch.")
        return None
    try:
        inputs = torch.cat(inputs, dim=0)
        labels = torch.cat(labels, dim=0)
        segment_labels = torch.cat(segment_labels, dim=0)
    except Exception as e:
        print(f"Error concatenating tensors: {e}")
        return None
    return {
        'input': inputs,
        'label': labels,
        'segment_label': segment_labels
    }
def custom_collate_fn(batch):
    processed_batch = collate_fn(batch)
    if processed_batch is None or len(processed_batch['input']) == 0:
        # Return a valid single-element dummy batch instead of an empty one.
        return {
            'input': torch.zeros((1, 128), dtype=torch.long),
            'label': torch.zeros((1,), dtype=torch.long),
            'segment_label': torch.zeros((1, 128), dtype=torch.long)
        }
    return processed_batch
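# Note: the dummy batch above keeps the DataLoader from crashing, but its
# all-zero row (label 0) still reaches the trainer and contributes to the loss.
# An alternative (a sketch, not part of the original setup) is to skip None
# batches in the training loop instead, as train_without_progress_status does.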
def train_without_progress_status(trainer, epoch, shuffle):
    for epoch_idx in range(epoch):
        print(f"EP_train:{epoch_idx}:")
        for batch in trainer.train_data:
            if batch is None:
                continue
            # A string here indicates a data-pipeline problem.
            if isinstance(batch, str):
                print(f"Error: Received a string instead of a dictionary in batch: {batch}")
                raise ValueError(f"Unexpected string in batch: {batch}")
            # Validate the batch structure before passing it to the trainer.
            if not isinstance(batch, dict):
                print(f"Error: Expected batch to be a dictionary but got {type(batch)} instead.")
                raise ValueError(f"Invalid batch structure: {batch}")
            # These keys must match what collate_fn produces.
            if not all(key in batch for key in ['input', 'segment_label', 'label']):
                print(f"Error: Batch missing expected keys. Batch keys: {batch.keys()}")
                raise ValueError("Batch does not contain expected keys.")
            if not all(isinstance(batch[key], torch.Tensor) for key in batch):
                print(f"Error: Expected all values in batch to be tensors, but got: {batch}")
                raise ValueError("Batch contains non-tensor values.")
            try:
                trainer.iteration(epoch_idx, batch)
            except Exception as e:
                print(f"Error during batch processing: {e}")
                sys.stdout.flush()
                raise  # Propagate for easier debugging.
def main(opt):
    # Use the GPU when available, otherwise fall back to CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"CUDA available: {torch.cuda.is_available()} "
          f"(device count: {torch.cuda.device_count()})")

    # Load the vocabulary.
    vocab = Vocab(opt.vocab_file)
    vocab.load_vocab()

    # Preprocess sequences and labels.
    input_ids, segment_labels = preprocess_data(opt.data_path, vocab, max_length=128)
    labels = preprocess_labels(opt.dataset)
    if input_ids is None or segment_labels is None or labels is None:
        print("Error in preprocessing data. Exiting.")
        return

    # Move tensors to the target device.
    input_ids = input_ids.to(device)
    segment_labels = segment_labels.to(device)
    labels = labels.to(device)  # Already a LongTensor from preprocess_labels.

    # Build the dataset and split it 80/20 into train and validation sets.
    dataset = TensorDataset(input_ids, segment_labels, labels)
    val_size = len(dataset) - int(0.8 * len(dataset))
    val_dataset, train_dataset = random_split(dataset, [val_size, len(dataset) - val_size])

    # DataLoaders for training and validation.
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        collate_fn=custom_collate_fn
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=32,
        shuffle=False,
        collate_fn=custom_collate_fn
    )

    # Initialize the custom BERT model and move it to the device.
    custom_model = CustomBERTModel(
        vocab_size=len(vocab.vocab),
        output_dim=2,
        pre_trained_model_path=opt.pre_trained_model_path
    ).to(device)

    # Initialize the fine-tuning trainer.
    trainer = BERTFineTuneTrainer1(
        bert=custom_model.bert,
        vocab_size=len(vocab.vocab),
        train_dataloader=train_dataloader,
        test_dataloader=val_dataloader,
        lr=5e-5,
        num_labels=2,
        with_cuda=torch.cuda.is_available(),
        log_freq=10,
        workspace_name=opt.output_dir,
        log_folder_path=opt.log_folder_path
    )

    # Train the model.
    trainer.train(epoch=20)

    # Save the full model object (use torch.save(custom_model.state_dict(), ...)
    # to save only the weights instead).
    os.makedirs(opt.output_dir, exist_ok=True)
    output_model_file = os.path.join(opt.output_dir, 'fine_tuned_model_2.pth')
    torch.save(custom_model, output_model_file)
    print(f'Model saved to {output_model_file}')
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Fine-tune BERT model.')
    parser.add_argument('--dataset', type=str, default='/home/jupyter/bert/dataset/hint_based/ratio_proportion_change_3/er/er_train.csv', help='Path to the label CSV file (must contain a last_hint_class column).')
    parser.add_argument('--data_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/gt/er.txt', help='Path to the input sequence file.')
    parser.add_argument('--output_dir', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/output/hint_classification', help='Directory to save the fine-tuned model.')
    parser.add_argument('--pre_trained_model_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/output/pretrain:1800ms:64hs:4l:8a:50s:64b:1000e:-5lr/bert_trained.seq_encoder.model.ep68', help='Path to the pre-trained BERT model.')
    parser.add_argument('--vocab_file', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/pretraining/vocab.txt', help='Path to the vocabulary file.')
    parser.add_argument('--log_folder_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/logs/oct_logs', help='Path to the folder for saving logs.')
    opt = parser.parse_args()
    main(opt)
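# Example invocation (the defaults above point at one specific workspace; the
# placeholder paths below are illustrative and should be replaced):
#
#   python hint_fine_tuning.py \
#       --dataset path/to/er_train.csv \
#       --data_path path/to/er.txt \
#       --vocab_file path/to/vocab.txt \
#       --pre_trained_model_path path/to/bert_trained.model \
#       --output_dir path/to/output \
#       --log_folder_path path/to/logs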