Abstract: This post records the process, steps, and methods of using the FCN fully convolutional network in the MindSpore AI framework for image semantic segmentation. It covers environment setup, dataset download, dataset loading and preprocessing, network construction, training preparation, model training, model evaluation, and model inference.
I. Concepts
1. Semantic Segmentation
Image semantic segmentation is an important branch of image processing, machine vision, and image understanding in AI, with applications in face recognition, object detection, medical imaging, satellite image analysis, and autonomous-driving perception. Its goal is to classify every pixel of an image and output an image of the same size as the input, where each output pixel carries the class of the corresponding input pixel. In the image domain, "semantics" refers to the content of the image, an understanding of what the picture means.
2. FCN (Fully Convolutional Network)
FCN (Fully Convolutional Networks) is an image semantic segmentation framework proposed by UC Berkeley in 2015: a fully convolutional network for end-to-end, pixel-level prediction. Fully convolutional networks mainly rely on three techniques:
1. Convolutionalization (Convolutional)
VGG-16 serves as FCN's backbone. In its original form it takes a fixed-size 224*224 RGB input, discards spatial coordinates, and produces non-spatial output: 1000 prediction values. Replacing the fully connected layers with convolutional layers instead yields two-dimensional output matrices, producing a heatmap mapped from the input image.
2. Upsampling (Upsample)
During the convolution stage, convolution and pooling operations shrink the feature maps. Upsampling operations then restore them, yielding dense predictions at the original image size. The upsampling (transposed-convolution) parameters are initialized from bilinear interpolation and then refined by backpropagation, so the network can learn nonlinear upsampling; a minimal sketch of this initialization follows the feature list below.
3. Skip Structure (Skip Layer)
The skip structure combines deep global information with shallow local information: the bottom stride-32 prediction (FCN-32s) is upsampled 2x and fused by addition with the stride-16 prediction from pool4 (FCN-16s); that result is upsampled 2x and fused by addition with the stride-8 prediction from pool3 (FCN-8s). Features:
(1) A fully convolutional (fully conv) network without fully connected layers (fc) can adapt to inputs of any size.
(2) Deconvolution (deconv) layers that enlarge the data size can output fine-grained results.
(3) The skip structure that combines results from layers of different depths ensures both robustness and precision.
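As a hedged illustration of the bilinear initialization described in technique 2 (our own sketch, not the tutorial's code; the helper name bilinear_kernel is hypothetical), the following builds the standard bilinear upsampling weights that are often used to initialize a transposed convolution before backpropagation refines it:

import numpy as np

def bilinear_kernel(in_channels, out_channels, kernel_size):
    # Standard bilinear-interpolation weights, commonly used to initialize
    # a transposed convolution so it starts as plain bilinear upsampling.
    factor = (kernel_size + 1) // 2
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float32)
    # Each channel upsamples itself: fill only the diagonal in/out pairs
    weight[range(in_channels), range(out_channels), :, :] = filt
    return weight

# Example: a 4x4 kernel for 2x upsampling of 21-class score maps
w = bilinear_kernel(21, 21, 4)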
II. Environment Setup
%%capture captured_output
# The environment has MindSpore 2.2.14 preinstalled; change the version number below to switch MindSpore versions
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# Check the current mindspore version
!pip show mindspore
Output:
Name: mindspore
Version: 2.2.14
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by:
III. Data Processing
1. Download the Dataset
from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset_fcn8s.tar"
download(url, "./dataset", kind="tar", replace=True)
Output:
Creating data folder...
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset_fcn8s.tar (537.2 MB)
file_sizes: 100%|█████████████████████████████| 563M/563M [00:03<00:00, 160MB/s]
Extracting tar file...
Successfully downloaded / unzipped to ./dataset
./dataset
2. Data Preprocessing
Because image resolutions in the PASCAL VOC 2012 dataset are inconsistent, the images are standardized during preprocessing.
3. Data Loading
Load a mix of the PASCAL VOC 2012 dataset and the SBD dataset.
import numpy as np
import cv2
import mindspore.dataset as ds

class SegDataset:
    def __init__(self,
                 image_mean,
                 image_std,
                 data_file='',
                 batch_size=32,
                 crop_size=512,
                 max_scale=2.0,
                 min_scale=0.5,
                 ignore_label=255,
                 num_classes=21,
                 num_readers=2,
                 num_parallel_calls=4):
        self.data_file = data_file
        self.batch_size = batch_size
        self.crop_size = crop_size
        self.image_mean = np.array(image_mean, dtype=np.float32)
        self.image_std = np.array(image_std, dtype=np.float32)
        self.max_scale = max_scale
        self.min_scale = min_scale
        self.ignore_label = ignore_label
        self.num_classes = num_classes
        self.num_readers = num_readers
        self.num_parallel_calls = num_parallel_calls
        assert max_scale > min_scale

    def preprocess_dataset(self, image, label):
        # Decode raw bytes into a BGR image and a grayscale label map
        image_out = cv2.imdecode(np.frombuffer(image, dtype=np.uint8), cv2.IMREAD_COLOR)
        label_out = cv2.imdecode(np.frombuffer(label, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
        # Random scaling
        sc = np.random.uniform(self.min_scale, self.max_scale)
        new_h, new_w = int(sc * image_out.shape[0]), int(sc * image_out.shape[1])
        image_out = cv2.resize(image_out, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
        label_out = cv2.resize(label_out, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
        # Normalize, then pad up to at least crop_size
        image_out = (image_out - self.image_mean) / self.image_std
        out_h, out_w = max(new_h, self.crop_size), max(new_w, self.crop_size)
        pad_h, pad_w = out_h - new_h, out_w - new_w
        if pad_h > 0 or pad_w > 0:
            image_out = cv2.copyMakeBorder(image_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
            label_out = cv2.copyMakeBorder(label_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=self.ignore_label)
        # Random crop
        offset_h = np.random.randint(0, out_h - self.crop_size + 1)
        offset_w = np.random.randint(0, out_w - self.crop_size + 1)
        image_out = image_out[offset_h: offset_h + self.crop_size, offset_w: offset_w + self.crop_size, :]
        label_out = label_out[offset_h: offset_h + self.crop_size, offset_w: offset_w + self.crop_size]
        # Random horizontal flip
        if np.random.uniform(0.0, 1.0) > 0.5:
            image_out = image_out[:, ::-1, :]
            label_out = label_out[:, ::-1]
        # HWC -> CHW
        image_out = image_out.transpose((2, 0, 1))
        image_out = image_out.copy()
        label_out = label_out.copy()
        label_out = label_out.astype("int32")
        return image_out, label_out

    def get_dataset(self):
        ds.config.set_numa_enable(True)
        dataset = ds.MindDataset(self.data_file, columns_list=["data", "label"],
                                 shuffle=True, num_parallel_workers=self.num_readers)
        transforms_list = self.preprocess_dataset
        dataset = dataset.map(operations=transforms_list, input_columns=["data", "label"],
                              output_columns=["data", "label"],
                              num_parallel_workers=self.num_parallel_calls)
        dataset = dataset.shuffle(buffer_size=self.batch_size * 10)
        dataset = dataset.batch(self.batch_size, drop_remainder=True)
        return dataset
# Parameters for creating the dataset
IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"

# Model training parameters
train_batch_size = 4
crop_size = 512
min_scale = 0.5
max_scale = 2.0
ignore_label = 255
num_classes = 21

# Instantiate the Dataset
dataset = SegDataset(image_mean=IMAGE_MEAN,
                     image_std=IMAGE_STD,
                     data_file=DATA_FILE,
                     batch_size=train_batch_size,
                     crop_size=crop_size,
                     max_scale=max_scale,
                     min_scale=min_scale,
                     ignore_label=ignore_label,
                     num_classes=num_classes,
                     num_readers=2,
                     num_parallel_calls=4)
dataset = dataset.get_dataset()
4. Training Set Visualization
import numpy as np
import matplotlib.pyplot as plt

plt.figure(figsize=(16, 8))

# Display samples from the training set
for i in range(1, 9):
    plt.subplot(2, 4, i)
    show_data = next(dataset.create_dict_iterator())
    show_images = show_data["data"].asnumpy()
    show_images = np.clip(show_images, 0, 1)
    # Convert the image to HWC format before displaying
    plt.imshow(show_images[0].transpose(1, 2, 0))
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0)
plt.show()
Output: (figure: a grid of eight sample training images)
IV. Network Construction
FCN network flow: the input image passes through pool1 (feature map reduced to 1/2 of the original size), pool2 (1/4), pool3 (1/8), pool4 (1/16), and pool5 (1/32); the conv6-7 convolutions keep the output at 1/32 of the original size. FCN-32s deconvolves this output directly back to the original size. FCN-16s upsamples the conv7 output 2x to 1/16 of the original size, fuses it with the pool4 feature map, and deconvolves the result back to the original size. FCN-8s upsamples the conv7 output 4x and the pool4 feature map 2x, fuses both with the pool3 feature map, and deconvolves the result back to the original size. Code to build the FCN-8s network:
import mindspore.nn as nn

class FCN8s(nn.Cell):
    def __init__(self, n_class):
        super().__init__()
        self.n_class = n_class
        self.conv1 = nn.SequentialCell(
            nn.Conv2d(in_channels=3, out_channels=64,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.SequentialCell(
            nn.Conv2d(in_channels=64, out_channels=128,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.SequentialCell(
            nn.Conv2d(in_channels=128, out_channels=256,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv4 = nn.SequentialCell(
            nn.Conv2d(in_channels=256, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv5 = nn.SequentialCell(
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv6 = nn.SequentialCell(
            nn.Conv2d(in_channels=512, out_channels=4096,
                      kernel_size=7, weight_init='xavier_uniform'),
            nn.BatchNorm2d(4096),
            nn.ReLU()
        )
        self.conv7 = nn.SequentialCell(
            nn.Conv2d(in_channels=4096, out_channels=4096,
                      kernel_size=1, weight_init='xavier_uniform'),
            nn.BatchNorm2d(4096),
            nn.ReLU()
        )
        self.score_fr = nn.Conv2d(in_channels=4096, out_channels=self.n_class,
                                  kernel_size=1, weight_init='xavier_uniform')
        self.upscore2 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
                                           kernel_size=4, stride=2, weight_init='xavier_uniform')
        self.score_pool4 = nn.Conv2d(in_channels=512, out_channels=self.n_class,
                                     kernel_size=1, weight_init='xavier_uniform')
        self.upscore_pool4 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
                                                kernel_size=4, stride=2, weight_init='xavier_uniform')
        self.score_pool3 = nn.Conv2d(in_channels=256, out_channels=self.n_class,
                                     kernel_size=1, weight_init='xavier_uniform')
        self.upscore8 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
                                           kernel_size=16, stride=8, weight_init='xavier_uniform')

    def construct(self, x):
        x1 = self.conv1(x)
        p1 = self.pool1(x1)
        x2 = self.conv2(p1)
        p2 = self.pool2(x2)
        x3 = self.conv3(p2)
        p3 = self.pool3(x3)
        x4 = self.conv4(p3)
        p4 = self.pool4(x4)
        x5 = self.conv5(p4)
        p5 = self.pool5(x5)
        x6 = self.conv6(p5)
        x7 = self.conv7(x6)
        sf = self.score_fr(x7)
        u2 = self.upscore2(sf)        # 2x upsample of conv7 scores
        s4 = self.score_pool4(p4)
        f4 = s4 + u2                  # fuse with pool4 scores
        u4 = self.upscore_pool4(f4)   # 2x upsample
        s3 = self.score_pool3(p3)
        f3 = s3 + u4                  # fuse with pool3 scores
        out = self.upscore8(f3)       # 8x upsample back to input size
        return out
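As a quick hedged smoke test (our own addition, not part of the tutorial), a dummy tensor can be pushed through the network to confirm that the output recovers the input's spatial size with one channel per class:

import numpy as np
import mindspore as ms

# Hypothetical check: batch of 1, 3x512x512 input
net_check = FCN8s(n_class=21)
dummy = ms.Tensor(np.random.rand(1, 3, 512, 512).astype(np.float32))
out = net_check(dummy)
print(out.shape)  # expected: (1, 21, 512, 512)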
V. Training Preparation
1. Import Partial VGG-16 Pretrained Weights
from download import download
from mindspore import load_checkpoint, load_param_into_net

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/fcn8s_vgg16_pretrain.ckpt"
download(url, "fcn8s_vgg16_pretrain.ckpt", replace=True)

def load_vgg16():
    ckpt_vgg16 = "fcn8s_vgg16_pretrain.ckpt"
    param_vgg = load_checkpoint(ckpt_vgg16)
    load_param_into_net(net, param_vgg)
Output:
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/fcn8s_vgg16_pretrain.ckpt (513.2 MB)
file_sizes: 100%|█████████████████████████████| 538M/538M [00:03<00:00, 179MB/s]
Successfully downloaded file to fcn8s_vgg16_pretrain.ckpt
2. Loss Function
The cross-entropy loss function mindspore.nn.CrossEntropyLoss() is used to compute the cross-entropy loss between the FCN network output and the mask.
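A minimal hedged sketch (our own, with hypothetical shapes) of how this loss applies to segmentation outputs: the logits carry one channel per class, the labels are per-pixel class indices, and ignore_index=255 masks out void/padded pixels:

import numpy as np
import mindspore as ms
import mindspore.nn as nn

# Hypothetical toy batch: 2 samples, 21 classes, 4x4 "images"
logits = ms.Tensor(np.random.randn(2, 21, 4, 4).astype(np.float32))
labels = ms.Tensor(np.random.randint(0, 21, (2, 4, 4)).astype(np.int32))

loss_fn = nn.CrossEntropyLoss(ignore_index=255)  # pixels labeled 255 are skipped
print(loss_fn(logits, labels))  # scalar mean loss over non-ignored pixels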
3. Custom Evaluation Metrics
Metrics are used to evaluate how well the model performs.
Suppose there are $K+1$ classes (from $L_0$ to $L_K$, including one empty class or background), and let $p_{ij}$ denote the number of pixels that belong to class $i$ but are predicted as class $j$. Then $p_{ii}$ is the number of true positives, while $p_{ij}$ and $p_{ji}$ are interpreted as false positives and false negatives, respectively.

Pixel Accuracy (PA): the proportion of correctly labeled pixels among all pixels.
$$PA = \frac{\sum_{i=0}^{K} p_{ii}}{\sum_{i=0}^{K}\sum_{j=0}^{K} p_{ij}}$$

Mean Pixel Accuracy (MPA): compute, for each class, the proportion of its pixels that are classified correctly, then average over all classes.
$$MPA = \frac{1}{K+1}\sum_{i=0}^{K}\frac{p_{ii}}{\sum_{j=0}^{K} p_{ij}}$$

Mean Intersection over Union (MIoU): the standard metric for semantic segmentation. It computes the ratio of the intersection to the union of two sets, the ground truth and the predicted segmentation: true positives (intersection) over true positives + false negatives + false positives (union). IoU is computed per class and then averaged.
$$MIoU = \frac{1}{K+1}\sum_{i=0}^{K}\frac{p_{ii}}{\sum_{j=0}^{K} p_{ij} + \sum_{j=0}^{K} p_{ji} - p_{ii}}$$

Frequency Weighted Intersection over Union (FWIoU): an improvement over MIoU that weights each class by its occurrence frequency.
$$FWIoU = \frac{1}{\sum_{i=0}^{K}\sum_{j=0}^{K} p_{ij}}\sum_{i=0}^{K}\frac{\left(\sum_{j=0}^{K} p_{ij}\right) p_{ii}}{\sum_{j=0}^{K} p_{ij} + \sum_{j=0}^{K} p_{ji} - p_{ii}}$$

import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.train as train
class PixelAccuracy(train.Metric):
    def __init__(self, num_class=21):
        super(PixelAccuracy, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        # Accumulate a num_class x num_class confusion matrix over valid pixels
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def eval(self):
        pixel_accuracy = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
        return pixel_accuracy
class PixelAccuracyClass(train.Metric):
    def __init__(self, num_class=21):
        super(PixelAccuracyClass, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def eval(self):
        # Per-class accuracy, then mean over classes (NaN-safe)
        mean_pixel_accuracy = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)
        mean_pixel_accuracy = np.nanmean(mean_pixel_accuracy)
        return mean_pixel_accuracy
class MeanIntersectionOverUnion(train.Metric):
    def __init__(self, num_class=21):
        super(MeanIntersectionOverUnion, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def eval(self):
        # Per-class IoU = diag / (row sum + column sum - diag), then mean
        mean_iou = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1) +
            np.sum(self.confusion_matrix, axis=0) -
            np.diag(self.confusion_matrix))
        mean_iou = np.nanmean(mean_iou)
        return mean_iou
class FrequencyWeightedIntersectionOverUnion(train.Metric):
    def __init__(self, num_class=21):
        super(FrequencyWeightedIntersectionOverUnion, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def eval(self):
        # Weight each class's IoU by its pixel frequency
        freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
        iu = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1) +
            np.sum(self.confusion_matrix, axis=0) -
            np.diag(self.confusion_matrix))
        frequency_weighted_iou = (freq[freq > 0] * iu[freq > 0]).sum()
        return frequency_weighted_iou
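The four metric classes all follow MindSpore's Metric protocol (clear/update/eval). As a hedged standalone usage sketch (our own addition, with hypothetical random data, assuming the batch shape these classes hard-code in update):

import numpy as np
import mindspore as ms

# Hypothetical batch matching the metrics' assumptions:
# logits (4, 21, 512, 512), labels (4, 512, 512)
metric = PixelAccuracy(num_class=21)
metric.clear()
logits = ms.Tensor(np.random.randn(4, 21, 512, 512).astype(np.float32))
labels = ms.Tensor(np.random.randint(0, 21, (4, 512, 512)).astype(np.int32))
metric.update(logits, labels)   # argmax over the class axis, then accumulate
print(metric.eval())            # pixel accuracy for this batch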
VI. Model Training
Import the VGG-16 pretrained parameters, instantiate the loss function and optimizer, compile the network with the Model API, and train the FCN-8s network.
import mindspore
from mindspore import Tensor
import mindspore.nn as nn
from mindspore.train import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor, Model
device_target = "Ascend"
mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target=device_target)

train_batch_size = 4
num_classes = 21
# Initialize the model structure
net = FCN8s(n_class=21)
# Load the VGG-16 pretrained parameters
load_vgg16()
# Compute the learning rate
min_lr = 0.0005
base_lr = 0.05
train_epochs = 1
iters_per_epoch = dataset.get_dataset_size()
total_step = iters_per_epoch * train_epochs

lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,
                                            base_lr,
                                            total_step,
                                            iters_per_epoch,
                                            decay_epoch=2)
lr = Tensor(lr_scheduler[-1])
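# Note (our addition, not from the tutorial): mindspore.nn.cosine_decay_lr
# returns a plain Python list with one learning-rate value per step, following
# a cosine curve from base_lr toward min_lr over decay_epoch epochs; the line
# above keeps only the last entry and trains with that fixed learning rate.
print(len(lr_scheduler))    # == total_step, one value per training step
print(lr_scheduler[-1])     # the value actually used as the fixed lr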
# Define the loss function
loss = nn.CrossEntropyLoss(ignore_index=255)
# Define the optimizer
optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.0001)
# Define the loss scale
scale_factor = 4
scale_window = 3000
loss_scale_manager = ms.amp.DynamicLossScaleManager(scale_factor, scale_window)
# Initialize the model
if device_target == "Ascend":
    model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale_manager,
                  metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(),
                           "mean IoU": MeanIntersectionOverUnion(),
                           "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})
else:
    model = Model(net, loss_fn=loss, optimizer=optimizer,
                  metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(),
                           "mean IoU": MeanIntersectionOverUnion(),
                           "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})

# Set the parameters for saving ckpt files
time_callback = TimeMonitor(data_size=iters_per_epoch)
loss_callback = LossMonitor()
callbacks = [time_callback, loss_callback]
save_steps = 330
keep_checkpoint_max = 5
config_ckpt = CheckpointConfig(save_checkpoint_steps=10,
                               keep_checkpoint_max=keep_checkpoint_max)
ckpt_callback = ModelCheckpoint(prefix="FCN8s",
                                directory="./ckpt",
                                config=config_ckpt)
callbacks.append(ckpt_callback)
model.train(train_epochs, dataset, callbacks=callbacks)
Output:
epoch: 1 step: 1, loss is 3.0504844
epoch: 1 step: 2, loss is 3.017057
epoch: 1 step: 3, loss is 2.9523003
epoch: 1 step: 4, loss is 2.9488814
epoch: 1 step: 5, loss is 2.666231
epoch: 1 step: 6, loss is 2.7145326
epoch: 1 step: 7, loss is 1.796408
epoch: 1 step: 8, loss is 1.5167583
epoch: 1 step: 9, loss is 1.6862022
epoch: 1 step: 10, loss is 2.4622822
......
epoch: 1 step: 1141, loss is 1.70966
epoch: 1 step: 1142, loss is 1.434751
epoch: 1 step: 1143, loss is 2.406475
Train epoch time: 762889.258 ms, per step time: 667.445 ms
VII. Model Evaluation
IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"

# Download the trained weight file
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/FCN8s.ckpt"
download(url, "FCN8s.ckpt", replace=True)
net = FCN8s(n_class=num_classes)
ckpt_file = "FCN8s.ckpt"
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)

if device_target == "Ascend":
    model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale_manager,
                  metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(),
                           "mean IoU": MeanIntersectionOverUnion(),
                           "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})
else:
    model = Model(net, loss_fn=loss, optimizer=optimizer,
                  metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(),
                           "mean IoU": MeanIntersectionOverUnion(),
                           "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})

# Instantiate the Dataset
dataset = SegDataset(image_mean=IMAGE_MEAN,
                     image_std=IMAGE_STD,
                     data_file=DATA_FILE,
                     batch_size=train_batch_size,
                     crop_size=crop_size,
                     max_scale=max_scale,
                     min_scale=min_scale,
                     ignore_label=ignore_label,
                     num_classes=num_classes,
                     num_readers=2,
                     num_parallel_calls=4)
dataset_eval = dataset.get_dataset()
model.eval(dataset_eval)
Output:
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/FCN8s.ckpt (1.00 GB)
file_sizes: 100%|██████████████████████████| 1.08G/1.08G [00:10<00:00, 99.7MB/s]
Successfully downloaded file to FCN8s.ckpt
{'pixel accuracy': 0.9734831394168291, 'mean pixel accuracy': 0.9423324801371116, 'mean IoU': 0.8961453779807752, 'frequency weighted IoU': 0.9488883312345654}
VIII. Model Inference
Use the trained network to display the model's inference results.
import cv2
import matplotlib.pyplot as plt
net = FCN8s(n_class=num_classes)
# Set hyperparameters
ckpt_file = "FCN8s.ckpt"
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)

eval_batch_size = 4
img_lst = []
mask_lst = []
res_lst = []

# Display inference results (top row: input images; bottom row: inference results)
plt.figure(figsize=(8, 5))
show_data = next(dataset_eval.create_dict_iterator())
show_images = show_data["data"].asnumpy()
mask_images = show_data["label"].reshape([4, 512, 512])
show_images = np.clip(show_images, 0, 1)
for i in range(eval_batch_size):
    img_lst.append(show_images[i])
    mask_lst.append(mask_images[i])
res = net(show_data["data"]).asnumpy().argmax(axis=1)
for i in range(eval_batch_size):
    plt.subplot(2, 4, i + 1)
    plt.imshow(img_lst[i].transpose(1, 2, 0))
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
    plt.subplot(2, 4, i + 5)
    plt.imshow(res[i])
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()
Output: (figure: four input images on top, the corresponding four predicted segmentation maps below)
IX. Summary
FCN uses fully convolutional layers to achieve end-to-end image segmentation through learning. Advantages: it accepts input images of arbitrary size, and it is efficient, avoiding the redundant storage and computation that pixel-patch approaches incur. Room for improvement: the results are not fine-grained enough; they tend to be blurry and over-smooth, and are insensitive to detail at boundaries. Pixel classification does not consider relationships between pixels (such as discontinuity and similarity), and the spatial regularization step is ignored, so spatial consistency is lacking.