Nemil committed
Commit 6270cb3 • Parent: 6a961b1

Upload 2 files
Files changed (2)
  1. app.py +40 -46
  2. requirements.txt +6 -6
app.py CHANGED
@@ -138,58 +138,52 @@ class Social_Media_Captioner:
 
 
     def _load_model(self):
-        try:
-            self.bnb_config = BitsAndBytesConfig(
-                load_in_4bit = True,
-                bnb_4bit_use_double_quant = True,
-                bnb_4bit_quant_type= "nf4",
-                bnb_4bit_compute_dtype=torch.bfloat16,
-            )
-            self.model = AutoModelForCausalLM.from_pretrained(
-                self.MODEL_NAME,
-                device_map = "auto",
-                trust_remote_code = True,
-                quantization_config = self.bnb_config
-            )
-
-            # Defining the tokenizers
-            self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
-            self.tokenizer.pad_token = self.tokenizer.eos_token
-
-            if self.use_finetuned:
-                # LORA Config Model
-                self.lora_config = LoraConfig(
-                    r=16,
-                    lora_alpha=32,
-                    target_modules=["query_key_value"],
-                    lora_dropout=0.05,
-                    bias="none",
-                    task_type="CAUSAL_LM"
-                )
-                self.model = get_peft_model(self.model, self.lora_config)
-
-                # Fitting the adapters
-                self.peft_config = PeftConfig.from_pretrained(self.peft_model_name)
-                self.model = AutoModelForCausalLM.from_pretrained(
-                    self.peft_config.base_model_name_or_path,
-                    return_dict = True,
-                    quantization_config = self.bnb_config,
-                    device_map= "auto",
-                    trust_remote_code = True
-                )
-                self.model = PeftModel.from_pretrained(self.model, self.peft_model_name)
-
-                # Defining the tokenizers
-                self.tokenizer = AutoTokenizer.from_pretrained(self.peft_config.base_model_name_or_path)
-                self.tokenizer.pad_token = self.tokenizer.eos_token
-
-            self.model_loaded = True
-            print("Model Loaded successfully")
-
-        except Exception as e:
-            print(e)
-            self.model_loaded = False
-
+        self.bnb_config = BitsAndBytesConfig(
+            load_in_4bit = True,
+            bnb_4bit_use_double_quant = True,
+            bnb_4bit_quant_type= "nf4",
+            bnb_4bit_compute_dtype=torch.bfloat16,
+        )
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.MODEL_NAME,
+            device_map = "auto",
+            trust_remote_code = True,
+            quantization_config = self.bnb_config
+        )
+
+        # Defining the tokenizers
+        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        if self.use_finetuned:
+            # LORA Config Model
+            self.lora_config = LoraConfig(
+                r=16,
+                lora_alpha=32,
+                target_modules=["query_key_value"],
+                lora_dropout=0.05,
+                bias="none",
+                task_type="CAUSAL_LM"
+            )
+            self.model = get_peft_model(self.model, self.lora_config)
+
+            # Fitting the adapters
+            self.peft_config = PeftConfig.from_pretrained(self.peft_model_name)
+            self.model = AutoModelForCausalLM.from_pretrained(
+                self.peft_config.base_model_name_or_path,
+                return_dict = True,
+                quantization_config = self.bnb_config,
+                device_map= "auto",
+                trust_remote_code = True
+            )
+            self.model = PeftModel.from_pretrained(self.model, self.peft_model_name)
+
+            # Defining the tokenizers
+            self.tokenizer = AutoTokenizer.from_pretrained(self.peft_config.base_model_name_or_path)
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        self.model_loaded = True
+        print("Model Loaded successfully")
 
 
     def inference(self, input_text: str, use_cached=True, cache_generation=True) -> str | None:
         if not self.model_loaded:
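
Note on the app.py hunk above: the commit removes the try/except that wrapped the whole of _load_model, so failures during quantized loading or PEFT adapter fitting now propagate to the caller instead of being printed and swallowed. A minimal, hypothetical caller-side guard follows; the import path and the handling shown are illustrative assumptions, not part of the commit:

    # Hypothetical usage sketch: after this commit, loading errors escape
    # _load_model(), so a caller wanting the old "print and mark unloaded"
    # behaviour must reproduce it itself.
    from app import Social_Media_Captioner  # assumed import path

    captioner = Social_Media_Captioner()
    try:
        captioner._load_model()
    except Exception as e:
        print(e)                        # what the removed except-block did
        captioner.model_loaded = False  # keeps inference() returning early
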
requirements.txt CHANGED
@@ -1,11 +1,11 @@
 evaluate
 jiwer
-huggingface_hub==0.20.0
-gradio
-bitsandbytes
-transformers @ git+https://github.com/huggingface/transformers.git
-peft @ git+https://github.com/huggingface/peft.git
-accelerate @ git+https://github.com/huggingface/accelerate.git
+huggingface_hub==0.16.4
+gradio==3.36.0
+bitsandbytes==0.41.0
+transformers==4.31.0
+peft==0.4.0
+accelerate==0.21.0
 einops
 safetensors
 torch
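
Since requirements.txt now pins the Hugging Face stack to exact releases instead of git checkouts, one quick way to confirm that a deployed environment matches the pins is to compare installed versions against them. A small standalone sketch using only the standard library; the pin list is copied from the file above:

    from importlib.metadata import PackageNotFoundError, version

    # Exact pins from the updated requirements.txt
    PINS = {
        "huggingface_hub": "0.16.4",
        "gradio": "3.36.0",
        "bitsandbytes": "0.41.0",
        "transformers": "4.31.0",
        "peft": "0.4.0",
        "accelerate": "0.21.0",
    }

    for pkg, expected in PINS.items():
        try:
            installed = version(pkg)
        except PackageNotFoundError:
            print(f"{pkg}: not installed (expected {expected})")
            continue
        if installed != expected:
            print(f"{pkg}: expected {expected}, got {installed}")
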