thinh-researcher
commited on
Commit
β’
e89a14c
1
Parent(s):
9d17be0
Update
Browse files
{resnet_config β resnet_model}/__init__.py
RENAMED
File without changes
|
{resnet_config β resnet_model}/configuration_resnet.py
RENAMED
File without changes
|
{resnet_config β resnet_model}/custom-resnet/config.json
RENAMED
File without changes
|
{resnet_config β resnet_model}/modeling_resnet.py
RENAMED
@@ -1,6 +1,6 @@
|
|
1 |
from typing import Dict
|
2 |
|
3 |
-
import
|
4 |
from timm.models.resnet import BasicBlock, Bottleneck, ResNet
|
5 |
from torch import Tensor, nn
|
6 |
from transformers import PreTrainedModel
|
@@ -11,6 +11,11 @@ BLOCK_MAPPING = {"basic": BasicBlock, "bottleneck": Bottleneck}
|
|
11 |
|
12 |
|
13 |
class ResnetModel(PreTrainedModel):
|
|
|
|
|
|
|
|
|
|
|
14 |
config_class = ResnetConfig
|
15 |
|
16 |
def __init__(self, config: ResnetConfig):
|
@@ -28,27 +33,30 @@ class ResnetModel(PreTrainedModel):
|
|
28 |
avg_down=config.avg_down,
|
29 |
)
|
30 |
|
31 |
-
def forward(self, tensor):
|
32 |
return self.model.forward_features(tensor)
|
33 |
|
34 |
|
35 |
class ResnetModelForImageClassification(PreTrainedModel):
|
|
|
|
|
|
|
|
|
|
|
36 |
config_class = ResnetConfig
|
37 |
|
38 |
def __init__(self, config: ResnetConfig):
|
39 |
super().__init__(config)
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
avg_down=config.avg_down,
|
51 |
-
)
|
52 |
|
53 |
def forward(self, tensor: Tensor, labels=None) -> Dict[str, Tensor]:
|
54 |
logits = self.model(tensor)
|
@@ -56,3 +64,12 @@ class ResnetModelForImageClassification(PreTrainedModel):
|
|
56 |
loss = nn.cross_entropy(logits, labels)
|
57 |
return {"loss": loss, "logits": logits}
|
58 |
return {"logits": logits}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from typing import Dict
|
2 |
|
3 |
+
import timm
|
4 |
from timm.models.resnet import BasicBlock, Bottleneck, ResNet
|
5 |
from torch import Tensor, nn
|
6 |
from transformers import PreTrainedModel
|
|
|
11 |
|
12 |
|
13 |
class ResnetModel(PreTrainedModel):
|
14 |
+
"""
|
15 |
+
The line that sets the config_class is not mandatory,
|
16 |
+
unless you want to register your model with the auto classes
|
17 |
+
"""
|
18 |
+
|
19 |
config_class = ResnetConfig
|
20 |
|
21 |
def __init__(self, config: ResnetConfig):
|
|
|
33 |
avg_down=config.avg_down,
|
34 |
)
|
35 |
|
36 |
+
def forward(self, tensor: Tensor) -> Tensor:
|
37 |
return self.model.forward_features(tensor)
|
38 |
|
39 |
|
40 |
class ResnetModelForImageClassification(PreTrainedModel):
|
41 |
+
"""
|
42 |
+
The line that sets the config_class is not mandatory,
|
43 |
+
unless you want to register your model with the auto classes
|
44 |
+
"""
|
45 |
+
|
46 |
config_class = ResnetConfig
|
47 |
|
48 |
def __init__(self, config: ResnetConfig):
|
49 |
super().__init__(config)
|
50 |
+
self.model = ResnetModel(config)
|
51 |
+
|
52 |
+
"""
|
53 |
+
You can have your model return anything you want,
|
54 |
+
but returning a dictionary like we did for ResnetModelForImageClassification,
|
55 |
+
with the loss included when labels are passed,
|
56 |
+
will make your model directly usable inside the Trainer class.
|
57 |
+
Using another output format is fine as long as you are planning on
|
58 |
+
using your own training loop or another library for training.
|
59 |
+
"""
|
|
|
|
|
60 |
|
61 |
def forward(self, tensor: Tensor, labels=None) -> Dict[str, Tensor]:
|
62 |
logits = self.model(tensor)
|
|
|
64 |
loss = nn.cross_entropy(logits, labels)
|
65 |
return {"loss": loss, "logits": logits}
|
66 |
return {"logits": logits}
|
67 |
+
|
68 |
+
|
69 |
+
if __name__ == "__main__":
|
70 |
+
resnet50d_config = ResnetConfig.from_pretrained("custom-resnet")
|
71 |
+
resnet50d = ResnetModelForImageClassification(resnet50d_config)
|
72 |
+
|
73 |
+
# Load pretrained weights from timm
|
74 |
+
pretrained_model: nn.Module = timm.create_model("resnet50d", pretrained=True)
|
75 |
+
resnet50d.model.load_state_dict(pretrained_model.state_dict())
|
run.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import timm
|
2 |
+
|
3 |
+
from resnet_model.configuration_resnet import ResnetConfig
|
4 |
+
from resnet_model.modeling_resnet import ResnetModel, ResnetModelForImageClassification
|
5 |
+
|
6 |
+
ResnetConfig.register_for_auto_class()
|
7 |
+
ResnetModel.register_for_auto_class("AutoModel")
|
8 |
+
ResnetModelForImageClassification.register_for_auto_class(
|
9 |
+
"AutoModelForImageClassification"
|
10 |
+
)
|
11 |
+
|
12 |
+
resnet50d_config = ResnetConfig(
|
13 |
+
block_type="bottleneck", stem_width=32, stem_type="deep", avg_down=True
|
14 |
+
)
|
15 |
+
resnet50d = ResnetModelForImageClassification(resnet50d_config)
|
16 |
+
pretrained_model = timm.create_model("resnet50d", pretrained=True)
|
17 |
+
resnet50d.model.load_state_dict(pretrained_model.state_dict())
|
18 |
+
|
19 |
+
resnet50d.push_to_hub("custom-resnet50d")
|