TensorFlow Lite のサイトにはホストされているモデルの一覧があり、ここからダウンロードすることができます。
https://www.tensorflow.org/lite/models

現在ここには以下のモデルがあります。
- AutoML mobile image classification models (Float Models)
- Image classification (Float Models)
- Image classification (Quantized Models)
- Other models
- Smart reply
モデルがたくさんあってどれを選べばいいのかわからない...となりますよね。size, accuracy, performance などを見てユースケースにあったものを選びましょう。
モバイルアプリに組み込むとなると size はできれば10Mb以下に抑えたいですし、カメラのプレビューに繋いでリアルタイムに推論するなら performance は30ms以下にしたいところです。一方静止画像で推論するなら perfomance が多少遅くなっても accuracy が高いものを選ぶことができます。
MobileNet
MobileNet は on-device や組み込みアプリケーションなどの制限されたリソースを考慮しながら正確さを最大限に高められるように設計されたモデルです。サイズが小さく低遅延で低消費電力という特徴があります。
MobileNet の Pre-trained Models は ImageNet Large Scale Visual Recognition Challenge 2012 (ILSVRC2012) の image classification dataset で学習されています。これの学習用データは 1000 classes の計 1.2 million の画像(class ごとに約700〜1300枚の画像)です。
Mobilenet_V1_1.0_224_quant
静止画像を Mobilenet_V1_1.0_224_quant で推論してみましょう。
Mobilenet_V1_1.0_224_quant は MobileNet V1の Post-training quantization が施されたモデルです。
まずはモデルをダウンロードしましょう。

中には tflite ファイルの他に checkpoint(ckpt.*)も入っています。
- $ ls mobilenet_v1_1.0_224_quant
- mobilenet_v1_1.0_224_quant.ckpt.data-00000-of-00001
- mobilenet_v1_1.0_224_quant.ckpt.index
- mobilenet_v1_1.0_224_quant.ckpt.meta
- mobilenet_v1_1.0_224_quant.tflite
- mobilenet_v1_1.0_224_quant_eval.pbtxt
- mobilenet_v1_1.0_224_quant_frozen.pb
- mobilenet_v1_1.0_224_quant_info.txt
https://www.tensorflow.org/install/ に従って tensorflow をインストールし、以下の python コードを実行して入力と出力の情報を print します。
dump_input_and_output.py
- import tensorflow as tf
- interpreter = tf.contrib.lite.Interpreter(model_path="mobilenet_v1_1.0_224_quant.tflite")
- interpreter.allocate_tensors()
- print("input")
- print(interpreter.get_input_details()[0])
- print("output")
- print(interpreter.get_output_details()[0])
- $ python dump_input_and_output.py
- input
- {'name': 'input', 'index': 87, 'shape': array([ 1, 224, 224, 3], dtype=int32), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.007843137718737125, 128)}
- output
- {'name': 'MobilenetV1/Predictions/Reshape_1', 'index': 86, 'shape': array([ 1, 1001], dtype=int32), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.00390625, 0)}
出力の shape は [ 1, 1001] の int32 です。1番目の 1 はバッチサイズ、2番目の 1001 は class の個数です。1000個より1つ多いのは index 0 が 'background' class として予約されているからです。出力の dtype は numpy.uint8 です。
1001個 の class それぞれの確率が出力されるわけですが、index 0 が 'background' として index 1 〜 1000 の class は何でしょうか。
それを調べるために tensorflow の datasets の中にある imagenet.py の create_readable_names_for_imagenet_labels() を利用します。 このメソッドは index と class 名(人が読めるラベル)のマップを返します。
LSVRC の synsets 一覧(index 1 〜 1000 の class の id 一覧)
https://github.com/tensorflow/models/blob/master/research/inception/inception/data/imagenet_lsvrc_2015_synsets.txt
- n01440764
- n01443537
- n01484850
- ...
https://github.com/tensorflow/models/blob/master/research/inception/inception/data/imagenet_metadata.txt
- n00004475 organism, being
- n00005787 benthos
- n00006024 heterotroph
- ...
- n01440764 → tench, Tinca tinca
- n01443537 → goldfish, Carassius auratus
- n01484850 → great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
- ...
- {
- 0: 'background',
- 1: 'tench, Tinca tinca',
- 2: 'goldfish, Carassius auratus',
- 3: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
- ...
- 1000: 'toilet tissue, toilet paper, bathroom tissue'
- }
dump_labels.py
- from six.moves import urllib
- synset_url = 'https://raw.githubusercontent.com/tensorflow/models/master/research/inception/inception/data/imagenet_lsvrc_2015_synsets.txt'
- synset_to_label_url = 'https://raw.githubusercontent.com/tensorflow/models/master/research/inception/inception/data/imagenet_metadata.txt'
- filename, _ = urllib.request.urlretrieve(synset_url)
- synset_list = [s.strip() for s in open(filename).readlines()]
- assert len(synset_list) == 1000
- filename, _ = urllib.request.urlretrieve(synset_to_label_url)
- synset_to_label_list = open(filename).readlines()
- assert len(synset_to_label_list) == 21842
- synset_to_label_map = {}
- for s in synset_to_label_list:
- parts = s.strip().split('\t')
- assert len(parts) == 2
- synset_to_label_map[parts[0]] = parts[1]
- print("background")
- for synset in synset_list:
- print(synset_to_label_map[synset])
- $ python dump_labels.py > labels.txt
- $ cat labels.txt
- background
- tench, Tinca tinca
- goldfish, Carassius auratus
- great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
- ...
- toilet tissue, toilet paper, bathroom tissue
これで事前準備が完了です。その2で ML Kit に組み込んでいきます。
「ML Kit Custom Model その2 : Mobilenet_V1_1.0_224_quant を LocalModel として使う」
0 件のコメント:
コメントを投稿