개발자로서 살아남기/YOLO - Colab 이용해서 Custom 학습하기

YOLO - Colab 이용해서 Custom 학습하기 (3)

코드 살인마 2020. 11. 29. 20:54
728x90

파이썬을 이용하는 다른 플랫폼에 적용시키기

darknet에서만 사용하는 것이 아닌 라즈베리파이, 웹사이트, 모바일 등에 적용할 수 있다.

준비물은 3가지이다.

  1. CFG 파일
  2. 훈련된 weight 파일
  3. names 파일

라이브러리와 YOLO 로드

import cv2
import numpy as np
from matplotlib import pyplot as plt

net = cv2.dnn.readNet("yolov4_19000.weights", "yolov4.cfg") #CFG 파일, weight 파일을 넣어준다.
classes = []
with open("food30.names", "rt",encoding = "UTF8") as f: #클래스 이름을 따로 저장해준다. 이 형식은 클래스가 한글이름 일 때 불러오는 방식이다.
    classes = [line.strip() for line in f.readlines()]
layer_names = net.getLayerNames()
output_layers = [layer_names[i[0] - 1] for i in net.getUnconnectedOutLayers()]
colors = np.random.uniform(0, 255, size=(len(classes), 3))

Detecting objects

blob = cv2.dnn.blobFromImage(img, 0.00392, (416, 416), (0, 0, 0), True, crop=True)     # 네트워크에 넣기 위한 전처리

net.setInput(blob)  # 전처리된 blob 네트워크에 입력
outs = net.forward(output_layers)     # 결과 받아오기

class_ids = []
confidences = []

boxes = []
for out in outs:
    for detection in out:
        scores = detection[5:]
        class_id = np.argmax(scores)
        if scores[class_id] != 0:
            print(class_id)
            print(scores[class_id])
        confidence = scores[class_id]
        if confidence > 0.6: # 임계값 설정 정확도가 60% 이상만
            # 탐지된 객체의 너비, 높이 및 중앙 좌표값 찾기
            center_x = int(detection[0] * width)
            center_y = int(detection[1] * height)
            #print(center_x,center_y)
            w = abs(int(detection[2] * width))
            h = abs(int(detection[3] * height))
            #print(w,h)
             # 객체의 사각형 테두리 중 좌상단 좌표값 찾기
            x = abs(int(center_x - w / 2))
            y = abs(int(center_y - h / 2))
            boxes.append([x, y, w, h])
            confidences.append(float(confidence))
            class_ids.append(class_id)

class index와 정확도가 출력된다. (여러 가지의 객체와 좌표가 검출된 것)

중복 제거

#indexes = cv2.dnn.NMSBoxes(boxes, confidences, 0.1, 0.1)    # Non Maximum Suppression (겹쳐있는 박스 중 confidence 가 가장 높은 박스를 선택)
#같은 index 중 확률이 가장 높은것

# bounding box 가 겹쳐도 여러개 검출 할 수 있도록 설정
class_list = list(set(class_ids))
idxx = []
indexes=[]
for i in range(len(class_list)):
    max_v=0
    for j in range(len(class_ids)):
        if class_ids[j] == class_list[i]:
            if max_v < confidences[j]:
                max_v = confidences[j]
                idxx.append(j)
    indexes.append(idxx[len(idxx)-1])            

print(class_ids)

주석처리된 부분은 같은 index 중 확률이 가장 높은 것만 뽑아주는 부분인데 겹친 부분에 다른 클래스가 잘 인식이 안 되는 문제가 발생하여
직접 만들어줬다.

class name 과 bounding box, 좌표 출력하기

font = cv2.FONT_HERSHEY_PLAIN
det_foods = []
for i in range(len(boxes)):
    if i in indexes:
        x, y, w, h = boxes[i]
        class_name = classes[class_ids[i]]
        print(class_name)
        print(confidences[i])
        label = f"{class_name} {boxes[i]}"
        det_foods.append(label)
        color = colors[i]
        print(x,y,x+w,y+h)
        # 사각형 테두리 그리기 및 텍스트 쓰기
        nx = (x + w) / width
        ny = (y + h) / height

        if nx > 1:
            nx = 1
        if ny > 1:
            ny =1
        cv2.rectangle(img, (x, y), (x + w, y + h), color, 2) #상하좌
        print(x/width,y/height,nx,ny)
        cv2.rectangle(img, (x - 1, y), (x + len(class_name)*13, y-12), color, -1)
        cv2.putText(img, class_name, (x, y - 4), cv2.FONT_HERSHEY_COMPLEX_SMALL, 0.5, (0, 0, 0), 1,cv2.LINE_AA)

b,g,r = cv2.split(img)
image2 = cv2.merge([r,g,b])
plt.imshow(image2)
#imShow(img)
plt.show()
print(det_foods)

아래의 그림과 같이 나오게 된다.

순서대로 클래스 이름, 클래스 이름일 확률, 좌표값, 비율좌표값