2019年1月1日火曜日

ML Kit Custom Model その1 : TensorFlow Lite Hosted Models を利用する

ML Kit Custom Model を使ってみるには TensorFlow Lite 形式のモデルファイルが必要です。
TensorFlow Lite のサイトにはホストされているモデルの一覧があり、ここからダウンロードすることができます。

https://www.tensorflow.org/lite/models




現在ここには以下のモデルがあります。
  • AutoML mobile image classification models (Float Models)
  • Image classification (Float Models)
  • Image classification (Quantized Models)
  • Other models
    • Smart reply
Image classification は画像識別を行うモデルです。入力として画像のピクセルデータを渡すと、画像に写っているものを識別し、各ラベルの確率が出力されます。

モデルがたくさんあってどれを選べばいいのかわからない...となりますよね。size, accuracy, performance などを見てユースケースにあったものを選びましょう。

モバイルアプリに組み込むとなると size はできれば10Mb以下に抑えたいですし、カメラのプレビューに繋いでリアルタイムに推論するなら performance は30ms以下にしたいところです。一方静止画像で推論するなら perfomance が多少遅くなっても accuracy が高いものを選ぶことができます。



MobileNet

MobileNet は on-device や組み込みアプリケーションなどの制限されたリソースを考慮しながら正確さを最大限に高められるように設計されたモデルです。サイズが小さく低遅延で低消費電力という特徴があります。

MobileNet の Pre-trained Models は ImageNet Large Scale Visual Recognition Challenge 2012 (ILSVRC2012) の image classification dataset で学習されています。これの学習用データは 1000 classes の計 1.2 million の画像(class ごとに約700〜1300枚の画像)です。



Mobilenet_V1_1.0_224_quant

静止画像を Mobilenet_V1_1.0_224_quant で推論してみましょう。

Mobilenet_V1_1.0_224_quant は MobileNet V1の Post-training quantization が施されたモデルです。

まずはモデルをダウンロードしましょう。



中には tflite ファイルの他に checkpoint(ckpt.*)も入っています。
  1. $ ls mobilenet_v1_1.0_224_quant  
  2. mobilenet_v1_1.0_224_quant.ckpt.data-00000-of-00001  
  3. mobilenet_v1_1.0_224_quant.ckpt.index  
  4. mobilenet_v1_1.0_224_quant.ckpt.meta  
  5. mobilenet_v1_1.0_224_quant.tflite  
  6. mobilenet_v1_1.0_224_quant_eval.pbtxt  
  7. mobilenet_v1_1.0_224_quant_frozen.pb  
  8. mobilenet_v1_1.0_224_quant_info.txt  
モデルを使うにあたって、どんな入力を渡せばいいのか、どんな出力が得られるのかを知らなければいけません。そこでまず入力と出力の形式を調べます。

https://www.tensorflow.org/install/ に従って tensorflow をインストールし、以下の python コードを実行して入力と出力の情報を print します。

dump_input_and_output.py
  1. import tensorflow as tf  
  2.   
  3. interpreter = tf.contrib.lite.Interpreter(model_path="mobilenet_v1_1.0_224_quant.tflite")  
  4. interpreter.allocate_tensors()  
  5.   
  6. print("input")  
  7. print(interpreter.get_input_details()[0])  
  8.   
  9. print("output")  
  10. print(interpreter.get_output_details()[0])  
  1. $ python dump_input_and_output.py  
  2. input  
  3. {'name''input''index'87'shape': array([  1224224,   3], dtype=int32), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.007843137718737125128)}  
  4. output  
  5. {'name''MobilenetV1/Predictions/Reshape_1''index'86'shape': array([   11001], dtype=int32), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.003906250)}  
入力の shape は [ 1, 224, 224, 3] の int32 です。1番目の 1 はバッチサイズ、2番目の 224 は画像の width、3番目の 224 は画像の height、4番目の 3 は色情報(R,G,B)です。入力の dtype は numpy.uint8 です。

出力の shape は [ 1, 1001] の int32 です。1番目の 1 はバッチサイズ、2番目の 1001 は class の個数です。1000個より1つ多いのは index 0 が 'background' class として予約されているからです。出力の dtype は numpy.uint8 です。


1001個 の class それぞれの確率が出力されるわけですが、index 0 が 'background' として index 1 〜 1000 の class は何でしょうか。

それを調べるために tensorflow の datasets の中にある imagenet.py の create_readable_names_for_imagenet_labels() を利用します。 このメソッドは index と class 名(人が読めるラベル)のマップを返します。

LSVRC の synsets 一覧(index 1 〜 1000 の class の id 一覧)
https://github.com/tensorflow/models/blob/master/research/inception/inception/data/imagenet_lsvrc_2015_synsets.txt
  1. n01440764  
  2. n01443537  
  3. n01484850  
  4. ...  
と、id とラベル名の一覧
https://github.com/tensorflow/models/blob/master/research/inception/inception/data/imagenet_metadata.txt
  1. n00004475 organism, being  
  2. n00005787 benthos  
  3. n00006024 heterotroph  
  4. ...  
からマップを作っています。
  1. n01440764 → tench, Tinca tinca  
  2. n01443537 → goldfish, Carassius auratus  
  3. n01484850 → great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias  
  4. ...  
  1. {  
  2.   0'background',  
  3.   1'tench, Tinca tinca',  
  4.   2'goldfish, Carassius auratus',  
  5.   3'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',  
  6.   ...  
  7.   1000'toilet tissue, toilet paper, bathroom tissue'  
  8. }  
これを参考にして、ラベルだけ出力する python コードを実行して labels.txt として保存しておきます。

dump_labels.py
  1. from six.moves import urllib  
  2.   
  3. synset_url = 'https://raw.githubusercontent.com/tensorflow/models/master/research/inception/inception/data/imagenet_lsvrc_2015_synsets.txt'  
  4. synset_to_label_url = 'https://raw.githubusercontent.com/tensorflow/models/master/research/inception/inception/data/imagenet_metadata.txt'  
  5.   
  6. filename, _ = urllib.request.urlretrieve(synset_url)  
  7. synset_list = [s.strip() for s in open(filename).readlines()]  
  8. assert len(synset_list) == 1000  
  9.   
  10. filename, _ = urllib.request.urlretrieve(synset_to_label_url)  
  11. synset_to_label_list = open(filename).readlines()  
  12. assert len(synset_to_label_list) == 21842  
  13.   
  14. synset_to_label_map = {}  
  15. for s in synset_to_label_list:  
  16.   parts = s.strip().split('\t')  
  17.   assert len(parts) == 2  
  18.   synset_to_label_map[parts[0]] = parts[1]  
  19.   
  20. print("background")  
  21.   
  22. for synset in synset_list:  
  23.   print(synset_to_label_map[synset])  
  1. $ python dump_labels.py > labels.txt  
  2. $ cat labels.txt  
  3. background  
  4. tench, Tinca tinca  
  5. goldfish, Carassius auratus  
  6. great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias  
  7. ...  
  8. toilet tissue, toilet paper, bathroom tissue  



これで事前準備が完了です。その2で ML Kit に組み込んでいきます。

「ML Kit Custom Model その2 : Mobilenet_V1_1.0_224_quant を LocalModel として使う」


0 件のコメント:

コメントを投稿