glasses.models.classification.deit package
Module contents
- class glasses.models.classification.deit.DeiT(*args, head: torch.nn.modules.module.Module = <class 'glasses.models.classification.deit.DeiTClassificationHead'>, tokens: torch.nn.modules.module.Module = <class 'glasses.models.classification.deit.DeiTTokens'>, **kwargs)
Bases: glasses.models.classification.vit.ViT
Implementation of DeiT, proposed in "Training data-efficient image transformers & distillation through attention".
The paper introduces attention-based distillation, in which a new token, the distillation (dist) token, is added to the model alongside the cls token.
Predefined configurations:
DeiT.deit_tiny_patch16_224()
DeiT.deit_small_patch16_224()
DeiT.deit_base_patch16_224()
DeiT.deit_base_patch16_384()
- Parameters
head (nn.Module, optional) – Module used as the classification head. Defaults to DeiTClassificationHead.
tokens (nn.Module, optional) – Module holding the special tokens (cls and dist) added to the patch embeddings. Defaults to DeiTTokens.
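The variant names listed above appear to follow the common patch{P}_{R} convention (patch size P, square input resolution R). Under that assumption, the hypothetical helper below works out how many tokens each variant's encoder processes: (R / P)^2 patch tokens plus the cls and dist tokens. It is a sketch for illustration, not part of glasses.

```python
import re


def deit_sequence_length(variant: str) -> int:
    """Hypothetical helper: number of tokens a DeiT variant feeds to its encoder.

    Assumes the name ends in 'patch{P}_{R}' (patch size P, square R x R input).
    """
    match = re.search(r"patch(\d+)_(\d+)$", variant)
    if match is None:
        raise ValueError(f"unrecognised variant name: {variant}")
    patch, resolution = int(match.group(1)), int(match.group(2))
    if resolution % patch != 0:
        raise ValueError("resolution must be divisible by the patch size")
    n_patches = (resolution // patch) ** 2  # non-overlapping patches on a square grid
    return n_patches + 2  # plus the cls token and the dist token


for name in ["deit_tiny_patch16_224", "deit_small_patch16_224",
             "deit_base_patch16_224", "deit_base_patch16_384"]:
    print(name, deit_sequence_length(name))
```

For example, a 224 x 224 input with 16 x 16 patches gives a 14 x 14 grid, so 196 patch tokens plus 2 special tokens; the 384 variant works on a 24 x 24 grid instead.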
- class glasses.models.classification.deit.DeiTClassificationHead(emb_size: int = 768, n_classes: int = 1000)
Bases: torch.nn.modules.module.Module
DeiT classification head. It uses two heads, one applied to the cls token and one applied to the dist token. At test time, the prediction is the average of the two outputs; during training, both predictions are returned.
- Parameters
emb_size (int, optional) – Embedding dimension. Defaults to 768.
n_classes (int, optional) – Number of output classes. Defaults to 1000.
- forward(x: torch.Tensor) → torch.Tensor
Computes the predictions from the cls and dist tokens: the average of the two heads' outputs at test time, or both outputs during training.
Note
As with any nn.Module, call the module instance rather than forward() directly: calling the instance runs the registered hooks, while calling forward() silently ignores them.
- training: bool
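The train/test behaviour described for DeiTClassificationHead can be illustrated with plain Python lists standing in for one sample's logits. This is a sketch, not the glasses implementation: all function names are hypothetical, and the loss follows the hard-label distillation objective from the DeiT paper (cls prediction trained on the true label, dist prediction trained on the teacher's argmax, equal weights), which this page does not itself describe.

```python
import math
from typing import List, Tuple, Union

Logits = List[float]


def deit_head_output(cls_logits: Logits, dist_logits: Logits,
                     training: bool) -> Union[Logits, Tuple[Logits, Logits]]:
    """Mimic the documented behaviour of the head on a single sample.

    cls_logits / dist_logits stand in for the outputs of the heads applied
    to the cls and dist tokens (hypothetical inputs).
    """
    if training:
        # Training: keep both predictions so each gets its own loss term.
        return cls_logits, dist_logits
    # Test time: average the two predictions element-wise.
    return [(c + d) / 2 for c, d in zip(cls_logits, dist_logits)]


def cross_entropy(logits: Logits, target: int) -> float:
    """Numerically stable -log(softmax(logits)[target])."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def hard_distillation_loss(cls_logits: Logits, dist_logits: Logits,
                           label: int, teacher_logits: Logits) -> float:
    """Hard-label distillation loss as described in the DeiT paper (sketch)."""
    # The teacher's hard decision is the target for the dist prediction.
    teacher_label = max(range(len(teacher_logits)), key=teacher_logits.__getitem__)
    return (0.5 * cross_entropy(cls_logits, label)
            + 0.5 * cross_entropy(dist_logits, teacher_label))


cls_out, dist_out = [2.0, 0.0, -1.0], [0.0, 1.0, 0.0]
print(deit_head_output(cls_out, dist_out, training=False))  # [1.0, 0.5, -0.5]
loss = hard_distillation_loss(cls_out, dist_out, label=0,
                              teacher_logits=[0.1, 3.0, 0.2])
print(loss)
```

Returning both outputs during training is what allows the two tokens to receive different targets; at test time only their average is needed.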
- class glasses.models.classification.deit.DeiTTokens(emb_size: int)
Bases: glasses.models.classification.vit.ViTTokens
Tokens for DeiT. It contains the cls token already present in ViT plus an additional token, dist, used for distillation.
- Parameters
emb_size (int) – Embedding dimension.
- training: bool
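DeiTTokens only states that it holds the cls and dist tokens. The minimal sketch below shows how such tokens are typically combined with the patch embeddings, using plain lists instead of tensors. The ordering [cls, dist, patch_1, ..., patch_N] is an assumption taken from the reference DeiT code and is not confirmed here for glasses; prepend_deit_tokens is a hypothetical name.

```python
from typing import List

Vector = List[float]


def prepend_deit_tokens(patches: List[Vector], cls_token: Vector,
                        dist_token: Vector) -> List[Vector]:
    """Build the token sequence seen by the encoder (sketch with plain lists).

    Assumed ordering: [cls, dist, patch_1, ..., patch_N]. Every vector must
    have the same length (the embedding dimension).
    """
    emb_size = len(cls_token)
    if len(dist_token) != emb_size or any(len(p) != emb_size for p in patches):
        raise ValueError("all tokens must share the embedding dimension")
    return [cls_token, dist_token, *patches]


emb_size = 4
patches = [[float(i)] * emb_size for i in range(9)]  # e.g. a 3 x 3 grid of patch embeddings
cls_token = [0.5] * emb_size    # learnable in a real model; fixed values here
dist_token = [-0.5] * emb_size  # learnable in a real model; fixed values here
sequence = prepend_deit_tokens(patches, cls_token, dist_token)
print(len(sequence))  # 11
```

With this layout, a classification head can read the cls token at position 0 and the dist token at position 1 of the encoder output.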