Labels
feature (A new feature)
Description
Problem statement
Flair's implementation for calculating Euclidean distances scales poorly with the number of prototypes. This leads to slow training for the PrototypicalDecoder with distance_function="euclidean".
Solution
EuclideanDistance.forward loops over mat_2.size(0), which is the number of prototypes in the case of the PrototypicalDecoder. This leads to a substantial slowdown in training when the number of prototypes is large.
Instead of calculating distances with a Python for loop, utilize torch.cdist.
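For reference, a minimal sketch (not from the issue) of a third option, which vectorizes the loop via broadcasting; the function name euclidean_distance_broadcast is hypothetical. It is mathematically equivalent to the loop, but it materializes a (batch, num_prototypes, dim) intermediate tensor, so torch.cdist remains the lighter choice for memory:

```python
import torch
from torch import Tensor

def euclidean_distance_broadcast(mat_1: Tensor, mat_2: Tensor) -> Tensor:
    # (batch, 1, dim) - (1, num_prototypes, dim) -> (batch, num_prototypes, dim),
    # then sum the squared differences over the embedding dimension.
    return ((mat_1.unsqueeze(1) - mat_2.unsqueeze(0)) ** 2).sum(dim=-1)
```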
Additional Context
Here is a script demonstrating the potential performance gains with the refactor.
```python
import time

import torch
from torch import Tensor


# The current implementation
def euclidean_distance_old(mat_1: Tensor, mat_2: Tensor) -> Tensor:
    _dist = [torch.sum((mat_1 - mat_2[i]) ** 2, dim=1) for i in range(mat_2.size(0))]
    dist = torch.stack(_dist, dim=1)
    return dist


# Refactored implementation
def euclidean_distance_new(mat_1: Tensor, mat_2: Tensor) -> Tensor:
    return torch.cdist(mat_1, mat_2).pow(2)


# Compare performance
def test_euclidean_distance_performance():
    total_samples = 10
    total_time_method_1 = 0.0
    total_time_method_2 = 0.0
    for _ in range(total_samples):
        batch_size = 4
        num_prototypes = 10_000
        embeddings_size = 128
        mat_1 = torch.randn(batch_size, embeddings_size)
        mat_2 = torch.randn(num_prototypes, embeddings_size)

        # Method 1
        start_time = time.time()
        dist_method_1 = euclidean_distance_old(mat_1, mat_2)
        end_time = time.time()
        total_time_method_1 += end_time - start_time

        # Method 2
        start_time = time.time()
        dist_method_2 = euclidean_distance_new(mat_1, mat_2)
        end_time = time.time()
        total_time_method_2 += end_time - start_time

        # Compare outputs
        if not torch.allclose(dist_method_1, dist_method_2):
            print("Test failed!")
            return False

    avg_time_method_1 = total_time_method_1 / total_samples
    avg_time_method_2 = total_time_method_2 / total_samples
    print("Average time per calculation - Method 1 (EuclideanDistance module): {:.6f} seconds".format(avg_time_method_1))
    print("Average time per calculation - Method 2 (torch.cdist): {:.6f} seconds".format(avg_time_method_2))
    speedup_factor = avg_time_method_1 / avg_time_method_2
    print("Refactored method is {:.2f} times faster than old method".format(speedup_factor))


test_euclidean_distance_performance()
```
Output:

```
Average time per calculation - Method 1 (EuclideanDistance module): 0.239093 seconds
Average time per calculation - Method 2 (torch.cdist): 0.001680 seconds
Refactored method is 142.34 times faster than old method
```
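One caveat worth noting: by default, torch.cdist may use a matmul-based expansion for Euclidean distance, which is fast but can introduce small floating-point differences versus the loop, so the allclose comparison may need an explicit tolerance. The compute_mode keyword (a real torch.cdist parameter) can force the direct computation when exact agreement matters more than speed. A hedged sketch:

```python
import torch

mat_1 = torch.randn(4, 128)
mat_2 = torch.randn(10_000, 128)

# Loop-based squared distances (reference)
reference = torch.stack(
    [torch.sum((mat_1 - mat_2[i]) ** 2, dim=1) for i in range(mat_2.size(0))], dim=1
)

# Force the direct (non-matmul) path for tighter numerical agreement
direct = torch.cdist(mat_1, mat_2, compute_mode="donot_use_mm_for_euclid_dist").pow(2)

assert torch.allclose(reference, direct, atol=1e-3)
```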
david-waterworth