support float 16 on inference

This commit is contained in:
Cherrytest
2025-09-29 16:15:24 +00:00
parent cf6aec8113
commit 88bfd511d4

View File

@ -993,7 +993,7 @@ class BasicLayer(nn.Module):
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)).to(x.dtype)
for blk in self.blocks:
blk.H, blk.W = H, W