From 88bfd511d473702df7964f38242b4cab24863347 Mon Sep 17 00:00:00 2001 From: Cherrytest Date: Mon, 29 Sep 2025 16:15:24 +0000 Subject: [PATCH] support float 16 on inference --- birefnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/birefnet.py b/birefnet.py index bc63529..7ad5f3a 100644 --- a/birefnet.py +++ b/birefnet.py @@ -993,7 +993,7 @@ class BasicLayer(nn.Module): mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)).to(x.dtype) for blk in self.blocks: blk.H, blk.W = H, W