glasses.models.classification.fishnet package

Module contents

class glasses.models.classification.fishnet.FishNet(encoder: torch.nn.modules.module.Module = <class 'glasses.models.classification.fishnet.FishNetEncoder'>, head: torch.nn.modules.module.Module = <class 'glasses.models.classification.fishnet.FishNetHead'>, *args, **kwargs)[source]

Bases: glasses.models.classification.base.ClassificationModule

Implementation of FishNet proposed in FishNet: A Versatile Backbone for Image, Region, and Pixel Level Prediction

Honestly, this model is very weird, and the paper contains some mistakes that nobody ever cared to correct. It is a nice idea, but it could have been described better and definitely implemented better. The author’s code is terrible; I have based most of my implementation on this amazing repo, Fishnet-PyTorch.

The following image is taken from the paper and shows the architecture detail.

https://github.com/FrancescoSaverioZuppichini/glasses/blob/develop/docs/_static/images/FishNet.png?raw=true
FishNet.fishnet99()
FishNet.fishnet150()

Examples

# change the activation function
FishNet.fishnet99(activation=nn.SELU)
# change number of classes (default is 1000)
FishNet.fishnet99(n_classes=100)
# pass a different block
block = lambda in_ch, out_ch, **kwargs: nn.Sequential(FishNetBottleNeck(in_ch, out_ch), SpatialSE(out_ch))
FishNet.fishnet99(block=block)
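The `block` argument in the last example is a factory that is called with the input and output feature counts, and the default shown in the class signatures is built with `functools.partial`. The sketch below illustrates only that pattern. `FakeBlock` and `FakeShortcut` are hypothetical stand-ins written for this example and are not part of glasses; the real blocks are `torch.nn.Module` subclasses.

```python
from functools import partial


class FakeShortcut:
    """Hypothetical stand-in for a shortcut layer (not a glasses class)."""

    def __init__(self, in_features: int, out_features: int, kernel_size: int = 3):
        self.in_features = in_features
        self.out_features = out_features
        self.kernel_size = kernel_size


class FakeBlock:
    """Hypothetical stand-in for a residual block (not a glasses class)."""

    def __init__(self, in_features: int, out_features: int, shortcut=None, **kwargs):
        self.in_features = in_features
        self.out_features = out_features
        # The shortcut is itself a factory, built with the same feature counts.
        self.shortcut = shortcut(in_features, out_features) if shortcut else None


# Pre-bind the keyword arguments once; the model later supplies the feature counts.
block_factory = partial(FakeBlock, shortcut=partial(FakeShortcut, kernel_size=1))
layer = block_factory(64, 128)
```

Any callable with the same `(in_features, out_features, **kwargs)` shape, such as the lambda above, can presumably be passed in the same way.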
Parameters
  • in_channels (int, optional) – Number of channels in the input Image (3 for RGB and 1 for Gray). Defaults to 3.

  • n_classes (int, optional) – Number of classes. Defaults to 1000.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

classmethod fishnet150(*args, **kwargs) glasses.models.classification.fishnet.FishNet[source]

Return a fishnet150 model

Returns

A FishNet instance configured as fishnet150.

Return type

FishNet

classmethod fishnet99(*args, **kwargs) glasses.models.classification.fishnet.FishNet[source]

Return a fishnet99 model

Returns

A FishNet instance configured as fishnet99.

Return type

FishNet

initialize()[source]
training: bool
class glasses.models.classification.fishnet.FishNetBodyBlock(in_features: int, out_features: int, trans_features: int, block: torch.nn.modules.module.Module = functools.partial(<class 'glasses.models.classification.resnet.ResNetBottleneckPreActBlock'>, shortcut=functools.partial(<class 'glasses.nn.blocks.BnActConv'>, kernel_size=1)), depth: int = 1, trans_depth: int = 1, *args, **kwargs)[source]

Bases: torch.nn.modules.module.Module

FishNet body block, called the Up-sampling & Refinement block in the paper.

Parameters
  • in_features (int) – Number of input features

  • out_features (int) – Number of output features

  • trans_features (int) – [description]

  • block (nn.Module, optional) – [description]. Defaults to FishNetBottleNeck.

  • depth (int, optional) – [description]. Defaults to 1.

  • trans_depth (int, optional) – [description]. Defaults to 1.

forward(x: torch.Tensor, res: torch.Tensor) torch.Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
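The note above can be illustrated with a toy class. This is a simplified sketch of the idea, not PyTorch's actual implementation: `TinyModule` is a hypothetical stand-in whose `__call__` runs `forward` and then the registered hooks, so calling `forward` directly skips them.

```python
class TinyModule:
    """Toy stand-in for nn.Module, showing only the forward-hook mechanism."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def forward(self, x):
        return x * 2

    def __call__(self, x):
        out = self.forward(x)
        # Hooks run only on this path, never when forward() is called directly.
        for hook in self._forward_hooks:
            hook(self, x, out)
        return out


seen = []
module = TinyModule()
module.register_forward_hook(lambda mod, inputs, output: seen.append(output))

called = module(3)          # goes through __call__, so the hook fires
direct = module.forward(3)  # bypasses __call__, so the hook is skipped
```

After both calls `seen` holds a single entry, because only the call through the instance triggered the hook.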

training: bool
class glasses.models.classification.fishnet.FishNetBrigde(in_features: int, out_features: int, block: torch.nn.modules.module.Module = functools.partial(<class 'glasses.models.classification.resnet.ResNetBottleneckPreActBlock'>, shortcut=functools.partial(<class 'glasses.nn.blocks.BnActConv'>, kernel_size=1)), depth: int = 1, activation: torch.nn.modules.module.Module = functools.partial(<class 'torch.nn.modules.activation.ReLU'>, inplace=True))[source]

Bases: torch.nn.modules.module.Module

A weird layer that ‘bridges’ the tail and the body of the model.

Parameters
  • in_features (int) – Number of input features

  • out_features (int) – Number of output features

  • block (nn.Module, optional) – [description]. Defaults to FishNetBottleNeck.

  • depth (int, optional) – [description]. Defaults to 1.

forward(x: torch.Tensor) torch.Tensor[source]

training: bool
class glasses.models.classification.fishnet.FishNetChannelReductionShortcut(in_features: int, out_features: int, *args, **kwargs)[source]

Bases: torch.nn.modules.module.Module

Channel reduction output \(r(x)\) is computed as follows:

\(r(x)=\hat{x}=\left[\hat{x}(1), \hat{x}(2), \ldots, \hat{x}(c_{out})\right], \quad \hat{x}(n)=\sum_{j=0}^{k} x(k \cdot n+j), \quad n \in\{0,1, \ldots, c_{out}\}\)

Where \(k = \frac{c_{in}}{c_{out}}\)
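A minimal sketch of the channel reduction idea, in plain Python rather than PyTorch. It treats each channel as a single number and assumes zero-based indexing with exactly \(k\) inputs summed per output channel (that is, \(j\) running from 0 to \(k-1\)), which is one reading of the formula. The function name `channel_reduction` is hypothetical and not part of glasses.

```python
from typing import List


def channel_reduction(x: List[float], out_features: int) -> List[float]:
    """Sum each group of k adjacent channels, with k = in_features / out_features."""
    in_features = len(x)
    if in_features % out_features != 0:
        raise ValueError("in_features must be divisible by out_features")
    k = in_features // out_features
    # Output channel n collects input channels k*n, k*n + 1, ..., k*n + k - 1.
    return [sum(x[k * n : k * n + k]) for n in range(out_features)]


# Six input channels reduced to three: k = 2, so adjacent pairs are summed.
reduced = channel_reduction([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], out_features=3)
```

On real feature maps the same grouping would be applied along the channel axis, with each sum taken element-wise over the spatial dimensions.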

Parameters
  • in_features (int) – Number of input features

  • out_features (int) – Number of output features

forward(x: torch.Tensor) torch.Tensor[source]

training: bool
class glasses.models.classification.fishnet.FishNetEncoder(in_channels: int = 3, start_features: int = 64, tail_depths: List[int] = [1, 1, 1], body_depths: List[int] = [1, 1, 1], body_trans_depths: List[int] = [1, 1, 1], head_depths: List[int] = [1, 1, 1], head_trans_depths: List[int] = [1, 1, 1], bridge_depth: int = 1, block: torch.nn.modules.module.Module = functools.partial(<class 'glasses.models.classification.resnet.ResNetBottleneckPreActBlock'>, shortcut=functools.partial(<class 'glasses.nn.blocks.BnActConv'>, kernel_size=1)), stem: torch.nn.modules.module.Module = functools.partial(<class 'glasses.models.classification.resnet.ResNetStem3x3'>, widths=[32, 32]), activation: torch.nn.modules.module.Module = functools.partial(<class 'torch.nn.modules.activation.ReLU'>, inplace=True), **kwargs)[source]

Bases: torch.nn.modules.module.Module

FishNet encoder composed of a tail, a body and a head.

The following image is taken from the paper and shows the architecture detail.

https://github.com/FrancescoSaverioZuppichini/glasses/blob/develop/docs/_static/images/FishNetEncoder.png?raw=true
Parameters
  • in_channels (int, optional) – [description]. Defaults to 3.

  • start_features (int, optional) – [description]. Defaults to 64.

  • tail_depths (List[int], optional) – [description]. Defaults to [1, 1, 1].

  • body_depths (List[int], optional) – [description]. Defaults to [1, 1, 1].

  • body_trans_depths (List[int], optional) – [description]. Defaults to [1, 1, 1].

  • head_depths (List[int], optional) – [description]. Defaults to [1, 1, 1].

  • head_trans_depths (List[int], optional) – [description]. Defaults to [1, 1, 1].

  • bridge_depth (int, optional) – [description]. Defaults to 1.

  • block (nn.Module, optional) – [description]. Defaults to FishNetBottleNeck.

  • activation (nn.Module, optional) – [description]. Defaults to ReLUInPlace.

static find_widths(start_features: int = 64, depth: int = 3) List[int][source]

This code iteratively computes the correct number of in and out features for each FishNet layer.

Code copied from Fishnet-PyTorch

Parameters
  • start_features (int, optional) – [description]. Defaults to 64.

  • depth (int, optional) – [description]. Defaults to 3.

Returns

[description]

Return type

List[int]

forward(x)[source]

training: bool
class glasses.models.classification.fishnet.FishNetHead(in_features: int, n_classes: int, activation: torch.nn.modules.module.Module = <class 'torch.nn.modules.activation.ReLU'>)[source]

Bases: torch.nn.modules.container.Sequential

FishNet head composed of 1x1 convs.

class glasses.models.classification.fishnet.FishNetHeadBlock(in_features: int, out_features: int, trans_features: int, block: torch.nn.modules.module.Module = functools.partial(<class 'glasses.models.classification.resnet.ResNetBottleneckPreActBlock'>, shortcut=functools.partial(<class 'glasses.nn.blocks.BnActConv'>, kernel_size=1)), depth: int = 1, trans_depth: int = 1, **kwargs)[source]

Bases: glasses.models.classification.fishnet.FishNetBodyBlock

FishNet head block, called the Down-sampling & Refinement block in the paper.

Parameters
  • in_features (int) – Number of input features

  • out_features (int) – Number of output features

  • trans_features (int) – [description]

  • block (nn.Module, optional) – [description]. Defaults to FishNetBottleNeck.

  • depth (int, optional) – [description]. Defaults to 1.

  • trans_depth (int, optional) – [description]. Defaults to 1.

training: bool
class glasses.models.classification.fishnet.FishNetTail(in_features: int, out_features: int, depth: int = 1, block: torch.nn.modules.module.Module = functools.partial(<class 'glasses.models.classification.resnet.ResNetBottleneckPreActBlock'>, shortcut=functools.partial(<class 'glasses.nn.blocks.BnActConv'>, kernel_size=1)), *args, **kwargs)[source]

Bases: torch.nn.modules.container.Sequential

FishNet Tail

Parameters
  • in_features (int) – Number of input features

  • out_features (int) – Number of output features

  • depth (int, optional) – [description]. Defaults to 1.

  • block (nn.Module, optional) – [description]. Defaults to FishNetBottleNeck.
