# gpt-moe-mcts / example_usage.py
import torch
from transformers import GPT2Tokenizer

from modeling_gpt_moe_mcts import GPTMoEMCTSModel
from configuration_gpt_moe_mcts import GPTMoEMCTSConfig

# Initialize configuration
config = GPTMoEMCTSConfig()

# Initialize model and put it in eval mode (disables dropout for inference)
model = GPTMoEMCTSModel(config)
model.eval()
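
# Assuming these classes follow the transformers custom-model pattern
# (subclassing PretrainedConfig / PreTrainedModel), the usual save/load
# round-trip should apply (a sketch, not verified against this repo):
#   model.save_pretrained("./gpt-moe-mcts")
#   model = GPTMoEMCTSModel.from_pretrained("./gpt-moe-mcts")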

# Initialize tokenizer (using GPT2Tokenizer as a base)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Prepare input
text = "Hello, how are you?"
inputs = tokenizer(text, return_tensors="pt")

# Forward pass without tracking gradients (inference only)
with torch.no_grad():
    outputs = model(**inputs)

# Get the predicted next token (greedy argmax over the final position)
next_token_logits = outputs.logits[0, -1, :]
next_token = next_token_logits.argmax().item()
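
# As an alternative to greedy argmax (a sketch, not part of the original
# example), you could sample from the temperature-scaled softmax instead:
#   probs = torch.softmax(next_token_logits / 0.8, dim=-1)
#   next_token = torch.multinomial(probs, num_samples=1).item()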

# Decode the predicted token id back into text
predicted_text = tokenizer.decode([next_token])

print(f"Input: {text}")
print(f"Predicted next token: {predicted_text}")