دسته‌بندی نشده

مکان یابی یک شی داخل عکس با پردازش تصویر

مکان-یابی-اشیا-داخل-تصویر-هوش-مصنوعی

در مطالب قبلی با نحوه دسته بندی تصاویر با استفاده از هوش مصنوعی آشنا شدیم. بطور مثال یک مجموعه داده شامل تصاویر حیوانات داریم و میخواهیم با آموزش دادن یک شبکه کاری کنیم که هر تصویر جدید به شبکه داده شد، آن را در کلاس مرتبطش قرار دهد.

اما در این مطلب، علاوه بر کلاس بندی تصاویر، قصد داریم هوش مصنوعی طراحی کنیم که مکان یابی شی در عکس را نیز انجام دهد. مثلا وقتی یک عکس گربه به سیستم میدهیم، نوع حیوان را تشخیص دهد و مکان قرار گیری گربه در تصویر را مشخص کند. (دور گربه کادر بکشد)

البته بطور پیشفرض در نظر میگیریم که فقط یک شی داخل تصویر وجود دارد. (با الگوریتم های پیشرفته تر مانند Yolo اشیا مختلف را با سرعت بیشتری در تصاویر پیدا کنیم)

مکان یابی و ساخت باندینگ‌باکس با روش رگرسیون

مانند گذشته برای پردازش تصاویر از شبکه عصبی کانولوشنال استفاده میکنیم. به یاد دارید که برای دسته بندی تصاویر با CNN بعد از چندلایه کانولوشن لایه آخر را بصورت فلت و به تعداد کلاس های موجود در نظر میگرفتیم، شبکه پس از آموزش دیدن با داده های Train در لایه آخر تشخیص میداد که تصویر داده شد مربوط به کدام دسته بندی میباشد(سگ است یا گربه)

حال برای مکان یابی در لایه آخر علاوه بر کلاس بندی ها، با استفاده از ویژگی های استخراج شده 4عدد دیگر را نیز بدست می آوریم که مشخص کننده مختصات جعبه محاطی یا همان باندینگ باکس ما هستند. (معمولا با نرمالایز کردن، اعداد بدست آمده بین 0 تا 1 هستند)

مکان-یابی-شی-با-هوش-مصنوعی

پس از پیش بینی اولیه، جعبه محاطی بدست آمده را با جعبه محاطی اصلی (موجود در داده های train) مقایسه میکنیم.

حال فقط کافی است که با محاسبه خطا و با بروزرسانی وزن ها در مرحله بک پروپگیشن میزان loss را کاهش دهیم و به وزن های بهینه برای شبکه برسیم؛

مکان-یابی-شی-در-عکس

برای محاسبه خطا اعداد پیش بینی شده توسط شبکه را با 4مختصات اصلی که در داده آموزشی وجود دارد مقایسه میکنیم، سپس از تابع هزینه مجموع مربعات خطا استفاده میکنیم و تلاش میکنیم که مقدار خطا را به نزدیکی صفر برسانیم.

کدنویسی پروژه

برای اجرای این پروژه از زبان پایتون و کتابخانه هوش مصنوعی پایتورچ استفاده شده است.

نکته: کدها و تصاویر استفاده شده در مطلب، از گیت‌هاب استادرضوی برداشته شده است (لینک)

ابتدا کتابخانه های مورد نیاز را وارد میکنیم:

%reload_ext autoreload
%autoreload 2
%matplotlib inline

import os
import sys
import cv2
import time
import json
import numpy as np
import matplotlib.pyplot as plt

from glob import glob
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from utils import to_var
from train import train_model
from data_utils import create_validation_data
from vis_utils import imshow

use_gpu = torch.cuda.is_available()

توابع کمکی:

def get_model(model_name, num_classes, pretrained=True):
    return models.__dict__[model_name](pretrained)


def read_annotations(path):
    """ Read Bounding Boxes from a json file.
    """
    anno_classes = [f.split('_')[0] for f in os.listdir(path)]
    bb_json = {}
    
    for c in anno_classes:
        j = json.load(open(f'{path}/{c}_labels.json', 'r'))
        for l in j:
            if 'annotations' in l and len(l['annotations']) > 0:
                fname = l['filename'].split('/')[-1]
                bb_json[fname] = sorted(
                    l['annotations'], key=lambda x: x['height'] * x['width'])[-1]
    return bb_json


def bbox_to_r1c1r2c2(bbox):
    """ Convert BB from [h, w, x, y] to [r1, c1, r2, c2] format.
    """
    
    # extract h, w, x, y and convert to list
    bb = []
    bb.append(bbox['height'])
    bb.append(bbox['width'])
    bb.append(max(bbox['x'], 0))
    bb.append(max(bbox['y'], 0))
    
    # convert to float
    bb = [float(x) for x in bb]
    
    # convert to [r1, c1, r2, c2] format
    r1 = bb[3]
    c1 = bb[2]
    r2 = r1 + bb[0]
    c2 = c1 + bb[1]
    
    return [r1, c1, r2, c2]


def plot_bbox(img, bbox, w, h, color='red'):
    """ Plot bounding box on the image tensor. 
    """
    img = img.cpu().numpy().transpose((1, 2, 0))  # (H, W, C)
    
    # denormalize
    mean = np.array([0.485, 0.456, 0.406])
    std  = np.array([0.229, 0.224, 0.225])
    img = std * img + mean
    img = np.clip(img, 0, 1)
    
    # scale
    hs, ws = img.shape[:2]
    h_scale = h / hs
    w_scale = w / ws
    
    bb = np.array(bbox, dtype=np.float32)
    bx, by = bb[1], bb[0]
    bw = bb[3] - bb[1]
    bh = bb[2] - bb[0]
    
    bx *= w * w_scale
    by *= h * h_scale
    bw *= w * w_scale
    bh *= h * h_scale
    
    # scale image
    img = cv2.resize(img, (w, h))
    
    # create BB rectangle
    rect = plt.Rectangle((bx, by), bw, bh, color=color, fill=False, lw=3)
    
    # plot
    plt.figure(figsize=(10, 10))
    plt.axis('off')
    plt.imshow(img)
    plt.gca().add_patch(rect)
    plt.show()

نوبت به وارد کردن داده های اولیه رسیده است. داده های اولیه این پروژه مربوط به یک مسابقه اجرا شده در سایت kaggle هست. در این دیتاست تعداد زیادی تصاویر از قایق های ماهیگیری هنگام صید ماهی میباشد. ما قصد داریم نوع ماهی را مشخص کنیم و مکان آن را در عکس مشخص کنیم(بطور پیش فرض در نظر میگیریم که فقط یک ماهی داخل تصویر برای شناسایی است)

مکان-یابی-ماهی-در-عکس-مسابقه-کگل

این مسابقه 150هزار دلار جایزه داشته است!

DATA_DIR = "D:/datasets/kaggle/fish"

train_dir = f'{DATA_DIR}/train'
valid_dir = f'{DATA_DIR}/valid'
anno_dir = f'{DATA_DIR}/annotations'

sz = 299  # image size
bs = 32   # batch size
model_name = 'resnet34'
num_classes = 8

داده های مربوط به مختصات جعبه های محاطی در یک فایل json ذخیره شده است با کد زیر آن را فراخوانی میکنیم

bb_json = read_annotations(anno_dir)

کدهای زیر برای آشنایی با ساختار داده ها میباشد

print(os.listdir(DATA_DIR))

 

# all images for each fish class is in a separate directory
print(os.listdir(f'{DATA_DIR}/train'))

 

files = glob(f'{DATA_DIR}/train/ALB/*.*')
files[:5]

حال با اجرای کد زیر یکی از تصاویر موجود در دیتاست را به عنوان نمونه میبینیم:

Image.open(files[1])

برای مشاهده داده های مربوط به مختصات، از کد زیر استفاده میکنیم:

anno_files = os.listdir(anno_dir)
anno_files

ساخت داده های Valid برای اعتبارسنجی مدل:

if not os.path.exists(valid_dir):
    create_validation_data(train_dir, valid_dir, split=0.2, ext='jpg')

با تعریف کلاس و زیر کلاس های زیر و استفاده از آنها در مدل میتوانیم پردازش های مورد نیاز را روی دیتاست انجام دهیم

class FishDataset(Dataset):
    def __init__(self, ds, bboxes, sz=299):
        """ Prepare fish dataset
        
        Inputs:
            root: the directory which contains all required data such as images, labels, etc.
            ds: torchvision ImageFolder dataset.
            bboxes: a dictionary containing the coordinates of the bounding box in each images
            transforms: required transformations on each image
        """
        self.imgs = ds.imgs
        self.classes = ds.classes
        self.bboxes = bboxes
        self.sz = sz
        self.tfms = transforms.Compose([
            transforms.Resize((sz, sz)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    
    def __getitem__(self, index):
        img, lbl = self.imgs[index]
        
        # get bounding box
        img_name = os.path.basename(img)
        if img_name in self.bboxes.keys():
            bbox = self.bboxes[img_name]
        else:
            bbox = {'class': 'rect', 'height': 0., 'width': 0., 'x': 0., 'y': 0.}
            
        # convert [h, w, x, y] to [r1, c1, r2, c2] format        
        bbox = bbox_to_r1c1r2c2(bbox)
        
        # read image and perform transformations
        image = Image.open(img).convert('RGB')
        w, h = image.size
        
        w_scale = sz / w
        h_scale = sz / h
        
        # transformations
        image = self.tfms(image)
        
        # normalize and scale bounding box
        bbox[0] = (bbox[0] / h) * h_scale
        bbox[1] = (bbox[1] / w) * w_scale
        bbox[2] = (bbox[2] / h) * h_scale
        bbox[3] = (bbox[3] / w) * w_scale
        
        # return image tensor, label tensor and bounding box tensor
        return image, lbl, torch.Tensor(bbox), (w, h)
    
    def __len__(self):
        return len(self.imgs)

ساخت دیتاست داده های آموزشی و Valid:

# training data
train_data = datasets.ImageFolder(train_dir)
train_ds = FishDataset(train_data, bb_json, sz=sz)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)

# validation data
valid_data = datasets.ImageFolder(valid_dir)
valid_ds = FishDataset(valid_data, bb_json, sz=sz)
valid_dl = DataLoader(valid_ds, batch_size=bs, shuffle=False)

برای مشاهده تعدادی از داده ها قطعه کد زیر را اجرا میکنیم

dataiter = iter(train_dl)
imgs, lbls, bbs, sizes = next(dataiter)
img = torchvision.utils.make_grid(imgs, nrow=8)
plt.figure(figsize=(16, 8))
imshow(img, title='A random batch of training data')

دسته بندی و مکان یابی در تصاویر

Model

class ClassifierLocalizer(nn.Module):
    def __init__(self, model_name, num_classes=8):
        super(ClassifierLocalizer, self).__init__()
        self.num_classes = num_classes
        
        # create cnn model
        model = get_model(model_name, num_classes)
        
        # remove fc layers and add a new fc layer
        num_features = model.fc.in_features
        model.fc = nn.Linear(num_features, num_classes + 4) # classifier + localizer
        self.model = model
    
    def forward(self, x):
        x = self.model(x)                    # extract features from CNN
        scores = x[:, :self.num_classes]     # class scores
        coords = x[:, self.num_classes:]     # bb corners coordinates
        return scores, F.sigmoid(coords)     # sigmoid output is in [0, 1]

تعریف تابع هزینه

با توجه به اینکه دسته بندی و مکان یابی را همزمان انجام میدهیم، نیاز به دو تابع هزینه متفاوت داریم

  • تابع هزینه Cross Entropy برای دسته بندی
  • تابع هزینه مجموع مربعات خطا برای رگرسیون (مختصات جعبه)
class LocalizationLoss(nn.Module):
    def __init__(self, num_classes=8):
        super(LocalizationLoss, self).__init__()
        self.ce_loss = nn.CrossEntropyLoss(size_average=False)
        self.mse_loss = nn.MSELoss(size_average=False)
        
    def forward(self, scores, locs, labels, bboxes):
        # Cross Entropy (for classification)
        loss_cls = self.ce_loss(scores, labels)
        
        # Sum of Squared errors (for corner points)
        loss_r1 = self.mse_loss(locs[:, 0], bboxes[:, 0]) / 2.0
        loss_c1 = self.mse_loss(locs[:, 1], bboxes[:, 1]) / 2.0
        loss_r2 = self.mse_loss(locs[:, 2], bboxes[:, 2]) / 2.0
        loss_c2 = self.mse_loss(locs[:, 3], bboxes[:, 3]) / 2.0
        
        return loss_cls, loss_r1 + loss_c1 + loss_r2 + loss_c2
model = ClassifierLocalizer(model_name)
if use_gpu: model = model.cuda()
    
criterion = LocalizationLoss()
if use_gpu: criterion = criterion.cuda()
    
optimizer = optim.Adam(model.parameters(), lr=0.0001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

و پس از اجرای کد زیر مدل شروع به آموزش دیدن میکند. (تعداد epochها را میتوانید تغییر دهید)

model = train_model(model, train_dl, valid_dl, criterion, optimizer, scheduler, num_epochs=1)

کد زیر نیز برای پیش بینی مکان جعبه محاطی میباشد:

valid_dl = DataLoader(valid_ds, batch_size=1, shuffle=True)
imgs, lbls, bbs, sizes = next(iter(valid_dl))
scores, locs = model(to_var(imgs, volatile=True))

scores = scores.data.cpu().numpy()
locs = locs.data.cpu().numpy()

pred_lbl = np.argmax(scores, axis=1)[0]
pred_bb = locs[0].tolist()

print(pred_lbl, ':', valid_ds.classes[pred_lbl])
w, h = sizes[0].numpy(), sizes[1].numpy()

plot_bbox(imgs[0], pred_bb, w, h)

مکان یابی یک شی در تصویر با هوش مصنوعی

  1. یکبار دیگر مراحل انجام شده را با هم مرور میکنیم:
  2. از یک شبکه کانولوشنال مانند رزنت استفاده میکنیم
  3. لایه های فولی کانکتد آن را حذف میکنیم
  4. آخرین لایه کانولوشن را خودمان فلت میکنیم و یک سری ویژگی از شبکه استخراج میکنیم
  5. ویژگی های استخراج شده را به تعداد کلاس های مورد نیاز تبدیل میکنیم(اگر امتیاز هر کلاس بالاتر باشد عکس متعلق به آن دسته‌بندی است)
  6. علاوه بر آن ها به تعدادی خروجی دیگر برای پیش ینی مختصات جعبه محاطی و مکان یابی آن نیاز داریم
  7. با train کردن یک شبکه CNN هر دو مسئله (کلاس‌بندی، مکان‌یابی) را همزمان انجام داده می‌شود؛
author-avatar

درباره محمد اسماعیلی

علاقه مند به مفاهیم هوش مصنوعی، دیتاساینس و سئو؛ مطالبی که برام جالب باشه رو اینجا می نویسم، و این دلیل بر متخصص بودن من در اون حوزه ها نمیشه😊

نوشته های مرتبط

دیدگاهتان را بنویسید

نشانی ایمیل شما منتشر نخواهد شد. بخش‌های موردنیاز علامت‌گذاری شده‌اند *