The base `AutoModel` class.
Usage:
auto_model = AutoModel()
model = auto_model.from_name("my_name")
model, config = auto_model.from_pretrained("my_name")
model, config = auto_model.from_pretrained("my_name", my_config)
Source code in glasses/models/auto/base.py
class AutoModel:
    """The base `AutoModel` class.

    Looks up a model by name in `names_to_configs` and builds it, optionally
    loading pretrained weights from a storage backend. The base class only
    declares `names_to_configs`; it must be assigned (presumably by a
    subclass) before the lookup methods are used.

    Usage:
    ```python
    auto_model = AutoModel()
    model = auto_model.from_name("my_name")
    model, config = auto_model.from_pretrained("my_name")
    model, config = auto_model.from_pretrained("my_name", my_config)
    ```
    """

    names_to_configs: Dict[str, Callable[[], Config]]
    """Holds the map from model name to a zero-argument config factory."""

    @classmethod
    def _unknown_name_message(cls, name: str) -> str:
        """Build the error message for a model `name` that is not registered."""
        suggestions = difflib.get_close_matches(name, cls.names_to_configs.keys())
        msg = f'Model "{name}" does not exist.'
        if suggestions:
            # get_close_matches returns the best match first.
            msg += f' Did you mean "{suggestions[0]}"?'
        return msg

    @classmethod
    def get_config_from_name(cls, name: str) -> Config:
        """Return a new config instance for the model registered as `name`.

        Args:
            name: registered model name.

        Returns:
            The result of calling the config factory stored under `name`.

        Raises:
            KeyError: if `name` is not registered; the message suggests the
                closest registered name when there is one.
        """
        try:
            config_factory = cls.names_to_configs[name]
        except KeyError:
            # Replace the bare KeyError(name) with a message carrying suggestions.
            raise KeyError(cls._unknown_name_message(name)) from None
        return config_factory()

    @classmethod
    def from_name(cls, name: str):
        """Build the model registered as `name` without loading any weights.

        Args:
            name: registered model name.

        Returns:
            Whatever `config.build()` returns for that model's config.

        Raises:
            KeyError: if `name` is not registered (see `get_config_from_name`).
        """
        return cls.get_config_from_name(name).build()

    @classmethod
    def from_pretrained(
        cls,
        name: str,
        config: Optional[Config] = None,
        storage: Optional[Storage] = None,
    ) -> "tuple[nn.Module, Config]":
        """Build a model and load the weights stored under `name`.

        Args:
            name: key passed to `storage.get` to fetch the weights and config.
            config: config used to build the model. When ``None``, the config
                returned by the storage is used instead.
            storage: backend to fetch from; defaults to a new ``LocalStorage()``.

        Returns:
            A ``(model, config)`` tuple, where ``config`` is the config that
            was actually used to build the model.

        Note:
            A ``RuntimeError`` raised by ``load_state_dict`` is logged as a
            warning instead of propagated. PyTorch raises it for
            missing/unexpected keys or shape mismatches, so the returned model
            may be only partially loaded.
        """
        storage = LocalStorage() if storage is None else storage
        state_dict, loaded_config = storage.get(name)
        config = loaded_config if config is None else config
        model = config.build()
        try:
            model.load_state_dict(state_dict)
        except RuntimeError as e:
            # Best effort: keep the built model even if some weights don't fit.
            logger.warning(str(e))
        else:
            # Only report success when the weights actually loaded cleanly.
            logger.info(f"Loaded pretrained weights for {name}.")
        return model, config
|
names_to_configs: Dict[str, Callable[[], Config]]
class-attribute
Holds the map from model name to a zero-argument factory that returns that model's config.