python – What causes periodic inference spikes in my pytorch-based transformer model for text classification on a fixed input batch?

I have a PyTorch-based, 3-layer transformer model that I use to classify text inputs. I have a fixed input X containing a batch of 10 text inputs. I call the inference function on X repeatedly, 25 times, and record the timing of each call. I see periodic spikes in the inference time. Why is this happening?

Note: I use time.time() to measure delta timings
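
Concretely, the measurement loop looks roughly like this (infer here is a stand-in for the inference function shown below, and X is the fixed batch; the names are simplified, not my exact code):

import time

timings = []
for _ in range(25):
    start = time.time()
    infer(X)                              # X is the fixed batch of 10 texts
    timings.append(time.time() - start)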

Inference code (infer part in graph):

# Move the pre-tokenized numpy inputs onto the GPU
input_ids = torch.from_numpy(input_ids).to(self.device)
attention_masks = torch.from_numpy(attention_masks).to(self.device)
b_input_task = torch.full((input_ids.shape[0],), -1, dtype=torch.int32).to(self.device)
with torch.no_grad():
    result = self.opt_model((input_ids, attention_masks, b_input_task))
# Copying back to the CPU blocks until the GPU work producing `result` has finished
logits = result.detach().cpu().numpy()

I notice that the line result = self.opt_model… is the one whose timing behaves non-deterministically.

I already call self.opt_model.eval() in the __init__ function of my class, which is initialized once at the very beginning. I am running this on a V100 GPU.
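
For reference, the setup happens once in the constructor, roughly like this (the class and path names are illustrative, not my real ones):

import torch

class TextClassifier:
    def __init__(self, model_path):
        self.device = torch.device("cuda")  # V100
        self.opt_model = torch.load(model_path, map_location=self.device)
        self.opt_model.eval()               # set once, up front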

Graph with batch size 10 (10 inputs, repeatedly bombarded):

[image: timing graph, batch size 10]

Graph with batch size 1 (1 input, repeatedly bombarded):
[image: timing graph, batch size 1]

If I add a 5 second delay between the repeated inference calls, the timings change and come closer to what I expect:

[image: timing graph with 5 s delay between calls]
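
The delayed variant is the same measurement loop with a sleep added between iterations:

import time

for _ in range(25):
    start = time.time()
    infer(X)
    timings.append(time.time() - start)
    time.sleep(5)                         # the added 5 second delay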

What is causing this behaviour?
