TensorFlow Lite のサイトにはホストされているモデルの一覧があり、ここからダウンロードすることができます。
https://www.tensorflow.org/lite/models
現在ここには以下のモデルがあります。
- AutoML mobile image classification models (Float Models)
- Image classification (Float Models)
- Image classification (Quantized Models)
- Other models
- Smart reply
モデルがたくさんあってどれを選べばいいのかわからない...となりますよね。size, accuracy, performance などを見てユースケースにあったものを選びましょう。
モバイルアプリに組み込むとなると size はできれば10Mb以下に抑えたいですし、カメラのプレビューに繋いでリアルタイムに推論するなら performance は30ms以下にしたいところです。一方静止画像で推論するなら perfomance が多少遅くなっても accuracy が高いものを選ぶことができます。
MobileNet
MobileNet は on-device や組み込みアプリケーションなどの制限されたリソースを考慮しながら正確さを最大限に高められるように設計されたモデルです。サイズが小さく低遅延で低消費電力という特徴があります。
MobileNet の Pre-trained Models は ImageNet Large Scale Visual Recognition Challenge 2012 (ILSVRC2012) の image classification dataset で学習されています。これの学習用データは 1000 classes の計 1.2 million の画像(class ごとに約700〜1300枚の画像)です。
Mobilenet_V1_1.0_224_quant
静止画像を Mobilenet_V1_1.0_224_quant で推論してみましょう。
Mobilenet_V1_1.0_224_quant は MobileNet V1の Post-training quantization が施されたモデルです。
まずはモデルをダウンロードしましょう。
中には tflite ファイルの他に checkpoint(ckpt.*)も入っています。
$ ls mobilenet_v1_1.0_224_quant
mobilenet_v1_1.0_224_quant.ckpt.data-00000-of-00001
mobilenet_v1_1.0_224_quant.ckpt.index
mobilenet_v1_1.0_224_quant.ckpt.meta
mobilenet_v1_1.0_224_quant.tflite
mobilenet_v1_1.0_224_quant_eval.pbtxt
mobilenet_v1_1.0_224_quant_frozen.pb
mobilenet_v1_1.0_224_quant_info.txt
モデルを使うにあたって、どんな入力を渡せばいいのか、どんな出力が得られるのかを知らなければいけません。そこでまず入力と出力の形式を調べます。
https://www.tensorflow.org/install/ に従って tensorflow をインストールし、以下の python コードを実行して入力と出力の情報を print します。
dump_input_and_output.py
import tensorflow as tf
interpreter = tf.contrib.lite.Interpreter(model_path="mobilenet_v1_1.0_224_quant.tflite")
interpreter.allocate_tensors()
print("input")
print(interpreter.get_input_details()[0])
print("output")
print(interpreter.get_output_details()[0])
$ python dump_input_and_output.py
input
{'name': 'input', 'index': 87, 'shape': array([ 1, 224, 224, 3], dtype=int32), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.007843137718737125, 128)}
output
{'name': 'MobilenetV1/Predictions/Reshape_1', 'index': 86, 'shape': array([ 1, 1001], dtype=int32), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.00390625, 0)}
入力の shape は [ 1, 224, 224, 3] の int32 です。1番目の 1 はバッチサイズ、2番目の 224 は画像の width、3番目の 224 は画像の height、4番目の 3 は色情報(R,G,B)です。入力の dtype は numpy.uint8 です。
出力の shape は [ 1, 1001] の int32 です。1番目の 1 はバッチサイズ、2番目の 1001 は class の個数です。1000個より1つ多いのは index 0 が 'background' class として予約されているからです。出力の dtype は numpy.uint8 です。
1001個 の class それぞれの確率が出力されるわけですが、index 0 が 'background' として index 1 〜 1000 の class は何でしょうか。
それを調べるために tensorflow の datasets の中にある imagenet.py の create_readable_names_for_imagenet_labels() を利用します。 このメソッドは index と class 名(人が読めるラベル)のマップを返します。
LSVRC の synsets 一覧(index 1 〜 1000 の class の id 一覧)
https://github.com/tensorflow/models/blob/master/research/inception/inception/data/imagenet_lsvrc_2015_synsets.txt
n01440764
n01443537
n01484850
...
と、id とラベル名の一覧
https://github.com/tensorflow/models/blob/master/research/inception/inception/data/imagenet_metadata.txt
n00004475 organism, being
n00005787 benthos
n00006024 heterotroph
...
からマップを作っています。
n01440764 → tench, Tinca tinca
n01443537 → goldfish, Carassius auratus
n01484850 → great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
...
{
0: 'background',
1: 'tench, Tinca tinca',
2: 'goldfish, Carassius auratus',
3: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
...
1000: 'toilet tissue, toilet paper, bathroom tissue'
}
これを参考にして、ラベルだけ出力する python コードを実行して labels.txt として保存しておきます。
dump_labels.py
from six.moves import urllib
synset_url = 'https://raw.githubusercontent.com/tensorflow/models/master/research/inception/inception/data/imagenet_lsvrc_2015_synsets.txt'
synset_to_label_url = 'https://raw.githubusercontent.com/tensorflow/models/master/research/inception/inception/data/imagenet_metadata.txt'
filename, _ = urllib.request.urlretrieve(synset_url)
synset_list = [s.strip() for s in open(filename).readlines()]
assert len(synset_list) == 1000
filename, _ = urllib.request.urlretrieve(synset_to_label_url)
synset_to_label_list = open(filename).readlines()
assert len(synset_to_label_list) == 21842
synset_to_label_map = {}
for s in synset_to_label_list:
parts = s.strip().split('\t')
assert len(parts) == 2
synset_to_label_map[parts[0]] = parts[1]
print("background")
for synset in synset_list:
print(synset_to_label_map[synset])
$ python dump_labels.py > labels.txt
$ cat labels.txt
background
tench, Tinca tinca
goldfish, Carassius auratus
great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
...
toilet tissue, toilet paper, bathroom tissue
これで事前準備が完了です。その2で ML Kit に組み込んでいきます。
「ML Kit Custom Model その2 : Mobilenet_V1_1.0_224_quant を LocalModel として使う」
0 件のコメント:
コメントを投稿