Tensorflow Java多GPU推理

我有一个具有多个GPU的服务器,并希望在Java应用程序内的模型推理期间充分利用它们。 默认情况下,tensorflow会占用所有可用的GPU,但仅使用第一个GPU。

我可以想出三个选项来克服这个问题:

  1. 在进程级别限制设备可见性,即使用CUDA_VISIBLE_DEVICES环境变量。

    这将要求我运行java应用程序的几个实例并在它们之间分配流量。 不是那个诱人的想法。

  2. 在单个应用程序中启动多个会话,并尝试通过ConfigProto为每个会话分配一个设备:

     public class DistributedPredictor { private Predictor[] nested; private int[] counters; // ... public DistributedPredictor(String modelPath, int numDevices, int numThreadsPerDevice) { nested = new Predictor[numDevices]; counters = new int[numDevices]; for (int i = 0; i < nested.length; i++) { nested[i] = new Predictor(modelPath, i, numDevices, numThreadsPerDevice); } } public Prediction predict(Data data) { int i = acquirePredictorIndex(); Prediction result = nested[i].predict(data); releasePredictorIndex(i); return result; } private synchronized int acquirePredictorIndex() { int i = argmin(counters); counters[i] += 1; return i; } private synchronized void releasePredictorIndex(int i) { counters[i] -= 1; } } public class Predictor { private Session session; public Predictor(String modelPath, int deviceIdx, int numDevices, int numThreadsPerDevice) { GPUOptions gpuOptions = GPUOptions.newBuilder() .setVisibleDeviceList("" + deviceIdx) .setAllowGrowth(true) .build(); ConfigProto config = ConfigProto.newBuilder() .setGpuOptions(gpuOptions) .setInterOpParallelismThreads(numDevices * numThreadsPerDevice) .build(); byte[] graphDef = Files.readAllBytes(Paths.get(modelPath)); Graph graph = new Graph(); graph.importGraphDef(graphDef); this.session = new Session(graph, config.toByteArray()); } public Prediction predict(Data data) { // ... } } 

    这种方法似乎一目了然。 但是,会话偶尔会忽略setVisibleDeviceList选项,并且所有会话都会导致第一个导致Out-Of-Memory崩溃的设备。

  3. 使用tf.device()规范在python中以多塔方式构建模型。 在java端,在共享会话中给出不同的Predictor不同的塔。

    对我来说,感觉很麻烦和惯用。

更新:正如@ash提议的那样,还有另一种选择:

  1. 通过修改其定义( graphDef )为现有图的每个操作分配适当的设备。

    要完成它,可以调整方法2中的代码:

     public class Predictor { private Session session; public Predictor(String modelPath, int deviceIdx, int numDevices, int numThreadsPerDevice) { byte[] graphDef = Files.readAllBytes(Paths.get(modelPath)); graphDef = setGraphDefDevice(graphDef, deviceIdx) Graph graph = new Graph(); graph.importGraphDef(graphDef); ConfigProto config = ConfigProto.newBuilder() .setAllowSoftPlacement(true) .build(); this.session = new Session(graph, config.toByteArray()); } private static byte[] setGraphDefDevice(byte[] graphDef, int deviceIdx) throws InvalidProtocolBufferException { String deviceString = String.format("/gpu:%d", deviceIdx); GraphDef.Builder builder = GraphDef.parseFrom(graphDef).toBuilder(); for (int i = 0; i < builder.getNodeCount(); i++) { builder.getNodeBuilder(i).setDevice(deviceString); } return builder.build().toByteArray(); } public Prediction predict(Data data) { // ... } } 

    就像其他提到的方法一样,这个方法并没有让我免于在设备之间手动分配数据。 但至少它运行稳定,并且相对容易实现。 总的来说,这看起来像(几乎)正常的技术。

使用tensorflow java API有一种优雅的方式来做这样的基本事情吗? 任何想法,将不胜感激。

简而言之:有一种解决方法,每个GPU最终会有一个会话。

细节:

一般流程是TensorFlow运行时尊重为图中的操作指定的设备。 如果没有为操作指定设备,则它会根据某些启发式“放置”它。 这些启发式技术目前导致“在GPU上进行操作:0如果GPU可用并且有GPU内核用于操作”( Placer::Run以防您感兴趣)。

您认为我要求的是TensorFlow的合理function请求 – 能够将序列化图形中的设备视为“虚拟”设备,在运行时映射到一组“phyiscal”设备,或者设置“默认设备” ”。 此function目前不存在。 向ConfigProto添加此选项是您可能要为其提交function请求的内容。

我可以在此期间提出一个解决方法。 首先,对您提出的解决方案进行一些评论。

  1. 你的第一个想法肯定会奏效,但正如你所指出的那样,很麻烦。

  2. ConfigProto中使用visible_device_list进行设置并不ConfigProto ,因为它实际上是一个每进程设置,并且在进程中创建第一个会话后被忽略。 这肯定没有记录,应该是(并且有点不幸的是,它出现在每会话配置中)。 但是,这解释了为什么您的建议不起作用以及您仍然看到使用单个GPU的原因。

  3. 这可行。

另一种选择是最终得到不同的图形(操作明确放在不同的GPU上),每个GPU产生一个会话。 这样的东西可以用来编辑图形并明确地为每个操作分配一个设备:

 public static byte[] modifyGraphDef(byte[] graphDef, String device) throws Exception { GraphDef.Builder builder = GraphDef.parseFrom(graphDef).toBuilder(); for (int i = 0; i < builder.getNodeCount(); ++i) { builder.getNodeBuilder(i).setDevice(device); } return builder.build().toByteArray(); } 

之后,您可以使用以下内容为每个GPU创建一个GraphSession

 final int NUM_GPUS = 8; // setAllowSoftPlacement: Just in case our device modifications were too aggressive // (eg, setting a GPU device on an operation that only has CPU kernels) // setLogDevicePlacment: So we can see what happens. byte[] config = ConfigProto.newBuilder() .setLogDevicePlacement(true) .setAllowSoftPlacement(true) .build() .toByteArray(); Graph graphs[] = new Graph[NUM_GPUS]; Session sessions[] = new Session[NUM_GPUS]; for (int i = 0; i < NUM_GPUS; ++i) { graphs[i] = new Graph(); graphs[i].importGraphDef(modifyGraphDef(graphDef, String.format("/gpu:%d", i))); sessions[i] = new Session(graphs[i], config); } 

然后使用sessions[i]在GPU #i上执行图形。

希望有所帮助。