import gradio as gr
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AutoModelForSeq2SeqLM, AutoTokenizer
import torch
from langchain.memory import ConversationBufferMemory
# Select the device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the tokenizer and model for DistilGPT-2
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
model = GPT2LMHeadModel.from_pretrained("distilgpt2")
model.to(device)
# Load the summarization model and tokenizer (T5-small)
summarizer_tokenizer = AutoTokenizer.from_pretrained("t5-small")
summarizer_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(device)
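# T5 is trained with task prefixes, so summarization inputs are
# prepended with "summarize: " in summarize_history below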
def summarize_history(history):
    input_ids = summarizer_tokenizer.encode(
        "summarize: " + history,
        return_tensors="pt"
    ).to(device)
    summary_ids = summarizer_model.generate(
        input_ids,
        max_length=50,
        min_length=25,
        length_penalty=5.0,
        num_beams=2,
    )
    summary = summarizer_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary
# Set up conversational memory using LangChain's ConversationBufferMemory
memory = ConversationBufferMemory()
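# ConversationBufferMemory stores the raw transcript and, by default,
# returns it as a single string under the "history" key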
# Define the chatbot function with memory
def chat_with_distilgpt2(input_text):
    # Retrieve the conversation history from memory
    conversation_history = memory.load_memory_variables({})["history"]

    # Summarize the history once it grows past ~200 words, which also helps
    # keep the prompt within DistilGPT-2's 1024-token context window
    if len(conversation_history.split()) > 200:
        conversation_history = summarize_history(conversation_history)

    # Combine the (possibly summarized) history with the current user input
    full_input = f"{conversation_history}\nUser: {input_text}\nAssistant:"

    # Tokenize the prompt; keeping the attention mask lets generate()
    # distinguish real tokens from padding
    inputs = tokenizer(full_input, return_tensors="pt").to(device)

    # Generate the response. do_sample=True is required for temperature,
    # top_k, and top_p to take effect; max_new_tokens bounds only the
    # generated continuation (it replaces the conflicting max_length)
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        do_sample=True,
        num_return_sequences=1,
        no_repeat_ngram_size=3,
        repetition_penalty=1.2,
        temperature=0.7,
        top_k=20,
        top_p=0.8,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens, not the echoed prompt
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )

    # Update the memory with the user input and model response
    memory.save_context({"input": input_text}, {"output": response})
    return response
# Set up the Gradio interface
interface = gr.Interface(
    fn=chat_with_distilgpt2,
    inputs=gr.Textbox(label="Chat with DistilGPT-2"),
    outputs=gr.Textbox(label="DistilGPT-2's Response"),
    title="DistilGPT-2 Chatbot with Memory",
    description="A simple chatbot powered by DistilGPT-2, with conversational memory provided by LangChain.",
)
# Launch the Gradio app
interface.launch()
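# launch() serves the app locally by default; passing share=True
# (a standard Gradio option) would also create a temporary public link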