$$
\begin{align*}
\text{destination of a one}&=\text{\#zeros in total}+\text{\#ones before} \\
&=(\text{\#keys in total}-\text{\#ones in total})+\text{\#ones before} \\
&=\text{input size}-\text{\#ones in total}+\text{\#ones before}
\end{align*}
$$
Parallelizing a Radix Sort Iteration by Assigning One Input Key to Each Thread
下图展示了每个线程查找其键的目标索引所执行的操作。
Finding the Destination of Each Input Key
对应的内核代码如下所示。每个线程先确定自己负责的键的索引,并提取出当前迭代对应的 bit;随后对所有线程的 bit 做 exclusive scan(不包含当前元素的前缀和)。因为这些位不是 0 就是 1,所以 exclusive scan 在位置 i 处的结果恰好等于索引 i 之前 1 的个数。
/**
 * Destination index of one key in a single 1-bit radix-sort pass.
 *
 * Zeros keep their relative order at the front of the output, ones keep
 * their relative order behind all zeros:
 *   destination of a zero = index - #ones before
 *   destination of a one  = (n - #ones in total) + #ones before
 *
 * @param bit        the key's bit for this pass (0 or 1)
 * @param index      the key's position in the input
 * @param onesBefore number of keys before `index` whose bit is 1
 * @param onesTotal  number of keys in the whole input whose bit is 1
 * @param n          number of keys in the input
 * @return position of the key in the output
 */
__host__ __device__ inline unsigned int radixDestination(unsigned int bit,
                                                         unsigned int index,
                                                         unsigned int onesBefore,
                                                         unsigned int onesTotal,
                                                         unsigned int n) {
    return (bit == 0u) ? (index - onesBefore)
                       : (n - onesTotal + onesBefore);
}

/**
 * Block-wide, in-place exclusive prefix sum (up-sweep / down-sweep tree scan).
 *
 * Must be called by EVERY thread of the block (it contains barriers) and is
 * only valid when a single block works on `data`.
 *
 * The input is padded with zeros up to the next power of two so that any
 * n >= 1 is handled. The dynamic shared memory passed at launch must therefore
 * hold nextPow2(n) unsigned ints. The tree levels are walked with
 * block-stride loops, so the block may have fewer than n/2 threads.
 *
 * @param data array of n values in global memory; overwritten with its
 *             exclusive prefix sum
 * @param n    number of elements; n <= 0 is a no-op
 * @return the sum of all n input values (same value in every thread)
 */
__device__ unsigned int blockExclusiveScan(unsigned int* data, int n) {
    extern __shared__ unsigned int temp[];

    // n is identical for all threads, so this exit is block-uniform and no
    // thread is left waiting at one of the barriers below.
    if (n <= 0) {
        return 0u;
    }

    const int tid = threadIdx.x;
    const int numThreads = blockDim.x;

    // Smallest power of two >= n: the tree below needs a full binary layout.
    int paddedN = 1;
    while (paddedN < n) {
        paddedN <<= 1;
    }

    // Make values written to `data` by other threads of the block visible
    // before they are loaded.
    __syncthreads();

    // Load input into shared memory, zero-padding the tail.
    for (int k = tid; k < paddedN; k += numThreads) {
        temp[k] = (k < n) ? data[k] : 0u;
    }

    // Build sum in place up the tree.
    int offset = 1;
    for (int d = paddedN >> 1; d > 0; d >>= 1) {
        __syncthreads();
        for (int k = tid; k < d; k += numThreads) {
            const int ai = offset * (2 * k + 1) - 1;
            const int bi = offset * (2 * k + 2) - 1;
            temp[bi] += temp[ai];
        }
        offset *= 2;
    }

    // The root now holds the sum of all elements; every thread keeps a copy
    // before the root is cleared for the down-sweep.
    __syncthreads();
    const unsigned int total = temp[paddedN - 1];
    __syncthreads();

    // Clear the last element.
    if (tid == 0) {
        temp[paddedN - 1] = 0u;
    }

    // Traverse down the tree.
    for (int d = 1; d < paddedN; d *= 2) {
        offset >>= 1;
        __syncthreads();
        for (int k = tid; k < d; k += numThreads) {
            const int ai = offset * (2 * k + 1) - 1;  // left child
            const int bi = offset * (2 * k + 2) - 1;  // right child
            const unsigned int left = temp[ai];
            temp[ai] = temp[bi];
            temp[bi] += left;
        }
    }

    // Write results to output array.
    __syncthreads();
    for (int k = tid; k < n; k += numThreads) {
        data[k] = temp[k];
    }

    // Let every thread of the block see the complete result in `data`.
    __syncthreads();
    return total;
}

/**
 * Kernel: in-place exclusive prefix sum of bits[0..N).
 *
 * Launch with ONE block and nextPow2(N) * sizeof(unsigned int) bytes of
 * dynamic shared memory.
 */
__global__ void exclusiveScan(unsigned int* bits, int N) {
    blockExclusiveScan(bits, N);
}

/**
 * Kernel: one radix-sort iteration on bit `iter`, one input key per thread.
 *
 * Launch with ONE block of at least N threads and nextPow2(N) *
 * sizeof(unsigned int) bytes of dynamic shared memory (the scan is
 * block-wide, so several blocks would race on `bits`).
 *
 * @param input  N keys to split
 * @param output receives the keys, zeros of bit `iter` first, stable;
 *               must not alias `input`
 * @param bits   scratch array of N elements; on return holds, per key, the
 *               number of preceding keys whose bit is 1
 * @param N      number of keys
 * @param iter   bit position to sort on (0..31)
 */
__global__ void radix_sort_iter(unsigned int* input,
                                unsigned int* output,
                                unsigned int* bits,
                                int N,
                                unsigned int iter) {
    const unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int numKeys = (N > 0) ? static_cast<unsigned int>(N) : 0u;
    const bool active = i < numKeys;

    unsigned int key = 0u;
    unsigned int bit = 0u;
    if (active) {
        key = input[i];
        bit = (key >> iter) & 1u;
        bits[i] = bit;
    }

    // # ones before each key (left in bits[]), and # ones in total (returned).
    // Called outside the `active` guard because it contains barriers.
    const unsigned int numberOnesTotal = blockExclusiveScan(bits, N);

    if (active) {
        const unsigned int numberOnesBefore = bits[i];
        const unsigned int dst =
            radixDestination(bit, i, numberOnesBefore, numberOnesTotal, numKeys);
        output[dst] = key;
    }
}
// Capacity, in keys, of the per-block shared-memory section.
#define SECTION_SIZE 32

/**
 * Global destination of the key at position `sortedIdx` of a locally sorted
 * section (all zeros of the section first, then its ones).
 *
 * @param sortedIdx      position of the key inside the sorted section
 * @param zerosInSection number of zero-bit keys in this section
 * @param zeroBase       output index where this section's zeros start
 * @param oneBase        output index where this section's ones start
 * @return index in the output array
 */
__host__ __device__ inline unsigned int sectionDestination(unsigned int sortedIdx,
                                                           unsigned int zerosInSection,
                                                           unsigned int zeroBase,
                                                           unsigned int oneBase) {
    return (sortedIdx < zerosInSection)
               ? (zeroBase + sortedIdx)
               : (oneBase + (sortedIdx - zerosInSection));
}

/**
 * Kernel: one radix-sort iteration on bit `iter` in which every block first
 * partitions its own section in shared memory, so that neighbouring threads
 * then write neighbouring output addresses.
 *
 * Launch layout: 1-D blocks with blockDim.x <= SECTION_SIZE, enough blocks to
 * cover N keys, no dynamic shared memory.
 *
 * NOTE(review): the per-block bucket counts published in `table` are read by
 * the other blocks inside the same launch. Nothing in a regular launch orders
 * those accesses across blocks, so as written the result is only guaranteed
 * for gridDim.x == 1. For larger grids, split the kernel at the marked point
 * into separate launches (or use a cooperative launch with a grid-wide sync).
 *
 * @param input  N keys to split
 * @param output receives the keys, zeros of bit `iter` first, stable;
 *               must not alias `input`
 * @param bits   scratch array of N elements; on return holds, per key, the
 *               number of preceding one-bit keys within its own section
 * @param table  scratch array of 2 * gridDim.x elements; on return holds the
 *               per-block zero counts followed by the per-block one counts
 * @param N      number of keys
 * @param iter   bit position to sort on (0..31)
 */
__global__ void memory_coalescing_radix_sort(unsigned int* input,
                                             unsigned int* output,
                                             unsigned int* bits,
                                             unsigned int* table,
                                             int N,
                                             int iter) {
    __shared__ unsigned int output_s[SECTION_SIZE];  // locally sorted section
    __shared__ unsigned int scan_s[SECTION_SIZE];    // inclusive scan of the bits
    __shared__ unsigned int base_s[2];               // [0] zero base, [1] one base

    const unsigned int tid = threadIdx.x;
    const unsigned int sectionLen = blockDim.x;
    const unsigned int sectionStart = blockIdx.x * blockDim.x;
    const unsigned int globalIdx = sectionStart + tid;
    const unsigned int numKeys = (N > 0) ? static_cast<unsigned int>(N) : 0u;
    const bool valid = globalIdx < numKeys;

    // Number of real keys in this section (the last section may be partial).
    unsigned int validCount = 0u;
    if (sectionStart < numKeys) {
        const unsigned int remaining = numKeys - sectionStart;
        validCount = (remaining < sectionLen) ? remaining : sectionLen;
    }

    // Load the key. Padding slots get bit 1 so that the stable partition puts
    // them behind every real key of the section.
    const unsigned int key = valid ? input[globalIdx] : 0u;
    const unsigned int bit = valid ? ((key >> iter) & 1u) : 1u;

    scan_s[tid] = bit;
    if (tid == 0u) {
        base_s[0] = 0u;
        base_s[1] = 0u;
    }
    __syncthreads();

    // Inclusive scan of the bits; read, barrier, then write so that no thread
    // overwrites a value a neighbour still has to read.
    for (unsigned int stride = 1u; stride < sectionLen; stride *= 2u) {
        const unsigned int addend = (tid >= stride) ? scan_s[tid - stride] : 0u;
        __syncthreads();
        scan_s[tid] += addend;
        __syncthreads();
    }

    const unsigned int onesBefore = scan_s[tid] - bit;
    // Padding slots count as ones, so the zeros are exactly the real zeros.
    const unsigned int zerosInSection = sectionLen - scan_s[sectionLen - 1u];
    const unsigned int onesInSection = validCount - zerosInSection;

    if (valid) {
        bits[globalIdx] = onesBefore;
    }

    // Sort each section: zeros first, ones after, both in original order.
    const unsigned int localDst =
        (bit == 0u) ? (tid - onesBefore) : (zerosInSection + onesBefore);
    output_s[localDst] = key;

    // Store local bucket sizes: all zero counts first, then all one counts.
    if (tid == 0u) {
        table[blockIdx.x] = zerosInSection;
        table[gridDim.x + blockIdx.x] = onesInSection;
    }
    __syncthreads();

    // ---- grid-wide dependency: every block must have published its counts
    // ---- before this point (see NOTE(review) in the header comment).

    // Exclusive prefix sum over the table, evaluated only for this block's two
    // entries. The table itself is left untouched so blocks do not overwrite
    // values other blocks are reading.
    unsigned int zeroPartial = 0u;
    unsigned int onePartial = 0u;
    for (unsigned int b = tid; b < gridDim.x; b += sectionLen) {
        const unsigned int zerosOfBlock = table[b];
        onePartial += zerosOfBlock;  // every zero in the input precedes every one
        if (b < blockIdx.x) {
            zeroPartial += zerosOfBlock;
            onePartial += table[gridDim.x + b];
        }
    }
    atomicAdd(&base_s[0], zeroPartial);
    atomicAdd(&base_s[1], onePartial);
    __syncthreads();

    // Write results to output array: consecutive threads hold consecutive keys
    // of the same bucket, so they write consecutive addresses.
    if (tid < validCount) {
        const unsigned int dst =
            sectionDestination(tid, zerosInSection, base_s[0], base_s[1]);
        output[dst] = output_s[tid];
    }
}