mmsegmentationのv1系でモデルを訓練する方法

プログラミング

2024/01/08

前提・課題

  • mmsegmentation のv1系を使って、セマンティックセグメンテーションモデルを訓練したい
  • 既存の解説記事にはmmsegmentation v0系を使ったものが多いが、v1系(mmcv>=2.0.0)を使ったものは見つからなかった
  • この記事ではカスタムデータでの学習方法は扱わない。追って記事にする予定
  • この記事では以下の環境を使用している
    • Ubuntu 22.04 LTS
    • Python-3.10.4
    • CUDA Toolkit 12.3
    • mmsegmentation 1.2.2

方法

公式リポジトリのdemo/MMSegmentation_Tutorial.ipynb にチュートリアルがあるので、それを参考にコードを書く

mmsegmentation/demo/MMSegmentation_Tutorial.ipynb at main · open-mmlab/mmsegmentation
OpenMMLab Semantic Segmentation Toolbox and Benchmark. - open-mmlab/mmsegmentation

インストール

mmsegmentationと依存関係のインストールを行うsetup.shを作成

#!/bin/bash
# Install mmsegmentation v1.x and its dependencies.
# NOTE: the original shebang was "# !/bin/bash" — the space made it a plain
# comment, so the script would run under whatever shell invoked it.
# Abort on the first failing command so a partial install is not mistaken
# for a successful one.
set -euo pipefail

# PyTorch (CUDA 12.1 wheels)
python3 -m pip install --upgrade pip wheel setuptools \
    torch==2.1.1 \
    torchvision==0.16.1 \
    --index-url https://download.pytorch.org/whl/cu121

# mmsegmentation stack: mmengine and mmcv>=2.0.0 are the v1-series requirements
python3 -m pip install -U openmim
mim install mmengine
mim install "mmcv>=2.0.0"
python3 -m pip install "mmsegmentation>=1.0.0" ftfy regex

実行

# Make the setup script executable, then run it
chmod +x setup.sh
./setup.sh

データの準備

データをダウンロードして展開

# Download the Stanford Background dataset
wget http://dags.stanford.edu/data/iccv09Data.tar.gz -O stanford_background.tar.gz
# Extract; this creates the iccv09Data directory
tar -zxf stanford_background.tar.gz

ラベル画像作成とデータ分割を行うmake_data_iccv09.py を作成する

import os.path as osp

import mmengine
import numpy as np
from PIL import Image

# Dataset layout: iccv09Data/images (.jpg) and iccv09Data/labels (.regions.txt)
data_root = "iccv09Data"
img_dir = "images"
ann_dir = "labels"

# RGB color for each of the 8 classes (palette index == class id).
palette = [
    [128, 128, 128],
    [129, 127, 38],
    [120, 69, 125],
    [53, 125, 34],
    [0, 11, 123],
    [118, 20, 12],
    [122, 81, 25],
    [241, 134, 51],
]

# Convert each .regions.txt annotation into a paletted PNG segmentation map
# that mmsegmentation can load.
# NOTE(review): the dataset marks unknown pixels as -1; astype(np.uint8)
# wraps them to 255, which mmseg conventionally treats as the ignore index —
# confirm this is the intended behavior.
for file in mmengine.scandir(osp.join(data_root, ann_dir), suffix=".regions.txt"):
    seg_map = np.loadtxt(osp.join(data_root, ann_dir, file)).astype(np.uint8)
    seg_img = Image.fromarray(seg_map).convert("P")
    seg_img.putpalette(np.array(palette, dtype=np.uint8))
    seg_img.save(osp.join(data_root, ann_dir, file.replace(".regions.txt", ".png")))


# Split into train/val sets. The list is sorted so the split is reproducible
# across runs (scandir iteration order is not guaranteed).
split_dir = "splits"
mmengine.mkdir_or_exist(osp.join(data_root, split_dir))
filename_list = sorted(
    osp.splitext(filename)[0]
    for filename in mmengine.scandir(osp.join(data_root, ann_dir), suffix=".png")
)
# First 4/5 of the files form the train set, the remaining 1/5 the val set.
train_length = int(len(filename_list) * 4 / 5)
with open(osp.join(data_root, split_dir, "train.txt"), "w") as f:
    f.writelines(line + "\n" for line in filename_list[:train_length])
with open(osp.join(data_root, split_dir, "val.txt"), "w") as f:
    f.writelines(line + "\n" for line in filename_list[train_length:])

実行。iccv09Data/splits に分割リストが、iccv09Data/labels にラベル画像が生成される

# Generate the label PNGs and the train/val split files
python make_data_iccv09.py

学習

モデルの学習を行うtrain_iccv09.py を作成する

import os

from mmseg.registry import DATASETS
from mmseg.datasets import BaseSegDataset
from mmengine import Config
from mmengine.runner import Runner

# Dataset layout: iccv09Data/images (.jpg) and iccv09Data/labels (.png)
data_root = "iccv09Data"
img_dir = "images"
ann_dir = "labels"
# The 8 semantic classes of the Stanford Background dataset
classes = (
    "sky",
    "tree",
    "road",
    "grass",
    "water",
    "bldg",
    "mntn",
    "fg obj",
)
# RGB color for each class, index-aligned with `classes`
palette = [
    [128, 128, 128],
    [129, 127, 38],
    [120, 69, 125],
    [53, 125, 34],
    [0, 11, 123],
    [118, 20, 12],
    [122, 81, 25],
    [241, 134, 51],
]


@DATASETS.register_module()
class StanfordBackgroundDataset(BaseSegDataset):
    """Stanford Background dataset: .jpg images paired with .png label maps."""

    # Class names and display colors picked up by mmseg's dataset machinery.
    METAINFO = dict(classes=classes, palette=palette)

    def __init__(self, **kwargs):
        # Fix the file suffixes for this dataset; everything else is
        # forwarded unchanged to BaseSegDataset.
        super().__init__(seg_map_suffix=".png", img_suffix=".jpg", **kwargs)


def make_config(dir_save: str, path_config: str, path_model: str):
    """Build the training config for the Stanford Background dataset.

    Starts from the downloaded PSPNet/Cityscapes config and overrides the
    parts that differ for single-GPU fine-tuning on this 8-class dataset.

    Args:
        dir_save: Working directory where checkpoints and logs are written.
        path_config: Path to the base mmsegmentation config file.
        path_model: Path to the pretrained checkpoint to fine-tune from.

    Returns:
        The modified ``mmengine.Config`` instance, ready for ``Runner.from_cfg``.
    """
    cfg = Config.fromfile(path_config)
    print(f"Config:\n{cfg.pretty_text}")

    # Since we use only one GPU, BN is used instead of SyncBN
    cfg.norm_cfg = dict(type="BN", requires_grad=True)
    cfg.crop_size = (256, 256)
    cfg.model.data_preprocessor.size = cfg.crop_size
    cfg.model.backbone.norm_cfg = cfg.norm_cfg
    cfg.model.decode_head.norm_cfg = cfg.norm_cfg
    cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg
    # Derive the head class counts from `classes` instead of hard-coding 8,
    # so the config stays consistent if the class list changes.
    cfg.model.decode_head.num_classes = len(classes)
    cfg.model.auxiliary_head.num_classes = len(classes)

    # Modify dataset type and path
    cfg.dataset_type = "StanfordBackgroundDataset"
    cfg.data_root = data_root

    cfg.train_dataloader.batch_size = 8

    cfg.train_pipeline = [
        dict(type="LoadImageFromFile"),
        dict(type="LoadAnnotations"),
        dict(
            type="RandomResize",
            scale=(320, 240),
            ratio_range=(0.5, 2.0),
            keep_ratio=True,
        ),
        dict(type="RandomCrop", crop_size=cfg.crop_size, cat_max_ratio=0.75),
        dict(type="RandomFlip", prob=0.5),
        dict(type="PackSegInputs"),
    ]

    cfg.test_pipeline = [
        dict(type="LoadImageFromFile"),
        dict(type="Resize", scale=(320, 240), keep_ratio=True),
        # add loading annotation after ``Resize`` because ground truth
        # does not need to do resize data transform
        dict(type="LoadAnnotations"),
        dict(type="PackSegInputs"),
    ]

    cfg.train_dataloader.dataset.type = cfg.dataset_type
    cfg.train_dataloader.dataset.data_root = cfg.data_root
    cfg.train_dataloader.dataset.data_prefix = dict(
        img_path=img_dir, seg_map_path=ann_dir
    )
    cfg.train_dataloader.dataset.pipeline = cfg.train_pipeline
    cfg.train_dataloader.dataset.ann_file = "splits/train.txt"

    cfg.val_dataloader.dataset.type = cfg.dataset_type
    cfg.val_dataloader.dataset.data_root = cfg.data_root
    cfg.val_dataloader.dataset.data_prefix = dict(
        img_path=img_dir, seg_map_path=ann_dir
    )
    cfg.val_dataloader.dataset.pipeline = cfg.test_pipeline
    cfg.val_dataloader.dataset.ann_file = "splits/val.txt"

    # Evaluate the test split with the same settings as validation.
    cfg.test_dataloader = cfg.val_dataloader

    # Load the pretrained weights
    cfg.load_from = path_model

    # Set up working dir to save files and logs.
    cfg.work_dir = dir_save

    # Short schedule for the tutorial: 200 iterations, validate and
    # checkpoint once at the end, log every 10 iterations.
    cfg.train_cfg.max_iters = 200
    cfg.train_cfg.val_interval = 200
    cfg.default_hooks.logger.interval = 10
    cfg.default_hooks.checkpoint.interval = 200

    # Set seed to facilitate reproducing the result
    cfg["randomness"] = dict(seed=0)

    # Let's have a look at the final config used for training
    print(f"Config:\n{cfg.pretty_text}")

    return cfg


def main(dir_save, path_config, path_model):
    """Build the training config and launch the training loop."""
    runner = Runner.from_cfg(make_config(dir_save, path_config, path_model))
    runner.train()


if __name__ == "__main__":
    # Checkpoint and config downloaded beforehand via `mim download`.
    checkpoints_dir = "checkpoints"
    config_name = "pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py"
    weights_name = "pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth"
    output_dir = "./save/tutorial"

    main(
        output_dir,
        os.path.join(checkpoints_dir, config_name),
        os.path.join(checkpoints_dir, weights_name),
    )

事前学習済みモデルと設定ファイルをダウンロードし、学習を実行

# modelとconfig ダウンロード。checkpointsディレクトリが作成されファイルが保存される
mim download mmsegmentation --config pspnet_r50-d8_4xb2-40k_cityscapes-512x1024 --dest checkpoints

# 学習実行
python train_iccv09.py

  • save/tutorial ディレクトリが作成され、モデルファイルやログが保存される
  • 筆者の環境(GPUあり)では1分程度かかった

推論

推論を行うpredict_iccv09.pyを作成する

import os

import matplotlib.pyplot as plt
import mmcv
from mmseg.apis import init_model, inference_model, show_result_pyplot

from train_iccv09 import make_config


def main(dir_save, path_config, path_model):
    """Run inference on one sample image and save the visualization."""
    cfg = make_config(dir_save, path_config, path_model)
    model = init_model(cfg, path_model, "cuda:0")

    path_img = "iccv09Data/images/6000124.jpg"
    path_save = "out.jpg"

    # mmcv reads images as BGR; convert to RGB only for visualization.
    image = mmcv.imread(path_img)
    prediction = inference_model(model, image)

    plt.figure(figsize=(8, 6))
    rendered = show_result_pyplot(model, mmcv.bgr2rgb(image), prediction, show=False)
    plt.imshow(rendered)
    plt.savefig(path_save)


if __name__ == "__main__":
    # Reuse the downloaded base config; the weights come from the training
    # run's final checkpoint under the working directory.
    checkpoints_dir = "checkpoints"
    config_name = "pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py"
    output_dir = "./save/tutorial"

    main(
        output_dir,
        os.path.join(checkpoints_dir, config_name),
        os.path.join(output_dir, "iter_200.pth"),
    )
  • mmseg.apis の show_result_pyplot() で推論結果を描画している

実行

# Run inference; the result image is written to out.jpg
python predict_iccv09.py
  • 推論結果がout.jpgとして保存される

参考記事

mmsegmentation/demo/MMSegmentation_Tutorial.ipynb
https://github.com/open-mmlab/mmsegmentation/blob/main/demo/MMSegmentation_Tutorial.ipynb
mmsegmentation v1系でのカスタムデータセット学習チュートリアル

openmimを使ったMMDetectionのconfigファイルとcheckpointsファイルのダウンロード方法 2022
https://qiita.com/nyakiri_0726/items/6581e16924400ccc24bf
mim download コマンドでmodelとconfigをダウンロードできる

コメント

タイトルとURLをコピーしました