いよいよ ML Kit Custom Model に組み込みます。
ベースとして MLKitSample ( https://github.com/yanzm/MLKitSample/tree/start ) の start ブランチを利用します。
Firebase でプロジェクトを作ってアプリを登録し、google-services.json を app モジュール直下に配置します。詳しくは上記 MLKitSample の README.md 課題1,2 を参考にしてください。
アプリの dependencies に以下を追加します。
- dependencies {
- ...
- implementation "com.google.firebase:firebase-ml-model-interpreter:16.2.4"
- }
今回は mobilenet_v1_1.0_224_quant.tflite をあらかじめアプリにバンドルして LocalModel として利用します。
そのため app/src/main/assets に mobilenet_v1_1.0_224_quant.tflite を配置します。また labels.txt も assets に配置します。

推論を実行する際には、入力データの他に入出力の形式も指定しないといけません。そのために用意するのが FirebaseModelInputOutputOptions です。
前回調べた入出力の shape は以下の通りでした。
入力の shape は [ 1, 224, 224, 3] の int32 です。1番目の 1 はバッチサイズ、2番目の 224 は画像の width、3番目の 224 は画像の height、4番目の 3 は色情報(R,G,B)です。入力の dtype は numpy.uint8 です。
出力の shape は [ 1, 1001] の int32 です。1番目の 1 はバッチサイズ、2番目の 1001 は class の個数です。出力の dtype は numpy.uint8 です。
入力の shape は [ 1, 224, 224, 3] の int32 なので inputDims として intArrayOf(1, 224, 224, 3) を用意します。
出力の shape は [ 1, 1001] の int32 なので outputDims として intArrayOf(1, 1001) を用意します。
入出力いずれも dtype は numpy.uint8 なので、dataType には FirebaseModelDataType.BYTE を指定します。
- private val dataOptions: FirebaseModelInputOutputOptions by lazy {
- val inputDims = intArrayOf(1, 224, 224, 3)
- val outputDims = intArrayOf(1, 1001)
- FirebaseModelInputOutputOptions.Builder()
- .setInputFormat(0, FirebaseModelDataType.BYTE, inputDims)
- .setOutputFormat(0, FirebaseModelDataType.BYTE, outputDims)
- .build()
- }
assets にある model を LocalModel として使うには FirebaseLocalModelSource を用意します。 FirebaseLocalModelSource.Builder() の引数に LocalModelSource を識別するための名前(ここでは "asset")を指定し、setAssetFilePath() で assets/ に配置したモデルファイルを指定します。
作成した FirebaseLocalModelSource は FirebaseModelManager.registerLocalModelSource() で登録しておきます。
FirebaseModelOptions.Builder の setLocalModelName() には FirebaseLocalModelSource.Builder() の引数に指定した名前(ここでは "asset")を渡します。
最後に FirebaseModelOptions を渡して FirebaseModelInterpreter のインスタンスを取得します。
- private val interpreter: FirebaseModelInterpreter by lazy {
- val localModelSource = FirebaseLocalModelSource.Builder("asset")
- .setAssetFilePath("mobilenet_v1_1.0_224_quant.tflite")
- .build()
- FirebaseModelManager.getInstance().registerLocalModelSource(localModelSource)
- val modelOptions = FirebaseModelOptions.Builder()
- .setLocalModelName("asset")
- .build()
- FirebaseModelInterpreter.getInstance(modelOptions)!!
- }
次に検出部分の実装をしていきます。
FirebaseModelInterpreter の入力データは ByteBuffer か配列か多次元配列でなければなりません。 ここでは ByteBuffer にしてみます。
モデルに渡す画像は 224 x 224 なので Bitmap.createScaledBitmap() で画像をスケールし、ピクセル情報のうち R,G,B のデータを ByteBuffer に入れます。
- /**
- * Bitmap を ByteBuffer に変換
- */
- private fun convertToByteBuffer(bitmap: Bitmap): ByteBuffer {
- val intValues = IntArray(224 * 224).apply {
- // bitmap を 224 x 224 に変換
- val scaled = Bitmap.createScaledBitmap(bitmap, 224, 224, true)
- // ピクセル情報を IntArray に取り出し
- scaled.getPixels(this, 0, 224, 0, 0, 224, 224)
- }
- // ピクセル情報のうち R,G,B を ByteBuffer に入れる
- return ByteBuffer.allocateDirect(1 * 224 * 224 * 3).apply {
- order(ByteOrder.nativeOrder())
- rewind()
- for (value in intValues) {
- put((value shr 16 and 0xFF).toByte())
- put((value shr 8 and 0xFF).toByte())
- put((value and 0xFF).toByte())
- }
- }
- }
- private fun detect(bitmap: Bitmap) {
- overlay.clear()
- // Bitmap を ByteBuffer に変換
- val imageByteBuffer = convertToByteBuffer(bitmap)
- }
- private fun detect(bitmap: Bitmap) {
- overlay.clear()
- // Bitmap を ByteBuffer に変換
- val imageByteBuffer = convertToByteBuffer(bitmap)
- val inputs = FirebaseModelInputs.Builder()
- .add(imageByteBuffer)
- .build()
- interpreter
- .run(inputs, dataOptions)
- .addOnSuccessListener { outputs ->
- val output = outputs!!.getOutput<Array<ByteArray>>(0)
- // output.size : 1
- val labelProbabilities: ByteArray = output[0]
- // labelProbabilities.size : 1001
- // labelProbabilities の各 Byte を 255f で割ると確率値になる
- // 確率の高い上位3つを取得
- val topLabels = getTopLabels(labelProbabilities, 3)
- overlay.add(TextsData(topLabels))
- }
- .addOnFailureListener { e ->
- e.printStackTrace()
- detectButton.isEnabled = true
- progressBar.visibility = View.GONE
- Toast.makeText(this, e.message, Toast.LENGTH_SHORT).show()
- }
- }
- private val labelList: List<String> by lazy {
- assets.open("labels.txt").reader().use {
- it.readLines()
- }
- }
- @Synchronized
- private fun getTopLabels(labelProbabilities: ByteArray, maxSize: Int): List<String> {
- val sortedLabels = PriorityQueue<Map.Entry<String, Float>>(
- maxSize,
- Comparator<Map.Entry<String, Float>> { o1, o2 -> o1.value.compareTo(o2.value) }
- )
- labelList.forEachIndexed { index, label ->
- sortedLabels.add(
- AbstractMap.SimpleEntry<String, Float>(
- label,
- (labelProbabilities[index].toInt() and 0xff) / 255f
- )
- )
- if (sortedLabels.size > maxSize) {
- sortedLabels.poll()
- }
- }
- return sortedLabels.map { "${it.key} : ${it.value}" }
- .reversed()
- }

サイなのだが、トリケラトプスって出てる...

猫がでない...

像は認識できた
次は「ML Kit Custom Model その3 : Mobilenet_V1_1.0_224_quant を CloudModel として使う」 です。
0 件のコメント:
コメントを投稿