本文最后更新于:14 天前
simple-tensorflow-serving
TensorFlow Serving是一种灵活,高性能的机器学习模型服务系统,专为生产环境而设计。TensorFlow服务可以轻松部署新算法和实验,同时保持相同的服务器架构和API。TensorFlow Serving提供与TensorFlow模型的开箱即用集成,但可以轻松扩展以提供其他类型的模型和数据。
官方提供的serving只支持TensorFlow相关的模型,但是这里介绍的simple-TensorFlow-serving Support multiple models of TensorFlow/MXNet/PyTorch/Caffe2/CNTK/ONNX/H2o/Scikit-learn/XGBoost/PMML, 操作流程和官方的serving是一样的,很是简单。
savedmodel格式模型保存
【直接保存】
tf.saved_model.simple_save(sess,
"./model/1/", # 保存路径
inputs={"myInput": inputs}, # 数据接口
outputs={"myOutput": logitic}) # 输出接口
【TensorFlow】
checkpoints转savedmodel
import tensorflow as tf
# 定义输入和标签的placeholder
inputs = tf.placeholder(tf.float32, [None, 32, 32, 3]) # 输入数据
logitic = ... # 输出结果
print('TESTING....')
# TODO: 保存模型
init = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
print("Evaluate The Model")
sess.run(init)
# TODO: 读取模型
saver.restore(sess, './model/cifar.ckpt') # 加载checkpoint模型
# Initialize v1 since the saver will not.
tf.saved_model.simple_save(
sess,
"./savedmodel/2/",
inputs={"image": inputs},
outputs={"scores": logitic}
)
.pb转savedmodel
import tensorflow as tf
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
export_dir = './saved/1' # 保存路径
graph_pb = 'freezed.pb' # 导入模型文件
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
with tf.gfile.GFile(graph_pb, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sigs = {}
with tf.Session(graph=tf.Graph()) as sess:
# name="" is important to ensure we don't get spurious prefixing
tf.import_graph_def(graph_def, name="")
g = tf.get_default_graph()
inp = g.get_tensor_by_name("input:0") # 输入数据
out = g.get_tensor_by_name("output:0") # 输出结果
sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
tf.saved_model.signature_def_utils.predict_signature_def(
{"in": inp}, {"out": out})
builder.add_meta_graph_and_variables(sess,
[tag_constants.SERVING],
signature_def_map=sigs)
builder.save()
【Keras】
.h5转savedmodel
import tensorflow as tf
with tf.device("/cpu:0"):
model = tf.keras.models.load_model('./model.h5') # 导入模型文件
tf.saved_model.simple_save(
tf.keras.backend.get_session(),
"h5_savedmodel/1/", # 保存路径
inputs={"image": model.input}, # 输入数据
outputs={"scores": model.output} # 输出结果
)
print('end')
模型部署
simple_tensorflow_serving --model_base_path="./models/"
生成客户端代码
curl http://localhost:8500/v1/models/default/gen_client?language=python > client.py
查看模型接口
saved_model_cli show --dir saved_model/0 --tag_set serve --signature_def serving_default
The given SavedModel SignatureDef contains the following input(s):
inputs['input'] tensor_info:
dtype: DT_FLOAT
shape: (1, -1, 161, 1)
name: Placeholder:0
The given SavedModel SignatureDef contains the following output(s):
outputs['output'] tensor_info:
dtype: DT_FLOAT
shape: (-1, -1, 2599)
name: dense/BiasAdd:0
Method name is: tensorflow/serving/predict
本博客所有文章除特别声明外,均采用 CC BY-SA 3.0协议 。转载请注明出处!