# Streanlit-GEMMA-2B / model_functions.py
# Uploaded by aryachakraborty ("Upload 4 files", commit 31d6ed6, 756 bytes).
from transformers import AutoModelForCausalLM,AutoTokenizer
import streamlit as st
@st.cache_resource(show_spinner='Loading the Gemma model. Be patient🙏')
def LOAD_GEMMA():
    """Load the ``aryachakraborty/GEMMA-2B-NL-SQL`` tokenizer and model.

    Judging by its name, the checkpoint is a Gemma-2B model tuned for
    natural-language-to-SQL; confirm against the model card.

    The function is wrapped in ``st.cache_resource``, so repeated calls
    (e.g. on Streamlit reruns) get back the already-loaded pair instead of
    reloading the weights. A spinner with the message above is shown while
    the first load runs.

    Returns:
        A ``(tokenizer, model)`` tuple, with the model placed on the CPU.
    """
    checkpoint = "aryachakraborty/GEMMA-2B-NL-SQL"

    gemma_tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # low_cpu_mem_usage=True asks transformers to keep peak RAM close to a
    # single copy of the weights while loading.
    gemma_model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        low_cpu_mem_usage=True,
    )
    # Explicitly place the model on the CPU.
    gemma_model = gemma_model.cpu()

    return gemma_tokenizer, gemma_model
@st.cache_resource(show_spinner='Loading the Mistral model. Be patient🙏')
def LOAD_MISTRAL(model_id=''):
    """Load a Mistral-family causal-LM checkpoint and its tokenizer on the CPU.

    This mirrors ``LOAD_GEMMA``. The result is cached with
    ``st.cache_resource`` so Streamlit reruns reuse the loaded pair instead
    of reloading the weights.

    Args:
        model_id: Hugging Face Hub repo id or local path of the checkpoint.
            It defaults to ``''`` because the original stub had no checkpoint
            chosen yet. Callers must supply one (or change the default).

    Returns:
        A ``(tokenizer, model)`` tuple, with the model placed on the CPU.

    Raises:
        ValueError: If ``model_id`` is empty. ``from_pretrained('')`` would
            otherwise fail with a far less helpful error.
    """
    if not model_id:
        raise ValueError(
            "LOAD_MISTRAL: no model_id configured; pass a Hugging Face repo id "
            "or local checkpoint path."
        )

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # The original passed the misspelled 'low_cpu_usage', which transformers
    # forwards to the model constructor and which therefore raises TypeError.
    # The correct option is 'low_cpu_mem_usage'.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
    ).cpu()
    # The original stub never returned. Return the same (tokenizer, model)
    # pair shape as LOAD_GEMMA so callers can unpack it.
    return tokenizer, model