CascadePSP

Open source (MIT license), Python
884 stars, 97 forks, 3 issues, 14 watchers; last commit 5 months ago

About CascadePSP

CascadePSP is a deep learning framework for class-agnostic, very high-resolution semantic segmentation refinement. Published at CVPR 2020, it uses a global and local refinement strategy to turn coarse segmentation masks into precise, high-resolution outputs. The implementation is built on PyTorch and includes both training and testing code. A key feature is that it refines masks without class-specific training, so it can be applied to a variety of datasets, including the provided UHD dataset BIG and the Relabeled PASCAL VOC 2012. The architecture consists of a global step that captures coarse context and a local step that adjusts boundary details, both supported by a dedicated refinement module. The software also ships as a pip package, segmentation-refinement, which refines an image with only a few lines of code. It runs on both CUDA and CPU devices and exposes parameters that trade memory usage against speed. Pretrained models are available for immediate inference.
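The global step can be pictured as running the refinement network on a downscaled copy of the input and scaling the result back up. The sketch below assumes, purely for illustration, that `L` sets the longer side of that working resolution; `resize_nearest`, `global_step` and `refine_fn` are hypothetical names (with `refine_fn` standing in for the network), and the package's real resizing and model calls differ.

```python
import numpy as np


def resize_nearest(arr: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Nearest-neighbour resize of an H x W or H x W x C array via index maps."""
    h, w = arr.shape[:2]
    rows = np.arange(out_h) * h // out_h  # source row for each output row
    cols = np.arange(out_w) * w // out_w  # source column for each output column
    return arr[rows[:, None], cols[None, :]]


def global_step(image, mask, L, refine_fn):
    """Hypothetical global step: refine at a resolution whose longer side is L
    (an assumption for illustration), then resize back to the input size."""
    h, w = mask.shape[:2]
    scale = L / max(h, w)
    small_h, small_w = max(1, round(h * scale)), max(1, round(w * scale))
    small_image = resize_nearest(image, small_h, small_w)
    small_mask = resize_nearest(mask, small_h, small_w)
    coarse = refine_fn(small_image, small_mask)  # stand-in for the network
    return resize_nearest(coarse, h, w)


def threshold_refiner(image, mask):
    """Toy stand-in for the network: just binarise the mask."""
    return (mask > 127).astype(np.uint8) * 255


if __name__ == "__main__":
    image = np.zeros((2000, 1000, 3), np.uint8)
    mask = np.zeros((2000, 1000), np.uint8)
    mask[500:1500, 250:750] = 200
    out = global_step(image, mask, 900, threshold_refiner)
    print(out.shape)  # (2000, 1000)
```

Nearest-neighbour resizing keeps the sketch dependency-free; a real pipeline would more likely use area or bilinear interpolation (for example via `cv2.resize`) for the image and for soft masks.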

Platforms: Web, Self-hosted

Languages: Python

CascadePSP: Toward Class-Agnostic and Very High-Resolution Segmentation via Global and Local Refinement

Ho Kei Cheng*, Jihoon Chung*, Yu-Wing Tai, Chi-Keung Tang

[arXiv] [PDF]

[Supplementary Information (Comparisons with DenseCRF included!)]

[Supplementary image results]

Introduction

CascadePSP is a deep learning model for high-resolution segmentation refinement. This repository contains our PyTorch implementation with both training and testing functionalities. We also provide the annotated UHD dataset BIG and the pretrained model.

Here are some refinement results on high-resolution images.

Quick Start

Tested on PyTorch 1.0, though later versions will likely work for inference as well.

Check out this folder. We have built a pip package that can refine an input image with two lines of code.

Install with

pip install segmentation-refinement

Code demo:

import cv2
import matplotlib.pyplot as plt
import segmentation_refinement as refine

image = cv2.imread('test/aeroplane.jpg')
mask = cv2.imread('test/aeroplane.png', cv2.IMREAD_GRAYSCALE)

# model_path can also be specified here
# This step takes some time to load the model
refiner = refine.Refiner(device='cuda:0')  # device can also be 'cpu'

# fast=True runs the global step only; fast=False runs both the global and local steps.
# Smaller L -> less memory usage; also faster in fast mode.
output = refiner.refine(image, mask, fast=False, L=900)

# Save the refined mask
cv2.imwrite('output.png', output)

# Display the refined mask in grayscale
plt.imshow(output, cmap='gray')
plt.show()
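Because `cv2.imread` returns `None` rather than raising when a path is wrong, and the demo assumes the image and mask line up pixel for pixel, a small check before calling `refiner.refine` can save a confusing error later. A minimal sketch, assuming (as the demo suggests) that the refiner expects an H x W x 3 `uint8` image and an H x W `uint8` mask of the same size; `prepare_inputs` is a hypothetical helper, not part of the `segmentation_refinement` package.

```python
import numpy as np


def prepare_inputs(image, mask):
    """Validate an (image, mask) pair and normalise the mask to H x W uint8 in 0-255.

    Hypothetical helper: the expected shapes and dtypes are assumptions based on the demo.
    """
    if image is None or mask is None:
        # cv2.imread signals a missing or unreadable file by returning None.
        raise FileNotFoundError("image or mask could not be loaded")
    if image.ndim != 3 or image.shape[2] != 3:
        raise ValueError(f"expected an H x W x 3 image, got shape {image.shape}")
    if mask.ndim == 3:
        # A mask read in colour: treat a pixel as foreground if any channel is set.
        mask = mask.max(axis=2)
    if mask.shape != image.shape[:2]:
        raise ValueError(
            f"mask size {mask.shape} does not match image size {image.shape[:2]}"
        )
    if mask.dtype == bool or mask.max() <= 1:
        # Binary or [0, 1] masks are stretched to the 0-255 range.
        mask = mask.astype(np.float32) * 255
    return image, np.clip(mask, 0, 255).astype(np.uint8)
```

In the demo, it would sit between the two `cv2.imread` calls and the call to `refiner.refine`, as `image, mask = prepare_inputs(image, mask)`.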

Network Overview

Global Step & Local Step

Global Step Local Step
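The local step works on the full-resolution image piece by piece. A minimal sketch of that tiling, assuming square crops of side `size` placed every `stride` pixels, with overlapping predictions averaged; `crop_windows`, `local_step` and `refine_fn` (a stand-in for the network) are hypothetical, and the package's actual crop size, padding and blending may differ.

```python
import numpy as np


def crop_windows(h, w, size, stride):
    """List (y0, y1, x0, x1) windows of at most size x size that together cover h x w.

    Windows start every `stride` pixels; the last one along each axis is shifted
    back so that it ends exactly at the image border.
    """
    if not 0 < stride <= size:
        raise ValueError("need 0 < stride <= size so that no pixel is skipped")

    def starts(length):
        if length <= size:
            return [0]
        positions = list(range(0, length - size, stride))
        positions.append(length - size)  # final window flush with the border
        return positions

    return [
        (y, min(y + size, h), x, min(x + size, w))
        for y in starts(h)
        for x in starts(w)
    ]


def local_step(image, coarse_mask, size, stride, refine_fn):
    """Hypothetical local step: refine each full-resolution crop and average overlaps."""
    h, w = coarse_mask.shape[:2]
    total = np.zeros((h, w), np.float64)
    hits = np.zeros((h, w), np.float64)
    for y0, y1, x0, x1 in crop_windows(h, w, size, stride):
        patch = refine_fn(image[y0:y1, x0:x1], coarse_mask[y0:y1, x0:x1])
        total[y0:y1, x0:x1] += patch
        hits[y0:y1, x0:x1] += 1
    return total / hits  # every pixel is covered at least once
```

Feeding the upsampled output of the global step in as `coarse_mask` gives the two-stage global-then-local cascade that CascadePSP is built around.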

Refinement Module

More Results

Refining the masks of Human 3.6M

[Figure: three example rows, each showing Image | Original Mask | Refined Mask]

The first row shows a failure case (see the neck).

Credit

PSPNet implementation: https://github.com/Lextal/pspnet-pytorch

SyncBN implementation: https://github.com/vacancy/Synchronized-BatchNorm-PyTorch

If you find our work useful in your research, please cite the following:

@inproceedings{cheng2020cascadepsp,
  title={{CascadePSP}: Toward Class-Agnostic and Very High-Resolution Segmentation via Global and Local Refinement},
  author={Cheng, Ho Kei and Chung, Jihoon and Tai, Yu-Wing and Tang, Chi-Keung},
  booktitle={CVPR},
  year={2020}
}