Skip to content

Commit

Permalink
moving aTV delta parameter to constant memory
Browse files Browse the repository at this point in the history
  • Loading branch information
kylechampley committed May 5, 2024
1 parent 9035cde commit 83d6ce0
Showing 1 changed file with 109 additions and 107 deletions.
216 changes: 109 additions & 107 deletions src/total_variation.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#define USE_P_HUBER

__constant__ float d_HUBER_P;
__constant__ float d_HUBER_DELTA;
__constant__ float d_HUBER_SLOPE;
__constant__ float d_HUBER_SHIFT;

Expand All @@ -29,50 +30,50 @@ __forceinline__ __device__ float square(const float x)
return x * x;
}

__forceinline__ __device__ float Huber(const float x, const float delta)
__forceinline__ __device__ float Huber(const float x)
{
#ifdef USE_P_HUBER
if (fabs(x) <= delta)
if (fabs(x) <= d_HUBER_DELTA)
return 0.5f * x * x;
else
return d_HUBER_SLOPE * pow(fabs(x), d_HUBER_P) + d_HUBER_SHIFT;
//return delta * delta * (5.0f * pow(fabs(x / delta), 1.2f) - 2.0f) / 6.0f;
#else
if (fabs(x) <= delta)
if (fabs(x) <= d_HUBER_DELTA)
return 0.5f * x * x;
else
return delta * (fabs(x) - 0.5f * delta);
return d_HUBER_DELTA * (fabs(x) - 0.5f * d_HUBER_DELTA);
#endif
}

__forceinline__ __device__ float DHuber(const float x, const float delta)
__forceinline__ __device__ float DHuber(const float x)
{
#ifdef USE_P_HUBER
if (fabs(x) <= delta)
if (fabs(x) <= d_HUBER_DELTA)
return x;
else
return d_HUBER_P * d_HUBER_SLOPE * pow(fabs(x), d_HUBER_P) / x;
//return x * pow(fabs(x / delta), -0.8f); // p=1.2
//return delta * pow(fabs(x/delta), 0.2f);
#else
if (fabs(x) <= delta)
if (fabs(x) <= d_HUBER_DELTA)
return x;
else
return (x > 0.0f) ? delta : -delta;
return (x > 0.0f) ? d_HUBER_DELTA : -d_HUBER_DELTA;
#endif
}

__forceinline__ __device__ float DDHuber(const float x, const float delta)
__forceinline__ __device__ float DDHuber(const float x)
{
#ifdef USE_P_HUBER
if (fabs(x) <= delta)
if (fabs(x) <= d_HUBER_DELTA)
return 1.0;
else
return d_HUBER_P * d_HUBER_SLOPE * pow(fabs(x), d_HUBER_P - 2.0f);
//return pow(fabs(x / delta), d_HUBER_P - 2.0f); // require a divide
//return pow(fabs(x / delta), -0.8f); // p=1.2
#else
return delta / max(delta, fabs(x));
return d_HUBER_DELTA / max(d_HUBER_DELTA, fabs(x));
#endif
}

Expand All @@ -98,43 +99,43 @@ __device__ float aTV_Huber_costTerm(float* f, const int i, const int j, const in

if (numNeighbors == 6)
{
return (Huber(curVal - f_i_plus[j * N.z + k], delta) +
Huber(curVal - f_i_minus[j * N.z + k], delta) +
Huber(curVal - f_i[j_plus * N.z + k], delta) +
Huber(curVal - f_i[j_minus * N.z + k], delta) +
Huber(curVal - f_i[j * N.z + k_plus], delta) +
Huber(curVal - f_i[j * N.z + k_minus], delta)) * dist_1;
return (Huber(curVal - f_i_plus[j * N.z + k]) +
Huber(curVal - f_i_minus[j * N.z + k]) +
Huber(curVal - f_i[j_plus * N.z + k]) +
Huber(curVal - f_i[j_minus * N.z + k]) +
Huber(curVal - f_i[j * N.z + k_plus]) +
Huber(curVal - f_i[j * N.z + k_minus])) * dist_1;
}
else
{
return (Huber(curVal - f_i_plus[j * N.z + k], delta) +
Huber(curVal - f_i_minus[j * N.z + k], delta) +
Huber(curVal - f_i[j_plus * N.z + k], delta) +
Huber(curVal - f_i[j_minus * N.z + k], delta) +
Huber(curVal - f_i[j * N.z + k_plus], delta) +
Huber(curVal - f_i[j * N.z + k_minus], delta)) *
return (Huber(curVal - f_i_plus[j * N.z + k]) +
Huber(curVal - f_i_minus[j * N.z + k]) +
Huber(curVal - f_i[j_plus * N.z + k]) +
Huber(curVal - f_i[j_minus * N.z + k]) +
Huber(curVal - f_i[j * N.z + k_plus]) +
Huber(curVal - f_i[j * N.z + k_minus])) *
dist_1 +
(Huber(curVal - f_i_plus[j_plus * N.z + k], delta) +
Huber(curVal - f_i_plus[j_minus * N.z + k], delta) +
Huber(curVal - f_i_plus[j * N.z + k_plus], delta) +
Huber(curVal - f_i_plus[j * N.z + k_minus], delta) +
Huber(curVal - f_i_minus[j_plus * N.z + k], delta) +
Huber(curVal - f_i_minus[j_minus * N.z + k], delta) +
Huber(curVal - f_i_minus[j * N.z + k_plus], delta) +
Huber(curVal - f_i_minus[j * N.z + k_minus], delta) +
Huber(curVal - f_i[j_plus * N.z + k_plus], delta) +
Huber(curVal - f_i[j_plus * N.z + k_minus], delta) +
Huber(curVal - f_i[j_minus * N.z + k_plus], delta) +
Huber(curVal - f_i[j_minus * N.z + k_minus], delta)) *
(Huber(curVal - f_i_plus[j_plus * N.z + k]) +
Huber(curVal - f_i_plus[j_minus * N.z + k]) +
Huber(curVal - f_i_plus[j * N.z + k_plus]) +
Huber(curVal - f_i_plus[j * N.z + k_minus]) +
Huber(curVal - f_i_minus[j_plus * N.z + k]) +
Huber(curVal - f_i_minus[j_minus * N.z + k]) +
Huber(curVal - f_i_minus[j * N.z + k_plus]) +
Huber(curVal - f_i_minus[j * N.z + k_minus]) +
Huber(curVal - f_i[j_plus * N.z + k_plus]) +
Huber(curVal - f_i[j_plus * N.z + k_minus]) +
Huber(curVal - f_i[j_minus * N.z + k_plus]) +
Huber(curVal - f_i[j_minus * N.z + k_minus])) *
dist_2 +
(Huber(curVal - f_i_plus[j_plus * N.z + k_plus], delta) +
Huber(curVal - f_i_plus[j_plus * N.z + k_minus], delta) +
Huber(curVal - f_i_plus[j_minus * N.z + k_plus], delta) +
Huber(curVal - f_i_plus[j_minus * N.z + k_minus], delta) +
Huber(curVal - f_i_minus[j_plus * N.z + k_plus], delta) +
Huber(curVal - f_i_minus[j_plus * N.z + k_minus], delta) +
Huber(curVal - f_i_minus[j_minus * N.z + k_plus], delta) +
Huber(curVal - f_i_minus[j_minus * N.z + k_minus], delta)) *
(Huber(curVal - f_i_plus[j_plus * N.z + k_plus]) +
Huber(curVal - f_i_plus[j_plus * N.z + k_minus]) +
Huber(curVal - f_i_plus[j_minus * N.z + k_plus]) +
Huber(curVal - f_i_plus[j_minus * N.z + k_minus]) +
Huber(curVal - f_i_minus[j_plus * N.z + k_plus]) +
Huber(curVal - f_i_minus[j_plus * N.z + k_minus]) +
Huber(curVal - f_i_minus[j_minus * N.z + k_plus]) +
Huber(curVal - f_i_minus[j_minus * N.z + k_minus])) *
dist_3;
}
//*/
Expand Down Expand Up @@ -217,74 +218,74 @@ __device__ float aTV_Huber_quadFormTerm(float* f, float* d, const int i, const i

if (numNeighbors == 6)
{
return (DDHuber(curVal - f_i_plus[j * N.z + k], delta) *
return (DDHuber(curVal - f_i_plus[j * N.z + k]) *
square(curVal_d - d_i_plus[j * N.z + k]) +
DDHuber(curVal - f_i_minus[j * N.z + k], delta) *
DDHuber(curVal - f_i_minus[j * N.z + k]) *
square(curVal_d - d_i_minus[j * N.z + k]) +
DDHuber(curVal - f_i[j_plus * N.z + k], delta) *
DDHuber(curVal - f_i[j_plus * N.z + k]) *
square(curVal_d - d_i[j_plus * N.z + k]) +
DDHuber(curVal - f_i[j_minus * N.z + k], delta) *
DDHuber(curVal - f_i[j_minus * N.z + k]) *
square(curVal_d - d_i[j_minus * N.z + k]) +
DDHuber(curVal - f_i[j * N.z + k_plus], delta) *
DDHuber(curVal - f_i[j * N.z + k_plus]) *
square(curVal_d - d_i[j * N.z + k_plus]) +
DDHuber(curVal - f_i[j * N.z + k_minus], delta) *
DDHuber(curVal - f_i[j * N.z + k_minus]) *
square(curVal_d - d_i[j * N.z + k_minus])) * dist_1;
}
else
{
return (DDHuber(curVal - f_i_plus[j * N.z + k], delta) *
return (DDHuber(curVal - f_i_plus[j * N.z + k]) *
square(curVal_d - d_i_plus[j * N.z + k]) +
DDHuber(curVal - f_i_minus[j * N.z + k], delta) *
DDHuber(curVal - f_i_minus[j * N.z + k]) *
square(curVal_d - d_i_minus[j * N.z + k]) +
DDHuber(curVal - f_i[j_plus * N.z + k], delta) *
DDHuber(curVal - f_i[j_plus * N.z + k]) *
square(curVal_d - d_i[j_plus * N.z + k]) +
DDHuber(curVal - f_i[j_minus * N.z + k], delta) *
DDHuber(curVal - f_i[j_minus * N.z + k]) *
square(curVal_d - d_i[j_minus * N.z + k]) +
DDHuber(curVal - f_i[j * N.z + k_plus], delta) *
DDHuber(curVal - f_i[j * N.z + k_plus]) *
square(curVal_d - d_i[j * N.z + k_plus]) +
DDHuber(curVal - f_i[j * N.z + k_minus], delta) *
DDHuber(curVal - f_i[j * N.z + k_minus]) *
square(curVal_d - d_i[j * N.z + k_minus])) *
dist_1 +
(DDHuber(curVal - f_i_plus[j_plus * N.z + k], delta) *
(DDHuber(curVal - f_i_plus[j_plus * N.z + k]) *
square(curVal_d - d_i_plus[j_plus * N.z + k]) +
DDHuber(curVal - f_i_plus[j_minus * N.z + k], delta) *
DDHuber(curVal - f_i_plus[j_minus * N.z + k]) *
square(curVal_d - d_i_plus[j_minus * N.z + k]) +
DDHuber(curVal - f_i_plus[j * N.z + k_plus], delta) *
DDHuber(curVal - f_i_plus[j * N.z + k_plus]) *
square(curVal_d - d_i_plus[j * N.z + k_plus]) +
DDHuber(curVal - f_i_plus[j * N.z + k_minus], delta) *
DDHuber(curVal - f_i_plus[j * N.z + k_minus]) *
square(curVal_d - d_i_plus[j * N.z + k_minus]) +
DDHuber(curVal - f_i_minus[j_plus * N.z + k], delta) *
DDHuber(curVal - f_i_minus[j_plus * N.z + k]) *
square(curVal_d - d_i_minus[j_plus * N.z + k]) +
DDHuber(curVal - f_i_minus[j_minus * N.z + k], delta) *
DDHuber(curVal - f_i_minus[j_minus * N.z + k]) *
square(curVal_d - d_i_minus[j_minus * N.z + k]) +
DDHuber(curVal - f_i_minus[j * N.z + k_plus], delta) *
DDHuber(curVal - f_i_minus[j * N.z + k_plus]) *
square(curVal_d - d_i_minus[j * N.z + k_plus]) +
DDHuber(curVal - f_i_minus[j * N.z + k_minus], delta) *
DDHuber(curVal - f_i_minus[j * N.z + k_minus]) *
square(curVal_d - d_i_minus[j * N.z + k_minus]) +
DDHuber(curVal - f_i[j_plus * N.z + k_plus], delta) *
DDHuber(curVal - f_i[j_plus * N.z + k_plus]) *
square(curVal_d - d_i[j_plus * N.z + k_plus]) +
DDHuber(curVal - f_i[j_plus * N.z + k_minus], delta) *
DDHuber(curVal - f_i[j_plus * N.z + k_minus]) *
square(curVal_d - d_i[j_plus * N.z + k_minus]) +
DDHuber(curVal - f_i[j_minus * N.z + k_plus], delta) *
DDHuber(curVal - f_i[j_minus * N.z + k_plus]) *
square(curVal_d - d_i[j_minus * N.z + k_plus]) +
DDHuber(curVal - f_i[j_minus * N.z + k_minus], delta) *
DDHuber(curVal - f_i[j_minus * N.z + k_minus]) *
square(curVal_d - d_i[j_minus * N.z + k_minus])) *
dist_2 +
(DDHuber(curVal - f_i_plus[j_plus * N.z + k_plus], delta) *
(DDHuber(curVal - f_i_plus[j_plus * N.z + k_plus]) *
square(curVal_d - d_i_plus[j_plus * N.z + k_plus]) +
DDHuber(curVal - f_i_plus[j_plus * N.z + k_minus], delta) *
DDHuber(curVal - f_i_plus[j_plus * N.z + k_minus]) *
square(curVal_d - d_i_plus[j_plus * N.z + k_minus]) +
DDHuber(curVal - f_i_plus[j_minus * N.z + k_plus], delta) *
DDHuber(curVal - f_i_plus[j_minus * N.z + k_plus]) *
square(curVal_d - d_i_plus[j_minus * N.z + k_plus]) +
DDHuber(curVal - f_i_plus[j_minus * N.z + k_minus], delta) *
DDHuber(curVal - f_i_plus[j_minus * N.z + k_minus]) *
square(curVal_d - d_i_plus[j_minus * N.z + k_minus]) +
DDHuber(curVal - f_i_minus[j_plus * N.z + k_plus], delta) *
DDHuber(curVal - f_i_minus[j_plus * N.z + k_plus]) *
square(curVal_d - d_i_minus[j_plus * N.z + k_plus]) +
DDHuber(curVal - f_i_minus[j_plus * N.z + k_minus], delta) *
DDHuber(curVal - f_i_minus[j_plus * N.z + k_minus]) *
square(curVal_d - d_i_minus[j_plus * N.z + k_minus]) +
DDHuber(curVal - f_i_minus[j_minus * N.z + k_plus], delta) *
DDHuber(curVal - f_i_minus[j_minus * N.z + k_plus]) *
square(curVal_d - d_i_minus[j_minus * N.z + k_plus]) +
DDHuber(curVal - f_i_minus[j_minus * N.z + k_minus], delta) *
DDHuber(curVal - f_i_minus[j_minus * N.z + k_minus]) *
square(curVal_d - d_i_minus[j_minus * N.z + k_minus])) *
dist_3;
}
Expand Down Expand Up @@ -402,43 +403,43 @@ __global__ void aTV_Huber_gradient(float* f, float* Df, int3 N, float delta, flo
// dist 3: 8
if (numNeighbors == 6)
{
Df[uint64(i) * uint64(N.y * N.z) + uint64(j * N.z + k)] = (DHuber(curVal - f_i_plus[j * N.z + k], delta) +
DHuber(curVal - f_i_minus[j * N.z + k], delta) +
DHuber(curVal - f_i[j_plus * N.z + k], delta) +
DHuber(curVal - f_i[j_minus * N.z + k], delta) +
DHuber(curVal - f_i[j * N.z + k_plus], delta) +
DHuber(curVal - f_i[j * N.z + k_minus], delta)) * dist_1;
Df[uint64(i) * uint64(N.y * N.z) + uint64(j * N.z + k)] = (DHuber(curVal - f_i_plus[j * N.z + k]) +
DHuber(curVal - f_i_minus[j * N.z + k]) +
DHuber(curVal - f_i[j_plus * N.z + k]) +
DHuber(curVal - f_i[j_minus * N.z + k]) +
DHuber(curVal - f_i[j * N.z + k_plus]) +
DHuber(curVal - f_i[j * N.z + k_minus])) * dist_1;
}
else
{
Df[uint64(i) * uint64(N.y * N.z) + uint64(j * N.z + k)] = (DHuber(curVal - f_i_plus[j * N.z + k], delta) +
DHuber(curVal - f_i_minus[j * N.z + k], delta) +
DHuber(curVal - f_i[j_plus * N.z + k], delta) +
DHuber(curVal - f_i[j_minus * N.z + k], delta) +
DHuber(curVal - f_i[j * N.z + k_plus], delta) +
DHuber(curVal - f_i[j * N.z + k_minus], delta)) *
Df[uint64(i) * uint64(N.y * N.z) + uint64(j * N.z + k)] = (DHuber(curVal - f_i_plus[j * N.z + k]) +
DHuber(curVal - f_i_minus[j * N.z + k]) +
DHuber(curVal - f_i[j_plus * N.z + k]) +
DHuber(curVal - f_i[j_minus * N.z + k]) +
DHuber(curVal - f_i[j * N.z + k_plus]) +
DHuber(curVal - f_i[j * N.z + k_minus])) *
dist_1 +
(DHuber(curVal - f_i_plus[j_plus * N.z + k], delta) +
DHuber(curVal - f_i_plus[j_minus * N.z + k], delta) +
DHuber(curVal - f_i_plus[j * N.z + k_plus], delta) +
DHuber(curVal - f_i_plus[j * N.z + k_minus], delta) +
DHuber(curVal - f_i_minus[j_plus * N.z + k], delta) +
DHuber(curVal - f_i_minus[j_minus * N.z + k], delta) +
DHuber(curVal - f_i_minus[j * N.z + k_plus], delta) +
DHuber(curVal - f_i_minus[j * N.z + k_minus], delta) +
DHuber(curVal - f_i[j_plus * N.z + k_plus], delta) +
DHuber(curVal - f_i[j_plus * N.z + k_minus], delta) +
DHuber(curVal - f_i[j_minus * N.z + k_plus], delta) +
DHuber(curVal - f_i[j_minus * N.z + k_minus], delta)) *
(DHuber(curVal - f_i_plus[j_plus * N.z + k]) +
DHuber(curVal - f_i_plus[j_minus * N.z + k]) +
DHuber(curVal - f_i_plus[j * N.z + k_plus]) +
DHuber(curVal - f_i_plus[j * N.z + k_minus]) +
DHuber(curVal - f_i_minus[j_plus * N.z + k]) +
DHuber(curVal - f_i_minus[j_minus * N.z + k]) +
DHuber(curVal - f_i_minus[j * N.z + k_plus]) +
DHuber(curVal - f_i_minus[j * N.z + k_minus]) +
DHuber(curVal - f_i[j_plus * N.z + k_plus]) +
DHuber(curVal - f_i[j_plus * N.z + k_minus]) +
DHuber(curVal - f_i[j_minus * N.z + k_plus]) +
DHuber(curVal - f_i[j_minus * N.z + k_minus])) *
dist_2 +
(DHuber(curVal - f_i_plus[j_plus * N.z + k_plus], delta) +
DHuber(curVal - f_i_plus[j_plus * N.z + k_minus], delta) +
DHuber(curVal - f_i_plus[j_minus * N.z + k_plus], delta) +
DHuber(curVal - f_i_plus[j_minus * N.z + k_minus], delta) +
DHuber(curVal - f_i_minus[j_plus * N.z + k_plus], delta) +
DHuber(curVal - f_i_minus[j_plus * N.z + k_minus], delta) +
DHuber(curVal - f_i_minus[j_minus * N.z + k_plus], delta) +
DHuber(curVal - f_i_minus[j_minus * N.z + k_minus], delta)) *
(DHuber(curVal - f_i_plus[j_plus * N.z + k_plus]) +
DHuber(curVal - f_i_plus[j_plus * N.z + k_minus]) +
DHuber(curVal - f_i_plus[j_minus * N.z + k_plus]) +
DHuber(curVal - f_i_plus[j_minus * N.z + k_minus]) +
DHuber(curVal - f_i_minus[j_plus * N.z + k_plus]) +
DHuber(curVal - f_i_minus[j_plus * N.z + k_minus]) +
DHuber(curVal - f_i_minus[j_minus * N.z + k_plus]) +
DHuber(curVal - f_i_minus[j_minus * N.z + k_minus])) *
dist_3;
}
//*/
Expand Down Expand Up @@ -490,6 +491,7 @@ void setConstantMemoryParameters(const float delta, const float p)
//d_HUBER_SLOPE * pow(fabs(x), d_HUBER_P) + d_HUBER_SHIFT

cudaMemcpyToSymbol(d_HUBER_P, &p, sizeof(float));
cudaMemcpyToSymbol(d_HUBER_DELTA, &delta, sizeof(float));
cudaMemcpyToSymbol(d_HUBER_SLOPE, &HuberSlope, sizeof(float));
cudaMemcpyToSymbol(d_HUBER_SHIFT, &HuberShift, sizeof(float));
}
Expand Down

0 comments on commit 83d6ce0

Please sign in to comment.