いよいよ ML Kit Custom Model に組み込みます。
ベースとして MLKitSample ( https://github.com/yanzm/MLKitSample/tree/start ) の start ブランチを利用します。
Firebase でプロジェクトを作ってアプリを登録し、google-services.json を app モジュール直下に配置します。詳しくは上記 MLKitSample の README.md 課題1,2 を参考にしてください。
アプリの dependencies に以下を追加します。
dependencies {
...
implementation "com.google.firebase:firebase-ml-model-interpreter:16.2.4"
}
今回は mobilenet_v1_1.0_224_quant.tflite をあらかじめアプリにバンドルして LocalModel として利用します。
そのため app/src/main/assets に mobilenet_v1_1.0_224_quant.tflite を配置します。また labels.txt も assets に配置します。
推論を実行する際には、入力データの他に入出力の形式も指定しないといけません。そのために用意するのが FirebaseModelInputOutputOptions です。
前回調べた入出力の shape は以下の通りでした。
入力の shape は [ 1, 224, 224, 3] の int32 です。1番目の 1 はバッチサイズ、2番目の 224 は画像の width、3番目の 224 は画像の height、4番目の 3 は色情報(R,G,B)です。入力の dtype は numpy.uint8 です。
出力の shape は [ 1, 1001] の int32 です。1番目の 1 はバッチサイズ、2番目の 1001 は class の個数です。出力の dtype は numpy.uint8 です。
入力の shape は [ 1, 224, 224, 3] の int32 なので inputDims として intArrayOf(1, 224, 224, 3) を用意します。
出力の shape は [ 1, 1001] の int32 なので outputDims として intArrayOf(1, 1001) を用意します。
入出力いずれも dtype は numpy.uint8 なので、dataType には FirebaseModelDataType.BYTE を指定します。
private val dataOptions: FirebaseModelInputOutputOptions by lazy {
val inputDims = intArrayOf(1, 224, 224, 3)
val outputDims = intArrayOf(1, 1001)
FirebaseModelInputOutputOptions.Builder()
.setInputFormat(0, FirebaseModelDataType.BYTE, inputDims)
.setOutputFormat(0, FirebaseModelDataType.BYTE, outputDims)
.build()
}
assets にある model を LocalModel として使うには FirebaseLocalModelSource を用意します。 FirebaseLocalModelSource.Builder() の引数に LocalModelSource を識別するための名前(ここでは "asset")を指定し、setAssetFilePath() で assets/ に配置したモデルファイルを指定します。
作成した FirebaseLocalModelSource は FirebaseModelManager.registerLocalModelSource() で登録しておきます。
FirebaseModelOptions.Builder の setLocalModelName() には FirebaseLocalModelSource.Builder() の引数に指定した名前(ここでは "asset")を渡します。
最後に FirebaseModelOptions を渡して FirebaseModelInterpreter のインスタンスを取得します。
private val interpreter: FirebaseModelInterpreter by lazy {
val localModelSource = FirebaseLocalModelSource.Builder("asset")
.setAssetFilePath("mobilenet_v1_1.0_224_quant.tflite")
.build()
FirebaseModelManager.getInstance().registerLocalModelSource(localModelSource)
val modelOptions = FirebaseModelOptions.Builder()
.setLocalModelName("asset")
.build()
FirebaseModelInterpreter.getInstance(modelOptions)!!
}
次に検出部分の実装をしていきます。
FirebaseModelInterpreter の入力データは ByteBuffer か配列か多次元配列でなければなりません。 ここでは ByteBuffer にしてみます。
モデルに渡す画像は 224 x 224 なので Bitmap.createScaledBitmap() で画像をスケールし、ピクセル情報のうち R,G,B のデータを ByteBuffer に入れます。
/**
* Bitmap を ByteBuffer に変換
*/
private fun convertToByteBuffer(bitmap: Bitmap): ByteBuffer {
val intValues = IntArray(224 * 224).apply {
// bitmap を 224 x 224 に変換
val scaled = Bitmap.createScaledBitmap(bitmap, 224, 224, true)
// ピクセル情報を IntArray に取り出し
scaled.getPixels(this, 0, 224, 0, 0, 224, 224)
}
// ピクセル情報のうち R,G,B を ByteBuffer に入れる
return ByteBuffer.allocateDirect(1 * 224 * 224 * 3).apply {
order(ByteOrder.nativeOrder())
rewind()
for (value in intValues) {
put((value shr 16 and 0xFF).toByte())
put((value shr 8 and 0xFF).toByte())
put((value and 0xFF).toByte())
}
}
}
private fun detect(bitmap: Bitmap) {
overlay.clear()
// Bitmap を ByteBuffer に変換
val imageByteBuffer = convertToByteBuffer(bitmap)
}
FirebaseModelInputs.Builder で入力データを作成し、FirebaseModelInterpreter の run() で推論を実行します。
private fun detect(bitmap: Bitmap) {
overlay.clear()
// Bitmap を ByteBuffer に変換
val imageByteBuffer = convertToByteBuffer(bitmap)
val inputs = FirebaseModelInputs.Builder()
.add(imageByteBuffer)
.build()
interpreter
.run(inputs, dataOptions)
.addOnSuccessListener { outputs ->
val output = outputs!!.getOutput<Array<ByteArray>>(0)
// output.size : 1
val labelProbabilities: ByteArray = output[0]
// labelProbabilities.size : 1001
// labelProbabilities の各 Byte を 255f で割ると確率値になる
// 確率の高い上位3つを取得
val topLabels = getTopLabels(labelProbabilities, 3)
overlay.add(TextsData(topLabels))
}
.addOnFailureListener { e ->
e.printStackTrace()
detectButton.isEnabled = true
progressBar.visibility = View.GONE
Toast.makeText(this, e.message, Toast.LENGTH_SHORT).show()
}
}
assets の labels.txt を読んでラベルのリストを用意しておき、labelProbabilities の各確率に対応するラベルを割り出します。
getTopLabels() では PriorityQueue を使って確率の高いラベルだけ残すようにします。
private val labelList: List<String> by lazy {
assets.open("labels.txt").reader().use {
it.readLines()
}
}
@Synchronized
private fun getTopLabels(labelProbabilities: ByteArray, maxSize: Int): List<String> {
val sortedLabels = PriorityQueue<Map.Entry<String, Float>>(
maxSize,
Comparator<Map.Entry<String, Float>> { o1, o2 -> o1.value.compareTo(o2.value) }
)
labelList.forEachIndexed { index, label ->
sortedLabels.add(
AbstractMap.SimpleEntry<String, Float>(
label,
(labelProbabilities[index].toInt() and 0xff) / 255f
)
)
if (sortedLabels.size > maxSize) {
sortedLabels.poll()
}
}
return sortedLabels.map { "${it.key} : ${it.value}" }
.reversed()
}
サイなのだが、トリケラトプスって出てる...
猫がでない...
像は認識できた
次は「ML Kit Custom Model その3 : Mobilenet_V1_1.0_224_quant を CloudModel として使う」 です。
0 件のコメント:
コメントを投稿