第5回 Java WebアプリでTensorFlow(実装編) (4/4)

技術特集

1_tensorflow_keras1_tit

tit_tensorflow_keras

4) JavaからTensorFlow Servingを呼び出す

(1) Javaプロジェクトの準備

TensorFlow ServingにはGoogleが開発したRPCの実装であるgRPCを用いて接続します。
Javaの外部ライブラリとして、gRPC本体と、gRPCで使用するTensorFlow Serving用のインタフェースを定義するものを参照する必要があります。

Mavenを使う場合、gRPC本体は依存関係に以下を追加すれば利用可能です。

<dependency>
  <groupId>io.grpc</groupId>
  <artifactId>grpc-netty</artifactId>
  <version>1.9.0</version>
</dependency>
<dependency>
  <groupId>io.grpc</groupId>
  <artifactId>grpc-protobuf</artifactId>
  <version>1.9.0</version>
</dependency>
<dependency>
  <groupId>io.grpc</groupId>
  <artifactId>grpc-stub</artifactId>
  <version>1.9.0</version>
</dependency>

gRPCで使用するTensorFlow Serving用のインタフェースについては、その定義がprotobufという形式で配布されており、公式な手順としてはこれをJavaのクラスに自分で変換して埋め込むということが必要です。
有志開発者によりあとはMavenでビルドするだけの状態に整えられたプロジェクトが公開されていますので、利用するとよいでしょう。

(2) 画像のロード

今回作成した学習済みモデルは、入力として固定長のfloat値の集合を受け取りますので、判定したい画像をfloat値のリストに変換してやる必要があります。
そのような操作を行うライブラリを探してみたのですが、そのものずばり実現できるものは見つからなかったので下記のようなクラスを作成しました。

package jp.scsk.furuba.wi.service;
 
import java.awt.Image;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
 
import javax.imageio.ImageIO;
 
public class ImageLoader {
    private static final int NUM_CHANNELS = 3;
    private final int width;
    private final int height;
 
    public ImageLoader(int width, int height) {
        this.width = width;
        this.height = height;
    }
 
    public List<Float> load(String path) {
        BufferedImage rawImage;
        try (InputStream imgIs = new FileInputStream(new File(path))) {
            rawImage = ImageIO.read(imgIs);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
 
        BufferedImage image = new BufferedImage(width, height, rawImage.getType());
        Image scaledImage =
            rawImage.getScaledInstance(width, height, Image.SCALE_AREA_AVERAGING);
        image.getGraphics().drawImage(scaledImage, 0, 0, width, height, null);
 
        List<Float> flatImage = new ArrayList<>(width * height * NUM_CHANNELS);
        for (int x = 0; x < width; x++) {
            for (int y = 0; y < height; y++) {
                int rgb = image.getRGB(x, y);
                // intで表現されているRGB値をRGBの各チャンネルに分解し、0-1のfloatに変換
                flatImage.add(toFloatRed(rgb));
                flatImage.add(toFloatGreen(rgb));
                flatImage.add(toFloatBlue(rgb));
            }
        }
        return flatImage;
    }
 
    private static final float MAX_PIXEL_VALUE = 255;
 
    public static float toFloatRed(int rgb) {
        return (rgb >> 16 &amp; 0xff) / MAX_PIXEL_VALUE;
    }
 
    public static float toFloatGreen(int rgb) {
        return (rgb >> 8 &amp; 0xff) / MAX_PIXEL_VALUE;
    }
 
    public static float toFloatBlue(int rgb) {
        return (rgb &amp; 0xff) / MAX_PIXEL_VALUE;
    }
}

画像を読み込んで指定のサイズに変換し、Floatのリストとして出力するようにしています。
今回手元にある学習済みモデルは224ピクセル×224ピクセル×3チャンネルのデータを入力として受け取りますので、たとえば以下のように使うことになります。

ImageLoader loader = new ImageLoader(224, 224);
List<Float> flatImage = loader.load("(画像のパス)");

なお、処理にAWT関連のクラスを使用している関係上、サーバ環境で動かすとエラーとなる場合があります。
その際にはJavaの起動オプションとして-Djava.awt.headless=trueを指定してください。

(3) データをサーバで処理

無事入力データを得たところで、それを実際にサーバで処理していきます。
まずはサーバとの通信を担うStubと呼ばれるオブジェクトを作成します。

ManagedChannel grpcChannel = ManagedChannelBuilder
    .forAddress("(サーバの名前かIPアドレス)", 9000)
    .usePlaintext(true)
    .build();
PredictionServiceBlockingStub grpcStub =
    PredictionServiceGrpc.newBlockingStub(grpcChannel);

Stubにはいくつか種類がありますが、今回は同期的に処理を書きたいので、サーバからの応答を待つ間ブロックするBlockingStubを利用しています。

次にサーバに送信するリクエストを作成し、送信します。
判定したい画像のデータはflatImage変数に入っている想定です。

ClassificationRequest request;
{
    ExampleList exampleList = ExampleList.newBuilder()
        .addExamples(buildExample(flatImage))
        .build();
 
    request = ClassificationRequest.newBuilder()
        .setModelSpec(ModelSpec.newBuilder().setName("(学習済みモデルの識別名)"))
        .setInput(Input.newBuilder().setExampleList(exampleList))
        .build();
}

ClassificationRequestなどのクラスはprotoファイルから生成したもので、基本的にBuilderによってオブジェクトを作成します。
ClassificationRequestのBuilderにおいて、setModelSpecではTensorFlow Servingの保持しているどのモデルに宛てたリクエストなのかということを指定し、
setInputでは送信する入力データを指定しています。

送信する入力データであるexampleListの作成において、このコードでは1件だけ画像データをセットしていますが、addExamplesメソッドを繰り返し呼ぶことで複数件のデータをまとめてリクエストに含めることもできます。

上記のコードで使用しているbuildExample関数の定義は以下のとおりです:

private static Example buildExample(List<Float> flatImage) {
    FloatList floatList = FloatList.newBuilder().addAllValue(flatImage).build();
    Feature feature = Feature.newBuilder().setFloatList(floatList).build();
    Features features = Features.newBuilder().putFeature("img", feature).build();
    Example example = Example.newBuilder().setFeatures(features).build();
    return example;
}

List<Float>であるところのflatImageを、TensorFlow Servingが要求する入力データの形であるExampleに変換しています。

その後、grpcStubを通してリクエストの送信を行います。

ClassificationResponse response = grpcStub
    .withDeadlineAfter(10, TimeUnit.SECONDS)
    .classify(request);

(4) 結果の確認

先程得たresponseから処理の結果が取得できます。
処理結果は、入力データ1件ごとにClassificationsというクラスのオブジェクトに格納されています。これは分類タスクの結果を格納するものでで、getClassesメソッドを用いて分類ごとのスコア(ここでは猫のスコアと犬のスコア)にアクセスできます。

List<Classifications> results = response.getResult().getClassificationsList()
Classifications firstResult = results.get(0)
 
System.out.println(String.format("猫らしさ:%f", firstResult.getClasses(0).getScore()));
System.out.println(String.format("犬らしさ:%f", firstResult.getClasses(1).getScore()));

ここまでのコード(buildExample関数を除く)をまとめると、以下の通りになります。

ImageLoader loader = new ImageLoader(224, 224);
List<Float> flatImage = loader.load("(画像のパス)");
 
ManagedChannel grpcChannel = ManagedChannelBuilder
    .forAddress("(サーバの名前かIPアドレス)", 9000)
    .usePlaintext(true)
    .build();
PredictionServiceBlockingStub grpcStub = PredictionServiceGrpc.newBlockingStub(grpcChannel);
 
ClassificationRequest request;
{
    ExampleList exampleList = ExampleList.newBuilder()
        .addExamples(buildExample(flatImage))
        .build();
 
    request = ClassificationRequest.newBuilder()
        .setModelSpec(ModelSpec.newBuilder().setName("(学習済みモデルの識別名)"))
        .setInput(Input.newBuilder().setExampleList(exampleList))
        .build();
}
 
ClassificationResponse response = grpcStub
    .withDeadlineAfter(10, TimeUnit.SECONDS)
    .classify(request);
 
List<Classifications> results = response.getResult().getClassificationsList()
Classifications firstResult = results.get(0)
 
System.out.println(String.format("猫らしさ:%f", firstResult.getClasses(0).getScore()));
System.out.println(String.format("犬らしさ:%f", firstResult.getClasses(1).getScore()));

これを手元にある犬の画像に対して実行したところ、下記のような結果を得ました:

猫らしさ:0.016830
犬らしさ:0.983170

無事、JavaからTensorFlowの学習済みモデルを利用することができました!

おわりに

TensorFlow Servingを用いると、学習済みモデルをサービス化し、gRPCが使える限り好きな言語から利用することができます。
これにより、ディープラーニング関連の作業はPython環境で行いつつ、Javaのお堅いシステムからそれを利用するということが実現できます。

さて、Kerasお仕着せの楽勝ディープラーニングからスタートしたこの解説も、やや駆け足ではありましたが、本番のビジネスで使える道具立てがそろうところまで、たどり着きました。
ここから先は現場の課題に応じて、うまくディープラーニングや、その他の手段を適用していくという段階になります。
この解説が、あなたの現場の課題解決にいつか役立つことを祈念しつつ、このあたりで終わりにしたいと思います。

ここまでお読みいただき、ありがとうございました。