TharvinPrakash committed
Commit 5020ea6
Parent(s): 8fda1c7

Update app.py

Files changed (1):
  app.py +12 -9
app.py CHANGED
@@ -1,24 +1,24 @@
  import streamlit as st
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM, AutoModelForSeq2SeqLM
  import torch

  # Load Hugging Face tokenizer and model for re-punctuation
  @st.cache_resource
  def load_re_punctuate_model():
      tokenizer = AutoTokenizer.from_pretrained("SJ-Ray/Re-Punctuate")
-     model = AutoModelForSeq2SeqLM.from_pretrained("SJ-Ray/Re-Punctuate",from_tf=True)
+     model = TFAutoModelForSeq2SeqLM.from_pretrained("SJ-Ray/Re-Punctuate")
      return tokenizer, model

- # Load Hugging Face tokenizer and model for headline generation
+ # Load Hugging Face tokenizer and model for headline generation (local path)
  @st.cache_resource
- def load_headline_model():
-     tokenizer = AutoTokenizer.from_pretrained("Michau/t5-base-en-generate-headline")
-     model = AutoModelForSeq2SeqLM.from_pretrained("Michau/t5-base-en-generate-headline")
+ def load_headline_model(model_path):
+     tokenizer = AutoTokenizer.from_pretrained(model_path)
+     model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
      return tokenizer, model

  # Function to re-punctuate text
  def re_punctuate_text(tokenizer, model, text):
-     inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
+     inputs = tokenizer(text, return_tensors="tf", max_length=512, truncation=True)
      outputs = model.generate(inputs["input_ids"], max_length=512, num_beams=4, early_stopping=True)
      return tokenizer.decode(outputs[0], skip_special_tokens=True)

@@ -27,7 +27,7 @@ def generate_headline_text(tokenizer, model, text, max_length=50):
      inputs = tokenizer(f"headline: {text}", return_tensors="pt", truncation=True, padding=True)
      with torch.no_grad():
          outputs = model.generate(
-             **inputs,
+             **inputs,
              max_length=max_length,
              num_beams=5,
              no_repeat_ngram_size=2,
@@ -45,6 +45,9 @@ selected_model = st.selectbox("Choose a model to use:", model_options)
  # User input text
  input_text = st.text_area("Enter text:", placeholder="Type your input here...")

+ # Default local model path for headline generation
+ local_model_path = r"C:\Users\Tharvin prakash\.cache\huggingface\hub\models--Michau--t5-base-en-generate-headline\snapshots\f526532f788c45b6b6288286e5ef929fa768ef6a"
+
  # Button to process text based on the selected model
  if st.button("Process Text") and input_text:
      with st.spinner("Processing..."):
@@ -52,7 +55,7 @@ if st.button("Process Text") and input_text:
              tokenizer, model = load_re_punctuate_model()
              result = re_punctuate_text(tokenizer, model, input_text)
          else:  # Generate Headline
-             tokenizer, model = load_headline_model()
+             tokenizer, model = load_headline_model(local_model_path)
              result = generate_headline_text(tokenizer, model, input_text)

          # Display result
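
The core change in this commit is loading SJ-Ray/Re-Punctuate natively through TFAutoModelForSeq2SeqLM with TensorFlow tensors, rather than through AutoModelForSeq2SeqLM with from_tf=True. Below is a minimal standalone sketch of that new loading path, assuming TensorFlow is installed and the model repository ships TF weights as the diff implies; the sample_text value is illustrative only and not part of the app.

# Sketch of the re-punctuation path as updated by this commit (assumes TensorFlow is installed).
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("SJ-Ray/Re-Punctuate")
model = TFAutoModelForSeq2SeqLM.from_pretrained("SJ-Ray/Re-Punctuate")

sample_text = "hello how are you doing today"  # illustrative input, not from the app
inputs = tokenizer(sample_text, return_tensors="tf", max_length=512, truncation=True)
outputs = model.generate(inputs["input_ids"], max_length=512, num_beams=4, early_stopping=True)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))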