# Source code for vis4d.data.datasets.coco

"""COCO dataset."""

from __future__ import annotations

import contextlib
import io
import os
from collections.abc import Sequence

import numpy as np
import pycocotools.mask as maskUtils
from pycocotools.coco import COCO as COCOAPI

from vis4d.common import ArgsType, DictStrAny
from vis4d.data.const import CommonKeys as K
from vis4d.data.typing import DictData

from .base import Dataset
from .util import CacheMappingMixin, im_decode

# COCO detection: maps the 80 COCO category names to contiguous class
# indices 0-79 (the standard COCO detection label space).
coco_det_map: dict[str, int] = {
    "person": 0,
    "bicycle": 1,
    "car": 2,
    "motorcycle": 3,
    "airplane": 4,
    "bus": 5,
    "train": 6,
    "truck": 7,
    "boat": 8,
    "traffic light": 9,
    "fire hydrant": 10,
    "stop sign": 11,
    "parking meter": 12,
    "bench": 13,
    "bird": 14,
    "cat": 15,
    "dog": 16,
    "horse": 17,
    "sheep": 18,
    "cow": 19,
    "elephant": 20,
    "bear": 21,
    "zebra": 22,
    "giraffe": 23,
    "backpack": 24,
    "umbrella": 25,
    "handbag": 26,
    "tie": 27,
    "suitcase": 28,
    "frisbee": 29,
    "skis": 30,
    "snowboard": 31,
    "sports ball": 32,
    "kite": 33,
    "baseball bat": 34,
    "baseball glove": 35,
    "skateboard": 36,
    "surfboard": 37,
    "tennis racket": 38,
    "bottle": 39,
    "wine glass": 40,
    "cup": 41,
    "fork": 42,
    "knife": 43,
    "spoon": 44,
    "bowl": 45,
    "banana": 46,
    "apple": 47,
    "sandwich": 48,
    "orange": 49,
    "broccoli": 50,
    "carrot": 51,
    "hot dog": 52,
    "pizza": 53,
    "donut": 54,
    "cake": 55,
    "chair": 56,
    "couch": 57,
    "potted plant": 58,
    "bed": 59,
    "dining table": 60,
    "toilet": 61,
    "tv": 62,
    "laptop": 63,
    "mouse": 64,
    "remote": 65,
    "keyboard": 66,
    "cell phone": 67,
    "microwave": 68,
    "oven": 69,
    "toaster": 70,
    "sink": 71,
    "refrigerator": 72,
    "book": 73,
    "clock": 74,
    "vase": 75,
    "scissors": 76,
    "teddy bear": 77,
    "hair drier": 78,
    "toothbrush": 79,
}

# COCO segmentation categories: "background" plus the 20 Pascal-VOC-style
# class names, mapped to indices 0-20. Used instead of coco_det_map when
# use_pascal_voc_cats=True (see COCO._generate_data_mapping).
coco_seg_map: dict[str, int] = {
    "background": 0,
    "airplane": 1,
    "bicycle": 2,
    "bird": 3,
    "boat": 4,
    "bottle": 5,
    "bus": 6,
    "car": 7,
    "cat": 8,
    "chair": 9,
    "cow": 10,
    "dining table": 11,
    "dog": 12,
    "horse": 13,
    "motorcycle": 14,
    "person": 15,
    "potted plant": 16,
    "sheep": 17,
    "couch": 18,
    "train": 19,
    "tv": 20,
}


class COCO(CacheMappingMixin, Dataset):
    """COCO dataset class."""

    DESCRIPTION = """COCO is a large-scale object detection, segmentation, and captioning dataset."""
    HOMEPAGE = "http://cocodataset.org"
    PAPER = "http://arxiv.org/abs/1405.0312"
    LICENSE = "BY-NC-SA 2.0"

    # Data keys this dataset can produce via __getitem__.
    KEYS = [
        K.images,
        K.input_hw,
        K.original_images,
        K.original_hw,
        K.sample_names,
        K.boxes2d,
        K.boxes2d_classes,
        K.instance_masks,
        K.seg_masks,
    ]

    def __init__(
        self,
        data_root: str,
        keys_to_load: Sequence[str] = (
            K.images,
            K.boxes2d,
            K.boxes2d_classes,
            K.instance_masks,
        ),
        split: str = "train2017",
        remove_empty: bool = False,
        use_pascal_voc_cats: bool = False,
        cache_as_binary: bool = False,
        cached_file_path: str | None = None,
        **kwargs: ArgsType,
    ) -> None:
        """Initialize the COCO dataset.

        Args:
            data_root (str): Path to the root directory of the dataset.
            keys_to_load (tuple[str, ...]): Keys to load from the dataset.
            split (split): Which split to load. Default: "train2017".
            remove_empty (bool): Whether to remove images with no
                annotations.
            use_pascal_voc_cats (bool): Whether to use Pascal VOC categories.
            cache_as_binary (bool): Whether to cache the dataset as binary.
                Default: False.
            cached_file_path (str | None): Path to a cached file. If cached
                file exist then it will load it instead of generating the
                data mapping. Default: None.
        """
        super().__init__(**kwargs)
        self.data_root = data_root
        self.keys_to_load = keys_to_load
        self.split = split
        self.remove_empty = remove_empty
        self.use_pascal_voc_cats = use_pascal_voc_cats

        # handling keys to load
        self.validate_keys(keys_to_load)

        # Annotations are only parsed in __getitem__ if at least one
        # annotation-derived key was requested.
        self.load_annotations = (
            K.boxes2d in keys_to_load
            or K.boxes2d_classes in keys_to_load
            or K.instance_masks in keys_to_load
            or K.seg_masks in keys_to_load
        )

        # Build (or load from cache) the per-image sample mapping via the
        # CacheMappingMixin; _filter_data optionally drops empty samples.
        self.data, _ = self._load_mapping(
            self._generate_data_mapping,
            self._filter_data,
            cache_as_binary=cache_as_binary,
            cached_file_path=cached_file_path,
        )
[docs] def __repr__(self) -> str: """Concise representation of the dataset.""" return ( f"COCODataset(root={self.data_root}, split={self.split}, " f"use_pascal_voc_cats={self.use_pascal_voc_cats})" )
def _filter_data(self, data: list[DictStrAny]) -> list[DictStrAny]: """Remove empty samples.""" if self.remove_empty: samples = [] for sample in data: if len(sample["anns"]) > 0: samples.append(sample) return samples return data def _generate_data_mapping(self) -> list[DictStrAny]: """Generate coco dataset mapping.""" annotation_file = os.path.join( self.data_root, "annotations", "instances_" + self.split + ".json" ) with contextlib.redirect_stdout(io.StringIO()): coco_api = COCOAPI(annotation_file) cat_ids = sorted(coco_api.getCatIds()) cats_map = {c["id"]: c["name"] for c in coco_api.loadCats(cat_ids)} if self.use_pascal_voc_cats: voc_cats = set(coco_seg_map.keys()) img_ids = sorted(coco_api.imgs.keys()) imgs = coco_api.loadImgs(img_ids) samples = [] for img_id, img in zip(img_ids, imgs): anns = coco_api.imgToAnns[img_id] if self.use_pascal_voc_cats: anns = [ ann for ann in anns if cats_map[ann["category_id"]] in voc_cats ] for ann in anns: cat_name = cats_map[ann["category_id"]] if self.use_pascal_voc_cats: ann["category_id"] = coco_seg_map[cat_name] else: ann["category_id"] = coco_det_map[cat_name] samples.append({"img_id": img_id, "img": img, "anns": anns}) return samples
[docs] def __len__(self) -> int: """Return length of dataset.""" return len(self.data)
    def __getitem__(self, idx: int) -> DictData:
        """Transform coco sample to vis4d input format.

        Returns:
            DataDict[DataKeys, Union[torch.Tensor, Dict[Any]]]
        """
        data = self.data[idx]
        img_h, img_w = data["img"]["height"], data["img"]["width"]
        dict_data: DictData = {}

        if K.images in self.keys_to_load:
            img_path = os.path.join(
                self.data_root, self.split, data["img"]["file_name"]
            )
            im_bytes = self.data_backend.get(img_path)
            img = im_decode(im_bytes, mode=self.image_channel_mode)
            # Add a leading batch dimension; presumably (1, H, W, C) given
            # the shape[1:3] check below — TODO confirm im_decode layout.
            img_ = np.ascontiguousarray(img, dtype=np.float32)[None]
            assert (img_h, img_w) == img_.shape[
                1:3
            ], "Image's shape doesn't match annotation."

            # NOTE(review): sample_names is only set when K.images is
            # requested — confirm this is intentional.
            dict_data[K.sample_names] = data["img"]["id"]
            dict_data[K.images] = img_
            dict_data[K.input_hw] = [img_h, img_w]

            # Original (untransformed) image shares the same array here;
            # downstream transforms are expected to replace K.images only.
            if K.original_images in self.keys_to_load:
                dict_data[K.original_images] = img_
                dict_data[K.original_hw] = [img_h, img_w]

        if self.load_annotations:
            boxes = []
            classes = []
            masks = []
            for ann in data["anns"]:
                if K.boxes2d in self.keys_to_load:
                    # Convert COCO (x, y, w, h) to (x1, y1, x2, y2).
                    x1, y1, width, height = ann["bbox"]
                    x2, y2 = x1 + width, y1 + height
                    boxes.append((x1, y1, x2, y2))
                if (
                    K.boxes2d in self.keys_to_load
                    or K.boxes2d_classes in self.keys_to_load
                    or K.seg_masks in self.keys_to_load
                ):
                    classes.append(ann["category_id"])
                if (
                    K.seg_masks in self.keys_to_load
                    or K.instance_masks in self.keys_to_load
                ):
                    # Segmentation may be polygons (list), uncompressed RLE
                    # (dict with list counts) or compressed RLE.
                    mask_ann = ann.get("segmentation", None)
                    if mask_ann is not None:
                        if isinstance(mask_ann, list):
                            rles = maskUtils.frPyObjects(
                                mask_ann, img_h, img_w
                            )
                            rle = maskUtils.merge(rles)
                        elif isinstance(mask_ann["counts"], list):
                            # uncompressed RLE
                            rle = maskUtils.frPyObjects(mask_ann, img_h, img_w)
                        else:
                            # RLE
                            rle = mask_ann
                        masks.append(maskUtils.decode(rle))
                    else:  # pragma: no cover
                        # NOTE(review): np.empty (not zeros) — the
                        # placeholder mask content is undefined; confirm
                        # this is acceptable for missing segmentations.
                        masks.append(np.empty((img_h, img_w), dtype=np.uint8))

            # Stack into arrays; use explicit empty shapes so downstream
            # code sees consistent ranks even with zero annotations.
            box_tensor = (
                np.empty((0, 4), dtype=np.float32)
                if not boxes
                else np.array(boxes, dtype=np.float32)
            )
            mask_tensor = (
                np.empty((0, img_h, img_w), dtype=np.uint8)
                if not masks
                else np.ascontiguousarray(masks, dtype=np.uint8)
            )

            if K.boxes2d in self.keys_to_load:
                dict_data[K.boxes2d] = box_tensor
            if K.boxes2d_classes in self.keys_to_load:
                dict_data[K.boxes2d_classes] = np.array(
                    classes, dtype=np.int64
                )
            if K.instance_masks in self.keys_to_load:
                dict_data[K.instance_masks] = mask_tensor
            if K.seg_masks in self.keys_to_load:
                # Collapse per-instance binary masks into a single semantic
                # map by weighting each mask with its class id and taking
                # the per-pixel max; pixels covered by more than one
                # instance are set to 255 (ignore label).
                seg_masks = (
                    mask_tensor * np.array(classes)[:, None, None]
                ).max(axis=0)
                seg_masks = seg_masks.astype(np.int64)
                seg_masks[mask_tensor.sum(0) > 1] = 255  # discard overlapped
                dict_data[K.seg_masks] = seg_masks[None]
        return dict_data