# Source code for glasses.utils.Storage

from functools import partial, reduce
from collections import OrderedDict

"""Local copy with minor changes of `PytorchModuleStorage <https://github.com/FrancescoSaverioZuppichini/PytorchModuleStorage>`_

.. warning::
    The `PytorchModuleStorage` library is not great; I wish to update it and make it a dependency of glasses.
    This is why type hints are missing here.
"""


class MutipleKeysDict(OrderedDict):
    """Ordered dict that can also look up several keys at once.

    Indexing with a ``list`` of keys returns the list of matching values
    (a list containing exactly one key returns that single value instead of
    a one-element list). Indexing with any other key is a normal lookup.

    Example:

    ```python
    d = MutipleKeysDict({ 'a' : 1, 'b' : 2, 'c' : 3})
    d[['a', 'b']] # out [1,2]
    d[['a']] # out 1
    ```
    """

    def __getitem__(self, keys):
        """Return the value stored under ``keys``, or the values for a list of keys.

        Args:
            keys: a single (hashable) key, or a ``list`` of keys.

        Returns:
            The stored value for a single key. For a list of keys, the list of
            values in the requested order, or the bare value when exactly one
            key was requested.

        Raises:
            KeyError: if any requested key is missing.
        """
        if isinstance(keys, list):
            # dict.__getitem__ bypasses this override for each individual key.
            values = [dict.__getitem__(self, key) for key in keys]
            # Asking for a single key yields the bare value, not a 1-element list.
            return values[0] if len(values) == 1 else values
        # Plain lookup: a stored value that happens to be a one-element list
        # is returned unchanged (it is NOT unwrapped).
        return super().__getitem__(keys)
class ModuleStorage:
    """Store the values passed to hooks attached to selected layers.

    ``where2layers`` is either:

    * a ``dict`` mapping a group name to a list of layers. Values are stored in
      ``self.state[group][layer]``, and only layers of the group currently
      selected by ``self.where`` are recorded. The selection defaults to the
      first key and is changed through ``__call__``.
    * a ``list`` of layers. Values are stored in ``self.state[layer]``.

    Any other type leaves ``names`` empty, so construction fails with
    ``IndexError``.

    Args:
        where2layers: the layers to watch, as described above.
        debug: if ``True``, print a message when hooks are registered, called
            or removed.
    """

    def __init__(self, where2layers, debug=False):
        self.where2layers = where2layers
        # Default selection: the first group name (dict) or first layer (list).
        self.where = list(self.names)[0]
        self.state = self._state
        # Handles returned by ``register_*_hook``; consumed by ``clear``.
        self.unsubcribe = []
        self.debug = debug

    @property
    def _is_grouped(self):
        """``True`` when ``where2layers`` is a dict of group name -> layers."""
        return isinstance(self.where2layers, dict)

    @property
    def _is_flat(self):
        """``True`` when ``where2layers`` is a plain list of layers."""
        return isinstance(self.where2layers, list)

    @property
    def _state(self):
        """A fresh, empty storage with one slot per entry of ``names``."""
        return MutipleKeysDict(
            {k: MutipleKeysDict() if self._is_grouped else [] for k in self.names}
        )

    @property
    def names(self):
        """Group names (dict ``where2layers``) or the layers themselves (list)."""
        if self._is_grouped:
            return self.where2layers.keys()
        if self._is_flat:
            return self.where2layers
        return []

    @property
    def layers(self):
        """Flatten all watched layers into a single list.

        For a dict ``where2layers``, the per-group lists are concatenated in
        key order. Duplicates are kept.
        """
        if self._is_grouped:
            return reduce(lambda a, b: a + b, self.where2layers.values())
        if self._is_flat:
            return self.where2layers
        return []

    def register_hooks(self, how="forward"):
        """Register one hook on every watched layer.

        A layer listed more than once (e.g. in several groups) is hooked only
        once. Every hook on a given layer would write the same value to the
        same ``self.state`` slot, so extra hooks would only cost time.

        Args:
            how: ``"forward"`` or ``"backward"``.

        Raises:
            ValueError: if ``how`` is neither ``"forward"`` nor ``"backward"``
                (checked up front, even when there are no layers).
        """
        if how not in ("forward", "backward"):
            raise ValueError("type must be 'forward' or 'backward'")
        # Track identities so duplicate layers are skipped without requiring hashability.
        seen = set()
        for layer in self.layers:
            if id(layer) in seen:
                continue
            seen.add(id(layer))
            # The layer object itself is used as the storage key.
            callback = partial(self.hook, name=layer)
            if how == "forward":
                handle = layer.register_forward_hook(callback)
            else:
                # NOTE(review): newer PyTorch releases recommend
                # ``register_full_backward_hook``; confirm semantics before switching.
                handle = layer.register_backward_hook(callback)
            self.unsubcribe.append(handle)
            if self.debug:
                print(f"[INFO] {how} hook registered to {layer}")

    def hook(self, m, i, o, name):
        """Hook callback: store ``o`` (the value passed by the hook) under ``name``.

        Args:
            m: the layer that fired the hook.
            i: the hook's second argument; unused.
            o: the hook's third argument; this is what gets stored.
            name: the storage key, bound to the layer by ``register_hooks``.
        """
        if self.debug:
            print(f"{m} called")
        if self._is_grouped:
            # Only record layers that belong to the currently selected group.
            if m in self.where2layers[self.where]:
                self.state[self.where][name] = o
        elif self._is_flat:
            self.state[name] = o

    def clear(self):
        """Remove every hook registered by ``register_hooks``.

        Values already stored in ``self.state`` are left untouched.
        """
        if self.debug:
            print("[INFO] clear")
        for handle in self.unsubcribe:
            handle.remove()
        # Drop the removed handles so a later ``clear`` does not revisit them.
        self.unsubcribe = []

    def __call__(self, where=None, *args, **kwargs):
        """Select which group's layers are recorded by ``hook``.

        Args:
            where: a key of ``self.state``; ``None`` keeps the current selection.
            *args, **kwargs: ignored.

        Raises:
            KeyError: if ``where`` is not a key of ``self.state``.
        """
        if where is None:
            return
        if where not in self.keys():
            raise KeyError(f"we cannot find any layers with key {where}")
        self.where = where

    def __repr__(self):
        # Show the ``.shape`` of every stored value.
        # NOTE(review): for a list ``where2layers`` the stored value is the hook
        # output itself, so ``enumerate`` walks its first dimension -- confirm intent.
        def items(x):
            return x.items() if type(x) == MutipleKeysDict else enumerate(x)

        return str(
            {k: [{i: e.shape for i, e in items(v)}] for k, v in self.state.items()}
        )

    def __getitem__(self, key):
        """Return ``self.state[key]`` (a list of keys is supported, see MutipleKeysDict)."""
        return self.state[key]

    def keys(self):
        """Return the keys of ``self.state``: group names (dict) or layers (list)."""
        return self.state.keys()
class ForwardModuleStorage(ModuleStorage):
    """Record the outputs of selected layers while running ``module`` forward.

    Forward hooks are registered on construction.

    Args:
        module: the callable to run on the inputs (presumably an ``nn.Module``).
        *args, **kwargs: forwarded to ``ModuleStorage`` (``where2layers``, ``debug``).
    """

    def __init__(self, module, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.module = module
        self.register_hooks(how="forward")

    def __call__(self, x, *args, **kwargs):
        """Run ``module`` on ``x``, or on each element of ``x`` if it is a list.

        ``*args``/``**kwargs`` (e.g. ``where``) are passed to
        ``ModuleStorage.__call__`` to select the group to record. The module's
        return values are discarded; the hooks fill ``self.state``.
        """
        super().__call__(*args, **kwargs)
        inputs = x if isinstance(x, list) else [x]
        for inp in inputs:
            self.module(inp)
class BackwardModuleStorage(ModuleStorage):
    """Record values passed to the backward hooks of selected layers.

    Backward hooks are registered on construction. Each hook stores its third
    argument (presumably the gradient w.r.t. the layer output -- confirm
    against the PyTorch hook signature) in ``self.state``.

    Args:
        *args, **kwargs: forwarded to ``ModuleStorage`` (``where2layers``, ``debug``).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.register_hooks(how="backward")

    def __call__(self, x, *args, **kwargs):
        """Call ``.backward()`` on ``x``, or on each element of ``x`` if it is a list.

        ``*args``/``**kwargs`` (e.g. ``where``) are passed to
        ``ModuleStorage.__call__`` to select the group to record.
        """
        super().__call__(*args, **kwargs)
        outputs = x if isinstance(x, list) else [x]
        for out in outputs:
            out.backward()