Fast Kernel Summation via Slicing

In the following, we give details on the implemented fast kernel summation. A theoretical background can be found here. A precise specification of all attributes and arguments in this library is given here. We aim to compute the kernel sums

\[s_m=\sum_{n=1}^N w_n K(x_n,y_m)\]

for all $m=1,…,M$. The naive implementation has a computational complexity of $O(MN)$. The implementation of this library has complexity $O(M+N)$.
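To make the $O(MN)$ baseline concrete, the naive summation can be sketched in plain Python. The Gaussian kernel $K(x,y)=\exp(-\lVert x-y\rVert^2/(2\sigma^2))$ is used here as an example; the exact kernel parametrization used by the library may differ.

```python
import math

def naive_kernel_sum(x, y, w, scale=1.0):
    """Naive O(M*N) evaluation of s_m = sum_n w_n K(x_n, y_m).

    x: list of N points, y: list of M points (each a list of d floats),
    w: list of N weights. Uses a Gaussian kernel as an illustrative
    example; the library's kernel parametrization may differ.
    """
    s = []
    for ym in y:
        total = 0.0
        for xn, wn in zip(x, w):
            sq_dist = sum((a - b) ** 2 for a, b in zip(xn, ym))
            total += wn * math.exp(-sq_dist / (2 * scale**2))
        s.append(total)
    return s
```

Every entry $s_m$ touches all $N$ input points, which is exactly the quadratic cost that the sliced fast summation avoids.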

Supported Kernels

The implementation currently supports the following kernels:

Note: A better treatment for kernels which are non-smooth or singular at $x=y$ is implemented in the NFFT3 library.

Usage and Example

To use the fast kernel summation, we first create a Fastsum object via fastsum = Fastsum(d, kernel="Gauss"). It takes the dimension d and the kernel name (as a string from the list above) as input.

Afterwards, we can compute the vector $s=(s_1,…,s_M)$ by s=fastsum(x, y, w, scale, xis_or_P), where x has shape (...,N,d), y has shape (...,M,d), w has shape (...,N), and scale is the kernel parameter. The argument xis_or_P either takes the number of considered slices as an integer (a higher number gives higher accuracy) or the slices themselves as a tensor of size (P,d). The fastsum method supports batching: arbitrarily many batch dimensions can be prepended to each of the tensors x, y and w. The number of batch dimensions has to be the same for all inputs, but broadcasting is supported (i.e., batch dimensions of size 1 are expanded).
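The batching rule just described mirrors standard broadcasting semantics. As an illustration only (this helper is hypothetical and not part of the library's API), the following function computes the resulting broadcast batch shape from the input shapes:

```python
def broadcast_batch_shape(x_shape, y_shape, w_shape):
    """Illustrative helper (not part of simple_torch_NFFT): given the
    shapes of x (..., N, d), y (..., M, d) and w (..., N), return the
    broadcast batch shape. All inputs must have the same number of
    batch dimensions; size-1 dimensions are expanded.
    """
    # strip the trailing non-batch dimensions: (N, d), (M, d) and (N,)
    bx, by, bw = x_shape[:-2], y_shape[:-2], w_shape[:-1]
    if not (len(bx) == len(by) == len(bw)):
        raise ValueError("all inputs need the same number of batch dimensions")
    out = []
    for sizes in zip(bx, by, bw):
        target = max(sizes)
        if any(s not in (1, target) for s in sizes):
            raise ValueError(f"incompatible batch sizes {sizes}")
        out.append(target)
    return tuple(out)
```

For example, x of shape (3, 1, 100, 10), y of shape (1, 5, 200, 10) and w of shape (3, 1, 100) broadcast to the batch shape (3, 5), so the result s would have shape (3, 5, 200).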

Other optional arguments for the constructor of the Fastsum object are described in the specification. A full example looks as follows:

import torch
from simple_torch_NFFT import Fastsum

device = "cuda" if torch.cuda.is_available() else "cpu"

d = 10  # data dimension
kernel = "Gauss"  # kernel type

fastsum = Fastsum(d, kernel=kernel, device=device)  # create fastsum object

scale = 1.0  # kernel parameter
P = 256  # number of projections for slicing
N, M = 10000, 10000  # number of data points

# data generation
x = torch.randn((N, d), device=device, dtype=torch.float)
y = torch.randn((M, d), device=device, dtype=torch.float)
x_weights = torch.rand(x.shape[0]).to(x)  # weights w_n

kernel_sum = fastsum(x, y, x_weights, scale, P)  # compute kernel sums s_m
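As mentioned above, xis_or_P can alternatively be a tensor of size (P,d) containing the slicing directions themselves. A minimal sketch for generating such directions in plain Python follows; it assumes the directions should be unit vectors drawn uniformly from the sphere (the usual slicing construction), which is an assumption here rather than a documented requirement of the library.

```python
import math
import random

def random_directions(P, d, seed=0):
    """Draw P uniformly distributed unit vectors in R^d by normalizing
    Gaussian samples (a standard construction; whether the library
    expects exactly this distribution is an assumption).
    """
    rng = random.Random(seed)
    xis = []
    for _ in range(P):
        v = [rng.gauss(0.0, 1.0) for _ in range(d)]
        norm = math.sqrt(sum(c * c for c in v))
        xis.append([c / norm for c in v])
    return xis  # nested list of shape (P, d)
```

The result could then be converted with torch.tensor(xis, device=device) and passed in place of the integer P.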