kdevoe committed
Commit: a1f6cc4
1 Parent(s): 684c258

Loading model weights from saved file manually to prevent issue when using load_pretrained

Files changed (1):
  app.py  +12 -11
app.py CHANGED
@@ -1,6 +1,7 @@
 import gradio as gr
-from transformers import GPT2Tokenizer, GPT2LMHeadModel, AutoModelForSeq2SeqLM, AutoTokenizer
+from transformers import GPT2Tokenizer, GPT2LMHeadModel, AutoModelForSeq2SeqLM, AutoTokenizer, GPT2Config
 import torch
+from safetensors.torch import load_file as safetensors_load_file  # Import safetensors loading function
 from langchain.memory import ConversationBufferMemory
 
 # Move model to device (GPU if available)
@@ -9,17 +10,16 @@ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cp
 # Load the tokenizer (use pre-trained tokenizer for GPT-2 family)
 tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
 
-# Load the fine-tuned model from the local safetensors file
-model_path = "./model.safetensors"  # Path to your local model file
-model = GPT2LMHeadModel.from_pretrained(
-    pretrained_model_name_or_path=None,  # None because it's not from a model name
-    config="distilgpt2",  # Specify the config for distilgpt2
-    local_files_only=True,  # Only look for local files
-    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
-)
+# Load the configuration for the model (DistilGPT2 is a smaller GPT-2)
+config = GPT2Config.from_pretrained("distilgpt2")
 
-# Load the safetensors weights
-model.load_state_dict(torch.load(model_path, map_location=device))
+# Initialize the model using the configuration
+model = GPT2LMHeadModel(config)
+
+# Load the weights from the safetensors file
+model_path = "./model.safetensors"  # Path to your local model file
+state_dict = safetensors_load_file(model_path)  # Use safetensors loader
+model.load_state_dict(state_dict)  # Load the state dict into the model
 
 # Move model to the device (GPU or CPU)
 model.to(device)
@@ -73,3 +73,4 @@ interface.launch()
 
 
 
+
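For reference, the new loading path can be exercised end to end with a short smoke test. This is a minimal sketch that mirrors the added lines above; it assumes a distilgpt2-compatible ./model.safetensors file next to the script, and the prompt and generation settings are illustrative, not taken from app.py.

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
from safetensors.torch import load_file as safetensors_load_file

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the model from the distilgpt2 config, then load the local safetensors weights
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
config = GPT2Config.from_pretrained("distilgpt2")
model = GPT2LMHeadModel(config)
state_dict = safetensors_load_file("./model.safetensors")  # assumed local weights file
model.load_state_dict(state_dict)
model.to(device)
model.eval()

# Quick generation check (prompt and settings are illustrative)
inputs = tokenizer("Hello, how are you?", return_tensors="pt").to(device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=30, pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))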