TensorFlow Serving介绍

简介

TensorFlow的模型文件包含了深度学习模型的Graph和所有参数,其实就是checkpoint文件,用户可以加载模型文件继续训练或者对外提供Inference服务。

使用SavedModel导出模型

模型导出方式参考 https://tensorflow.github.io/serving/serving_basic

使用方法基本如下。

from tensorflow.python.saved_model import builder as saved_model_builder

export_path_base = sys.argv[-1]
export_path = os.path.join(
      compat.as_bytes(export_path_base),
      compat.as_bytes(str(FLAGS.model_version)))
print 'Exporting trained model to', export_path

builder = saved_model_builder.SavedModelBuilder(export_path)
builder.add_meta_graph_and_variables(
      sess, [tag_constants.SERVING],
      signature_def_map={
           'predict_images':
               prediction_signature,
           signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
               classification_signature,
      },
      legacy_init_op=legacy_init_op)

builder.save()

可以参考 https://github.com/tobegit3hub/deep_recommend_system/ 提供的可运行代码示例。

./dense_classifier.py --mode savedmodel

使用exporter导出模型

这里有导出TensorFlow serving支持的模型文件例子,可以参考使用 https://github.com/tobegit3hub/deep_recommend_system/blob/master/dense_classifier.py

导出的代码也比较简单,用户在inputs和output中填入模型Inference时的输入和输出即可。

from tensorflow.contrib.session_bundle import exporter

flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string("model_path", "./model", "The path to export the model")
flags.DEFINE_integer("export_version", 1, "Version number of the model")

# Define the graph
keys_placeholder = tf.placeholder(tf.int32, shape=[None, 1])
keys = tf.identity(keys_placeholder)

# Start the session

# Export the model
print("Exporting trained model to {}".format(FLAGS.model_path))
model_exporter = exporter.Exporter(saver)
model_exporter.init(
  sess.graph.as_graph_def(),
    named_graph_signatures={
      'inputs': exporter.generic_signature({"keys": keys_placeholder, "features": inference_features}),
      'outputs': exporter.generic_signature({"keys": keys, "softmax": inference_softmax, "prediction": inference_op})
    })
model_exporter.export(FLAGS.model_path, tf.constant(FLAGS.export_version), sess)
print 'Done exporting!'

与SavedModel方法相比,两者都可以直接用TensorFlow Serving加载,我们使用deep_recommend_system导出两种模型方式测试过预测结果一模一样,只是模型文件大小不同。

导入带assert的模型文件

在NLP等场景除了参数文件,还需要导入vocabulary等文件,可以在exporter中设置assets_collection,参考 https://github.com/tensorflow/serving/issues/264

results matching ""

    No results matching ""