kdevoe committed on
Commit
e1ba8ed
1 Parent(s): 5167829

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -17
app.py CHANGED
@@ -11,18 +11,18 @@ tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
11
  model = GPT2LMHeadModel.from_pretrained("distilgpt2")
12
  model.to(device)
13
 
14
- # Load summarization model (e.g., T5-small)
15
- summarizer_tokenizer = AutoTokenizer.from_pretrained("t5-small")
16
- summarizer_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(device)
17
 
18
- def summarize_history(history):
19
- input_ids = summarizer_tokenizer.encode(
20
- "summarize: " + history,
21
- return_tensors="pt"
22
- ).to(device)
23
- summary_ids = summarizer_model.generate(input_ids, max_length=50, min_length=25, length_penalty=5., num_beams=2)
24
- summary = summarizer_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
25
- return summary
26
 
27
  # Set up conversational memory using LangChain's ConversationBufferMemory
28
  memory = ConversationBufferMemory()
@@ -32,9 +32,9 @@ def chat_with_distilgpt2(input_text):
32
  # Retrieve conversation history
33
  conversation_history = memory.load_memory_variables({})['history']
34
 
35
- # Summarize if history exceeds certain length
36
- if len(conversation_history.split()) > 200:
37
- conversation_history = summarize_history(conversation_history)
38
 
39
  # Combine the (possibly summarized) history with the current user input
40
  full_input = f"{conversation_history}\nUser: {input_text}\nAssistant:"
@@ -50,9 +50,9 @@ def chat_with_distilgpt2(input_text):
50
  num_return_sequences=1,
51
  no_repeat_ngram_size=3,
52
  repetition_penalty=1.2,
53
- temperature=0.9,
54
- top_k=20,
55
- top_p=0.8,
56
  early_stopping=True,
57
  pad_token_id=tokenizer.eos_token_id,
58
  eos_token_id=tokenizer.eos_token_id
 
11
  model = GPT2LMHeadModel.from_pretrained("distilgpt2")
12
  model.to(device)
13
 
14
+ # # Load summarization model (e.g., T5-small)
15
+ # summarizer_tokenizer = AutoTokenizer.from_pretrained("t5-small")
16
+ # summarizer_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(device)
17
 
18
+ # def summarize_history(history):
19
+ # input_ids = summarizer_tokenizer.encode(
20
+ # "summarize: " + history,
21
+ # return_tensors="pt"
22
+ # ).to(device)
23
+ # summary_ids = summarizer_model.generate(input_ids, max_length=50, min_length=25, length_penalty=5., num_beams=2)
24
+ # summary = summarizer_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
25
+ # return summary
26
 
27
  # Set up conversational memory using LangChain's ConversationBufferMemory
28
  memory = ConversationBufferMemory()
 
32
  # Retrieve conversation history
33
  conversation_history = memory.load_memory_variables({})['history']
34
 
35
+ # # Summarize if history exceeds certain length
36
+ # if len(conversation_history.split()) > 200:
37
+ # conversation_history = summarize_history(conversation_history)
38
 
39
  # Combine the (possibly summarized) history with the current user input
40
  full_input = f"{conversation_history}\nUser: {input_text}\nAssistant:"
 
50
  num_return_sequences=1,
51
  no_repeat_ngram_size=3,
52
  repetition_penalty=1.2,
53
+ # temperature=0.9,
54
+ # top_k=20,
55
+ # top_p=0.8,
56
  early_stopping=True,
57
  pad_token_id=tokenizer.eos_token_id,
58
  eos_token_id=tokenizer.eos_token_id