glasses.models.classification.deit package
Module contents
- class glasses.models.classification.deit.DeiT(*args, head: torch.nn.modules.module.Module = <class 'glasses.models.classification.deit.DeiTClassificationHead'>, tokens: torch.nn.modules.module.Module = <class 'glasses.models.classification.deit.DeiTTokens'>, **kwargs)
- Bases: glasses.models.classification.vit.ViT
- Implementation of DeiT, proposed in "Training data-efficient image transformers & distillation through attention". DeiT introduces attention-based distillation, in which a new token, the dist token, is added to the model.
- Variants: DeiT.deit_tiny_patch16_224(), DeiT.deit_small_patch16_224(), DeiT.deit_base_patch16_224(), DeiT.deit_base_patch16_384()
- Parameters
- head (nn.Module, optional) – Classification head applied to the transformer output. Defaults to DeiTClassificationHead.
- tokens (nn.Module, optional) – Module providing the learnable cls and dist tokens. Defaults to DeiTTokens.
 
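The variant names appear to follow the DeiT paper's convention, where patch16_224 denotes 16x16 patches on a 224x224 input. Assuming that convention, this plain-Python sketch computes how many tokens the encoder processes once the cls and dist tokens are prepended to the patch tokens. The helper deit_sequence_length is hypothetical and not part of glasses.

```python
# Sketch: token-sequence length seen by the transformer encoder for each DeiT variant.
# Assumption: "patch{P}_{R}" in a variant name means P x P patches on an R x R image.


def deit_sequence_length(image_size: int, patch_size: int, n_extra_tokens: int = 2) -> int:
    """Patch tokens ((image_size / patch_size) ** 2) plus the extra cls and dist tokens."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side ** 2 + n_extra_tokens


# (image_size, patch_size) inferred from each variant name.
VARIANTS = {
    "deit_tiny_patch16_224": (224, 16),
    "deit_small_patch16_224": (224, 16),
    "deit_base_patch16_224": (224, 16),
    "deit_base_patch16_384": (384, 16),
}

for name, (size, patch) in VARIANTS.items():
    print(name, deit_sequence_length(size, patch))
```

With the dist token, the 224-pixel variants process 14 * 14 + 2 = 198 tokens instead of the 197 a ViT with only a cls token would use; the 384-pixel variant processes 24 * 24 + 2 = 578.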
- class glasses.models.classification.deit.DeiTClassificationHead(emb_size: int = 768, n_classes: int = 1000)
- Bases: torch.nn.modules.module.Module
- DeiT classification head. It relies on two heads, one applied to the cls token and one to the dist token. At test time, the prediction is made by averaging the outputs of the two heads, while during training both predictions are returned.
- Parameters
- emb_size (int, optional) – Embedding dimensions. Defaults to 768.
- n_classes (int, optional) – Number of output classes. Defaults to 1000.
 
 - forward(x: torch.Tensor) → torch.Tensor
- Defines the computation performed at every call. Should be overridden by all subclasses.
- Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
 - training: bool
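The train/eval behaviour described for DeiTClassificationHead can be illustrated with a minimal PyTorch sketch. TwoTokenHead and hard_distillation_loss are hypothetical names, not glasses APIs, and the sketch assumes the encoder output has shape (batch, tokens, emb_size) with the cls token at index 0 and the dist token at index 1. The loss follows the hard-label distillation objective from the DeiT paper: cross-entropy of the cls prediction against the true labels plus cross-entropy of the dist prediction against the teacher's argmax, each weighted 1/2.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TwoTokenHead(nn.Module):
    """Hypothetical sketch of a two-head classifier over the cls and dist tokens."""

    def __init__(self, emb_size: int = 768, n_classes: int = 1000):
        super().__init__()
        self.head = nn.Linear(emb_size, n_classes)       # classifies the cls token
        self.dist_head = nn.Linear(emb_size, n_classes)  # classifies the dist token

    def forward(self, x: torch.Tensor):
        # Assumption: cls token at position 0, dist token at position 1.
        cls_logits = self.head(x[:, 0])
        dist_logits = self.dist_head(x[:, 1])
        if self.training:
            # Training: return both predictions so each gets its own loss term.
            return cls_logits, dist_logits
        # Evaluation: average the two raw logits (averaging softmax outputs is an alternative).
        return (cls_logits + dist_logits) / 2


def hard_distillation_loss(cls_logits, dist_logits, labels, teacher_logits):
    """cls output vs. ground truth, dist output vs. teacher argmax, equal 1/2 weights."""
    teacher_labels = teacher_logits.argmax(dim=-1)
    return 0.5 * F.cross_entropy(cls_logits, labels) + 0.5 * F.cross_entropy(dist_logits, teacher_labels)


torch.manual_seed(0)
head = TwoTokenHead(emb_size=32, n_classes=10)
tokens = torch.randn(4, 198, 32)     # (batch, cls + dist + 196 patches, emb_size)
labels = torch.randint(0, 10, (4,))
teacher_logits = torch.randn(4, 10)  # stand-in for a pretrained teacher's output

head.train()
cls_logits, dist_logits = head(tokens)
loss = hard_distillation_loss(cls_logits, dist_logits, labels, teacher_logits)
loss.backward()

head.eval()
with torch.no_grad():
    preds = head(tokens)             # a single tensor of shape (4, 10)
```

Returning a tuple in training and a single tensor in evaluation lets one module serve both the distillation loss and plain inference, at the cost of a forward whose return type depends on the module's mode.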
 
- class glasses.models.classification.deit.DeiTTokens(emb_size: int)
- Bases: glasses.models.classification.vit.ViTTokens
- Tokens for DeiT. It contains the cls token present in ViT plus a special token, dist, used for distillation.
- Parameters
- emb_size (int) – Embedding dimensions.
 - training: bool
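The token mechanism described for DeiTTokens can be sketched in a few lines of PyTorch. ClsDistTokens is a hypothetical stand-in, not the glasses implementation (whose interface with ViTTokens is not documented here): it holds two learnable embeddings and prepends them to a batch of patch embeddings. The small random initialisation is an assumption.

```python
import torch
import torch.nn as nn


class ClsDistTokens(nn.Module):
    """Hypothetical sketch: learnable cls and dist tokens prepended to patch embeddings."""

    def __init__(self, emb_size: int):
        super().__init__()
        # One learnable vector per token, shaped (1, 1, emb_size) so it can be expanded over the batch.
        # The 0.02 scale of the random initialisation is an assumption.
        self.cls = nn.Parameter(torch.randn(1, 1, emb_size) * 0.02)
        self.dist = nn.Parameter(torch.randn(1, 1, emb_size) * 0.02)

    def forward(self, patches: torch.Tensor) -> torch.Tensor:
        # patches: (batch, n_patches, emb_size)
        batch = patches.shape[0]
        cls = self.cls.expand(batch, -1, -1)
        dist = self.dist.expand(batch, -1, -1)
        # Output: (batch, n_patches + 2, emb_size) with cls at index 0 and dist at index 1.
        return torch.cat([cls, dist, patches], dim=1)


torch.manual_seed(0)
tokens = ClsDistTokens(emb_size=64)
patches = torch.randn(2, 196, 64)  # e.g. 14 x 14 patches of a 224 x 224 image
sequence = tokens(patches)
```

Because two tokens are prepended instead of one, any positional embedding added to this sequence has to cover n_patches + 2 positions.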