4) JavaからTensorFlow Servingを呼び出す
(1) Javaプロジェクトの準備
TensorFlow ServingにはGoogleが開発したRPCの実装であるgRPCを用いて接続します。
Javaの外部ライブラリとして、gRPC本体と、gRPCで使用するTensorFlow Serving用のインタフェースを定義するものを参照する必要があります。
Mavenを使う場合、gRPC本体は依存関係に以下を追加すれば利用可能です。
<dependency> <groupId>io.grpc</groupId> <artifactId>grpc-netty</artifactId> <version>1.9.0</version> </dependency> <dependency> <groupId>io.grpc</groupId> <artifactId>grpc-protobuf</artifactId> <version>1.9.0</version> </dependency> <dependency> <groupId>io.grpc</groupId> <artifactId>grpc-stub</artifactId> <version>1.9.0</version> </dependency>
gRPCで使用するTensorFlow Serving用のインタフェースについては、その定義がprotobufという形式で配布されており、公式な手順としてはこれをJavaのクラスに自分で変換して埋め込むということが必要です。
有志開発者によりあとはMavenでビルドするだけの状態に整えられたプロジェクトが公開されていますので、利用するとよいでしょう。
(2) 画像のロード
今回作成した学習済みモデルは、入力として固定長のfloat値の集合を受け取りますので、判定したい画像をfloat値のリストに変換してやる必要があります。
そのような操作を行うライブラリを探してみたのですが、そのものずばり実現できるものは見つからなかったので下記のようなクラスを作成しました。
package jp.scsk.furuba.wi.service; import java.awt.Image; import java.awt.image.BufferedImage; import java.io.File; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.util.ArrayList; import java.util.List; import javax.imageio.ImageIO; public class ImageLoader { private static final int NUM_CHANNELS = 3; private final int width; private final int height; public ImageLoader(int width, int height) { this.width = width; this.height = height; } public List<Float> load(String path) { BufferedImage rawImage; try (InputStream imgIs = new FileInputStream(new File(path))) { rawImage = ImageIO.read(imgIs); } catch (IOException e) { throw new RuntimeException(e); } BufferedImage image = new BufferedImage(width, height, rawImage.getType()); Image scaledImage = rawImage.getScaledInstance(width, height, Image.SCALE_AREA_AVERAGING); image.getGraphics().drawImage(scaledImage, 0, 0, width, height, null); List<Float> flatImage = new ArrayList<>(width * height * NUM_CHANNELS); for (int x = 0; x < width; x++) { for (int y = 0; y < height; y++) { int rgb = image.getRGB(x, y); // intで表現されているRGB値をRGBの各チャンネルに分解し、0-1のfloatに変換 flatImage.add(toFloatRed(rgb)); flatImage.add(toFloatGreen(rgb)); flatImage.add(toFloatBlue(rgb)); } } return flatImage; } private static final float MAX_PIXEL_VALUE = 255; public static float toFloatRed(int rgb) { return (rgb >> 16 & 0xff) / MAX_PIXEL_VALUE; } public static float toFloatGreen(int rgb) { return (rgb >> 8 & 0xff) / MAX_PIXEL_VALUE; } public static float toFloatBlue(int rgb) { return (rgb & 0xff) / MAX_PIXEL_VALUE; } }
画像を読み込んで指定のサイズに変換し、Float
のリストとして出力するようにしています。
今回手元にある学習済みモデルは224ピクセル×224ピクセル×3チャンネルのデータを入力として受け取りますので、たとえば以下のように使うことになります。
ImageLoader loader = new ImageLoader(224, 224); List<Float> flatImage = loader.load("(画像のパス)");
なお、処理にAWT関連のクラスを使用している関係上、サーバ環境で動かすとエラーとなる場合があります。
その際にはJavaの起動オプションとして-Djava.awt.headless=true
を指定してください。
(3) データをサーバで処理
無事入力データを得たところで、それを実際にサーバで処理していきます。
まずはサーバとの通信を担うStubと呼ばれるオブジェクトを作成します。
ManagedChannel grpcChannel = ManagedChannelBuilder .forAddress("(サーバの名前かIPアドレス)", 9000) .usePlaintext(true) .build(); PredictionServiceBlockingStub grpcStub = PredictionServiceGrpc.newBlockingStub(grpcChannel);
Stubにはいくつか種類がありますが、今回は同期的に処理を書きたいので、サーバからの応答を待つ間ブロックするBlockingStubを利用しています。
次にサーバに送信するリクエストを作成し、送信します。
判定したい画像のデータはflatImage
変数に入っている想定です。
ClassificationRequest request; { ExampleList exampleList = ExampleList.newBuilder() .addExamples(buildExample(flatImage)) .build(); request = ClassificationRequest.newBuilder() .setModelSpec(ModelSpec.newBuilder().setName("(学習済みモデルの識別名)")) .setInput(Input.newBuilder().setExampleList(exampleList)) .build(); }
ClassificationRequest
などのクラスはprotoファイルから生成したもので、基本的にBuilderによってオブジェクトを作成します。
ClassificationRequest
のBuilderにおいて、setModelSpec
ではTensorFlow Servingの保持しているどのモデルに宛てたリクエストなのかということを指定し、
setInput
では送信する入力データを指定しています。
送信する入力データであるexampleList
の作成において、このコードでは1件だけ画像データをセットしていますが、addExamples
メソッドを繰り返し呼ぶことで複数件のデータをまとめてリクエストに含めることもできます。
上記のコードで使用しているbuildExample関数の定義は以下のとおりです:
private static Example buildExample(List<Float> flatImage) { FloatList floatList = FloatList.newBuilder().addAllValue(flatImage).build(); Feature feature = Feature.newBuilder().setFloatList(floatList).build(); Features features = Features.newBuilder().putFeature("img", feature).build(); Example example = Example.newBuilder().setFeatures(features).build(); return example; }
List<Float>
であるところのflatImage
を、TensorFlow Servingが要求する入力データの形であるExample
に変換しています。
その後、grpcStub
を通してリクエストの送信を行います。
ClassificationResponse response = grpcStub .withDeadlineAfter(10, TimeUnit.SECONDS) .classify(request);
(4) 結果の確認
先程得たresponse
から処理の結果が取得できます。
処理結果は、入力データ1件ごとにClassifications
というクラスのオブジェクトに格納されています。これは分類タスクの結果を格納するものでで、getClasses
メソッドを用いて分類ごとのスコア(ここでは猫のスコアと犬のスコア)にアクセスできます。
List<Classifications> results = response.getResult().getClassificationsList() Classifications firstResult = results.get(0) System.out.println(String.format("猫らしさ:%f", firstResult.getClasses(0).getScore())); System.out.println(String.format("犬らしさ:%f", firstResult.getClasses(1).getScore()));
ここまでのコード(buildExample
関数を除く)をまとめると、以下の通りになります。
ImageLoader loader = new ImageLoader(224, 224); List<Float> flatImage = loader.load("(画像のパス)"); ManagedChannel grpcChannel = ManagedChannelBuilder .forAddress("(サーバの名前かIPアドレス)", 9000) .usePlaintext(true) .build(); PredictionServiceBlockingStub grpcStub = PredictionServiceGrpc.newBlockingStub(grpcChannel); ClassificationRequest request; { ExampleList exampleList = ExampleList.newBuilder() .addExamples(buildExample(flatImage)) .build(); request = ClassificationRequest.newBuilder() .setModelSpec(ModelSpec.newBuilder().setName("(学習済みモデルの識別名)")) .setInput(Input.newBuilder().setExampleList(exampleList)) .build(); } ClassificationResponse response = grpcStub .withDeadlineAfter(10, TimeUnit.SECONDS) .classify(request); List<Classifications> results = response.getResult().getClassificationsList() Classifications firstResult = results.get(0) System.out.println(String.format("猫らしさ:%f", firstResult.getClasses(0).getScore())); System.out.println(String.format("犬らしさ:%f", firstResult.getClasses(1).getScore()));
これを手元にある犬の画像に対して実行したところ、下記のような結果を得ました:
猫らしさ:0.016830
犬らしさ:0.983170
無事、JavaからTensorFlowの学習済みモデルを利用することができました!
おわりに
TensorFlow Servingを用いると、学習済みモデルをサービス化し、gRPCが使える限り好きな言語から利用することができます。
これにより、ディープラーニング関連の作業はPython環境で行いつつ、Javaのお堅いシステムからそれを利用するということが実現できます。
さて、Kerasお仕着せの楽勝ディープラーニングからスタートしたこの解説も、やや駆け足ではありましたが、本番のビジネスで使える道具立てがそろうところまで、たどり着きました。
ここから先は現場の課題に応じて、うまくディープラーニングや、その他の手段を適用していくという段階になります。
この解説が、あなたの現場の課題解決にいつか役立つことを祈念しつつ、このあたりで終わりにしたいと思います。
ここまでお読みいただき、ありがとうございました。
※TensorFlow, the TensorFlow logo and any related marks are trademarks of Google Inc.
その他、本コンテンツ内で利用させて頂いた各プロダクト名やサービス名などは、各社もしくは各団体の商標または登録商標です。