File size: 2,985 Bytes
0efd337
ec853a0
07b00c0
0e4ab50
2ad20f5
ce16e77
 
 
07b00c0
 
 
ce16e77
07b00c0
ec853a0
 
 
 
 
 
 
 
 
 
 
 
 
0e4ab50
 
 
 
07b00c0
ec853a0
0e4ab50
 
ec853a0
 
 
 
 
0e4ab50
 
07b00c0
 
2ad20f5
fb600ee
 
 
 
 
 
 
 
 
f374df6
 
fb600ee
 
 
 
2ad20f5
0e4ab50
2ad20f5
0e4ab50
 
 
 
2ad20f5
 
 
 
07b00c0
 
 
 
 
0efd337
 
2ad20f5
 
0e4ab50
07b00c0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import gradio as gr
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AutoModelForSeq2SeqLM, AutoTokenizer
import torch
from langchain.memory import ConversationBufferMemory

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# DistilGPT-2: the causal language model that writes the chatbot's replies.
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
model = GPT2LMHeadModel.from_pretrained("distilgpt2").to(device)

# T5-small: a sequence-to-sequence model used to condense long chat histories.
summarizer_tokenizer = AutoTokenizer.from_pretrained("t5-small")
summarizer_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(device)

def summarize_history(history, max_input_tokens=512):
    """Condense a conversation transcript into a short abstractive summary.

    Runs the module-level T5 summarizer (``summarizer_tokenizer`` /
    ``summarizer_model``) on ``device`` using T5's ``"summarize: "`` task
    prefix.

    Args:
        history: The transcript text to summarize.
        max_input_tokens: Upper bound on the number of tokens (task prefix
            included) fed to the summarizer. Longer input is truncated so the
            cost stays bounded as the conversation grows, and the input stays
            within the 512-token length t5-small was trained on.

    Returns:
        The decoded summary string with special tokens removed.
    """
    input_ids = summarizer_tokenizer.encode(
        "summarize: " + history,
        return_tensors="pt",
        # Without truncation the encoder input grows with the whole chat
        # transcript (attention cost is quadratic in its length).
        truncation=True,
        max_length=max_input_tokens,
    ).to(device)
    summary_ids = summarizer_model.generate(
        input_ids,
        max_length=50,
        min_length=25,
        length_penalty=5.0,
        num_beams=2,
    )
    return summarizer_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

# Conversational memory shared by every chat turn. The prefixes match the
# "User:" / "Assistant:" labels used in the prompt below, so the model sees a
# single consistent transcript format.
memory = ConversationBufferMemory(human_prefix="User", ai_prefix="Assistant")

# Number of new tokens generated for each reply.
_MAX_NEW_TOKENS = 100
# Word count above which the stored history is summarized before prompting.
_SUMMARY_WORD_THRESHOLD = 200


def _extract_reply(generated_text):
    """Return only the assistant's turn from freshly generated text.

    GPT-2 tends to keep writing further dialogue turns on its own, so
    everything from the next "User:" label onward is discarded.
    """
    return generated_text.split("\nUser:", 1)[0].strip()


def chat_with_distilgpt2(input_text):
    """Generate one assistant reply and record the exchange in ``memory``.

    The stored transcript is first summarized with ``summarize_history`` if it
    exceeds ``_SUMMARY_WORD_THRESHOLD`` words. It is then combined with the
    new user message into a ``User:`` / ``Assistant:`` prompt for DistilGPT-2.
    The summary is used only for this prompt; ``memory`` itself keeps the full
    transcript.

    Args:
        input_text: The user's latest message.

    Returns:
        The assistant's reply only. The prompt and any extra dialogue turns
        the model invents are stripped before the reply is returned and saved.
    """
    conversation_history = memory.load_memory_variables({})["history"]

    if len(conversation_history.split()) > _SUMMARY_WORD_THRESHOLD:
        conversation_history = summarize_history(conversation_history)

    full_input = f"{conversation_history}\nUser: {input_text}\nAssistant:"

    input_ids = tokenizer.encode(full_input, return_tensors="pt").to(device)
    # GPT-2 has a fixed context window (config.n_positions). Keep only the most
    # recent prompt tokens so that prompt + reply fit; otherwise a long input
    # overruns the position embeddings.
    max_prompt_tokens = model.config.n_positions - _MAX_NEW_TOKENS
    input_ids = input_ids[:, -max_prompt_tokens:]
    # pad_token_id == eos_token_id, so generate() cannot infer the mask itself.
    attention_mask = torch.ones_like(input_ids)

    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=_MAX_NEW_TOKENS,
        num_return_sequences=1,
        no_repeat_ngram_size=3,
        repetition_penalty=1.2,
        # temperature / top_k / top_p only take effect when sampling.
        do_sample=True,
        temperature=0.7,
        top_k=20,
        top_p=0.8,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # generate() returns prompt + continuation; decode only the new tokens so
    # the history is not copied back into memory on every turn.
    new_tokens = outputs[0][input_ids.shape[1]:]
    response = _extract_reply(tokenizer.decode(new_tokens, skip_special_tokens=True))

    memory.save_context({"input": input_text}, {"output": response})

    return response

# Build a text-in / text-out web UI around the chat function.
user_box = gr.Textbox(label="Chat with DistilGPT-2")
reply_box = gr.Textbox(label="DistilGPT-2's Response")

interface = gr.Interface(
    fn=chat_with_distilgpt2,
    inputs=user_box,
    outputs=reply_box,
    title="DistilGPT-2 Chatbot with Memory",
    description=(
        "This is a simple chatbot powered by the DistilGPT-2 model with "
        "conversational memory, using LangChain."
    ),
)

# Start the Gradio web server for the app.
interface.launch()