
[Feature]: Refactor EuclideanDistance for performance optimization #3484

@sheldon-roberts

Description

Problem statement

Flair's implementation for calculating Euclidean distances scales poorly with the number of prototypes. This leads to slow training speeds for the PrototypicalDecoder when distance_function="euclidean".

Solution

EuclideanDistance.forward loops over mat_2.size(0), which corresponds to num_prototypes in the case of the PrototypicalDecoder. This Python-level loop causes a substantial slowdown in training speed when the number of prototypes is large.

Instead of calculating distances with a for loop, use torch.cdist, which computes all pairwise distances in a single vectorized call, as sketched below.
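
A minimal sketch of what the refactored module could look like (the class name mirrors Flair's existing EuclideanDistance module, but the exact base class and surrounding code are assumptions; the proposed change is only the torch.cdist call):

import torch
from torch import Tensor, nn


class EuclideanDistance(nn.Module):
    """Squared Euclidean distance between each row of mat_1 and each row of mat_2.

    Sketch of the proposed refactor, not Flair's actual code.
    """

    def forward(self, mat_1: Tensor, mat_2: Tensor) -> Tensor:
        # torch.cdist computes all pairwise L2 distances in one batched call;
        # squaring preserves the squared-distance semantics of the old loop.
        return torch.cdist(mat_1, mat_2).pow(2)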

Additional Context

Here is a script demonstrating the potential performance gains with the refactor.

import torch
import time
from torch import Tensor


# The current implementation: builds the distance matrix one prototype at a time
def euclidean_distance_old(mat_1: Tensor, mat_2: Tensor) -> Tensor:
    # Squared Euclidean distance from every row of mat_1 to every row of mat_2,
    # accumulated column by column in a Python loop.
    _dist = [torch.sum((mat_1 - mat_2[i]) ** 2, dim=1) for i in range(mat_2.size(0))]
    dist = torch.stack(_dist, dim=1)
    return dist


# Refactored implementation: a single vectorized torch.cdist call
def euclidean_distance_new(mat_1: Tensor, mat_2: Tensor) -> Tensor:
    # cdist returns L2 distances; squaring matches the old squared-distance output
    return torch.cdist(mat_1, mat_2).pow(2)


# Compare performance
def test_euclidean_distance_performance():
    total_samples = 10
    total_time_method_1 = 0.0
    total_time_method_2 = 0.0

    for _ in range(total_samples):
        batch_size = 4
        num_prototypes = 10_000
        embeddings_size = 128

        mat_1 = torch.randn(batch_size, embeddings_size)
        mat_2 = torch.randn(num_prototypes, embeddings_size)

        # Method 1
        start_time = time.time()
        dist_method_1 = euclidean_distance_old(mat_1, mat_2)
        end_time = time.time()
        total_time_method_1 += (end_time - start_time)

        # Method 2
        start_time = time.time()
        dist_method_2 = euclidean_distance_new(mat_1, mat_2)
        end_time = time.time()
        total_time_method_2 += (end_time - start_time)

        # Compare outputs
        if not torch.allclose(dist_method_1, dist_method_2):
            print("Test failed!")
            return False

    avg_time_method_1 = total_time_method_1 / total_samples
    avg_time_method_2 = total_time_method_2 / total_samples

    print("Average time per calculation - Method 1 (EuclideanDistance module): {:.6f} seconds".format(avg_time_method_1))
    print("Average time per calculation - Method 2 (torch.cdist): {:.6f} seconds".format(avg_time_method_2))

    speedup_factor = avg_time_method_1 / avg_time_method_2
    print("Refactored method is {:.2f} times faster than old method".format(speedup_factor))

test_euclidean_distance_performance()

Output:

Average time per calculation - Method 1 (EuclideanDistance module): 0.239093 seconds
Average time per calculation - Method 2 (torch.cdist): 0.001680 seconds
Refactored method is 142.34 times faster than old method
