from timm.models.layers import config
from glasses.models.AutoTransform import AutoTransform
import logging
from argparse import ArgumentParser
from dataclasses import dataclass
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import Dict
from glasses.utils.weights.storage import LocalStorage, HuggingFaceStorage
import pretrainedmodels
import timm
import torch
from torch import Tensor, nn
from torchvision.models import (
densenet121,
densenet161,
densenet169,
densenet201,
resnet18,
resnet50,
resnet101,
resnet152,
resnext50_32x4d,
resnext101_32x8d,
vgg11,
vgg13,
vgg16,
vgg19,
wide_resnet50_2,
wide_resnet101_2,
mobilenetv2,
)
from tqdm.autonotebook import tqdm
from glasses.models.AutoModel import AutoModel
from glasses.models import *
from glasses.utils.ModuleTransfer import ModuleTransfer
from glasses.models.classification.vit import ViTTokens
from glasses.models.classification.deit import DeiTTokens
def vit_clone(key: str):
    """Build the glasses ViT named ``key`` with weights cloned from timm.

    The timm model ``key`` is created with pretrained weights and its weights
    are transferred into ``AutoModel.from_name(key)`` via ``clone_model``,
    using a random sample input whose spatial size is the ``input_size`` of
    the ``AutoTransform`` registered under the same name. ``ViTTokens``
    modules are skipped by the transfer, so the position embedding and the
    class token are copied over explicitly afterwards.

    Args:
        key: model name, used both as the timm name and the glasses name.

    Returns:
        The glasses model carrying the cloned weights.
    """
    # timm expects a bool here; a string such as "False" would still be truthy.
    src = timm.create_model(key, pretrained=True)
    dst = AutoModel.from_name(key)
    cfg = AutoTransform.from_name(key)
    dst = clone_model(
        src,
        dst,
        torch.randn((1, 3, cfg.input_size, cfg.input_size)),
        dest_skip=[ViTTokens],
    )
    # Tokens/positions are not handled by the module transfer (ViTTokens is
    # skipped), so copy them by hand. squeeze(0) drops the leading dim of
    # timm's pos_embed (presumably a batch dim of size 1 — confirm).
    dst.embedding.positions.data.copy_(src.pos_embed.data.squeeze(0))
    dst.embedding.tokens.cls.data.copy_(src.cls_token.data)
    return dst
def deit_clone(key: str):
    """Build the glasses DeiT named ``key`` from the distilled Facebook hub model.

    The hub entry name is derived by inserting ``distilled`` after the first
    two ``_``-separated parts of ``key`` (e.g. ``deit_tiny_patch16_224`` ->
    ``deit_tiny_distilled_patch16_224``). The sample input size comes from
    the ``AutoTransform`` registered under the matching ``vit_*`` name.
    ``DeiTTokens`` modules are skipped by the transfer, so the position
    embedding and the class/distillation tokens are copied explicitly.

    Args:
        key: glasses DeiT model name.

    Returns:
        The glasses model carrying the cloned weights.
    """
    parts = key.split("_")
    prefix = "_".join(parts[:2])
    suffix = "_".join(parts[2:])
    hub_key = prefix + "_distilled_" + suffix
    src = torch.hub.load("facebookresearch/deit:main", hub_key, pretrained=True)
    dst = AutoModel.from_name(key)

    # Replace the leading "deit" with "vit" to look up the transform config.
    vit_key = "vit_" + "_".join(parts[1:])
    cfg = AutoTransform.from_name(vit_key)
    side = cfg.input_size
    dst = clone_model(
        src,
        dst,
        torch.randn((1, 3, side, side)),
        dest_skip=[DeiTTokens],
    )

    embedding = dst.embedding
    embedding.positions.data.copy_(src.pos_embed.data.squeeze(0))
    embedding.tokens.cls.data.copy_(src.cls_token.data)
    embedding.tokens.dist.data.copy_(src.dist_token.data)
    return dst
# Registry of every model whose pretrained weights this script produces,
# keyed by glasses model name. Dict order matters: the `__main__` block writes
# the keys to pretrained_models.txt and processes the models in this order.
#
# Each value is one of:
#   * a zero-argument callable returning the pretrained source model; its
#     weights are transferred into `AutoModel.from_name(key)` by `clone_model`;
#   * None, which the `__main__` block replaces with
#     `partial(timm.create_model, key, pretrained=True)`;
#   * a `(clone_func, flag)` tuple, where `clone_func(key)` returns the
#     already-cloned glasses model (the flag is not read by the `__main__` block).
#
# NOTE: this literal is evaluated at import time, so e.g. a missing
# `pretrainedmodels` entry raises KeyError as soon as the module is imported.
zoo_source = {
    # Callables from torchvision take `pretrained=True`; timm entries pass the
    # timm name explicitly, which can differ from the glasses key
    # (e.g. "se_resnet50" -> "seresnet50", "eca_resnet26t" -> "ecaresnet26t").
    "resnet18": partial(resnet18, pretrained=True),
    "resnet26": partial(timm.create_model, "resnet26", pretrained=True),
    "resnet26d": partial(timm.create_model, "resnet26d", pretrained=True),
    "resnet34": partial(timm.create_model, "resnet34", pretrained=True),
    "resnet34d": partial(timm.create_model, "resnet34d", pretrained=True),
    "resnet50": partial(resnet50, pretrained=True),
    "resnet50d": partial(timm.create_model, "resnet50d", pretrained=True),
    "resnet101": partial(resnet101, pretrained=True),
    "resnet152": partial(resnet152, pretrained=True),
    "se_resnet50": partial(timm.create_model, "seresnet50", pretrained=True),
    "resnext50_32x4d": partial(resnext50_32x4d, pretrained=True),
    "resnext101_32x8d": partial(resnext101_32x8d, pretrained=True),
    "wide_resnet50_2": partial(wide_resnet50_2, pretrained=True),
    "wide_resnet101_2": partial(wide_resnet101_2, pretrained=True),
    "eca_resnet26t": partial(timm.create_model, "ecaresnet26t", pretrained=True),
    "eca_resnet50t": partial(timm.create_model, "ecaresnet50t", pretrained=True),
    "eca_resnet50d": partial(timm.create_model, "ecaresnet50d", pretrained=True),
    "eca_resnet101d": partial(timm.create_model, "ecaresnet101d", pretrained=True),
    # None: resolved at runtime to timm.create_model(key, pretrained=True),
    # i.e. the glasses key is assumed to also be the timm name.
    "regnetx_002": None,
    "regnetx_004": None,
    "regnetx_006": None,
    "regnetx_008": None,
    "regnetx_016": None,
    "regnetx_032": None,
    "regnetx_040": None,
    "regnetx_064": None,
    "regnety_002": None,
    "regnety_004": None,
    "regnety_006": None,
    "regnety_008": None,
    "regnety_016": None,
    "regnety_032": None,
    "regnety_040": None,
    "regnety_064": None,
    "densenet121": partial(densenet121, pretrained=True),
    "densenet169": partial(densenet169, pretrained=True),
    "densenet201": partial(densenet201, pretrained=True),
    "densenet161": partial(densenet161, pretrained=True),
    "vgg11": partial(vgg11, pretrained=True),
    "vgg13": partial(vgg13, pretrained=True),
    "vgg16": partial(vgg16, pretrained=True),
    "vgg19": partial(vgg19, pretrained=True),
    # pretrainedmodels factories are called with no arguments.
    # NOTE(review): relies on their defaults loading pretrained weights — confirm.
    "vgg11_bn": pretrainedmodels.__dict__["vgg11_bn"],
    "vgg13_bn": pretrainedmodels.__dict__["vgg13_bn"],
    "vgg16_bn": pretrainedmodels.__dict__["vgg16_bn"],
    "vgg19_bn": pretrainedmodels.__dict__["vgg19_bn"],
    "efficientnet_b0": partial(timm.create_model, "efficientnet_b0", pretrained=True),
    "efficientnet_b1": partial(timm.create_model, "efficientnet_b1", pretrained=True),
    "efficientnet_b2": partial(timm.create_model, "efficientnet_b2", pretrained=True),
    "efficientnet_b3": partial(timm.create_model, "efficientnet_b3", pretrained=True),
    "efficientnet_lite0": partial(
        timm.create_model, "efficientnet_lite0", pretrained=True
    ),
    # Disabled entry: "mobilenet_v2": partial(mobilenetv2, pretrained=True),
    # Tuples: custom clone functions that build and return the glasses model
    # themselves (see vit_clone / deit_clone).
    "vit_base_patch16_224": (vit_clone, True),
    "vit_base_patch16_384": (vit_clone, True),
    "vit_base_patch32_384": (vit_clone, True),
    "vit_huge_patch16_224": (vit_clone, True),
    "vit_huge_patch32_384": (vit_clone, True),
    "vit_large_patch16_224": (vit_clone, True),
    "vit_large_patch16_384": (vit_clone, True),
    "vit_large_patch32_384": (vit_clone, True),
    "deit_tiny_patch16_224": (deit_clone, True),
    "deit_small_patch16_224": (deit_clone, True),
    "deit_base_patch16_224": (deit_clone, True),
    "deit_base_patch16_384": (deit_clone, True),
}
def clone_model(
    src: nn.Module, dst: nn.Module, x: "Tensor | None" = None, **kwargs
) -> nn.Module:
    """Transfer the weights of ``src`` into ``dst`` and return ``dst``.

    Both models are put in eval mode, then ``ModuleTransfer`` is invoked with
    the sample input ``x`` to copy ``src``'s weights onto ``dst``.

    Args:
        src: model providing the weights.
        dst: model receiving the weights.
        x: sample input handed to ``ModuleTransfer``. When omitted, a fresh
            ``torch.rand((1, 3, 224, 224))`` is created for this call.
        **kwargs: forwarded to ``ModuleTransfer`` (e.g. ``dest_skip``).

    Returns:
        ``dst``, in eval mode, carrying the transferred weights.
    """
    if x is None:
        # Create the default per call rather than once at import time, so no
        # tensor is shared between calls (mutable-default pitfall).
        x = torch.rand((1, 3, 224, 224))
    src = src.eval()
    dst = dst.eval()
    # No separate forward passes here: their outputs were never used, and
    # ModuleTransfer receives x itself.
    ModuleTransfer(src, dst, **kwargs)(x)
    return dst
if __name__ == "__main__":
    # Clone every model in `zoo_source` into its glasses counterpart and store
    # the weights in the selected storage backend.
    parser = ArgumentParser()
    parser.add_argument("--storage", type=str, choices=["local", "hf"], default="hf")
    parser.add_argument("-o", type=Path)
    args = parser.parse_args()

    logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
    logging.info(f"Using {args.storage} storage 💾")

    # store the pretrained names (comma separated, in registry order)
    with open("pretrained_models.txt", "w", encoding="utf-8") as f:
        f.write(",".join(zoo_source))

    if args.o is not None:
        # NOTE(review): save_dir is created but not used anywhere below —
        # confirm whether it should be handed to the storage backend.
        save_dir = args.o
        save_dir.mkdir(exist_ok=True)

    storages = {"local": LocalStorage, "hf": HuggingFaceStorage}
    storage = storages[args.storage]()
    if args.storage == "local":
        logging.info(f"Store root={storage.root}")

    # Set to True to re-clone models that are already in the storage.
    override = False
    bar = tqdm(zoo_source.items())
    for key, src_def in bar:
        bar.set_description(key)
        if src_def is None:
            # None means: load the same-named pretrained model from timm.
            src_def = partial(timm.create_model, key, pretrained=True)
        if key in storage and not override:
            continue
        if isinstance(src_def, tuple):
            # Custom clone function that builds the glasses model itself;
            # the second tuple element is not needed here.
            clone_func, _ = src_def
            cloned = clone_func(key)
        else:
            src, dst = src_def(), AutoModel.from_name(key)
            cloned = clone_model(src, dst)
        storage.put(key, cloned)