2019年1月1日火曜日

ML Kit Custom Model その2 : Mobilenet_V1_1.0_224_quant を LocalModel として使う

「ML Kit Custom Model その1 : TensorFlow Lite Hosted Models を利用する」で mobilenet_v1_1.0_224_quant.tflite をダウンロードし、ラベル情報を labels.txt として用意しました。

いよいよ ML Kit Custom Model に組み込みます。


ベースとして MLKitSample ( https://github.com/yanzm/MLKitSample/tree/start ) の start ブランチを利用します。

Firebase でプロジェクトを作ってアプリを登録し、google-services.json を app モジュール直下に配置します。詳しくは上記 MLKitSample の README.md 課題1,2 を参考にしてください。

アプリの dependencies に以下を追加します。 dependencies { ... implementation "com.google.firebase:firebase-ml-model-interpreter:16.2.4" }

今回は mobilenet_v1_1.0_224_quant.tflite をあらかじめアプリにバンドルして LocalModel として利用します。
そのため app/src/main/assets に mobilenet_v1_1.0_224_quant.tflite を配置します。また labels.txt も assets に配置します。





推論を実行する際には、入力データの他に入出力の形式も指定しないといけません。そのために用意するのが FirebaseModelInputOutputOptions です。

前回調べた入出力の shape は以下の通りでした。

入力の shape は [ 1, 224, 224, 3] の int32 です。1番目の 1 はバッチサイズ、2番目の 224 は画像の width、3番目の 224 は画像の height、4番目の 3 は色情報(R,G,B)です。入力の dtype は numpy.uint8 です。

出力の shape は [ 1, 1001] の int32 です。1番目の 1 はバッチサイズ、2番目の 1001 は class の個数です。出力の dtype は numpy.uint8 です。

入力の shape は [ 1, 224, 224, 3] の int32 なので inputDims として intArrayOf(1, 224, 224, 3) を用意します。
出力の shape は [ 1, 1001] の int32 なので outputDims として intArrayOf(1, 1001) を用意します。

入出力いずれも dtype は numpy.uint8 なので、dataType には FirebaseModelDataType.BYTE を指定します。 private val dataOptions: FirebaseModelInputOutputOptions by lazy { val inputDims = intArrayOf(1, 224, 224, 3) val outputDims = intArrayOf(1, 1001) FirebaseModelInputOutputOptions.Builder() .setInputFormat(0, FirebaseModelDataType.BYTE, inputDims) .setOutputFormat(0, FirebaseModelDataType.BYTE, outputDims) .build() }

assets にある model を LocalModel として使うには FirebaseLocalModelSource を用意します。 FirebaseLocalModelSource.Builder() の引数に LocalModelSource を識別するための名前(ここでは "asset")を指定し、setAssetFilePath() で assets/ に配置したモデルファイルを指定します。

作成した FirebaseLocalModelSource は FirebaseModelManager.registerLocalModelSource() で登録しておきます。

FirebaseModelOptions.Builder の setLocalModelName() には FirebaseLocalModelSource.Builder() の引数に指定した名前(ここでは "asset")を渡します。

最後に FirebaseModelOptions を渡して FirebaseModelInterpreter のインスタンスを取得します。 private val interpreter: FirebaseModelInterpreter by lazy { val localModelSource = FirebaseLocalModelSource.Builder("asset") .setAssetFilePath("mobilenet_v1_1.0_224_quant.tflite") .build() FirebaseModelManager.getInstance().registerLocalModelSource(localModelSource) val modelOptions = FirebaseModelOptions.Builder() .setLocalModelName("asset") .build() FirebaseModelInterpreter.getInstance(modelOptions)!! }

次に検出部分の実装をしていきます。

FirebaseModelInterpreter の入力データは ByteBuffer か配列か多次元配列でなければなりません。 ここでは ByteBuffer にしてみます。

モデルに渡す画像は 224 x 224 なので Bitmap.createScaledBitmap() で画像をスケールし、ピクセル情報のうち R,G,B のデータを ByteBuffer に入れます。 /** * Bitmap を ByteBuffer に変換 */ private fun convertToByteBuffer(bitmap: Bitmap): ByteBuffer { val intValues = IntArray(224 * 224).apply { // bitmap を 224 x 224 に変換 val scaled = Bitmap.createScaledBitmap(bitmap, 224, 224, true) // ピクセル情報を IntArray に取り出し scaled.getPixels(this, 0, 224, 0, 0, 224, 224) } // ピクセル情報のうち R,G,B を ByteBuffer に入れる return ByteBuffer.allocateDirect(1 * 224 * 224 * 3).apply { order(ByteOrder.nativeOrder()) rewind() for (value in intValues) { put((value shr 16 and 0xFF).toByte()) put((value shr 8 and 0xFF).toByte()) put((value and 0xFF).toByte()) } } } private fun detect(bitmap: Bitmap) { overlay.clear() // Bitmap を ByteBuffer に変換 val imageByteBuffer = convertToByteBuffer(bitmap) } FirebaseModelInputs.Builder で入力データを作成し、FirebaseModelInterpreter の run() で推論を実行します。 private fun detect(bitmap: Bitmap) { overlay.clear() // Bitmap を ByteBuffer に変換 val imageByteBuffer = convertToByteBuffer(bitmap) val inputs = FirebaseModelInputs.Builder() .add(imageByteBuffer) .build() interpreter .run(inputs, dataOptions) .addOnSuccessListener { outputs -> val output = outputs!!.getOutput<Array<ByteArray>>(0) // output.size : 1 val labelProbabilities: ByteArray = output[0] // labelProbabilities.size : 1001 // labelProbabilities の各 Byte を 255f で割ると確率値になる // 確率の高い上位3つを取得 val topLabels = getTopLabels(labelProbabilities, 3) overlay.add(TextsData(topLabels)) } .addOnFailureListener { e -> e.printStackTrace() detectButton.isEnabled = true progressBar.visibility = View.GONE Toast.makeText(this, e.message, Toast.LENGTH_SHORT).show() } } assets の labels.txt を読んでラベルのリストを用意しておき、labelProbabilities の各確率に対応するラベルを割り出します。 getTopLabels() では PriorityQueue を使って確率の高いラベルだけ残すようにします。 private val labelList: List<String> by lazy { assets.open("labels.txt").reader().use { it.readLines() } } @Synchronized private fun getTopLabels(labelProbabilities: ByteArray, maxSize: Int): List<String> { val sortedLabels = PriorityQueue<Map.Entry<String, Float>>( maxSize, Comparator<Map.Entry<String, Float>> { o1, o2 -> o1.value.compareTo(o2.value) } ) labelList.forEachIndexed { index, label -> sortedLabels.add( AbstractMap.SimpleEntry<String, Float>( label, (labelProbabilities[index].toInt() and 0xff) / 255f ) ) if (sortedLabels.size > maxSize) { sortedLabels.poll() } } return sortedLabels.map { "${it.key} : ${it.value}" } .reversed() }



サイなのだが、トリケラトプスって出てる...



猫がでない...



像は認識できた



次は「ML Kit Custom Model その3 : Mobilenet_V1_1.0_224_quant を CloudModel として使う」 です。



0 件のコメント:

コメントを投稿