Browse Source

dimensions fix

master
anapt 7 years ago
parent
commit
67f40ab195
  1. 2
      mean_shift_cuda/meanshift_gpu_utils.cu
  2. 12
      mean_shift_cuda/meanshift_kernels.cu

2
mean_shift_cuda/meanshift_gpu_utils.cu

@ -361,7 +361,7 @@ void calculate_norm(Matrix d_mean_shift_vector, double *current_norm){
dim3 dimGrid; dim3 dimGrid;
do { do {
dimBlock.x = requested_block_size; dimBlock.x = requested_block_size;
dimBlock.y = d_mean_shift_vector.width; dimBlock.y = 1;
dimGrid.x = (d_mean_shift_vector.height + dimBlock.x - 1) / dimBlock.x; dimGrid.x = (d_mean_shift_vector.height + dimBlock.x - 1) / dimBlock.x;
dimGrid.y = 1; dimGrid.y = 1;

12
mean_shift_cuda/meanshift_kernels.cu

@ -86,13 +86,21 @@ __global__ void norm(Matrix mean_shift_vector, double *current_norm) {
// by accumulating results into cell_value // by accumulating results into cell_value
double cell_value = 0; double cell_value = 0;
int row = blockIdx.x * blockDim.x + threadIdx.x; int row = blockIdx.x * blockDim.x + threadIdx.x;
int col = blockIdx.y * blockDim.y + threadIdx.y;
// performs calculations only if thread's indexes are within matrix bounds // performs calculations only if thread's indexes are within matrix bounds
if (row * mean_shift_vector.width + col >= mean_shift_vector.width * mean_shift_vector.height){ if (row >= denominator.height){
return; return;
} }
// for (int column = 0; column < kernel_matrix.width; ++column){
// cell_value += kernel_matrix.elements[row * kernel_matrix.width + column];
// }
denominator.elements[row] = cell_value;
// performs calculations only if thread's indexes are within matrix bounds
// if (row * mean_shift_vector.width + col >= mean_shift_vector.width * mean_shift_vector.height){
// return;
// }
for (int element_index = 0; element_index < mean_shift_vector.width; ++element_index){ for (int element_index = 0; element_index < mean_shift_vector.width; ++element_index){
cell_value += mean_shift_vector.elements[row * mean_shift_vector.width + element_index] cell_value += mean_shift_vector.elements[row * mean_shift_vector.width + element_index]
* mean_shift_vector.elements[row * mean_shift_vector.width + element_index]; * mean_shift_vector.elements[row * mean_shift_vector.width + element_index];

Loading…
Cancel
Save