"""Smoke-test script for GPTMoEMCTSModel.

Builds a model from the default config, runs one forward pass on a short
prompt, and greedily decodes the predicted next token.
"""

import torch
from transformers import GPT2Tokenizer

from configuration_gpt_moe_mcts import GPTMoEMCTSConfig
from modeling_gpt_moe_mcts import GPTMoEMCTSModel

# Initialize configuration and model with default hyperparameters.
# NOTE(review): weights are randomly initialized (no checkpoint is loaded),
# so the predicted token is not meaningful — this only exercises the pipeline.
config = GPTMoEMCTSConfig()
model = GPTMoEMCTSModel(config)
# We are doing inference: switch off dropout / training-mode behavior.
model.eval()

# Tokenizer: the stock GPT-2 tokenizer is used as a base
# (presumably the custom model shares GPT-2's vocabulary — verify against config).
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Prepare input as PyTorch tensors (batch of one sequence).
text = "Hello, how are you?"
inputs = tokenizer(text, return_tensors="pt")

# Forward pass without gradient tracking — no training happens here.
with torch.no_grad():
    outputs = model(**inputs)

# Greedy decoding: argmax over the vocabulary at the last position
# of the first (and only) sequence in the batch.
next_token_logits = outputs.logits[0, -1, :]
next_token = next_token_logits.argmax()

# Decode the predicted token id back to text.
predicted_text = tokenizer.decode(next_token)
print(f"Input: {text}")
print(f"Predicted next token: {predicted_text}")