glasses.models.base package

Submodules

glasses.models.base.protocols module

class glasses.models.base.protocols.Freezable[source]

Bases: object

A protocol that allows the class that uses it to freeze and unfreeze its weights.

Example

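>>> import torch.nn as nn
>>> from glasses.models import ResNet
>>> from glasses.models.base.protocols import Freezable
>>> model = ResNet.resnet18()
>>> # the static method works on any module, here the encoder only
>>> Freezable.set_requires_grad(model.encoder)
>>> class MyModel(nn.Sequential, Freezable):
...     def __init__(self):
...         super().__init__(nn.Conv2d(3, 32, kernel_size=3), nn.BatchNorm2d(32), nn.ReLU())
>>> model = MyModel()
>>> model.freeze()
>>> model.unfreeze()
>>> # freeze only one specific layer
>>> model.freeze(model[0])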
freeze(who: Optional[torch.nn.modules.module.Module] = None)[source]
static set_requires_grad(module, to: bool = False)[source]
unfreeze(who: Optional[torch.nn.modules.module.Module] = None)[source]
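
As a rough illustration of what freezing amounts to in PyTorch, the sketch below toggles requires_grad on every parameter of a module. It assumes that set_requires_grad does essentially this. The helper set_requires_grad_sketch is a hypothetical stand-in written for this example, not the library's code, and the actual implementation may handle more cases.

```python
import torch.nn as nn


def set_requires_grad_sketch(module: nn.Module, to: bool = False) -> None:
    """Hypothetical stand-in: set requires_grad on all parameters of `module`."""
    for param in module.parameters():
        param.requires_grad = to


block = nn.Sequential(nn.Conv2d(3, 32, kernel_size=3), nn.BatchNorm2d(32), nn.ReLU())

# Freeze only the convolution; the batch norm parameters stay trainable.
set_requires_grad_sketch(block[0], to=False)
frozen = [name for name, p in block.named_parameters() if not p.requires_grad]
trainable = [name for name, p in block.named_parameters() if p.requires_grad]
```

Frozen parameters receive no gradients, so an optimizer built only from the parameters that still have requires_grad set to True would update just the remaining layers.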
class glasses.models.base.protocols.Interpretable[source]

Bases: object

Protocol that allows the class that subclasses it to interpret an input using an instance of Interpretability.

interpret(x: torch.Tensor, using: glasses.interpretability.Interpretability.Interpretability, *args, **kwargs)[source]
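
The signature suggests a delegation pattern in which the model hands the input and itself to the object passed as using. The sketch below illustrates that pattern under this assumption only. InterpretableSketch, MeanActivation, TinyModel, and the call convention using(x, module) are hypothetical names and choices made for this example, not the glasses API.

```python
import torch
import torch.nn as nn


class MeanActivation:
    """Hypothetical interpretability object: mean output value per channel."""

    def __call__(self, x: torch.Tensor, module: nn.Module) -> torch.Tensor:
        with torch.no_grad():
            out = module(x)
        # Average over batch, height and width, leaving one score per channel.
        return out.mean(dim=(0, 2, 3))


class InterpretableSketch:
    """Hypothetical mixin that delegates the work to the `using` object."""

    def interpret(self, x: torch.Tensor, using, *args, **kwargs):
        return using(x, self, *args, **kwargs)


class TinyModel(nn.Sequential, InterpretableSketch):
    def __init__(self):
        super().__init__(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU())


model = TinyModel()
scores = model.interpret(torch.randn(2, 3, 16, 16), using=MeanActivation())
```

Keeping the interpretation logic in a separate object means the model class needs no knowledge of any particular technique.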

Module contents

class glasses.models.base.Encoder[source]

Bases: torch.nn.modules.module.Module

Base encoder class. It provides access to the inner features.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

property features: List[torch.Tensor]
property features_widths: List[int]
property stages: List[torch.nn.modules.module.Module]
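
One common way to expose intermediate features in PyTorch is to register a forward hook on each stage. The sketch below shows that general technique; it is not a description of how Encoder is implemented. TinyEncoder and its attributes are hypothetical and only mirror the property names listed above.

```python
from typing import List

import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Hypothetical encoder that records the output of each stage."""

    def __init__(self, widths=(8, 16, 32)):
        super().__init__()
        in_channels = [3, *widths[:-1]]
        self.features_widths = list(widths)
        self.stages = nn.ModuleList(
            [
                nn.Sequential(nn.Conv2d(i, o, kernel_size=3, stride=2, padding=1), nn.ReLU())
                for i, o in zip(in_channels, widths)
            ]
        )
        self.features: List[torch.Tensor] = []
        for stage in self.stages:
            # The hook runs after the stage's forward pass and stores its output.
            stage.register_forward_hook(lambda _m, _inp, out: self.features.append(out))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.features.clear()  # keep only the features of the latest call
        for stage in self.stages:
            x = stage(x)
        return x


encoder = TinyEncoder()
_ = encoder(torch.randn(1, 3, 32, 32))
shapes = [tuple(f.shape) for f in encoder.features]
```

In this sketch the channel dimension of each stored feature matches the corresponding entry of features_widths, which is the kind of information a decoder needs to size its skip connections.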
training: bool
class glasses.models.base.VisionModule[source]

Bases: torch.nn.modules.module.Module, glasses.models.base.protocols.Freezable, glasses.models.base.protocols.Interpretable

Base vision module, all models should subclass it.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

summary(input_shape: Tuple[int] = (1, 3, 224, 224), device: Optional[torch.device] = None)[source]

Convenience method that runs torchinfo directly on the model.

Parameters

input_shape (tuple, optional) – Shape of the input, including the batch dimension. Defaults to (1, 3, 224, 224).

Returns

[description]

Return type

[type]

training: bool