2019年1月1日火曜日

ML Kit Custom Model その2 : Mobilenet_V1_1.0_224_quant を LocalModel として使う

「ML Kit Custom Model その1 : TensorFlow Lite Hosted Models を利用する」で mobilenet_v1_1.0_224_quant.tflite をダウンロードし、ラベル情報を labels.txt として用意しました。

いよいよ ML Kit Custom Model に組み込みます。


ベースとして MLKitSample ( https://github.com/yanzm/MLKitSample/tree/start ) の start ブランチを利用します。

Firebase でプロジェクトを作ってアプリを登録し、google-services.json を app モジュール直下に配置します。詳しくは上記 MLKitSample の README.md 課題1,2 を参考にしてください。

アプリの dependencies に以下を追加します。
  1. dependencies {  
  2.     ...  
  3.   
  4.     implementation "com.google.firebase:firebase-ml-model-interpreter:16.2.4"  
  5. }  


今回は mobilenet_v1_1.0_224_quant.tflite をあらかじめアプリにバンドルして LocalModel として利用します。
そのため app/src/main/assets に mobilenet_v1_1.0_224_quant.tflite を配置します。また labels.txt も assets に配置します。





推論を実行する際には、入力データの他に入出力の形式も指定しないといけません。そのために用意するのが FirebaseModelInputOutputOptions です。

前回調べた入出力の shape は以下の通りでした。

入力の shape は [ 1, 224, 224, 3] の int32 です。1番目の 1 はバッチサイズ、2番目の 224 は画像の width、3番目の 224 は画像の height、4番目の 3 は色情報(R,G,B)です。入力の dtype は numpy.uint8 です。

出力の shape は [ 1, 1001] の int32 です。1番目の 1 はバッチサイズ、2番目の 1001 は class の個数です。出力の dtype は numpy.uint8 です。

入力の shape は [ 1, 224, 224, 3] の int32 なので inputDims として intArrayOf(1, 224, 224, 3) を用意します。
出力の shape は [ 1, 1001] の int32 なので outputDims として intArrayOf(1, 1001) を用意します。

入出力いずれも dtype は numpy.uint8 なので、dataType には FirebaseModelDataType.BYTE を指定します。
  1. private val dataOptions: FirebaseModelInputOutputOptions by lazy {  
  2.     val inputDims = intArrayOf(12242243)  
  3.     val outputDims = intArrayOf(11001)  
  4.   
  5.     FirebaseModelInputOutputOptions.Builder()  
  6.         .setInputFormat(0, FirebaseModelDataType.BYTE, inputDims)  
  7.         .setOutputFormat(0, FirebaseModelDataType.BYTE, outputDims)  
  8.         .build()  
  9. }  


assets にある model を LocalModel として使うには FirebaseLocalModelSource を用意します。 FirebaseLocalModelSource.Builder() の引数に LocalModelSource を識別するための名前(ここでは "asset")を指定し、setAssetFilePath() で assets/ に配置したモデルファイルを指定します。

作成した FirebaseLocalModelSource は FirebaseModelManager.registerLocalModelSource() で登録しておきます。

FirebaseModelOptions.Builder の setLocalModelName() には FirebaseLocalModelSource.Builder() の引数に指定した名前(ここでは "asset")を渡します。

最後に FirebaseModelOptions を渡して FirebaseModelInterpreter のインスタンスを取得します。
  1. private val interpreter: FirebaseModelInterpreter by lazy {  
  2.   
  3.     val localModelSource = FirebaseLocalModelSource.Builder("asset")  
  4.         .setAssetFilePath("mobilenet_v1_1.0_224_quant.tflite")  
  5.         .build()  
  6.   
  7.     FirebaseModelManager.getInstance().registerLocalModelSource(localModelSource)  
  8.   
  9.     val modelOptions = FirebaseModelOptions.Builder()  
  10.         .setLocalModelName("asset")  
  11.         .build()  
  12.   
  13.     FirebaseModelInterpreter.getInstance(modelOptions)!!  
  14. }  


次に検出部分の実装をしていきます。

FirebaseModelInterpreter の入力データは ByteBuffer か配列か多次元配列でなければなりません。 ここでは ByteBuffer にしてみます。

モデルに渡す画像は 224 x 224 なので Bitmap.createScaledBitmap() で画像をスケールし、ピクセル情報のうち R,G,B のデータを ByteBuffer に入れます。
  1. /** 
  2.  * Bitmap を ByteBuffer に変換 
  3.  */  
  4. private fun convertToByteBuffer(bitmap: Bitmap): ByteBuffer {  
  5.   
  6.     val intValues = IntArray(224 * 224).apply {  
  7.   
  8.         // bitmap を 224 x 224 に変換  
  9.         val scaled = Bitmap.createScaledBitmap(bitmap, 224224true)  
  10.   
  11.         // ピクセル情報を IntArray に取り出し  
  12.         scaled.getPixels(this022400224224)  
  13.     }  
  14.   
  15.     // ピクセル情報のうち R,G,B を ByteBuffer に入れる  
  16.     return ByteBuffer.allocateDirect(1 * 224 * 224 * 3).apply {  
  17.         order(ByteOrder.nativeOrder())  
  18.         rewind()  
  19.   
  20.         for (value in intValues) {  
  21.             put((value shr 16 and 0xFF).toByte())  
  22.             put((value shr 8 and 0xFF).toByte())  
  23.             put((value and 0xFF).toByte())  
  24.         }  
  25.     }  
  26. }  
  27.   
  28. private fun detect(bitmap: Bitmap) {  
  29.     overlay.clear()  
  30.   
  31.     // Bitmap を ByteBuffer に変換  
  32.     val imageByteBuffer = convertToByteBuffer(bitmap)  
  33. }  
FirebaseModelInputs.Builder で入力データを作成し、FirebaseModelInterpreter の run() で推論を実行します。
  1. private fun detect(bitmap: Bitmap) {  
  2.     overlay.clear()  
  3.   
  4.     // Bitmap を ByteBuffer に変換  
  5.     val imageByteBuffer = convertToByteBuffer(bitmap)  
  6.   
  7.     val inputs = FirebaseModelInputs.Builder()  
  8.         .add(imageByteBuffer)  
  9.         .build()  
  10.   
  11.     interpreter  
  12.         .run(inputs, dataOptions)  
  13.         .addOnSuccessListener { outputs ->  
  14.   
  15.             val output = outputs!!.getOutput<Array<ByteArray>>(0)  
  16.             // output.size : 1  
  17.   
  18.             val labelProbabilities: ByteArray = output[0]  
  19.             // labelProbabilities.size : 1001  
  20.             // labelProbabilities の各 Byte を 255f で割ると確率値になる  
  21.   
  22.             // 確率の高い上位3つを取得  
  23.             val topLabels = getTopLabels(labelProbabilities, 3)  
  24.   
  25.             overlay.add(TextsData(topLabels))  
  26.         }  
  27.         .addOnFailureListener { e ->  
  28.             e.printStackTrace()  
  29.             detectButton.isEnabled = true  
  30.             progressBar.visibility = View.GONE  
  31.             Toast.makeText(this, e.message, Toast.LENGTH_SHORT).show()  
  32.         }  
  33. }  
assets の labels.txt を読んでラベルのリストを用意しておき、labelProbabilities の各確率に対応するラベルを割り出します。 getTopLabels() では PriorityQueue を使って確率の高いラベルだけ残すようにします。
  1. private val labelList: List<String> by lazy {  
  2.     assets.open("labels.txt").reader().use {  
  3.         it.readLines()  
  4.     }  
  5. }  
  6.   
  7. @Synchronized  
  8. private fun getTopLabels(labelProbabilities: ByteArray, maxSize: Int): List<String> {  
  9.     val sortedLabels = PriorityQueue<Map.Entry<String, Float>>(  
  10.         maxSize,  
  11.         Comparator<Map.Entry<String, Float>> { o1, o2 -> o1.value.compareTo(o2.value) }  
  12.     )  
  13.   
  14.     labelList.forEachIndexed { index, label ->  
  15.         sortedLabels.add(  
  16.             AbstractMap.SimpleEntry<String, Float>(  
  17.                 label,  
  18.                 (labelProbabilities[index].toInt() and 0xff) / 255f  
  19.             )  
  20.         )  
  21.         if (sortedLabels.size > maxSize) {  
  22.             sortedLabels.poll()  
  23.         }  
  24.     }  
  25.   
  26.     return sortedLabels.map { "${it.key} : ${it.value}" }  
  27.         .reversed()  
  28. }  




サイなのだが、トリケラトプスって出てる...



猫がでない...



像は認識できた



次は「ML Kit Custom Model その3 : Mobilenet_V1_1.0_224_quant を CloudModel として使う」 です。



0 件のコメント:

コメントを投稿