Hi guys,
I have been reproducing DeepLabV3+ recently,
and I wrote a module like this:
self.entry_flow = nn.Sequential()
# first convolution layer of entry_flow
self.entry_flow.add_module("conv1", nn.Conv2d(3, 32, 3, 2, 0, bias=False))
self.entry_flow.add_module("bn_relu1", BNReLU(32))
self.entry_flow.add_module("conv2", nn.Conv2d(32, 64, 3, bias=False))
self.entry_flow.add_module("bn_relu2", BNReLU(64))
# add three Block modules
self.entry_flow.add_module("block1", XceptionBlock(64, 64, [1, 1, 2]))
self.entry_flow.add_module("block2", XceptionBlock(128, 128, [1, 1, 2]))
self.entry_flow.add_module("block2", XceptionBlock(256, 256, [1, 1, 2]))
and I am wondering whether I can get the intermediate outputs of the inner modules "block1" and "block2"?
Any answer or suggestion will be appreciated!
Watch out: you added a module with the name "block2" twice. This is almost certainly not what you want. For my answer I replaced the second "block2" with "block3".
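To illustrate why the duplicate name matters, here is a minimal sketch. It assumes small stand-in layers (nn.Identity and nn.Conv2d with made-up channel sizes) in place of the real XceptionBlock, which is not shown in this thread. Registering a second module under an existing name replaces the first one instead of appending a new one.

```python
from torch import nn

# Stand-in layers; the channel sizes are hypothetical and only for illustration.
seq = nn.Sequential()
seq.add_module("block1", nn.Identity())
seq.add_module("block2", nn.Conv2d(64, 128, 3))
# Re-using the name "block2" overwrites the module registered just above.
seq.add_module("block2", nn.Conv2d(128, 256, 3))

names = [name for name, _ in seq.named_children()]
print(names)                    # only two entries remain
print(seq.block2.in_channels)   # the second Conv2d is the one that survived
```

So with the original code the network would silently end up with one block fewer than intended.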
No, you can't. That's the whole point of an nn.Sequential: perform all operations successively and only return the final result. If you do depend on the intermediate results you should use an nn.Module and implement a custom forward() method. In your case that might look like this:
from torch import nn


class BNRelu(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        raise RuntimeError


class XceptionBlock(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        raise RuntimeError


class MyModule(nn.Module):
    def __init__(self, return_intermediate=True):
        super().__init__()
        self.block0 = nn.Sequential(
            nn.Conv2d(3, 32, 3, 2, 0, bias=False),
            BNRelu(32),
            nn.Conv2d(32, 64, 3, bias=False),
            BNRelu(64),
        )
        self.block1 = XceptionBlock(64, 64, [1, 1, 2])
        self.block2 = XceptionBlock(128, 128, [1, 1, 2])
        self.block3 = XceptionBlock(256, 256, [1, 1, 2])
        self.return_intermediate = return_intermediate

    def forward(self, x):
        intermediate_results = {}
        x = intermediate_results["block0"] = self.block0(x)
        x = intermediate_results["block1"] = self.block1(x)
        x = intermediate_results["block2"] = self.block2(x)
        result = self.block3(x)
        if self.return_intermediate:
            return result, intermediate_results
        else:
            return result
If you construct this with model = MyModule(return_intermediate=True) you get a dict with all intermediate results in addition to the final result:
output, intermediate_results = model(input)
If you set return_intermediate=False it behaves like a regular nn.Sequential.
If you do this regularly, it might be advantageous to define an IntermediateSequential that handles this for you:
class IntermediateSequential(nn.Sequential):
    def __init__(self, *args, return_intermediate=True):
        super().__init__(*args)
        self.return_intermediate = return_intermediate

    def forward(self, input):
        if not self.return_intermediate:
            return super().forward(input)

        intermediate_outputs = {}
        output = input
        for name, module in self.named_children():
            output = intermediate_outputs[name] = module(output)

        return output, intermediate_outputs
With this you can use your old style:
self.entry_flow = IntermediateSequential()
# first convolution layer of entry_flow
self.entry_flow.add_module("conv1", nn.Conv2d(3, 32, 3, 2, 0, bias=False))
self.entry_flow.add_module("bn_relu1", BNReLU(32))
self.entry_flow.add_module("conv2", nn.Conv2d(32, 64, 3, bias=False))
self.entry_flow.add_module("bn_relu2", BNReLU(64))
# add three Block modules
self.entry_flow.add_module("block1", XceptionBlock(64, 64, [1, 1, 2]))
self.entry_flow.add_module("block2", XceptionBlock(128, 128, [1, 1, 2]))
self.entry_flow.add_module("block3", XceptionBlock(256, 256, [1, 1, 2]))
and get the results from block1 and block2 with
output_block3, intermediate_outputs = self.entry_flow(input)
output_block1 = intermediate_outputs["block1"]
output_block2 = intermediate_outputs["block2"]
@pmeier , hi, Philip!
Thanks sincerely for your kind reply and detailed suggestions.
I also have a small question: why is there a raise RuntimeError in the constructors of the modules?
I've simply added these classes to make the example "complete", since you have not given us an implementation for BNRelu and XceptionBlock. They raise a RuntimeError just to remind you that these classes need to be implemented. Just use whatever you used before.
@pmeier , OK, your code inspired me a lot.
I read some more material yesterday and found another implementation style, hooks, which I didn't see in your suggestions.
So I would like to know what you think about using register_forward_hook to get the outputs of the intermediate modules?
I don't think register_forward_hook() will help you here. The documentation states:
"The hook will be called every time after forward() has computed an output."
Thus, if you use an nn.Sequential, the output you get in the hook is only the final result and you have no option to retrieve the intermediate ones. The only way to use a forward hook here is to store the intermediate results during the calculation and retrieve them within the hook. I would advise against this, since:
[...] forward() method that stores the intermediate results, and thus this approach has no advantage over my solution, and
[...] nn.Module can get you in a lot of other trouble (think state_dict()).
@pmeier , I think there is a way to realize my idea, let me describe it below:
[...] nn.Module and name it with some specific names, so that we can retrieve them, no matter how deep these inner modules are inside the top network;
[...] model.named_children() and use the specific names to find them, then get their output,
[...] walk() in Python.
Just to be sure: you want to have access to every intermediate result, not just the "top level"? If that is the case, named_children() will not help you, as it only yields the immediate children of the nn.Module.
If I understand your approach correctly, it has two major flaws:
With your solution you are computing the outputs multiple times. This is a twofold problem:
@pmeier , sorry for my mistake. In fact, my idea is to use model.named_modules() to reach every inner module, so that we can access every intermediate result.
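The difference between named_children() and named_modules() discussed above can be seen in a small sketch. The nested model here is a hypothetical stand-in (the names "stem", "block1", "conv" and "relu" are made up for illustration), not the Xception entry flow from this thread.

```python
from torch import nn

# Hypothetical nested model: one top-level conv plus one nested container.
model = nn.Sequential()
model.add_module("stem", nn.Conv2d(3, 8, 3))
inner = nn.Sequential()
inner.add_module("conv", nn.Conv2d(8, 8, 3))
inner.add_module("relu", nn.ReLU())
model.add_module("block1", inner)

# named_children() yields only the immediate children.
child_names = [name for name, _ in model.named_children()]
# named_modules() walks the whole tree; the first entry is the model itself
# under the empty name, and nested modules get dotted names.
module_names = [name for name, _ in model.named_modules()]

print(child_names)   # ['stem', 'block1']
print(module_names)  # ['', 'stem', 'block1', 'block1.conv', 'block1.relu']
```

Note that this only locates the inner modules. It does not by itself give you their outputs without running them as part of a forward pass.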
@songyuc I think the best way without code modification is to use forward hooks as @pmeier mentioned. Another way would be to modify the forward of the module to return the elements that you want, but that requires more code.
Let us know if you have further questions.
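For reference, one possible reading of the forward-hook suggestion is to register a hook on each inner module of interest rather than on the container, so that each hook sees that module's own output during a single forward pass. The following is only a sketch under that assumption: the small model, the captured dict, the make_hook helper and the set of wanted names are all hypothetical and not taken from the code in this thread.

```python
import torch
from torch import nn

# Hypothetical stand-in model; padding=1 keeps the spatial size unchanged.
model = nn.Sequential()
model.add_module("conv1", nn.Conv2d(3, 8, 3, padding=1))
model.add_module("relu1", nn.ReLU())
model.add_module("conv2", nn.Conv2d(8, 16, 3, padding=1))

captured = {}  # filled by the hooks during the forward pass


def make_hook(name):
    # A forward hook receives (module, inputs, output) after forward() has run.
    def hook(module, inputs, output):
        captured[name] = output.detach()
    return hook


wanted = {"conv1", "relu1"}  # names of the inner modules to capture
handles = [
    module.register_forward_hook(make_hook(name))
    for name, module in model.named_modules()
    if name in wanted
]

x = torch.randn(1, 3, 16, 16)
with torch.no_grad():
    out = model(x)  # a single forward pass; nothing is computed twice

# Remove the hooks again so the model is left unchanged.
for handle in handles:
    handle.remove()

print({name: tuple(t.shape) for name, t in captured.items()})
```

The results live in a plain dict outside the model, so nothing extra is stored on the nn.Module itself. Whether this is preferable to a custom forward() is a judgment call that depends on how much of the model code you are able to change.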