GouthamVarma committed on
Commit
c19ce35
1 Parent(s): f5f0588

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -5
app.py CHANGED
@@ -1,20 +1,32 @@
1
  import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
5
  def load_model():
6
- # Fixed the spelling in the model name
7
- model_name = "GouthamVarma/mentalhealth_coversational_chatbot"
8
 
9
- # Load model without authentication since it's public
10
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
 
 
 
 
 
 
 
 
 
11
  model = AutoModelForCausalLM.from_pretrained(
12
  model_name,
 
13
  low_cpu_mem_usage=True,
14
  torch_dtype=torch.float32,
15
  device_map="cpu",
16
  trust_remote_code=True
17
  )
 
18
  return model, tokenizer
19
 
20
  print("Loading model... This might take a few minutes...")
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
3
  import torch
4
 
5
  def load_model():
6
+ model_name = "GouthamVarma/mentalhealth_coversational_chatbot"
7
+ base_model = "google/gemma-2b-it" # Base model we fine-tuned from
8
 
9
+ # First load the base model's tokenizer and config
10
+ tokenizer = AutoTokenizer.from_pretrained(
11
+ base_model,
12
+ trust_remote_code=True
13
+ )
14
+
15
+ config = AutoConfig.from_pretrained(
16
+ base_model,
17
+ trust_remote_code=True
18
+ )
19
+
20
+ # Then load your fine-tuned model
21
  model = AutoModelForCausalLM.from_pretrained(
22
  model_name,
23
+ config=config,
24
  low_cpu_mem_usage=True,
25
  torch_dtype=torch.float32,
26
  device_map="cpu",
27
  trust_remote_code=True
28
  )
29
+
30
  return model, tokenizer
31
 
32
  print("Loading model... This might take a few minutes...")