glasses.models.base package
Submodules
glasses.models.base.protocols module
- class glasses.models.base.protocols.Freezable
Bases: object
A protocol that lets the class using it freeze and unfreeze its weights.
- Example
>>> model = ResNet.resnet18()
>>> Freezable.set_requires_grad(model.encoder)
>>> class MyModel(nn.Sequential, Freezable):
...     def __init__(self):
...         super().__init__(nn.Conv2d(3, 32, kernel_size=3), nn.BatchNorm2d(32), nn.ReLU())
>>> model = MyModel()
>>> model.freeze()
>>> model.unfreeze()
>>> # freeze only one specific layer
>>> model.freeze(model[0])
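The page does not show how freezing works internally. The following is a minimal sketch, assuming that freezing amounts to toggling `requires_grad` on the parameters of the module or of a chosen sub-module. `FreezableSketch` and `TinyModel` are hypothetical names used for illustration, not the library's implementation.

```python
from typing import Optional

import torch.nn as nn


class FreezableSketch:
    """Hypothetical stand-in for Freezable; not the library's code."""

    def set_requires_grad(self, trainable: bool = True, who: Optional[nn.Module] = None) -> None:
        # Fall back to the whole module when no sub-module is given.
        # `is None` is used because an empty nn.Sequential is falsy.
        target = self if who is None else who
        for param in target.parameters():
            param.requires_grad = trainable

    def freeze(self, who: Optional[nn.Module] = None) -> None:
        self.set_requires_grad(False, who)

    def unfreeze(self, who: Optional[nn.Module] = None) -> None:
        self.set_requires_grad(True, who)


class TinyModel(nn.Sequential, FreezableSketch):
    def __init__(self) -> None:
        super().__init__(nn.Conv2d(3, 32, kernel_size=3), nn.BatchNorm2d(32), nn.ReLU())


model = TinyModel()
model.freeze(model[0])  # only the convolution stops receiving gradients
conv_frozen = all(not p.requires_grad for p in model[0].parameters())
bn_trainable = all(p.requires_grad for p in model[1].parameters())
```

Frozen parameters keep their values during training because the optimizer receives no gradient for them. Batch-norm running statistics are buffers rather than parameters, so under this sketch they would still update in training mode.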
- class glasses.models.base.protocols.Interpretable
Bases: object
Protocol that allows the class that subclasses it to interpret an input using an instance of Interpretability.
- interpret(x: torch.Tensor, using: glasses.interpretability.Interpretability.Interpretability, *args, **kwargs)
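The signature suggests a strategy pattern: the model hands its input and itself to an interpretability object. The sketch below illustrates that idea under the assumption that the strategy is a callable taking `(x, module)`. `InterpretabilitySketch`, `GradientSaliency`, `InterpretableSketch` and `TinyClassifier` are hypothetical names, and the real `Interpretability` interface may differ.

```python
import torch
import torch.nn as nn


class InterpretabilitySketch:
    """Hypothetical strategy base class; the real API may differ."""

    def __call__(self, x: torch.Tensor, module: nn.Module) -> torch.Tensor:
        raise NotImplementedError


class GradientSaliency(InterpretabilitySketch):
    """Assumed example strategy: |d(top score) / d(input)|."""

    def __call__(self, x: torch.Tensor, module: nn.Module) -> torch.Tensor:
        # Work on a fresh leaf tensor so the gradient lands in x.grad.
        x = x.detach().clone().requires_grad_(True)
        scores = module(x)
        scores.max(dim=1).values.sum().backward()
        return x.grad.abs()


class InterpretableSketch:
    def interpret(self, x: torch.Tensor, using: InterpretabilitySketch, *args, **kwargs):
        # Delegate to the strategy, passing the input and the model itself.
        return using(x, self, *args, **kwargs)


class TinyClassifier(nn.Sequential, InterpretableSketch):
    def __init__(self) -> None:
        super().__init__(nn.Flatten(), nn.Linear(3 * 8 * 8, 4))


model = TinyClassifier()
saliency = model.interpret(torch.randn(2, 3, 8, 8), using=GradientSaliency())
```

Keeping the interpretation logic in a separate object means a new technique can be added without touching the model classes.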
Module contents
- class glasses.models.base.Encoder
Bases: torch.nn.modules.module.Module
Base encoder class. It exposes the inner features of the model.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- property features: List[torch.Tensor]
- property features_widths: List[int]
- property stages: List[torch.nn.modules.module.Module]
- training: bool
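The three properties are listed without a description. As a rough illustration of how they could fit together, the sketch below assumes that `stages` are the sequential blocks of the encoder, `features_widths` are their output channel counts, and `features` are the stage outputs captured during the last forward pass. `EncoderSketch` and its forward-hook storage are assumptions, not the library's implementation.

```python
from typing import List, Sequence

import torch
import torch.nn as nn


class EncoderSketch(nn.Module):
    """Hypothetical encoder; names and hook-based storage are assumptions."""

    def __init__(self, in_channels: int = 3, widths: Sequence[int] = (8, 16)) -> None:
        super().__init__()
        self.widths = list(widths)
        blocks, previous = [], in_channels
        for width in self.widths:
            # Each stage halves the spatial resolution.
            blocks.append(
                nn.Sequential(nn.Conv2d(previous, width, 3, stride=2, padding=1), nn.ReLU())
            )
            previous = width
        self.layers = nn.ModuleList(blocks)
        self._features: List[torch.Tensor] = []
        for stage in self.layers:
            stage.register_forward_hook(self._store)

    def _store(self, module: nn.Module, inputs, output: torch.Tensor) -> None:
        # Forward hook: record the output of every stage.
        self._features.append(output)

    @property
    def stages(self) -> List[nn.Module]:
        return list(self.layers)

    @property
    def features_widths(self) -> List[int]:
        return self.widths

    @property
    def features(self) -> List[torch.Tensor]:
        return self._features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self._features.clear()  # drop features from the previous call
        for stage in self.layers:
            x = stage(x)
        return x


encoder = EncoderSketch()
out = encoder(torch.randn(1, 3, 32, 32))
shapes = [tuple(f.shape) for f in encoder.features]
```

A decoder that needs skip connections can read `features` after the forward pass and use `features_widths` to size its own layers.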
- class glasses.models.base.VisionModule
Bases: torch.nn.modules.module.Module, glasses.models.base.protocols.Freezable, glasses.models.base.protocols.Interpretable
Base vision module. All models should subclass it.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- summary(input_shape: Tuple[int] = (1, 3, 224, 224), device: Optional[torch.device] = None)
Convenience method that runs torchinfo directly on the model.
- Parameters
input_shape (tuple, optional) – Shape of the input passed to torchinfo, batch dimension included. Defaults to (1, 3, 224, 224).
- Returns
[description]
- Return type
[type]
- training: bool