使用python API进行的培训作为Java API中LabelImage模块的输入?

我有java tensorflow API的问题。 我使用python tensorflow API运行训练,生成文件output_graph.pb和output_labels.txt。 现在由于某种原因,我想使用这些文件作为java tensorflow API中LabelImage模块的输入。 我认为一切都会正常工作,因为该模块只需要一个.pb和一个.txt。 不过,当我运行模块时,我收到此错误:

2017-04-26 10:12:56.711402: W tensorflow/core/framework/op_def_util.cc:332] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization(). Exception in thread "main" java.lang.IllegalArgumentException: No Operation named [input] in the Graph at org.tensorflow.Session$Runner.operationByName(Session.java:343) at org.tensorflow.Session$Runner.feed(Session.java:137) at org.tensorflow.Session$Runner.feed(Session.java:126) at it.zero11.LabelImage.executeInceptionGraph(LabelImage.java:115) at it.zero11.LabelImage.main(LabelImage.java:68) 

如果你帮助我找到问题所在,我将非常感激。 此外,我想问你是否有办法从java tensorflow API运行培训,因为这会使事情变得更容易。

更确切地说:

事实上,我不使用自编代码,至少对于相关步骤。 我所做的就是使用这个模块进行培训, https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/image_retraining/retrain.py ,将其包含在包含子目录之间的图像的目录中根据他们的描述。 特别是,我认为这些是产生输出的线:

 output_graph_def = graph_util.convert_variables_to_constants( sess, graph.as_graph_def(), [FLAGS.final_tensor_name]) with gfile.FastGFile(FLAGS.output_graph, 'wb') as f: f.write(output_graph_def.SerializeToString()) with gfile.FastGFile(FLAGS.output_labels, 'w') as f: f.write('\n'.join(image_lists.keys()) + '\n') 

然后,我将输出(一个some_graph.pb和一个some_labels.txt)作为此java模块的输入: https : //github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/ org / tensorflow / examples / LabelImage.java ,替换默认输入。 我得到的错误是上面报告的错误。

LabelImage.java中默认使用的模型与正在重新训练的模型不同,因此输入和输出节点的名称不对齐。 请注意,TensorFlow模型是图形, feed()fetch()的参数是图形中节点的名称。 因此,您需要知道适合您模型的名称。

看看retrain.py ,它似乎有一个节点,它将JPEG文件的原始内容作为输入(节点DecodeJpeg/contents ),并在节点final_result生成标签集。

如果是这种情况,那么你将在Java中执行类似下面的操作(并且您不需要构造图形来对图像进行规范化的位,因为这似乎是重新训练模型的一部分,因此请替换LabelImage.java:64有类似的东西:

 try (Tensor image = Tensor.create(imageBytes); Graph g = new Graph()) { g.importGraphDef(graphDef); try (Session s = new Session(g); // Note the change to the name of the node and the fact // that it is being provided the raw imageBytes as input Tensor result = s.runner().feed("DecodeJpeg/contents", image).fetch("final_result").run().get(0)) { final long[] rshape = result.shape(); if (result.numDimensions() != 2 || rshape[0] != 1) { throw new RuntimeException( String.format( "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s", Arrays.toString(rshape))); } int nlabels = (int) rshape[1]; float[] probabilities = result.copyTo(new float[1][nlabels])[0]; // At this point nlabels = number of classes in your retrained model DoSomethingWith(probabilities); } } 

希望有所帮助。

关于“无操作”错误,我能够通过分别使用输入和输出层名称“Mul”和“final_result”来解决这个问题。 看到:

https://github.com/tensorflow/tensorflow/issues/2883