我试图使用Fast API和Tensorflow创建一个简单的图像识别API


介绍

我通常会经常使用Flask,但是我的熟人说"快速API很好!",所以我决定制作一个简单的图像识别API。
但是,我没有看到很多有关FastAPI和ML的日语文章,因此我决定创建这篇文章,而不是一份备忘!

在本文中,准备了开发环境之后,将对API服务器和前端进行简要说明。

这次使用的所有代码都发布在Github上。
(以下实现的文件夹结构在Github的前提下进行了描述。示例模型的下载也在README.md中进行了描述。)

什么是FastAPI?

Flask之类的Python框架之一。

有关简单概述和使用方法的摘要,请参阅以下文章。 (也非常感谢您在本文中的帮助!)

https://qiita.com/bee2/items/75d9c0d7ba20e7a4a0e9

对于那些想了解更多详细信息的人,我们建议使用官方的FastAPI教程!

https://fastapi.tiangolo.com/tutorial/

关于图像识别

这次我没有时间,所以我将使用tensorflow.keras模型进行构建!

具体来说,我们将按原样使用imagenet学习的ResNet50,并推断输入图像属于1000个类别中的哪个类别。

(我真正想使用的模型没有及时进行广受好评的学习……)

https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/keras?hl=ja

开发环境

Mac OS X Mojave
Python3.7.1(Anaconda)

环境

安装所需的Python库。

1
2
3
$pip install tensorflow==1.15
$pip install fastapi
$pip install uvicorn

由于存在以下情况,因此也请安装必要的库。
--Render index.html
-上传图片文件
-加载图像并调整大小

1
2
3
4
$pip install Jinja
$pip install aiofiles
$pip install python-multipart
$pip install opencv-python

API服务器

API服务器的实现如下。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# -*- coding: utf-8 -*-
import io
from typing import List

import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import decode_predictions
from fastapi import FastAPI, Request, File, UploadFile
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

# 画像認識モデルの用意
global model, graph
graph = tf.get_default_graph()
model = tf.keras.models.load_model("./static/model/resnet_imagenet.h5")

# FastAPIの用意
app = FastAPI()

# static/js/post.jsをindex.htmlから呼び出すために必要
app.mount("/static", StaticFiles(directory="static"), name="static")

# templates配下に格納したindex.htmlをrenderするために必要
templates = Jinja2Templates(directory="templates")


def read_image(bin_data, size=(224, 224)):
    """画像を読み込む

    Arguments:
        bin_data {bytes} -- 画像のバイナリデータ

    Keyword Arguments:
        size {tuple} -- リサイズしたい画像サイズ (default: {(224, 224)})

    Returns:
        numpy.array -- 画像
    """
    file_bytes = np.asarray(bytearray(bin_data.read()), dtype=np.uint8)
    img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, size)
    return img


@app.post("/api/image_recognition")
async def image_recognition(files: List[UploadFile] = File(...)):
    """画像認識API

    Keyword Arguments:
        files {List[UploadFile]} -- アップロードされたファイル情報 (default: {File(...)})

    Returns:
        dict -- 推論結果
    """
    bin_data = io.BytesIO(files[0].file.read())
    img = read_image(bin_data)
    with graph.as_default():
        pred = model.predict(np.expand_dims(img, axis=0))
        result_label = decode_predictions(pred, top=1)[0][0][1]
        return {"response": result_label}


@app.get("/")
async def index(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})

从前台接收数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
@app.post("/api/image_recognition")
async def image_recognition(files: List[UploadFile] = File(...)):
    """画像認識API

    Keyword Arguments:
        files {List[UploadFile]} -- アップロードされたファイル情報 (default: {File(...)})

    Returns:
        dict -- 推論結果
    """
    bin_data = io.BytesIO(files[0].file.read())
    img = read_image(bin_data)
    with graph.as_default():
        pred = model.predict(np.expand_dims(img, axis=0))
        result_label = decode_predictions(pred, top=1)[0][0][1]
        return {"response": result_label}

这次,我们使用快速API上传文件来获取POST图像。

1
bin_data = io.BytesIO(files[0].file.read())

由于仅发布了一个

文件,因此将其设置为文件[0],并且由于从正面以BASE64格式传递了该文件,因此已在API端将其转换为Bytes数组。

将数据转换为图像

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def read_image(bin_data, size=(224, 224)):
    """画像を読み込む

    Arguments:
        bin_data {bytes} -- 画像のバイナリデータ

    Keyword Arguments:
        size {tuple} -- リサイズしたい画像サイズ (default: {(224, 224)})

    Returns:
        numpy.array -- 画像
    """
    file_bytes = np.asarray(bytearray(bin_data.read()), dtype=np.uint8)
    img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, size)
    return img

借助opencv,将字节数组转换为uint8映像。
此时,由于opencv的默认格式为BGR,因此我将其转换为RGB并调整了大小。

推断

1
2
3
4
5
6
7
8
9
global model, graph
graph = tf.get_default_graph()
model = tf.keras.models.load_model("./static/model/resnet_imagenet.h5")

...

with graph.as_default():
        pred = model.predict(np.expand_dims(img, axis=0))
        result_label = decode_predictions(pred, top=1)[0][0][1]

我预先创建了resnet_imagenet.h5,并在文件顶部读取了它。
通过将此线程中的上下文固定到使用graph.as_default()全局设置的TensorFlow图,可通过预测函数来推断推理过程本身。

由于这一次我们使用的是来自tf.keras的ResNet50,因此我们使用的是encode_predictions,将predict的结果转换为标签以获取推断结果。

我认为,可以通过将.h5文件保存在项目目录中的某个位置并对其进行load_model来使用其他类似的模型和自制模型。

正面安装

我以此为参考。 (谢谢!)

https://qiita.com/katsunory/items/9bf9ee49ee5c08bf2b3d

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
<html>
<head>
    <meta http-qeuiv="Content-Type" content="text/html; charset=utf-8">
    Fastapi 画像認識テスト
    <script src="//code.jquery.com/jquery-2.2.3.min.js"></script>
    <script src="/static/js/post.js"></script>
</head>

<body>

<!-- ファイル選択ボタン -->
<div style="width: 500px">
  <form enctype="multipart/form-data" method="post">
    <input type="file" name="userfile" accept="image/*">
  </form>
</div>

<!-- 画像表示領域 -->
<canvas id="canvas" width="0" height="0"></canvas>

<!-- アップロード開始ボタン -->
<button class="btn btn-primary" id="post">投稿</button>
<br>
<h2 id="result"></h2>
</body>
</html>
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
// 画像をリサイズして、HTMLで表示する
$(function () {
  var file = null;
  var blob = null;
  const RESIZED_WIDTH = 300;
  const RESIZED_HEIGHT = 300;

  $("input[type=file]").change(function () {
    file = $(this).prop("files")[0];

    // ファイルチェック
    if (file.type != "image/jpeg" && file.type != "image/png") {
      file = null;
      blob = null;
      return;
    }

    var result = document.getElementById("result");
    result.innerHTML = "";

    // 画像をリサイズする
    var image = new Image();
    var reader = new FileReader();
    reader.onload = function (e) {
      image.onload = function () {
        var width, height;

        // 縦or横の長い方に合わせてリサイズする
        if (image.width > image.height) {
          var ratio = image.height / image.width;
          width = RESIZED_WIDTH;
          height = RESIZED_WIDTH * ratio;
        } else {
          var ratio = image.width / image.height;
          width = RESIZED_HEIGHT * ratio;
          height = RESIZED_HEIGHT;
        }

        var canvas = $("#canvas").attr("width", width).attr("height", height);
        var ctx = canvas[0].getContext("2d");
        ctx.clearRect(0, 0, width, height);
        ctx.drawImage(
          image,
          0,
          0,
          image.width,
          image.height,
          0,
          0,
          width,
          height
        );

        // canvasからbase64画像データを取得し、POST用のBlobを作成する
        var base64 = canvas.get(0).toDataURL("image/jpeg");
        var barr, bin, i, len;
        bin = atob(base64.split("base64,")[1]);
        len = bin.length;
        barr = new Uint8Array(len);
        i = 0;
        while (i < len) {
          barr[i] = bin.charCodeAt(i);
          i++;
        }
        blob = new Blob([barr], { type: "image/jpeg" });
        console.log(blob);
      };
      image.src = e.target.result;
    };
    reader.readAsDataURL(file);
  });

  // アップロード開始ボタンがクリックされたら
  $("#post").click(function () {
    if (!file || !blob) {
      return;
    }

    var name,
      fd = new FormData();
    fd.append("files", blob);

    // API宛にPOSTする
    $.ajax({
      url: "/api/image_recognition",
      type: "POST",
      dataType: "json",
      data: fd,
      processData: false,
      contentType: false,
    })
      .done(function (data, textStatus, jqXHR) {
          // 通信が成功した場合、結果を出力する
        var response = JSON.stringify(data);
        var response = JSON.parse(response);
        console.log(response);
        var result = document.getElementById("result");
        result.innerHTML = "この画像...「" + response["response"] + "」やんけ";
      })
      .fail(function (jqXHR, textStatus, errorThrown) {
          // 通信が失敗した場合、エラーメッセージを出力する
        var result = document.getElementById("result");
        result.innerHTML = "サーバーとの通信が失敗した...";
      });
  });
});

使用ajax对图像识别API进行POST,并显示结果。

操作检查

结果就是这样!
demo.png

(我想使前面板更时尚一些...)

结论

我制作了图像识别API来研究Fast API。
我认为这一次的实现不是最佳实践,但是我很高兴能够做出一些行之有效的事情。

我不知道将来使用哪种框架,但是FastAPI相对易于使用,我认为我应该从Flask切换到。

最后但同样重要的是,我要感谢大家的参考!