remosleandre commited on
Commit
b46b06b
1 Parent(s): 8936391

[FIX] weight_update

Browse files
Files changed (2) hide show
  1. model.py +1 -1
  2. model_hugging_face.py +6 -1
model.py CHANGED
@@ -44,7 +44,7 @@ class Architecture(nn.Module):
44
 
45
  def load_model():
46
  model = Architecture()
47
- model.load_state_dict(torch.load('model_weights.pth'))
48
  return model
49
 
50
  def inference_model(model, input):
 
44
 
45
  def load_model():
46
  model = Architecture()
47
+ model.load_state_dict(torch.load('./model_weights.pth'))
48
  return model
49
 
50
  def inference_model(model, input):
model_hugging_face.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from transformers import PreTrainedModel, PretrainedConfig
2
  import torch.nn as nn
3
  import torch
@@ -59,7 +60,11 @@ class Architecture(PreTrainedModel):
59
 
60
  # Loading the model from saved weights
61
  def load_model():
 
 
62
  config = ArchitectureConfig()
63
  model = Architecture(config)
64
- model.load_state_dict(torch.load('model_weights.pth'))
65
  return model
 
 
 
1
+ from transformers import AutoConfig, AutoModel
2
  from transformers import PreTrainedModel, PretrainedConfig
3
  import torch.nn as nn
4
  import torch
 
60
 
61
  # Loading the model from saved weights
62
  def load_model():
63
+ AutoConfig.register("architecture", ArchitectureConfig)
64
+ AutoModel.register(ArchitectureConfig, Architecture)
65
  config = ArchitectureConfig()
66
  model = Architecture(config)
67
+ model.load_state_dict(torch.load('./model_weights.pth'))
68
  return model
69
+
70
+ load_model()