Hooks are useful to debug and tinker with your PyTorch models. But how do you debug the hooks themselves?
Suppose you have a tensor with some hooks, for instance:
```python
import torch

t = torch.tensor(1.0, requires_grad=True)

def hook_1(grad):
    return grad * 2

def hook_2(grad):
    return grad + 1

t.register_hook(hook_1)
t.register_hook(hook_2)
```
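What if you add hooks like these in different places in your code? Their order matters: with the registration above, the gradient becomes `grad * 2 + 1`, whereas registering `hook_2` first would give `(grad + 1) * 2`. So you might want to print all the hooks attached to your tensor to check the order in which they are called.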
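A quick way to see the effect of ordering is to backpropagate through a simple function with both registration orders. This is a minimal sketch: `grad_with_hooks` is a hypothetical helper, and `y = 3 * t` is chosen only so the raw gradient is a round number (3.0).

```python
import torch

def hook_1(grad):
    return grad * 2

def hook_2(grad):
    return grad + 1

def grad_with_hooks(*hooks):
    """Hypothetical helper: return d(3*t)/dt after passing it through `hooks` in order."""
    t = torch.tensor(1.0, requires_grad=True)
    for hook in hooks:
        t.register_hook(hook)  # hooks run in registration order
    (t * 3).backward()  # raw gradient is 3.0
    return t.grad.item()

print(grad_with_hooks(hook_1, hook_2))  # (3 * 2) + 1 = 7.0
print(grad_with_hooks(hook_2, hook_1))  # (3 + 1) * 2 = 8.0
```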
I keep forgetting the code snippet to do that (the only place I found it is this excellent video about hooks), so here it is: `t._backward_hooks`
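It returns an ordered dictionary that maps each hook's handle ID to the hook, in the order the hooks were registered (which is the order in which they are called):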
```
OrderedDict([(0, <function hook_1 at 0x7f7157b481f0>), (1, <function hook_2 at 0x7f70f8fb8e50>)])
```
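If you only care about the order, printing the function names is often easier to read than the raw dictionary. The sketch below also uses the handle returned by `register_hook` to remove a hook and checks that it disappears from the dictionary. Keep in mind that `_backward_hooks` is a private attribute, so this relies on PyTorch internals that may change between versions.

```python
import torch

def hook_1(grad):
    return grad * 2

def hook_2(grad):
    return grad + 1

t = torch.tensor(1.0, requires_grad=True)
handle_1 = t.register_hook(hook_1)  # register_hook returns a removable handle
handle_2 = t.register_hook(hook_2)

# Private attribute: an ordered mapping of handle ID -> hook function.
hook_names = [h.__name__ for h in t._backward_hooks.values()]
print(hook_names)  # ['hook_1', 'hook_2']

# Removing a handle deletes its entry from the same dictionary.
handle_1.remove()
remaining = [h.__name__ for h in t._backward_hooks.values()]
print(remaining)  # ['hook_2']
```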