base
HeadForImageClassification
Bases: `nn.Module`

Base class for classification heads. A head's `forward` takes a list of feature tensors and returns the logits.

Define a custom classification head:
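```python
from typing import List

from torch import Tensor, nn

from glasses.models.vision.image.classification.heads.base import HeadForImageClassification


class LinearHead(HeadForImageClassification):
    def __init__(self, num_classes: int, in_channels: int):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flat = nn.Flatten()
        self.fc = nn.Linear(in_channels, num_classes)

    def forward(self, features: List[Tensor]) -> Tensor:
        # Classify from the last (deepest) feature map only.
        x = features[-1]
        x = self.pool(x)  # (B, C, H, W) -> (B, C, 1, 1)
        x = self.flat(x)  # -> (B, C)
        x = self.fc(x)    # -> (B, num_classes)
        return x
```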
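A quick way to check a head like `LinearHead` is to feed it a list of dummy feature maps and inspect the output shape. The sketch below assumes PyTorch is installed. It defines a minimal stand-in `HeadForImageClassification` so it runs without `glasses` installed. This stand-in is hypothetical and is not the library's actual class. The feature shapes are illustrative, for example two backbone stages with 256 and 512 channels.

```python
from typing import List

import torch
from torch import Tensor, nn


# Hypothetical stand-in for glasses' HeadForImageClassification so this
# snippet runs on its own; the real base class may define more than this.
class HeadForImageClassification(nn.Module):
    pass


class LinearHead(HeadForImageClassification):
    def __init__(self, num_classes: int, in_channels: int):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flat = nn.Flatten()
        self.fc = nn.Linear(in_channels, num_classes)

    def forward(self, features: List[Tensor]) -> Tensor:
        x = features[-1]   # only the last feature map is used
        x = self.pool(x)   # (B, C, H, W) -> (B, C, 1, 1)
        x = self.flat(x)   # (B, C, 1, 1) -> (B, C)
        return self.fc(x)  # (B, C) -> (B, num_classes)


head = LinearHead(num_classes=10, in_channels=512)
# Illustrative multi-stage features for a batch of 2 images.
features = [torch.randn(2, 256, 14, 14), torch.randn(2, 512, 7, 7)]
logits = head(features)
print(logits.shape)  # torch.Size([2, 10])
```

`in_channels` must match the channel count of the last feature map. With `in_channels=256` here, `self.fc` would raise a shape-mismatch error.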
Source code in glasses/models/vision/image/classification/heads/base.py
`forward(features)`

The forward method of the classification head. It maps a list of features to class logits.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `features` | `List[Tensor]` | A list of features. | *required* |
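`forward` receives the whole list rather than a single tensor, so a head can combine several stages. The sketch below shows a hypothetical `MultiScaleHead` that is not part of `glasses`. It global-average-pools each feature map, concatenates the pooled vectors and classifies them. The sketch assumes `in_channels` lists each stage's channel count in the same order as `features`.

```python
from typing import List, Sequence

import torch
from torch import Tensor, nn


class MultiScaleHead(nn.Module):
    # Hypothetical head; within glasses it would subclass HeadForImageClassification.
    def __init__(self, num_classes: int, in_channels: Sequence[int]):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flat = nn.Flatten()
        # Pooled vectors are concatenated, so the linear layer sees sum(C_i) inputs.
        self.fc = nn.Linear(sum(in_channels), num_classes)

    def forward(self, features: List[Tensor]) -> Tensor:
        # Pool each (B, C_i, H_i, W_i) map to (B, C_i), whatever its spatial size.
        pooled = [self.flat(self.pool(f)) for f in features]
        x = torch.cat(pooled, dim=1)  # (B, sum(C_i))
        return self.fc(x)


head = MultiScaleHead(num_classes=10, in_channels=[256, 512])
features = [torch.randn(2, 256, 14, 14), torch.randn(2, 512, 7, 7)]
out = head(features)
print(out.shape)  # torch.Size([2, 10])
```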
Returns:

| Name | Type | Description |
|---|---|---|
| `Tensor` | `Tensor` | The logits. |
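The returned logits are unnormalized class scores, with one row per image. In the sketch below, `torch.softmax` turns them into probabilities and `argmax` picks the predicted class. The tensor values are made up for illustration. For training, `nn.CrossEntropyLoss` takes the raw logits directly because it applies log-softmax internally.

```python
import torch
from torch import nn

# Illustrative logits for a batch of 2 images and 3 classes.
logits = torch.tensor([[2.0, 1.0, 0.1],
                       [0.0, 3.0, 0.5]])

probs = torch.softmax(logits, dim=1)  # each row sums to 1
preds = probs.argmax(dim=1)           # predicted class index per image
print(preds.tolist())  # [0, 1]

# CrossEntropyLoss expects raw logits plus integer class targets.
targets = torch.tensor([0, 1])
loss = nn.CrossEntropyLoss()(logits, targets)
```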