from transformers import PreTrainedModel | |
from .config import RealESRGANConfig | |
from .rrdbnet import RRDBNet | |
class RealESRGANModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around an ``RRDBNet`` network.

    The wrapped network is built from the fields of a ``RealESRGANConfig``
    and stored as ``self.model``.
    """

    config_class = RealESRGANConfig

    def __init__(self, config):
        """Build the wrapped ``RRDBNet`` from ``config``.

        Args:
            config: Configuration object providing ``num_in_ch``,
                ``num_out_ch``, ``num_feat``, ``num_block``, ``num_grow_ch``
                and ``scale``; each is forwarded unchanged to ``RRDBNet``.
        """
        super().__init__(config)
        self.model = RRDBNet(
            num_in_ch=config.num_in_ch,
            num_out_ch=config.num_out_ch,
            num_feat=config.num_feat,
            num_block=config.num_block,
            num_grow_ch=config.num_grow_ch,
            scale=config.scale,
        )

    def forward(self, tensor):
        """Run ``tensor`` through the wrapped ``RRDBNet`` and return its output.

        Args:
            tensor: Input passed straight to the wrapped network.
                NOTE(review): presumably an image batch with ``num_in_ch``
                channels -- confirm against ``RRDBNet``.

        Returns:
            Whatever the wrapped ``RRDBNet`` returns for ``tensor``.
        """
        # Call the module itself rather than ``.forward`` so that any hooks
        # registered on the wrapped network are executed.
        return self.model(tensor)