From cf6aec811327248541bb3f3a8d65c16f95cd1a35 Mon Sep 17 00:00:00 2001 From: Cherrytest Date: Mon, 22 Sep 2025 07:12:51 +0000 Subject: [PATCH] Update README.md --- README.md | 9 +++------ birefnet.py | 1 + 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 46483b3..c19d43b 100644 --- a/README.md +++ b/README.md @@ -146,15 +146,12 @@ transformers ```python from PIL import Image -import matplotlib.pyplot as plt import torch from torchvision import transforms from transformers import AutoModelForImageSegmentation -model = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True) -torch.set_float32_matmul_precision(['high', 'highest'][0]) -model.to('cuda') -model.eval() +device = 'cuda' if torch.cuda.is_available() else 'cpu' +model = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True).eval().to(device) # Data settings image_size = (1024, 1024) @@ -165,7 +162,7 @@ transform_image = transforms.Compose([ ]) image = Image.open(input_image_path) -input_images = transform_image(image).unsqueeze(0).to('cuda') +input_images = transform_image(image).unsqueeze(0).to(device) # Prediction with torch.no_grad(): diff --git a/birefnet.py b/birefnet.py index d68d7d3..bc63529 100644 --- a/birefnet.py +++ b/birefnet.py @@ -6,6 +6,7 @@ from transformers import PretrainedConfig class Config(PretrainedConfig): def __init__(self) -> None: + super().__init__() # PATH settings self.sys_home_dir = os.path.expanduser('~') # Make up your file system as: SYS_HOME_DIR/codes/dis/BiRefNet, SYS_HOME_DIR/datasets/dis/xx, SYS_HOME_DIR/weights/xx