Skip to content

base

HeadForImageClassification

Bases: nn.Module

Base class for image classification heads.

To define a custom classification head, subclass it and implement `forward`:

class LinearHead(HeadForImageClassification):
    """Example head: global average pooling of the last feature map, then a linear layer."""

    def __init__(self, num_classes: int, in_channels: int):
        super().__init__()
        # The attribute names (pool, flat, fc) become the submodule and
        # parameter names seen in the state dict.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flat = nn.Flatten()
        self.fc = nn.Linear(in_channels, num_classes)

    def forward(self, features: List[Tensor]) -> Tensor:
        # Only the last entry of `features` is used.
        last = features[-1]
        return self.fc(self.flat(self.pool(last)))
Source code in glasses/models/vision/image/classification/heads/base.py
class HeadForImageClassification(nn.Module):
    """Base class for image classification heads.

    Subclasses must override `forward`, which receives a list of features
    and returns the logits.

    Define a custom classification head:

    ```python
    from typing import List

    from torch import Tensor, nn

    class LinearHead(HeadForImageClassification):
        def __init__(self, num_classes: int, in_channels: int):
            super().__init__()
            self.pool = nn.AdaptiveAvgPool2d((1, 1))
            self.flat = nn.Flatten()
            self.fc = nn.Linear(in_channels, num_classes)

        def forward(self, features: List[Tensor]) -> Tensor:
            x = features[-1]
            x = self.pool(x)
            x = self.flat(x)
            x = self.fc(x)
            return x
    ```

    """

    def forward(self, features: List[Tensor]) -> Tensor:
        """The forward method of the classification head.

        Subclasses must override this method.

        Args:
            features (List[Tensor]): A list of features.

        Returns:
            Tensor: The logits

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        # Name the concrete class so a missing override is easy to diagnose;
        # a bare NotImplementedError carries no message at all.
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward(features)"
        )

forward(features)

The forward method of the classification head.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `features` | `List[Tensor]` | A list of features. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `Tensor` | `Tensor` | The logits |

Source code in glasses/models/vision/image/classification/heads/base.py
def forward(self, features: List[Tensor]) -> Tensor:
    """The forward method of the classification head.

    Subclasses must override this method.

    Args:
        features (List[Tensor]): A list of features.

    Returns:
        Tensor: The logits

    Raises:
        NotImplementedError: Always, unless overridden by a subclass.
    """
    # Name the concrete class so a missing override is easy to diagnose;
    # a bare NotImplementedError carries no message at all.
    raise NotImplementedError(
        f"{type(self).__name__} must implement forward(features)"
    )