python – Slow performance of PyTorch Categorical

I have been using a PPO (Proximal Policy Optimisation) architecture to train my agent in a custom simulator. Because the simulator is written in Rust, it has become quite fast, so the speed of my inner loop is now bottlenecked by a few functions inside the PPO agent.

When I profiled these functions with pyinstrument, it showed that most of the time is spent initialising the Categorical class and calculating the log probabilities.
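
One way to isolate the cost the profiler points at is a small `timeit` microbenchmark of just the `Categorical` construction, `sample` and `log_prob` calls. The sketch below is a hypothetical harness, not part of the original code: the probability tensor (one state, 4 actions) and the helper names are assumptions. It compares the default behaviour, where argument validation is on unless Python runs with `-O`, against `validate_args=False`, which skips the simplex and support checks.

```python
import timeit

import torch
from torch.distributions import Categorical

# Assumed sizes for illustration: one state per step, 4 discrete actions.
probs = torch.softmax(torch.randn(1, 4), dim=-1)


def with_validation():
    # Default: the constructor checks that `probs` lies on the simplex,
    # and log_prob checks that the sample is inside the support.
    dist = Categorical(probs)
    action = dist.sample()
    return dist.log_prob(action)


def without_validation():
    # Same calls with argument/sample validation switched off.
    dist = Categorical(probs, validate_args=False)
    action = dist.sample()
    return dist.log_prob(action)


def time_per_call(fn, number=2000):
    """Average wall-clock seconds per call of `fn`."""
    return timeit.timeit(fn, number=number) / number


results = {
    "validated": time_per_call(with_validation),
    "unvalidated": time_per_call(without_validation),
}
for name, seconds in results.items():
    print(f"{name}: {seconds * 1e6:.1f} us per call")
```

The absolute numbers depend on the PyTorch version and hardware, so only the relative difference on your own machine is meaningful.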

I hope someone can tell me whether there is a faster way to do this in PyTorch.

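    from torch.distributions import Categorical

    def act(self, state):
        """Samples an action for the given state and returns its log-probability."""
        # The right-hand side was cut off in the original post; an actor
        # network returning action probabilities is assumed here.
        action_probs = self.actor(state)
        dist = Categorical(action_probs)

        action = dist.sample()
        action_logprob = dist.log_prob(action)

        # Detach so the rollout tensors carry no autograd history.
        return action.detach(), action_logprob.detach()

    def evaluate(self, state, action):
        """Evaluates the action given the state."""
        action_probs = self.actor(state)  # assumed actor network, as in act()
        dist = Categorical(action_probs)

        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()
        state_values = self.critic(state)

        return action_logprobs, state_values, dist_entropy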
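
A possible alternative to the two methods above is to skip the `Categorical` object entirely and work on the probability tensor directly: `torch.multinomial` for sampling, `gather` for the log-probability of the chosen action, and the entropy formula written out by hand. This is a swapped-in technique, not necessarily what PyTorch or the original code does internally, and the function names `act_fast` and `evaluate_fast` are hypothetical. It assumes `action_probs` has shape `(batch, n_actions)` with rows that already sum to 1 (e.g. a softmax output); `Categorical` would renormalise them, these functions do not. Whether it is actually faster in a given setup would need to be measured.

```python
import torch
from torch.distributions import Categorical


def act_fast(action_probs: torch.Tensor):
    """Sample one action per row and return it with its log-probability.

    Assumes `action_probs` is (batch, n_actions) and each row sums to 1.
    """
    with torch.no_grad():
        # Draw one index per row, then drop the trailing sample dimension.
        action = torch.multinomial(action_probs, num_samples=1).squeeze(-1)
        # Probability of the chosen action, picked out with gather, then logged.
        logprob = action_probs.gather(-1, action.unsqueeze(-1)).squeeze(-1).log()
    return action, logprob


def evaluate_fast(action_probs: torch.Tensor, action: torch.Tensor):
    """Log-probabilities of `action` and the entropy of each row's distribution."""
    log_probs = action_probs.log()
    action_logprobs = log_probs.gather(-1, action.unsqueeze(-1)).squeeze(-1)
    # H(p) = -sum_i p_i log p_i. Clamping log(0) = -inf to the most negative
    # finite float makes 0 * log 0 evaluate to 0, as Categorical.entropy does.
    safe_log_probs = log_probs.clamp(min=torch.finfo(log_probs.dtype).min)
    dist_entropy = -(action_probs * safe_log_probs).sum(-1)
    return action_logprobs, dist_entropy


# Sanity check against Categorical on random probabilities.
probs = torch.softmax(torch.randn(3, 4), dim=-1)
action, logprob = act_fast(probs)
reference = Categorical(probs)
print(torch.allclose(logprob, reference.log_prob(action), atol=1e-5))
```

In `evaluate`, `evaluate_fast(self.actor(state), action)` would replace the `Categorical` construction, `log_prob` and `entropy` calls, while `self.critic(state)` stays as it is.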

[Image: pyinstrument profile showing where the program spends its time]

I have seen some other techniques for this, but it was not clear to me whether they would improve the speed.
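
Two commonly suggested tweaks that keep `Categorical` are sketched below; they are not necessarily the techniques referred to above. The first disables distribution argument validation globally with `Distribution.set_default_validate_args(False)`, so invalid inputs no longer raise an error. The second wraps action selection in `torch.no_grad()` so no autograd graph is built during rollouts, which is what the `.detach()` calls achieve after the fact. `TinyActor`, its layer sizes and the `state_dim`/`n_actions` values are hypothetical stand-ins for the question's actor network.

```python
import torch
from torch.distributions import Categorical, Distribution

# Turn off argument validation for every distribution created from now on.
# This skips the simplex/support checks in the constructor and in log_prob,
# so malformed inputs will no longer be reported.
Distribution.set_default_validate_args(False)


class TinyActor(torch.nn.Module):
    """Hypothetical stand-in for the question's actor network."""

    def __init__(self, state_dim: int = 8, n_actions: int = 4):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(state_dim, 64),
            torch.nn.Tanh(),
            torch.nn.Linear(64, n_actions),
        )

    def act(self, state: torch.Tensor):
        # no_grad avoids recording an autograd graph for rollout steps,
        # so the outputs need no explicit .detach().
        with torch.no_grad():
            # Passing raw logits skips a separate softmax layer;
            # Categorical normalises the logits internally.
            dist = Categorical(logits=self.net(state))
            action = dist.sample()
            return action, dist.log_prob(action)


actor = TinyActor()
action, logprob = actor.act(torch.randn(2, 8))
print(action, logprob)
```

The global switch is convenient but hides bugs such as NaN probabilities, so it may be worth keeping validation on while debugging and disabling it only for long training runs.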
