mirror of
https://www.modelscope.cn/AI-ModelScope/RMBG-2.0.git
synced 2026-04-02 11:02:56 +08:00
Update README.md
This commit is contained in:
@ -146,15 +146,12 @@ transformers
|
||||
|
||||
```python
|
||||
from PIL import Image
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
from transformers import AutoModelForImageSegmentation
|
||||
|
||||
model = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
|
||||
torch.set_float32_matmul_precision(['high', 'highest'][0])
|
||||
model.to('cuda')
|
||||
model.eval()
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
model = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True).eval().to(device)
|
||||
|
||||
# Data settings
|
||||
image_size = (1024, 1024)
|
||||
@ -165,7 +162,7 @@ transform_image = transforms.Compose([
|
||||
])
|
||||
|
||||
image = Image.open(input_image_path)
|
||||
input_images = transform_image(image).unsqueeze(0).to('cuda')
|
||||
input_images = transform_image(image).unsqueeze(0).to(device)
|
||||
|
||||
# Prediction
|
||||
with torch.no_grad():
|
||||
|
||||
@ -6,6 +6,7 @@ from transformers import PretrainedConfig
|
||||
|
||||
class Config(PretrainedConfig):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# PATH settings
|
||||
self.sys_home_dir = os.path.expanduser('~') # Make up your file system as: SYS_HOME_DIR/codes/dis/BiRefNet, SYS_HOME_DIR/datasets/dis/xx, SYS_HOME_DIR/weights/xx
|
||||
|
||||
|
||||
Reference in New Issue
Block a user