PyTorch3D: Why is the network not updating?

Created on 11 Apr 2020 · 7 Comments · Source: facebookresearch/pytorch3d

My network does not update when I use GraphConv. I adapted the code from MeshRCNN. Is this a bug?

Below are my code snippets, simplified as much as possible. For training, I have:

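import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch3d.ops import GraphConv, SubdivideMeshes
from pytorch3d.structures import Meshes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# MeshRefinementStage is defined below; CustomDataLoader and MeshLoss are my own classes (not shown)
net = MeshRefinementStage([16, 32, 32], 0, 'zero')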
dataset = CustomDataLoader(2).load_data()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-1, betas=(0.9, 0.999))
criterionMesh = MeshLoss().to(device)
criterionMSE = torch.nn.MSELoss(reduction='sum')

for epoch in range(3):
    for i, data in enumerate(dataset):
        verts = data['verts'].to(device)
        faces = data['faces'].to(device)
        meshes = Meshes(verts=verts.to(device), faces=faces.to(device))

        meshes_out = net(meshes, [0,1,0])
        subd = SubdivideMeshes()
        meshes.scale_verts_(1.3)
        meshes = subd(meshes)
        a = list(net.parameters())[0].clone()
        optimizer.zero_grad()
        #loss = criterionMesh(meshes_out, meshes)
        loss = criterionMSE(meshes_out.verts_packed(), meshes.verts_packed())
        loss.retain_grad()
        loss.backward()
        optimizer.step()
        b = list(net.parameters())[0].clone()
        print(torch.equal(a.data, b.data),'!!!!!')
    print('finish*')

The network looks like this:

class MeshRefinementStage(nn.Module):
    """
    Args:
        feat_dims (list of int): Output feature dimensions for each stage, [16, 32, 64].
        vert_feat_dim (int): Dimension of vert_feats we will receive from the previous stage; can be 0
        gconv_init (str): Specifies weight initialization for graph-conv layers, e.g. 'normal' or 'zero'
    """
    def __init__(self, feat_dims, vert_feat_dim, gconv_init='normal'):
        super(MeshRefinementStage, self).__init__()
        self.gconvs = nn.ModuleList()
        for i in range(len(feat_dims)):
            if i == 0:
                #input_dim = vert_feat_dim + 3
                input_dim = 3
            else:
                input_dim = feat_dims[i-1]
            output_dim = feat_dims[i]
            gconv = GraphConv(input_dim, output_dim, init=gconv_init, directed=False).to(device)
            self.gconvs.append(gconv)

        self.verts_offset = nn.Linear(feat_dims[-1], 3).to(device)
        nn.init.zeros_(self.verts_offset.weight)
        nn.init.constant_(self.verts_offset.bias, 0)

    def forward(self, meshes, vert_feats=None):
        """
        Args:
            meshes (Meshes): Initial meshes which will get refined
            vert_feats (tensor): Features from the previous refinement stage
        Returns:
            meshes_out (Meshes): Refined meshes in this stage.
        """
        vert_pos_packed = meshes.verts_packed()
        first_layer_feats = [vert_pos_packed]
        # if vert_feats is not None:
        #     first_layer_feats.append(vert_feats)
        vert_feats = torch.cat(first_layer_feats, dim=1)

        # run graph conv layers
        for gconv in self.gconvs:
            vert_feats = F.relu(gconv(vert_feats, meshes.edges_packed()))
            #vert_feats = torch.cat([vert_feats_nopos, vert_pos_packed], dim=1)

        # predict new meshes
        # # apply scale
        # verts_scale = torch.abs(self.verts_scale(vert_feats))
        # meshes_out = meshes.scale_verts(verts_scale)
        # apply offset
        verts_offsets = torch.tanh(self.verts_offset(vert_feats))
        meshes_out = meshes.offset_verts(verts_offsets)
        subd = SubdivideMeshes()
        meshes_out = subd(meshes_out)

        return meshes_out

The output is:

dataset [Dataset] was created
True !!!!!
True !!!!!
True !!!!!
True !!!!!
True !!!!!
finish*
True !!!!!
True !!!!!
True !!!!!
True !!!!!
True !!!!!
finish*
True !!!!!
True !!!!!
True !!!!!
True !!!!!
True !!!!!
finish*

This means it's not learning at all. However, if I use only one GraphConv layer in the net, like this:

class MeshRefinementStage(nn.Module):
    """
    Args:
        feat_dims (list of int): Output feature dimensions for each stage, [16, 32, 64].
        vert_feat_dim (int): Dimension of vert_feats we will receive from the previous stage; can be 0
        gconv_init (str): Specifies weight initialization for graph-conv layers, e.g. 'normal' or 'zero'
    """
    def __init__(self, feat_dims, vert_feat_dim, gconv_init='normal'):
        super(MeshRefinementStage, self).__init__()
        self.gconvs = nn.ModuleList()
        for i in range(1):
            if i == 0:
                #input_dim = vert_feat_dim + 3
                input_dim = 3
            else:
                input_dim = feat_dims[i-1]
            output_dim = feat_dims[i]
            gconv = GraphConv(input_dim, output_dim, init=gconv_init, directed=False).to(device)
            self.gconvs.append(gconv)

        self.verts_offset = nn.Linear(feat_dims[0], 3).to(device)
        nn.init.zeros_(self.verts_offset.weight)
        nn.init.constant_(self.verts_offset.bias, 0)

    def forward(self, meshes, vert_feats=None):
        """
        Args:
            meshes (Meshes): Initial meshes which will get refined
            vert_feats (tensor): Features from the previous refinement stage
        Returns:
            meshes_out (Meshes): Refined meshes in this stage.
        """
        vert_pos_packed = meshes.verts_packed()
        first_layer_feats = [vert_pos_packed]
        # if vert_feats is not None:
        #     first_layer_feats.append(vert_feats)
        vert_feats = torch.cat(first_layer_feats, dim=1)

        # run graph conv layers
        for gconv in self.gconvs:
            vert_feats = F.relu(gconv(vert_feats, meshes.edges_packed()))
            #vert_feats = torch.cat([vert_feats_nopos, vert_pos_packed], dim=1)

        # predict new meshes
        # # apply scale
        # verts_scale = torch.abs(self.verts_scale(vert_feats))
        # meshes_out = meshes.scale_verts(verts_scale)
        # apply offset
        verts_offsets = torch.tanh(self.verts_offset(vert_feats))
        meshes_out = meshes.offset_verts(verts_offsets)
        subd = SubdivideMeshes()
        meshes_out = subd(meshes_out)

        return meshes_out

The output is:

dataset [Dataset] was created
True !!!!!
False !!!!!
False !!!!!
False !!!!!
False !!!!!
finish*
False !!!!!
False !!!!!
False !!!!!
False !!!!!
False !!!!!
finish*
False !!!!!
False !!!!!
False !!!!!
False !!!!!
False !!!!!
finish*

What happened???


Hi @czkg! Your issue has nothing to do with PyTorch3D; it is an optimization problem, and solving it requires a better understanding of deep learning.

Instead of feeding you the answer, I will take a pedagogical approach and try to show you how to debug this and reach the answer yourself. I think that's more valuable than me telling you what's wrong with your code.

First tip: You should try to isolate your problem as much as possible. The code snippet you provided here is far from cleaned up, has lots of commented-out code and, most importantly, is not reproducible by anyone else because we don't have access to your data. So instead, I'd replace your dataset with simple meshes, e.g. ico_spheres. This is what I did to work with your code: I replaced your dataset with ico_spheres. I also removed all the SubdivideMeshes operations because they're irrelevant.
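If PyTorch3D is installed, pytorch3d.utils.ico_sphere(level, device) returns an icosphere as a Meshes object, and its edges_packed() gives the edge list, which is the kind of shareable input this tip has in mind. The sketch below goes one step further and needs only PyTorch: it hard-codes a regular tetrahedron and derives the undirected edge list from its faces, so a repro runs on anyone's machine. The helper names tetrahedron and unique_edges are hypothetical, and the 1.3 scaling only mirrors the question's scale_verts_(1.3) target, without the subdivision.

```python
import torch


def tetrahedron():
    """Hypothetical stand-in for a real dataset: a regular tetrahedron (4 verts, 4 faces)."""
    verts = torch.tensor(
        [[1.0, 1.0, 1.0], [1.0, -1.0, -1.0], [-1.0, 1.0, -1.0], [-1.0, -1.0, 1.0]]
    )
    faces = torch.tensor([[0, 1, 2], [0, 3, 1], [0, 2, 3], [1, 3, 2]])
    return verts, faces


def unique_edges(faces: torch.Tensor) -> torch.Tensor:
    """Undirected edge list (E, 2) built from triangle faces (F, 3), each edge listed once."""
    # each triangle (a, b, c) contributes the edges (a, b), (b, c) and (c, a)
    edges = torch.cat([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], dim=0)
    edges, _ = edges.sort(dim=1)  # make (i, j) and (j, i) identical
    return torch.unique(edges, dim=0)  # drop duplicates shared by adjacent faces


verts, faces = tetrahedron()
edges = unique_edges(faces)
target = verts * 1.3  # same idea as meshes.scale_verts_(1.3), minus the subdivision
print(edges.tolist())
```

Because every input is fixed in the script, anyone reading the issue can run it and see the same behaviour.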

Second tip: You are checking to see if the first layer of your MeshRefinementStage gets updated. This is defined here in your code:

a = list(net.parameters())[0].clone()

How about the later stages of your net? If you check the last stage, for example (list(net.parameters())[-1].clone()), you will see that it does get updated! That's a good clue! If you reduce the number of layers to 1 by defining your net as net = MeshRefinementStage([16], 0, 'zero'), you'll see that the first layer does get updated. We're getting somewhere!
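A more systematic version of this check is to snapshot every named parameter, take one optimizer step, and list the ones that moved. In the question's model, registration order should make list(net.parameters())[0] the first GraphConv's w0 weight and [-1] the verts_offset bias. The sketch below runs this check on a hypothetical toy model (not the question's model) in which every weight and bias starts at exactly zero, to show how the result can single out the last layer.

```python
import torch
import torch.nn as nn


def changed_parameters(model: nn.Module, before: dict) -> list:
    """Names of parameters whose values differ from the snapshot `before`."""
    return [name for name, p in model.named_parameters()
            if not torch.equal(before[name], p.detach())]


torch.manual_seed(0)
# Hypothetical toy model: every weight and bias starts at exactly zero.
model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 3))
for p in model.parameters():
    nn.init.zeros_(p)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, target = torch.randn(5, 3), torch.ones(5, 3)

# snapshot all parameters, not just the first one
before = {name: p.detach().clone() for name, p in model.named_parameters()}
optimizer.zero_grad()
loss = torch.nn.functional.mse_loss(model(x), target)
loss.backward()
optimizer.step()

print(changed_parameters(model, before))  # ['2.bias']
```

Only the final bias moves: its gradient is the loss gradient itself, while every other gradient is multiplied either by a zero activation or by a zero weight matrix further up the chain.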

Third tip: Checking whether the weights change is a good sanity check, but the explanation of why they don't change lies in the math. Deep learning is not black magic; it's simple math, in fact it's the chain rule! To better understand this behavior, you could simply inspect the intermediate feature activations in MeshRefinementStage. Going further, you could call retain_grad() on all these tensors and check their gradients as well. If you do, you will see something really interesting: the optimizer tries to push your weights to 0.0. This usually happens in deep learning when the initialization is suboptimal (so optimization gets stuck in a local optimum) and the problem is not well defined, as is the case here. To verify this in practice, you can play with different initialization schemes, e.g. model = MeshRefinementStage([16, 32, 32], 0, gconv_init="normal").to(device), which initializes the GraphConvs with a normal distribution instead of 0s. You will see a different model behaviour.
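Here is one way to follow this tip, sketched on a hypothetical stack of nn.Linear layers standing in for the GraphConvs (an assumption; a real check would add the retain_grad() calls inside MeshRefinementStage.forward). retain_grad() keeps .grad on the intermediate ReLU outputs, so you can compare how much gradient reaches each layer under two initializations. The "zero" branch zeroes the weights but keeps the default biases, and the "normal" branch draws weights from a normal distribution with std 0.01 and zeroes the biases; I believe this mirrors GraphConv's two init options, but check the PyTorch3D source before relying on that.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def activation_grad_norms(init: str, seed: int = 0) -> list:
    """Gradient norm of each hidden ReLU output after one backward pass.

    Hypothetical stand-in: plain nn.Linear layers instead of GraphConv.
    """
    torch.manual_seed(seed)
    layers = [nn.Linear(3, 16), nn.Linear(16, 32), nn.Linear(32, 3)]
    for layer in layers:
        if init == "zero":
            nn.init.zeros_(layer.weight)  # biases keep PyTorch's default init
        else:
            nn.init.normal_(layer.weight, mean=0.0, std=0.01)
            nn.init.zeros_(layer.bias)

    x, target = torch.randn(8, 3), torch.randn(8, 3)
    h, hidden = x, []
    for layer in layers[:-1]:
        h = F.relu(layer(h))
        h.retain_grad()  # keep .grad on this non-leaf tensor after backward()
        hidden.append(h)

    loss = F.mse_loss(layers[-1](h), target)
    loss.backward()
    return [t.grad.norm().item() for t in hidden]


print("zero:  ", activation_grad_norms("zero"))
print("normal:", activation_grad_norms("normal"))
```

With zero weights, every hidden activation's gradient is exactly zero, because backpropagating through a linear layer multiplies by its weight matrix; with the small normal init, some gradient reaches every layer.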

Ok, this was a long response, but I wanted to emphasize that there is an explanation for everything. You can certainly open issues with every deep learning library to get help debugging your code, but you should try to figure these things out on your own first. For future reference, a report of a potential bug should be accompanied by a thorough investigation and clear evidence that something is wrong. In your case, you will quickly see that there is an optimization issue with your implementation which leads to dead activations. I hope this helped.
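Dead activations can be measured directly. The sketch below is an illustration on a hypothetical toy model: it registers forward hooks that record the fraction of exactly-zero outputs from each nn.ReLU module. Note that MeshRefinementStage.forward calls F.relu, which module hooks cannot see, so applying this to the question's model would mean switching to nn.ReLU modules or logging the fraction inline.

```python
import torch
import torch.nn as nn


def dead_fraction_hooks(model: nn.Module) -> dict:
    """Attach forward hooks that record the fraction of exact zeros each nn.ReLU outputs."""
    stats = {}

    def make_hook(name):
        def hook(module, inputs, output):
            stats[name] = (output == 0).float().mean().item()
        return hook

    for name, module in model.named_modules():
        if isinstance(module, nn.ReLU):
            module.register_forward_hook(make_hook(name))
    return stats  # filled in on every forward pass


model = nn.Sequential(nn.Linear(3, 8), nn.ReLU(), nn.Linear(8, 3))
with torch.no_grad():
    model[0].weight.zero_()
    model[0].bias.fill_(-1.0)  # every pre-activation is -1, so every ReLU unit is dead
stats = dead_fraction_hooks(model)
model(torch.randn(4, 3))
print(stats)  # {'1': 1.0}
```

A layer that reports 1.0 on every batch passes no gradient back through its ReLU, so, with no skip connections, nothing before it receives any gradient.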

Closing this task.

Hi @gkioxari, thanks so much for your patience. As you suggested, I checked my network and noticed that no matter how many gconv layers I use, only the last layer in the module is updated. Do you have any suggestions?

The answer suggests more things, such as switching the initialization scheme of the GraphConv layers from "zero" to "normal". Again, I don't have access to your data, but when I reproduced your code with ico_spheres in place of your data, changing the initialization to normal worked fine.
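To see why switching from "zero" to "normal" matters, it helps to spell out what a GraphConv layer computes. The class below is a hypothetical plain-PyTorch stand-in (TinyGraphConv is not a PyTorch3D class) written from my reading of pytorch3d/ops/graph_conv.py: each vertex gets w0 applied to its own features plus the sum of w1 applied to each neighbor's features, with undirected edges counted in both directions. Check it against the PyTorch3D source before relying on the details.

```python
import torch
import torch.nn as nn


class TinyGraphConv(nn.Module):
    """Hypothetical stand-in for pytorch3d.ops.GraphConv with directed=False."""

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.w0 = nn.Linear(input_dim, output_dim)  # applied to the vertex itself
        self.w1 = nn.Linear(input_dim, output_dim)  # applied to each neighbor

    def forward(self, verts: torch.Tensor, edges: torch.Tensor) -> torch.Tensor:
        # verts: (V, input_dim) packed features; edges: (E, 2) vertex index pairs
        h0 = self.w0(verts)
        h1 = self.w1(verts)
        neighbor_sums = torch.zeros_like(h1)
        # undirected: each edge (i, j) sends h1[j] to vertex i and h1[i] to vertex j
        neighbor_sums.index_add_(0, edges[:, 0], h1[edges[:, 1]])
        neighbor_sums.index_add_(0, edges[:, 1], h1[edges[:, 0]])
        return h0 + neighbor_sums
```

When both weight matrices are zero, the output for vertex i reduces to b0 + deg(i) * b1. It no longer depends on the input positions, and the gradient the layer passes back to earlier layers is zero, even though its own weights can still receive gradient.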

I think the problem lies in your problem formulation. You're trying to map the (x, y, z) of each vertex of an initial, arbitrary 3D sample to a deformed (x, y, z) that approximates the vertex with the same index in the target (sphere) mesh, via operations that are index-agnostic (they share the same weights regardless of the vertex index in the mesh). As posed, this problem is ill-defined and will have optimization issues, which is exactly what you're seeing and what one would expect.
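The index-agnostic point can be checked numerically. The sketch below uses a hypothetical shared per-vertex MLP as an assumption standing in for the GraphConv stack (GraphConv additionally sums over neighbors, but it too shares its weights across vertices). Relabeling the vertices only relabels the outputs, and two vertices with identical inputs always get identical outputs, so the only thing such a network can use to tell vertices apart is their input features (plus, for GraphConv, their neighborhoods), never their index.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical shared per-vertex map: the same weights are applied to every row (vertex).
shared = nn.Sequential(nn.Linear(3, 16), nn.ReLU(), nn.Linear(16, 3))

verts = torch.randn(6, 3)
perm = torch.randperm(6)

out = shared(verts)
out_perm = shared(verts[perm])

# relabeling the vertices only relabels the outputs ...
print(torch.allclose(out_perm, out[perm]))  # True

# ... and two vertices with identical inputs get identical outputs, whatever their index
twins = torch.stack([verts[0], verts[0]])
twin_out = shared(twins)
print(torch.allclose(twin_out[0], twin_out[1]))  # True
```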

@gkioxari Yes, indeed. It works after I switched to 'normal'. It seems the net has a hard time learning anything from a zero initialization.

I explained above why your problem formulation is ill-defined and expected to have optimization issues. I think you should reconsider how you're setting your problem up! Good luck!

@gkioxari Thank you for your answer!
