thinh-researcher commited on
Commit
e89a14c
β€’
1 Parent(s): 9d17be0
{resnet_config β†’ resnet_model}/__init__.py RENAMED
File without changes
{resnet_config β†’ resnet_model}/configuration_resnet.py RENAMED
File without changes
{resnet_config β†’ resnet_model}/custom-resnet/config.json RENAMED
File without changes
{resnet_config β†’ resnet_model}/modeling_resnet.py RENAMED
@@ -1,6 +1,6 @@
1
  from typing import Dict
2
 
3
- import torch
4
  from timm.models.resnet import BasicBlock, Bottleneck, ResNet
5
  from torch import Tensor, nn
6
  from transformers import PreTrainedModel
@@ -11,6 +11,11 @@ BLOCK_MAPPING = {"basic": BasicBlock, "bottleneck": Bottleneck}
11
 
12
 
13
  class ResnetModel(PreTrainedModel):
 
 
 
 
 
14
  config_class = ResnetConfig
15
 
16
  def __init__(self, config: ResnetConfig):
@@ -28,27 +33,30 @@ class ResnetModel(PreTrainedModel):
28
  avg_down=config.avg_down,
29
  )
30
 
31
- def forward(self, tensor):
32
  return self.model.forward_features(tensor)
33
 
34
 
35
  class ResnetModelForImageClassification(PreTrainedModel):
 
 
 
 
 
36
  config_class = ResnetConfig
37
 
38
  def __init__(self, config: ResnetConfig):
39
  super().__init__(config)
40
- block_layer = BLOCK_MAPPING[config.block_type]
41
- self.model = ResNet(
42
- block_layer,
43
- config.layers,
44
- num_classes=config.num_classes,
45
- in_chans=config.input_channels,
46
- cardinality=config.cardinality,
47
- base_width=config.base_width,
48
- stem_width=config.stem_width,
49
- stem_type=config.stem_type,
50
- avg_down=config.avg_down,
51
- )
52
 
53
  def forward(self, tensor: Tensor, labels=None) -> Dict[str, Tensor]:
54
  logits = self.model(tensor)
@@ -56,3 +64,12 @@ class ResnetModelForImageClassification(PreTrainedModel):
56
  loss = nn.cross_entropy(logits, labels)
57
  return {"loss": loss, "logits": logits}
58
  return {"logits": logits}
 
 
 
 
 
 
 
 
 
 
1
  from typing import Dict
2
 
3
+ import timm
4
  from timm.models.resnet import BasicBlock, Bottleneck, ResNet
5
  from torch import Tensor, nn
6
  from transformers import PreTrainedModel
 
11
 
12
 
13
  class ResnetModel(PreTrainedModel):
14
+ """
15
+ The line that sets the config_class is not mandatory,
16
+ unless you want to register your model with the auto classes
17
+ """
18
+
19
  config_class = ResnetConfig
20
 
21
  def __init__(self, config: ResnetConfig):
 
33
  avg_down=config.avg_down,
34
  )
35
 
36
+ def forward(self, tensor: Tensor) -> Tensor:
37
  return self.model.forward_features(tensor)
38
 
39
 
40
  class ResnetModelForImageClassification(PreTrainedModel):
41
+ """
42
+ The line that sets the config_class is not mandatory,
43
+ unless you want to register your model with the auto classes
44
+ """
45
+
46
  config_class = ResnetConfig
47
 
48
  def __init__(self, config: ResnetConfig):
49
  super().__init__(config)
50
+ self.model = ResnetModel(config)
51
+
52
+ """
53
+ You can have your model return anything you want,
54
+ but returning a dictionary like we did for ResnetModelForImageClassification,
55
+ with the loss included when labels are passed,
56
+ will make your model directly usable inside the Trainer class.
57
+ Using another output format is fine as long as you are planning on
58
+ using your own training loop or another library for training.
59
+ """
 
 
60
 
61
  def forward(self, tensor: Tensor, labels=None) -> Dict[str, Tensor]:
62
  logits = self.model(tensor)
 
64
  loss = nn.cross_entropy(logits, labels)
65
  return {"loss": loss, "logits": logits}
66
  return {"logits": logits}
67
+
68
+
69
+ if __name__ == "__main__":
70
+ resnet50d_config = ResnetConfig.from_pretrained("custom-resnet")
71
+ resnet50d = ResnetModelForImageClassification(resnet50d_config)
72
+
73
+ # Load pretrained weights from timm
74
+ pretrained_model: nn.Module = timm.create_model("resnet50d", pretrained=True)
75
+ resnet50d.model.load_state_dict(pretrained_model.state_dict())
run.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import timm
2
+
3
+ from resnet_model.configuration_resnet import ResnetConfig
4
+ from resnet_model.modeling_resnet import ResnetModel, ResnetModelForImageClassification
5
+
6
+ ResnetConfig.register_for_auto_class()
7
+ ResnetModel.register_for_auto_class("AutoModel")
8
+ ResnetModelForImageClassification.register_for_auto_class(
9
+ "AutoModelForImageClassification"
10
+ )
11
+
12
+ resnet50d_config = ResnetConfig(
13
+ block_type="bottleneck", stem_width=32, stem_type="deep", avg_down=True
14
+ )
15
+ resnet50d = ResnetModelForImageClassification(resnet50d_config)
16
+ pretrained_model = timm.create_model("resnet50d", pretrained=True)
17
+ resnet50d.model.load_state_dict(pretrained_model.state_dict())
18
+
19
+ resnet50d.push_to_hub("custom-resnet50d")