2024/01/08
前提・課題
- mmsegmentation のv1系を使って、セマンティックセグメンテーションモデルを訓練したい
- 既存の解説記事にはmmsegmentation v0系を使ったものは多いが、v1系(mmcv>=2.0.0)のものが見つからなかった
- この記事ではカスタムデータでの学習方法は扱わない。追って記事にする予定
- この記事では以下の環境を使用している
- Ubuntu 22.04 LTS
- Python-3.10.4
- CUDA Toolkit 12.3
- mmsegmentation 1.2.2
方法
公式リポジトリのdemo/MMSegmentation_Tutorial.ipynb にチュートリアルがあるので、それを参考にコードを書く
mmsegmentation/demo/MMSegmentation_Tutorial.ipynb at main · open-mmlab/mmsegmentation
OpenMMLabSemanticSegmentationToolboxandBenchmark.-open-mmlab/mmsegmentation
インストール
mmsegmentationと依存関係のインストールを行うsetup.shを作成
# !/bin/bash
# torch
python3 -m pip install --upgrade pip wheel setuptools \
torch==2.1.1 \
torchvision==0.16.1 \
--index-url https://download.pytorch.org/whl/cu121
# mmsegmentation
python3 -m pip install -U openmim
mim install mmengine
mim install "mmcv>=2.0.0"
python3 -m pip install "mmsegmentation>=1.0.0" ftfy regex実行
chmod +x setup.sh
./setup.shデータの準備
データをダウンロードして展開
# データをダウンロード
wget http://dags.stanford.edu/data/iccv09Data.tar.gz -O stanford_background.tar.gz
# 展開。iccv09Dataが生成される
tar -zxf stanford_background.tar.gzラベル画像作成とデータ分割を行うmake_data_iccv09.py を作成する
import os.path as osp
import mmengine
import numpy as np
from PIL import Image
data_root = "iccv09Data"
img_dir = "images"
ann_dir = "labels"
palette = [
[128, 128, 128],
[129, 127, 38],
[120, 69, 125],
[53, 125, 34],
[0, 11, 123],
[118, 20, 12],
[122, 81, 25],
[241, 134, 51],
]
# convert dataset annotation to semantic segmentation map
for file in mmengine.scandir(osp.join(data_root, ann_dir), suffix=".regions.txt"):
seg_map = np.loadtxt(osp.join(data_root, ann_dir, file)).astype(np.uint8)
seg_img = Image.fromarray(seg_map).convert("P")
seg_img.putpalette(np.array(palette, dtype=np.uint8))
seg_img.save(osp.join(data_root, ann_dir, file.replace(".regions.txt", ".png")))
# split train/val set randomly
split_dir = "splits"
mmengine.mkdir_or_exist(osp.join(data_root, split_dir))
filename_list = [
osp.splitext(filename)[0]
for filename in mmengine.scandir(osp.join(data_root, ann_dir), suffix=".png")
]
with open(osp.join(data_root, split_dir, "train.txt"), "w") as f:
# select first 4/5 as train set
train_length = int(len(filename_list) * 4 / 5)
f.writelines(line + "\n" for line in filename_list[:train_length])
with open(osp.join(data_root, split_dir, "val.txt"), "w") as f:
# select last 1/5 as train set
f.writelines(line + "\n" for line in filename_list[train_length:])
- 元データではラベルがテキストで表現されており、それをインデックスカラーの画像として保存する
- インデックスカラーとは、画像のピクセルに色を入れるのではなく、色定義テーブルの参照番号を入れるというような方法らしい。 https://ja.wikipedia.org/wiki/%E3%82%A4%E3%83%B3%E3%83%87%E3%83%83%E3%82%AF%E3%82%B9%E3%82%AB%E3%83%A9%E3%83%BC
- インデックスカラーの画像でも、VSCodeやビューアーで画像を表示すると普通に色がついているように見える
- インデックスカラーの色設定に使う色はRGBで定義しているはず https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.putpalette
実行。iccv09Data/splitsとiccv09Data/labelsに画像が生成される
python make_data_iccv09.py学習
モデルの学習を行うtrain_iccv09.py を作成する
import os
from mmseg.registry import DATASETS
from mmseg.datasets import BaseSegDataset
from mmengine import Config
from mmengine.runner import Runner
data_root = "iccv09Data"
img_dir = "images"
ann_dir = "labels"
classes = (
"sky",
"tree",
"road",
"grass",
"water",
"bldg",
"mntn",
"fg obj",
)
palette = [
[128, 128, 128],
[129, 127, 38],
[120, 69, 125],
[53, 125, 34],
[0, 11, 123],
[118, 20, 12],
[122, 81, 25],
[241, 134, 51],
]
@DATASETS.register_module()
class StanfordBackgroundDataset(BaseSegDataset):
METAINFO = dict(classes=classes, palette=palette)
def __init__(self, **kwargs):
super().__init__(img_suffix=".jpg", seg_map_suffix=".png", **kwargs)
def make_config(dir_save: str, path_config: str, path_model: str):
cfg = Config.fromfile(path_config)
print(f"Config:\n{cfg.pretty_text}")
# Since we use only one GPU, BN is used instead of SyncBN
cfg.norm_cfg = dict(type="BN", requires_grad=True)
cfg.crop_size = (256, 256)
cfg.model.data_preprocessor.size = cfg.crop_size
cfg.model.backbone.norm_cfg = cfg.norm_cfg
cfg.model.decode_head.norm_cfg = cfg.norm_cfg
cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg
# modify num classes of the model in decode/auxiliary head
cfg.model.decode_head.num_classes = 8
cfg.model.auxiliary_head.num_classes = 8
# Modify dataset type and path
cfg.dataset_type = "StanfordBackgroundDataset"
cfg.data_root = data_root
cfg.train_dataloader.batch_size = 8
cfg.train_pipeline = [
dict(type="LoadImageFromFile"),
dict(type="LoadAnnotations"),
dict(
type="RandomResize",
scale=(320, 240),
ratio_range=(0.5, 2.0),
keep_ratio=True,
),
dict(type="RandomCrop", crop_size=cfg.crop_size, cat_max_ratio=0.75),
dict(type="RandomFlip", prob=0.5),
dict(type="PackSegInputs"),
]
cfg.test_pipeline = [
dict(type="LoadImageFromFile"),
dict(type="Resize", scale=(320, 240), keep_ratio=True),
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
dict(type="LoadAnnotations"),
dict(type="PackSegInputs"),
]
cfg.train_dataloader.dataset.type = cfg.dataset_type
cfg.train_dataloader.dataset.data_root = cfg.data_root
cfg.train_dataloader.dataset.data_prefix = dict(
img_path=img_dir, seg_map_path=ann_dir
)
cfg.train_dataloader.dataset.pipeline = cfg.train_pipeline
cfg.train_dataloader.dataset.ann_file = "splits/train.txt"
cfg.val_dataloader.dataset.type = cfg.dataset_type
cfg.val_dataloader.dataset.data_root = cfg.data_root
cfg.val_dataloader.dataset.data_prefix = dict(
img_path=img_dir, seg_map_path=ann_dir
)
cfg.val_dataloader.dataset.pipeline = cfg.test_pipeline
cfg.val_dataloader.dataset.ann_file = "splits/val.txt"
cfg.test_dataloader = cfg.val_dataloader
# Load the pretrained weights
cfg.load_from = path_model
# Set up working dir to save files and logs.
cfg.work_dir = dir_save
cfg.train_cfg.max_iters = 200
cfg.train_cfg.val_interval = 200
cfg.default_hooks.logger.interval = 10
cfg.default_hooks.checkpoint.interval = 200
# Set seed to facilitate reproducing the result
cfg["randomness"] = dict(seed=0)
# Let's have a look at the final config used for training
print(f"Config:\n{cfg.pretty_text}")
return cfg
def main(dir_save, path_config, path_model):
cfg = make_config(dir_save, path_config, path_model)
runner = Runner.from_cfg(cfg)
# start training
runner.train()
if __name__ == "__main__":
dir_model = "checkpoints"
config = "pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py"
model = "pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth"
dir_save = "./save/tutorial"
path_config = os.path.join(dir_model, config)
path_model = os.path.join(dir_model, model)
main(dir_save, path_config, path_model)
事前学習済みモデルと設定ファイルをダウンロードし、学習を実行
# modelとconfig ダウンロード。checkpointsディレクトリが作成されファイルが保存される
mim download mmsegmentation --config pspnet_r50-d8_4xb2-40k_cityscapes-512x1024 --dest checkpoints
# 学習実行
python train_iccv09.pysave/tutorialディレクトリが作成され、モデルファイルやログが保存される- 筆者の環境(GPUあり)では1分程度かかった
推論
推論を行うpredict_iccv09.pyを作成する
import os
import matplotlib.pyplot as plt
import mmcv
from mmseg.apis import init_model, inference_model, show_result_pyplot
from train_iccv09 import make_config
def main(dir_save, path_config, path_model):
cfg = make_config(dir_save, path_config, path_model)
model = init_model(cfg, path_model, "cuda:0")
path_img = "iccv09Data/images/6000124.jpg"
path_save = "out.jpg"
img = mmcv.imread(path_img)
result = inference_model(model, img)
plt.figure(figsize=(8, 6))
vis_result = show_result_pyplot(model, mmcv.bgr2rgb(img), result, show=False)
plt.imshow(vis_result)
plt.savefig(path_save)
if __name__ == "__main__":
dir_model = "checkpoints"
config = "pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py"
dir_save = "./save/tutorial"
path_config = os.path.join(dir_model, config)
path_model = os.path.join(dir_save, "iter_200.pth")
main(dir_save, path_config, path_model)
- 推論結果を描画する mmseg.apisの
show_result_pyplot()は- channelがRGB形式のnp.ndarrayを返す
- 入力画像はstrかnp.ndarrayで、strの場合は
mmcv.imread(img, channel_order='rgb')でRGB形式で読み込むことになっているので、np.ndarrayを渡すときもchannelはRGBで渡したほうが良さそう。https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/apis/inference.py#L164
実行
python predict_iccv09.py- 推論結果がout.jpgとして保存される
参考記事
mmsegmentation/demo/MMSegmentation_Tutorial.ipynb
https://github.com/open-mmlab/mmsegmentation/blob/main/demo/MMSegmentation_Tutorial.ipynb
mmsegmentation v1系でのカスタムデータセット学習チュートリアル
openmimを使ったMMDetectionのconfigファイルとcheckpointsファイルのダウンロード方法 2022
https://qiita.com/nyakiri_0726/items/6581e16924400ccc24bfmim download コマンドでmodelとconfigをダウンロードできる

コメント