인공지능(AI)

tensorflow keras 활용한 손글씨 맞추기 with Flask

sysman 2020. 12. 20. 22:49

프로젝트 소개

  1. 프로젝트 기획 목표
  • AI를 이용하여 손글씨를 맞추고 웹으로 개발하여 모든사람이 이용가능하게 할 것으로 목표
  1. 프로젝트 내용
    mnist 학습 후 가중치값 파일로 저장하기
  1. colab을 이용하여 mnist를 가져옴
  2. 데이터셋 전처리
  3. 라벨링 인코딩
  4. 모델링 설계
  5. 모델 컴파일
  6. 모델 fit으로 학습하기
  7. matplotlib으로 정확도, 로스율 확인
  8. 모델 평가
  9. 모델링 h5 파일로 저장
    1. python flask 설치
    2. 만들어놓은 h5 파일 로딩
    3. javascript로 만든 canvas를 이용하여 그린 데이터를 가져옴
    4. 가져온 데이터를 new_model에 predict하여 argmax로 가장 높은 숫자를 출력
    5. 해당 index 결과값을 web homepase에 출력
  10. 웹서버 구성
  1. 프로젝트 발전가능성
  • 아이들 교육용 또는 사진이나 글자 맞추기 등 활용분야가 넓을 것으로 생각됩니다.
  • 예를들면, 집 모양의 그림을 학습시켜 집을 그리면 집의 사진이 나오게 하거나, 사람 그림을 학습시켜 사람사진이 나오게 할 수 있다.
  • 필체를 학습시켜 그 사람의 필체가 어느정도 정확한지 확인하여 범인을 잡는데에도 사용 가능하다.
    20210909mnist_colab_excute.ipynb
    0.01MB

 

mlserver.py
0.00MB
predict.html
0.00MB

 

mlserver.py

# 관련 라이브러리 불러옴
from flask import Flask
from flask import render_template
from flask import request
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import models, layers
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
#from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.keras import backend as K

app = Flask(__name__ , static_url_path='')

#root에서 predict.html 파일 랜더링
@app.route("/")
def root(name=None):
    return render_template('predict.html', name=name)


@app.route("/predict", methods=['POST', 'GET'])

#예측값 함수 정의
def predict():
#폼의 이미지를 canvas에 대입
    canvas = request.form['images']
    print( canvas)
    #mnist 데이터를 로딩함
    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
    print(train_images.shape)
    print(train_labels.shape)
# 이미지 데이터 전처리
    test_images=test_images.reshape((10000,28,28,1))
    test_images=test_images.astype('float32')/255
    test_labels[0]
    test_labels=to_categorical(test_labels)

    print(test_labels[0])
    print(test_labels.shape)
# 미리 만들어 놓은 가중치 파일을 로딩
    new_model=tf.keras.models.load_model('./mnistCNN.h5')
    new_model.summary()
    new_model.evaluate(x= test_images, y=test_labels)

#캔버스값들을 소수형으로 변환
    lcanvas = canvas.split()
    lcanvas = [float(v) for v in lcanvas]

#리쉐입해서 new_model에 집어넣어 predict하고(정확한 값이 가장 높은 수가 나옴) np.argmax로 가장 높은 숫자를 예측값으로  p_val로 저장
    predict=new_model.predict(np.array(lcanvas).reshape(1,28,28,1))
    print(predict)
    K.clear_session()
    p_val=np.argmax(predict)
    print(p_val)
    # 이미 그래프가 있을 경우 중복이 될 수 있기 때문에, 기존 그래프를 모두 리셋한다.
    #tf.reset_default_graph()    
    print('\nreload has been done\n')
    print(lcanvas)
   # p_val은 예측한 결과값을 predict.html에 p_val변수로 출력 
    return render_template('predict.html', result=p_val)


if __name__ == "__main__":
    app.run(host='0.0.0.0')

 

predict.html

<!DOCTYPE html>
<html>
<head>
    <meta http-equiv="Content-Type" content="text/html; charset="utf-8">
    <title>Deep Learning mnist -jiwon</title>
	<script src="/js/jquery-3.1.1.min.js" type="text/javascript"></script>
    <script type="text/Javascript">


var mousePressed = false;
var lastX, lastY;
var ctx;


function InitThis() {
    ctx = document.getElementById('canvas').getContext("2d");

    $('#canvas').mousedown(function (e) {
        mousePressed = true;
        Draw(e.pageX - $(this).offset().left, e.pageY - $(this).offset().top, false);
    });

    $('#canvas').mousemove(function (e) {
        if (mousePressed) {
            Draw(e.pageX - $(this).offset().left, e.pageY - $(this).offset().top, true);
        }
    });

    $('#canvas').mouseup(function (e) {
        mousePressed = false;
    });

	    $('#canvas').mouseleave(function (e) {
        mousePressed = false;
    });
}


function Draw(x, y, isDown) {
    if (isDown) {
        ctx.beginPath();
        ctx.strokeStyle = "#000000";
        ctx.lineWidth = 5;
        ctx.lineJoin = "round";
        ctx.moveTo(lastX, lastY);
        ctx.lineTo(x, y);
        ctx.closePath();
        ctx.stroke();
    }
    lastX = x; lastY = y;
}


    var pixels = [];
    for (var i = 0; i < 28*28; i++) pixels[i] = 0
    var click = 0;
    var result = ""
    var canvas = document.getElementById("canvas");


    function clear_value(){
        canvas.getContext("2d").fillStyle = "rgb(255,255,255)";
        canvas.getContext("2d").fillRect(0, 0, 140, 140);
        for (var i = 0; i < 28*28; i++) pixels[i] = 0
    }


    function _submit() {

        var imgdata = ctx.getImageData(1, 1, 140, 140).data;
		console.log("imgdata:"+imgdata.length);
		console.log(imgdata);

		var count = 0;
		var rgb = 0;

		for (var i = 0; i < 28; i++) {
        	for (var j = 0; j < 28; j++) {
            	rgb = GetPixel(j,i);
				result += (rgb + " ");
        	}
		}
		//console.log(result);
        document.getElementById("images").value = result;
        document.getElementById("pform").submit();

    }


    function GetPixel(x, y)
    {
        var p = ctx.getImageData(x*5, y*5, 5, 5);
		for(var i=0; i<100; i++){
			if(p.data[i] == 0) {
				continue;
			}else{
				return 1;
			}
		}
		return 0;
    }


    function rgbToHex(r, g, b) {
        if (r > 255 || g > 255 || b > 255)
            throw "Invalid color component";
        return ((r << 16) | (g << 8) | b).toString(16);
    }



</script>
</head>
<body onload="InitThis();">
<h3><strong>please draw between 1 to 9<strong> </h3><br>
<form id="pform" action="/predict" method="POST" enctype="multipart/form-data">
<table>
<td style="border-style: none;">
<div style="border: solid 2px #666; width: 143px; height: 144px;">
<canvas id="canvas" width="140" height="140" ></canvas>
</div></td>
<td style="border-style: none;"><br>
<button onclick="clear_value()">Clear</button>
<button onclick="javascript:_submit()">submit</button>
<input type="hidden" id="images" name="images">


</td>
</table>
</form>
<hr>
<h1>AI Result : {{result}}</h1>
</body>
</html>

 

깃허브 참고

https://github.com/tech-picnic/ai_mnist_web/blob/master/README.md

 

GitHub - tech-picnic/ai_mnist_web: AI cnn flask WEB Develop

AI cnn flask WEB Develop . Contribute to tech-picnic/ai_mnist_web development by creating an account on GitHub.

github.com

 

웹서버 : Flask

AI : tensorflow Keras

 

flask를 이용해서 CNN 으로 손글씨 맞추는 것을 만들어 보았다.

 

노트북 성능이 딸려서 가끔 kernel dead로 죽어 버리는 것만 제외 하면 나름 재미 있었다.

 

그리고 가끔 틀리기도 한다. ㅎㅎ

 

집에서 웹서버 켜놓고 다른사람들 들어와서 해보라고 하니 신기해 하긴 한다.

 

 

방법 :

1. mnist 파일을 가져와서 모델링 한다음 최적화된 가중치 값을 만든다.

2. 그리고 그 값과 모델을 h5 확장자 파일로 저장한다.

3. flask 웹서버와 캔버스를 만들고 최적화된 모델값을 불러온다.

4. 그 값을 캔버스에 그린 파일을 리스트화 시켜 모델값을 predict 하면 결과가 나온다.