thinh-researcher commited on
Commit
db7d3b8
1 Parent(s): 695e3cb
Files changed (3) hide show
  1. resnet_model/configuration_resnet.py +1 -1
  2. run.py +8 -0
  3. test.py +6 -0
resnet_model/configuration_resnet.py CHANGED
@@ -20,7 +20,7 @@ class ResnetConfig(PretrainedConfig):
20
  Defining a model_type for your configuration (here model_type="resnet") is not mandatory,
21
  unless you want to register your model with the auto classes (see last section)."""
22
 
23
- model_type = "resnet"
24
 
25
  def __init__(
26
  self,
 
20
  Defining a model_type for your configuration (here model_type="resnet") is not mandatory,
21
  unless you want to register your model with the auto classes (see last section)."""
22
 
23
+ model_type = "rgbdsod-resnet"
24
 
25
  def __init__(
26
  self,
run.py CHANGED
@@ -2,11 +2,19 @@ import timm
2
 
3
  from resnet_model.configuration_resnet import ResnetConfig
4
  from resnet_model.modeling_resnet import ResnetModel, ResnetModelForImageClassification
 
5
 
6
  ResnetConfig.register_for_auto_class()
7
  ResnetModel.register_for_auto_class("AutoModel")
8
  ResnetModelForImageClassification.register_for_auto_class("AutoModel")
9
 
 
 
 
 
 
 
 
10
  resnet50d_config = ResnetConfig(
11
  block_type="bottleneck", stem_width=32, stem_type="deep", avg_down=True
12
  )
 
2
 
3
  from resnet_model.configuration_resnet import ResnetConfig
4
  from resnet_model.modeling_resnet import ResnetModel, ResnetModelForImageClassification
5
+ from transformers import AutoConfig, AutoModel, AutoModelForImageClassification
6
 
7
  ResnetConfig.register_for_auto_class()
8
  ResnetModel.register_for_auto_class("AutoModel")
9
  ResnetModelForImageClassification.register_for_auto_class("AutoModel")
10
 
11
+
12
+ # AutoConfig.register("rgbdsod-resnet", ResnetConfig)
13
+ # AutoModel.register(ResnetConfig, ResnetModel)
14
+ # AutoModelForImageClassification.register(
15
+ # ResnetConfig, ResnetModelForImageClassification
16
+ # )
17
+
18
  resnet50d_config = ResnetConfig(
19
  block_type="bottleneck", stem_width=32, stem_type="deep", avg_down=True
20
  )
test.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from transformers import AutoModel
2
+
3
+ model = AutoModel.from_pretrained(
4
+ "RGBD-SOD/custom-resnet50d", trust_remote_code=True
5
+ )
6
+ print(model)