Expand use_linear to all models and UpsampleInSpace module#52

Open
TsChala wants to merge 1 commit into ORNL:main from TsChala:Conv3DToLinear

Conversation

@TsChala (Collaborator) commented Apr 22, 2026

Following #48, I expanded the option to use a linear layer instead of Conv3d. The proposed changes are:

  1. Expand use_linear to the AViT, SViT and TurbT models as an option; the default is False.
  2. In the TurbT model, the UpsampleInSpace module also contains Conv3d's; these are likewise replaced with linear layers if use_linear=True.
  3. expand_projections is now compatible with use_linear=True, fixing the smooth-layer and expand_projections regressions from #48. #50 already addressed hMLP_output; I added the necessary parts for UpsampleInSpace as well.
  4. I found a small misalignment between the linear and Conv3d versions in hMLP_output's out_head bias. The bias dimension is out_chans if use_linear=False, but out_chans * kD * kH * kW if use_linear=True. I suggest adding the bias after rearranging, so it has the same size as in the Conv3d case. This keeps the parameter count consistent between the linear and Conv3d options.
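To illustrate item 4, here is a minimal sketch (the sizes `channels`, `out_chans`, and the kernel dims are hypothetical, not taken from the PR) showing that keeping `bias=False` in the linear layer and adding an `out_chans`-sized bias after rearranging matches the parameter count of an equivalent conv-style head with kernel == stride:

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration only (not taken from the PR).
channels, out_chans = 8, 3
kD = kH = kW = 2

# Conv-style head with kernel == stride (non-overlapping patches):
# its bias has out_chans entries.
deconv = nn.ConvTranspose3d(channels, out_chans,
                            kernel_size=(kD, kH, kW),
                            stride=(kD, kH, kW))

# Linear head: keep the layer bias off, then add an out_chans-sized
# bias after rearranging patch dims back into space (item 4 above).
linear = nn.Linear(channels, out_chans * kD * kH * kW, bias=False)
out_bias = nn.Parameter(torch.zeros(out_chans))

x = torch.randn(2, 5, channels)                # (batch, tokens, channels)
y = linear(x)                                  # (2, 5, out_chans*kD*kH*kW)
# Equivalent to rearrange(y, 'b n (c kd kh kw) -> b n c kd kh kw'):
y = y.view(2, 5, out_chans, kD, kH, kW)
y = y + out_bias.view(1, 1, out_chans, 1, 1, 1)  # bias sized as in conv case

n_deconv = sum(p.numel() for p in deconv.parameters())
n_linear = sum(p.numel() for p in linear.parameters()) + out_bias.numel()
assert n_deconv == n_linear  # 195 == 195: parameter counts agree
```

With the bias inside the linear layer instead, it would contribute out_chans * kD * kH * kW parameters, breaking the count parity with the conv option.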

@TsChala TsChala requested a review from pzhanggit April 22, 2026 15:20
@pzhanggit (Collaborator) left a comment


Thanks @TsChala. I have two comments: (1) could we be explicit about the assumptions required for the linear changes to be functionally equivalent? For conv2d/conv3d, we're assuming the patches are non-overlapping. Could we raise an error when that's not the case but use_linear is on?
(2) The current nn.Linear implementations of some functions, like UpsampleConv3d, do not reproduce the same functionality.
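A minimal sketch of the guard requested in (1); the function name and signature are hypothetical, not from the repo:

```python
def check_nonoverlapping(kernel_size, stride, use_linear):
    """Raise if use_linear is requested with overlapping patches.

    The linear path only matches conv2d/conv3d when patches are
    non-overlapping, i.e. kernel_size == stride.
    """
    if use_linear and tuple(kernel_size) != tuple(stride):
        raise ValueError(
            "use_linear=True assumes non-overlapping patches "
            f"(kernel_size == stride); got kernel_size={kernel_size}, "
            f"stride={stride}"
        )

check_nonoverlapping((4, 4, 4), (4, 4, 4), use_linear=True)  # OK
check_nonoverlapping((3, 3, 3), (1, 1, 1), use_linear=False)  # OK, conv path
```

Calling it with, say, kernel_size=(3, 3, 3), stride=(1, 1, 1), use_linear=True would raise the ValueError.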

Comment on lines +151 to +156
self.out_proj.append(nn.Linear(channels, channels * kD * kH * kW, bias=False))
self.out_proj.append(nn.InstanceNorm3d(channels, affine=True))
self.out_proj.append(nn.GELU())
# Final head
kD, kH, kW = self.ks[0]
self.out_head = nn.Linear(channels, channels * kD * kH * kW)
@pzhanggit (Collaborator)

The nn.Linear operations here are not equivalent to UpsampleConv3d, as UpsampleConv3d consists of nn.Upsample and nn.Conv3d.

@TsChala (Collaborator, Author)

Good point. Upon further investigation, I think it doesn't make sense to include the use_linear option for the UpsampleConv3d part. We do the upsampling to change the size, then apply the Conv3d block so that it doesn't change the input size; stride=1 and padding="same" are hardcoded exactly for this reason. Therefore, the non-overlapping scenario (kernel size == stride) is not possible here. I'll remove the use_linear option from the UpsampleConv3d parts.
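A small sketch of the point above (channel counts and spatial sizes are illustrative, not from the repo): with stride=1 and padding="same", each output voxel mixes an overlapping neighborhood, so a token-wise linear layer acting on channels only cannot reproduce the operation:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# stride=1, padding="same": each output voxel mixes a 3x3x3 neighborhood
# and the spatial size is unchanged (illustrative channel/size choices).
conv = nn.Conv3d(4, 4, kernel_size=3, stride=1, padding="same")
x = torch.randn(1, 4, 6, 6, 6)
y = conv(x)
assert y.shape == x.shape  # size-preserving, as in UpsampleConv3d

# Perturb a single input voxel: a *neighboring* output voxel changes
# too, which a per-voxel linear layer (channels -> channels) cannot do.
x2 = x.clone()
x2[0, :, 3, 3, 3] += 1.0
y2 = conv(x2)
assert not torch.allclose(y[0, :, 2, 3, 3], y2[0, :, 2, 3, 3])
```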

kD, kH, kW = self.ks[-(ilayer+1)]
# Apply linear, norm, activation
x = rearrange(x, 'tb c d h w -> (tb d h w) c')
x = self.out_proj[layer_idx](x) # Linear layer
@pzhanggit (Collaborator)

See my comment above. We will need to fix it to make it consistent.

Comment thread: matey/models/spatial_modules.py