🐐 Flask에서 TensorFlow 사용하기

프로젝트 배경

Flask로 Python 웹 서비스를 구축하고,
웹에서 사진을 첨부하면 YOLOv5를 이용해 릴 Chip의 객체 탐지(Object Detection)를 수행하고,
TensorFlow를 이용해 LED의 정상/비정상 여부를 판단하는 프로젝트였습니다.

이 글에서는 그중 TensorFlow를 이용한 부분만 다룹니다.


구성 환경

필자는 아래와 같은 환경에서 작성하였습니다.

CPU : i7-10750H
그래픽카드 : RTX2060
Ram : 32GB
OS : Windows 10
Python : 3.9.7


TensorFlow 설치

pip install tensorflow

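GPU를 사용하고 싶다면 아래와 같이 설치하면 됩니다. (참고: TensorFlow 2.1부터는 `tensorflow` 패키지 자체에 GPU 지원이 포함되어 있으며, 별도의 `tensorflow-gpu` 패키지는 이후 폐기(deprecated)되었습니다.)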

pip install tensorflow-gpu

CUDA Toolkit 설치

CUDA Toolkit은 GPU를 사용하기 위한 필수적인 요소입니다.
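필자는 TensorFlow 2.10.0 버전을 사용했고,
CUDA 11.7과 cuDNN 8.6.0을 설치하였습니다. (TensorFlow 2.10의 공식 테스트 구성은 CUDA 11.2 / cuDNN 8.1이며, 2.10은 네이티브 Windows에서 GPU를 지원하는 마지막 버전입니다.)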

requirements.txt

구현에 사용한 패키지 목록은 다음과 같습니다.

# requirements.txt
absl-py==1.3.0
aniso8601==9.0.1
asttokens==2.1.0
astunparse==1.6.3
attrs==22.1.0
awscli==1.26.5
backcall==0.2.0
boto3==1.25.4
botocore==1.28.5
cachetools==5.2.0
certifi==2022.9.24
charset-normalizer==2.1.1
click==8.1.3
colorama==0.4.4
contourpy==1.0.5
cycler==0.11.0
decorator==5.1.1
docker==6.0.0
docutils==0.16
executing==1.2.0
Flask==2.2.2
Flask-Cors==3.0.10
flask-marshmallow==0.14.0
Flask-RESTful==0.3.9
Flask-SQLAlchemy==3.0.0
flatbuffers==22.9.24
fonttools==4.37.4
gast==0.4.0
google-auth==2.12.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
greenlet==1.1.3
grpcio==1.49.1
h5py==3.7.0
idna==3.4
importlib-metadata==5.0.0
iniconfig==1.1.1
ipython==8.6.0
itsdangerous==2.1.2
jaraco.classes==3.2.3
jedi==0.18.1
Jinja2==3.1.2
jmespath==1.0.1
keras==2.10.0
Keras-Preprocessing==1.1.2
keyring==8.7
keyrings.alt==4.2.0
kiwisolver==1.4.4
libclang==14.0.6
Markdown==3.4.1
MarkupSafe==2.1.1
marshmallow==3.18.0
matplotlib==3.6.0
matplotlib-inline==0.1.6
more-itertools==9.0.0
mysql-connector-python==8.0.30
numpy==1.23.4
oauthlib==3.2.1
opencv-python==4.6.0.66
opt-einsum==3.3.0
packaging==21.3
pandas==1.5.1
parso==0.8.3
pickleshare==0.7.5
Pillow==9.3.0
pip==22.3
pluggy==1.0.0
prompt-toolkit==3.0.31
protobuf==3.19.6
psutil==5.9.3
psycopg2-binary==2.9.3
pure-eval==0.2.2
py==1.11.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
Pygments==2.13.0
pymssql==2.2.5
PyMySQL==1.0.2
pyparsing==3.0.9
pytest==7.1.3
python-dateutil==2.8.2
pytz==2022.4
pywin32==304
pywin32-ctypes==0.2.0
PyYAML==5.4.1
requests==2.28.1
requests-oauthlib==1.3.1
rsa==4.7.2
ruamel.yaml.clib==0.2.7
ruamel.yaml==0.17.21
s3transfer==0.6.0
scipy==1.9.3
seaborn==0.12.1
setuptools==57.4.0
six==1.16.0
SQLAlchemy==1.4.41
stack-data==0.6.0
tabulate==0.9.0
tensorboard==2.10.1
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.10.0
tensorflow-estimator==2.10.0
tensorflow-io-gcs-filesystem==0.27.0
termcolor==2.0.1
thop
tomli==2.0.1
torch==1.13.0
torchaudio==0.13.0
torchvision==0.14.0
tqdm==4.64.1
traitlets==5.5.0
typing_extensions==4.4.0
urllib3==1.26.12
voluptuous==0.13.1
wcwidth==0.2.5
websocket-client==1.4.1
Werkzeug==2.2.2
wheel==0.37.1
wrapt==1.14.1
yolo==0.3.1
zipp==3.10.0

TensorFlow를 이용하여 학습하기

Pass와 Fail 이미지를 클래스별 하위 디렉토리(예: pass/, fail/)에 나누어 담고
TensorFlow를 사용하여 학습을 진행했습니다.

import datetime
from distutils.command.upload import upload
from logging import exception
from tkinter.messagebox import showinfo
from xmlrpc.client import DateTime
from flask import render_template
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import cv2
from flask_sqlalchemy import Model
from sqlalchemy import create_engine, text, Column, Integer, String, DateTime, ForeignKey, Table
from sqlalchemy.orm import relationship, backref, sessionmaker, scoped_session
import numpy as np
import matplotlib.pyplot as plt
import pymysql
import pickle
import tensorflow as tf
import pandas as pd
from keras.optimizers import RMSprop
from keras.preprocessing import image
from keras.utils import load_img, img_to_array
from keras.preprocessing.image import ImageDataGenerator
from PIL import Image
from scipy import ndimage


def LED_train(train_dir='학습데이터 디렉토리 설정 /train',
              validation_dir='학습결과를 확인하는 디렉토리 설정 /validation'):
    """Train a small binary CNN on LED images and save it as an HDF5 file.

    Images are read with ``flow_from_directory``, so each directory must
    contain one sub-directory per class (for example pass/fail folders).

    Args:
        train_dir: Directory with the training images, one sub-folder per
            class. The default is the placeholder path from the original
            code and must be replaced with a real path.
        validation_dir: Directory with the validation images, laid out the
            same way as ``train_dir``. The default is also a placeholder.

    Returns:
        A ``(success, message)`` tuple. On success the model has been saved
        as ``LED_Model_SD_<YYYY_MM_DD_HH_MM_SS>.h5`` in the current working
        directory, and the message is ``'<file name> is saved'``. If any
        exception occurs, it is printed and ``(False, '<file name> is not
        saved')`` is returned.
    """
    # Keras image input is (height, width, channels), and flow_from_directory's
    # target_size is (height, width). Both must use the same pair. The original
    # code declared (126, 150) for the model but resized to (150, 126), which
    # Keras treats as an incompatible input shape.
    img_height, img_width = 126, 150
    batch_size = 12
    try:
        # The layer stack defines the model architecture.
        model = tf.keras.models.Sequential([
            tf.keras.layers.Conv2D(16, (3, 3), activation='relu',
                                   input_shape=(img_height, img_width, 3)),
            tf.keras.layers.MaxPooling2D(2, 2),
            tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
            tf.keras.layers.MaxPooling2D(2, 2),
            tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
            tf.keras.layers.MaxPooling2D(2, 2),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(512, activation='relu'),
            # A single sigmoid unit gives a probability in (0, 1) for the
            # positive class.
            tf.keras.layers.Dense(1, activation='sigmoid'),
        ])

        model.summary()

        # `lr` is a deprecated alias; `learning_rate` is the supported keyword.
        model.compile(optimizer=RMSprop(learning_rate=0.001),
                      loss='binary_crossentropy',
                      metrics=['accuracy'])

        # Pixels are scaled to [0, 1]. Inference must apply the same scaling.
        train_datagen = ImageDataGenerator(rescale=1.0 / 255.)
        test_datagen = ImageDataGenerator(rescale=1.0 / 255.)

        # binary_crossentropy needs 0/1 labels, hence class_mode='binary'.
        train_generator = train_datagen.flow_from_directory(
            train_dir,
            target_size=(img_height, img_width),
            batch_size=batch_size,
            class_mode='binary',
            shuffle=True)

        validation_generator = test_datagen.flow_from_directory(
            validation_dir,
            target_size=(img_height, img_width),
            batch_size=batch_size,
            class_mode='binary',
            shuffle=True)

        history = model.fit(
            train_generator,
            # Each epoch requests 200 batches of `batch_size` images.
            # NOTE(review): confirm the training set can supply
            # steps_per_epoch * epochs batches. Otherwise Keras stops early
            # and warns that the input ran out of data.
            steps_per_epoch=200,
            epochs=100,
            verbose=2,
            validation_data=validation_generator,
            # Only one validation batch (batch_size images) is evaluated
            # per epoch.
            validation_steps=1)

        # Training and validation accuracy, one value per epoch.
        acc = history.history['accuracy']
        val_acc = history.history['val_accuracy']
        print(history.history.keys())
        print(str(acc) + str(val_acc))
        print("**************************************")

        # The file name carries the current local date and time.
        today = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
        model_file = 'LED_Model_SD_' + today + '.h5'
        model.save(model_file, save_format='h5')
        print("Model Saved")
        return True, model_file + ' is saved'
    except Exception as e:
        # Best-effort boundary: report the failure to the caller instead of
        # raising.
        print(e)
        today = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
        return False, 'LED_Model_SD_' + today + '.h5 is not saved'
    

예측(predict)하기

Flask를 통해 전달된 이미지 파일을 받아 예측(predict)을 진행합니다.


from flask import render_template
import tensorflow as tf
tf.config.list_physical_devices('CPU')
import numpy as np
import os
from keras.optimizers import RMSprop
from keras.utils import load_img, img_to_array
from keras.preprocessing.image import ImageDataGenerator
from PIL import Image
import datetime

def led_predict(file_path, threshold=0.5, rescale=1.0 / 255.0):
    """Classify one LED image as pass ("T") or fail ("F") with a saved model.

    Args:
        file_path: Path of the image file to classify.
        threshold: Decision boundary applied to the model's sigmoid output.
            Scores strictly above it are labelled "T". The original code
            compared against 0, and a sigmoid output is always positive, so
            nearly every image came out "T".
        rescale: Factor applied to the raw 0-255 pixel values before
            prediction. The default matches the ``rescale=1/255`` used by the
            LED_train pipeline. Pass ``threshold=0, rescale=1`` to reproduce
            the old behaviour.

    Returns:
        On success, a tuple ``(discrimination, createDate, modelFile,
        file_path)``:
        - ``discrimination`` is "T" or "F".
        - ``createDate`` is the current local time as 'YYYY-MM-DD HH:MM:SS'.
        - ``modelFile`` is the model file name used.
        If any exception occurs, it is printed and the rendered
        ``error.html`` template is returned instead of a tuple.
        NOTE(review): callers must handle both return shapes.
    """
    try:
        modelFile = 'LED21.h5'
        # NOTE(review): the model is reloaded from disk on every call.
        # Consider loading it once if request latency matters.
        model = tf.keras.models.load_model(modelFile)

        img_size = 300
        img = load_img(file_path)
        print("input img size : ", img.size)

        # PIL's `size` and `resize()` are (width, height). Like the original
        # code, this uses the scaled HEIGHT (+1) as the new width and the
        # scaled WIDTH as the new height.
        # NOTE(review): presumably this was tuned so the array matches
        # LED21.h5's input shape. Confirm before changing it.
        scale = float(img_size) / max(img.size)
        new_size = (int(np.ceil(scale * img.size[1] + 1)),
                    int(np.ceil(scale * img.size[0])))
        img = img.resize(new_size, resample=Image.BILINEAR)
        print("input img resize : ", img.size)

        # Apply the same pixel scaling as training, then add the batch axis.
        # Assumes LED21.h5 was trained with 1/255 scaling (LED_train does
        # this). TODO confirm.
        batch = np.expand_dims(img_to_array(img) * rescale, axis=0)
        classes = model.predict(batch)
        print(classes[0])

        # A single-unit model output has shape (batch, 1); take the one score.
        score = float(classes[0][0])
        if score > threshold:
            print(file_path + " is a True")
            discrimination = "T"
        else:
            print(file_path + " is a False")
            discrimination = "F"

        createDate = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')

        return discrimination, createDate, modelFile, file_path
    except Exception as e:
        print(e)
        return render_template('error.html', error=e)

Note: 만들고 나니 내 것이 아니었다.

Leave a comment