TensorFlow Serving with a variable batch size

TensorFlow Serving can handle a variable batch size when making predictions. I never understood how to configure this, or what shape the returned results would have. Having finally figured it out, here are the changes to our previous serving setup that let our model classify a variable number of images in a single request.

Serving input function

The first step is to update the placeholder in our serving input receiver function. Previously we set the placeholder's shape to [1]; for a variable batch size, we simply set it to [None]. Our updated input receiver function is now:

import tensorflow as tf

# _img_string_to_tensor and input_img_size come from our previous serving setup.
def serving_input_receiver_fn():

    feature_spec = {
        'image': tf.FixedLenFeature([], dtype=tf.string)
    }

    # The number of examples expected per batch. Leaving it as None gives a
    # variable batch size (recommended).
    default_batch_size = None

    serialized_tf_example = tf.placeholder(
        dtype=tf.string, shape=[default_batch_size],
        name='input_image_tensor')

    received_tensors = {'images': serialized_tf_example}
    features = tf.parse_example(serialized_tf_example, feature_spec)

    # Decode each serialized image string into a float tensor, one per batch element
    fn = lambda image: _img_string_to_tensor(image, input_img_size)

    features['image'] = tf.map_fn(fn, features['image'], dtype=tf.float32)

    return tf.estimator.export.ServingInputReceiver(features, received_tensors)
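To see why [None] matters, here is a simplified sketch of how a fed batch's shape is checked against a placeholder's declared shape: a None dimension accepts any size, while a fixed dimension must match exactly. The is_compatible helper is a hypothetical name of our own, not a TensorFlow API.

```python
from typing import Optional, Sequence

Dim = Optional[int]


def is_compatible(declared: Sequence[Dim], actual: Sequence[int]) -> bool:
    """Return True if a concrete shape fits a declared placeholder shape.

    A None entry accepts any size; an integer entry must match exactly.
    Shapes with a different number of dimensions never match.
    """
    if len(declared) != len(actual):
        return False
    return all(d is None or d == a for d, a in zip(declared, actual))


# A batch of three serialized examples has shape [3].
batch_shape = [3]
print(is_compatible([1], batch_shape))     # False: fixed batch of one
print(is_compatible([None], batch_shape))  # True: any batch size
```

With shape [1] only single-image requests fit, whereas [None] accepts batches of any length while still requiring a 1-D vector of serialized examples.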

Making a batched request

Second, we’ll make a couple of changes to our make_request function so that it accepts a variable number of input paths or URLs. We also need to reshape the result so we can extract the predictions for each image.

The code is below, but the main changes are:

  • Accept multiple paths with *file_paths
  • Add each serialized example to a list
  • Set the shape of request.inputs['inputs'] to [len(serialized)]

The output classes and scores come back as flat lists, so we reshape them to (batch size, number of categories). We then extract the top-scoring class for each sample and convert the full results into a list of label/score dictionaries per sample. This is a bit more complex than expected, but since the response includes the dimensions of each output, the reshape is straightforward. See below for a full request and response sample.

import urllib.request

import numpy as np
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2

def make_request_multi(stub, *file_paths):
    """Classify one or more images (local paths or http(s) URLs) in a single request."""
    request = predict_pb2.PredictRequest()
    request.model_spec.name = 'default'

    # Serialize each image into a tf.train.Example, one per input path or URL
    serialized = []
    for file_path in file_paths:
        if file_path.startswith('http'):
            data = urllib.request.urlopen(file_path).read()
        else:
            with open(file_path, 'rb') as f:
                data = f.read()

        feature_dict = {
            'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[data]))
        }
        example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
        serialized.append(example.SerializeToString())

    # The batch dimension is simply the number of serialized examples
    request.inputs['inputs'].CopyFrom(
        tf.contrib.util.make_tensor_proto(serialized, shape=[len(serialized)]))

    # Send the request asynchronously with a 10 second timeout and wait for it
    result_future = stub.Predict.future(request, 10.0)
    prediction = result_future.result()

    # Get the batch size and number of categories from the response
    num_predictions = prediction.outputs['classes'].tensor_shape.dim[0].size
    num_classes = prediction.outputs['classes'].tensor_shape.dim[1].size
    output_shape = (num_predictions, num_classes)

    # The values come back as flat lists, so reshape them into one row per sample
    classes = np.reshape(prediction.outputs['classes'].string_val, output_shape)
    scores = np.reshape(prediction.outputs['scores'].float_val, output_shape)

    # Create a list of {label, score} dicts per sample
    results = []
    for sample_classes, sample_scores in zip(classes, scores):
        results.append([{'label': str(label), 'score': float(score)}
                        for label, score in zip(sample_classes, sample_scores)])

    # Pick the top-scoring class in each row (one per sample)
    predicted_classes = classes[np.arange(num_predictions), np.argmax(scores, axis=-1)]

    return predicted_classes, results

import json
import os

# stub is the prediction service stub from our previous setup
pred_classes, pred_class_scores = make_request_multi(
    stub,
    os.path.expanduser('~/Downloads/Dog_CTA_Desktop_HeroImage-1024x496.jpg'),
    os.path.expanduser('~/Downloads/Dog_CTA_Desktop_HeroImage-1024x496.jpg'),
    'https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTeclOl-kLClSo9SS0fH1dF2h35hABWBwQCYQMI2HWGZY4H6teBfg'
)

print(pred_classes)
# [b'dogs' b'dogs' b'cats']

print(json.dumps(pred_class_scores, indent=4))
# [
#     [
#         {
#             "label": "b'dogs'",
#             "score": 0.9997784495353699
#         },
#         {
#             "label": "b'cats'",
#             "score": 0.00022158514184411615
#         }
#     ],
#     [
#         {
#             "label": "b'dogs'",
#             "score": 0.9997784495353699
#         },
#         {
#             "label": "b'cats'",
#             "score": 0.00022158514184411615
#         }
#     ],
#     [
#         {
#             "label": "b'dogs'",
#             "score": 4.9226546252612025e-06
#         },
#         {
#             "label": "b'cats'",
#             "score": 0.9999951124191284
#         }
#     ]
# ]
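The reshaping and top-class selection can be tried without a running server. This sketch uses made-up flat outputs for a hypothetical batch of three images and two classes, laid out row-major the way the string_val and float_val fields are read above; the values and variable names are illustrative assumptions, not real model output.

```python
import numpy as np

# Hypothetical flat outputs: 3 images x 2 classes, row-major.
flat_classes = [b'dogs', b'cats', b'dogs', b'cats', b'dogs', b'cats']
flat_scores = [0.9, 0.1, 0.8, 0.2, 0.3, 0.7]
num_predictions, num_classes = 3, 2  # would come from tensor_shape.dim[0] / dim[1]

classes = np.reshape(flat_classes, (num_predictions, num_classes))
scores = np.reshape(flat_scores, (num_predictions, num_classes))

# Column index of the highest score in each row (one per image).
top = np.argmax(scores, axis=-1)

# Pair row i with column top[i] to get one label per image.
predicted = classes[np.arange(num_predictions), top]

# Per-image list of {label, score} dicts, decoding the byte-string labels.
results = [
    [{'label': label.decode('utf-8'), 'score': float(score)}
     for label, score in zip(row_classes, row_scores)]
    for row_classes, row_scores in zip(classes, scores)
]

print(predicted)   # [b'dogs' b'dogs' b'cats']
print(results[2])  # [{'label': 'dogs', 'score': 0.3}, {'label': 'cats', 'score': 0.7}]
```

Indexing with both a row array and a column array keeps each pick within its own row. By contrast, np.take without an axis argument indexes the flattened array, so it only returns correct labels when every row lists the classes in the same order. Decoding the labels here also avoids the "b'dogs'" strings that str() produces for byte strings.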