Skip to content

model

ViTHead

Bases: HeadForImageClassification

Source code in glasses/models/vision/image/classification/heads/vit/model.py
class ViTHead(HeadForImageClassification):
    """Head that pools the last feature tensor and maps it to `num_classes` outputs."""

    POLICIES = ["token", "mean"]

    def __init__(
        self, emb_size: int = 768, num_classes: int = 1000, policy: str = "token"
    ):
        """
        ViT Classification Head

        Args:
            emb_size (int, optional): Input size of the final linear layer. Defaults to 768.
            num_classes (int, optional): Output size of the final linear layer. Defaults to 1000.
            policy (str, optional): Pooling policy, one of `POLICIES` ('token' or 'mean'). Defaults to 'token'.

        Raises:
            ValueError: If `policy` is not one of `POLICIES`.
        """
        # Validate before touching the parent constructor so a bad policy
        # fails without building any layers.
        if policy not in self.POLICIES:
            message = f"Only policies {','.join(self.POLICIES)} are supported"
            raise ValueError(message)

        super().__init__()

        if policy == "mean":
            # Average over the middle (n) axis of a (b, n, e) input.
            pooling = Reduce("b n e -> b e", reduction="mean")
        else:
            # Keep only index 0 along the second axis
            # (presumably the class token -- confirm against the encoder).
            pooling = Lambda(lambda tokens: tokens[:, 0])

        self.pool = pooling
        self.fc = nn.Linear(emb_size, num_classes)

    def forward(self, features: List[Tensor]) -> Tensor:
        """
        Pool the last entry of `features` and apply the linear layer.

        Args:
            features (List[Tensor]): Feature tensors; only the last one is used.

        Returns:
            Tensor: Output of the linear layer applied to the pooled features.
        """
        last_features = features[-1]
        pooled = self.pool(last_features)
        return self.fc(pooled)

__init__(emb_size=768, num_classes=1000, policy='token')

ViT Classification Head

Parameters:

Name Type Description Default
emb_size int

Embedding dimension. Defaults to 768.

768
num_classes int

Number of output classes (output size of the final linear layer). Defaults to 1000.

1000
policy str

Pooling policy, can be token or mean. Defaults to 'token'.

'token'
Source code in glasses/models/vision/image/classification/heads/vit/model.py
def __init__(
    self, emb_size: int = 768, num_classes: int = 1000, policy: str = "token"
):
    """
    ViT Classification Head

    Args:
        emb_size (int, optional): Input size of the final linear layer. Defaults to 768.
        num_classes (int, optional): Output size of the final linear layer. Defaults to 1000.
        policy (str, optional): Pooling policy, can be token or mean. Defaults to 'token'.

    Raises:
        ValueError: If `policy` is not listed in `self.POLICIES`.
    """
    policy_is_known = policy in self.POLICIES
    if not policy_is_known:
        raise ValueError(f"Only policies {','.join(self.POLICIES)} are supported")

    super().__init__()

    # 'mean' averages over the n axis of a (b, n, e) input; any other accepted
    # policy keeps index 0 along the second axis.
    use_mean = policy == "mean"
    if use_mean:
        self.pool = Reduce("b n e -> b e", reduction="mean")
    else:
        self.pool = Lambda(lambda seq: seq[:, 0])

    self.fc = nn.Linear(emb_size, num_classes)