From 01726197ab112e280e2c31936255834c51e362d1 Mon Sep 17 00:00:00 2001 From: CNugteren Date: Mon, 15 Jun 2015 08:38:24 +0200 Subject: Fixed a bug in AXPBY defines for complex data-types --- src/kernels/xgemm.opencl | 63 ++++++++++++++++++++++++------------------------ src/kernels/xgemv.opencl | 9 ++++--- 2 files changed, 38 insertions(+), 34 deletions(-) (limited to 'src/kernels') diff --git a/src/kernels/xgemm.opencl b/src/kernels/xgemm.opencl index facaf5dc..a4f45e90 100644 --- a/src/kernels/xgemm.opencl +++ b/src/kernels/xgemm.opencl @@ -293,42 +293,43 @@ inline void StoreResults(__global realM* cgm, realM cpm[NWI][MWI/VWM], const int int idm = mg + get_group_id(0)*(MWG/VWM); int idn = ng + get_group_id(1)*NWG; int index = idn*(kSizeM/VWM) + idm; + realM cval = cgm[index]; #if VWM == 1 - AXPBY(cgm[index], alpha, cpm[ni][mi], beta, cgm[index]); + AXPBY(cgm[index], alpha, cpm[ni][mi], beta, cval); #elif VWM == 2 - AXPBY(cgm[index].x, alpha, cpm[ni][mi].x, beta, cgm[index].x); - AXPBY(cgm[index].y, alpha, cpm[ni][mi].y, beta, cgm[index].y); + AXPBY(cgm[index].x, alpha, cpm[ni][mi].x, beta, cval.x); + AXPBY(cgm[index].y, alpha, cpm[ni][mi].y, beta, cval.y); #elif VWM == 4 - AXPBY(cgm[index].x, alpha, cpm[ni][mi].x, beta, cgm[index].x); - AXPBY(cgm[index].y, alpha, cpm[ni][mi].y, beta, cgm[index].y); - AXPBY(cgm[index].z, alpha, cpm[ni][mi].z, beta, cgm[index].z); - AXPBY(cgm[index].w, alpha, cpm[ni][mi].w, beta, cgm[index].w); + AXPBY(cgm[index].x, alpha, cpm[ni][mi].x, beta, cval.x); + AXPBY(cgm[index].y, alpha, cpm[ni][mi].y, beta, cval.y); + AXPBY(cgm[index].z, alpha, cpm[ni][mi].z, beta, cval.z); + AXPBY(cgm[index].w, alpha, cpm[ni][mi].w, beta, cval.w); #elif VWM == 8 - AXPBY(cgm[index].s0, alpha, cpm[ni][mi].s0, beta, cgm[index].s0); - AXPBY(cgm[index].s1, alpha, cpm[ni][mi].s1, beta, cgm[index].s1); - AXPBY(cgm[index].s2, alpha, cpm[ni][mi].s2, beta, cgm[index].s2); - AXPBY(cgm[index].s3, alpha, cpm[ni][mi].s3, beta, cgm[index].s3); - AXPBY(cgm[index].s4, alpha, cpm[ni][mi].s4, beta, cgm[index].s4); - AXPBY(cgm[index].s5, alpha, cpm[ni][mi].s5, beta, cgm[index].s5); - AXPBY(cgm[index].s6, alpha, cpm[ni][mi].s6, beta, cgm[index].s6); - AXPBY(cgm[index].s7, alpha, cpm[ni][mi].s7, beta, cgm[index].s7); + AXPBY(cgm[index].s0, alpha, cpm[ni][mi].s0, beta, cval.s0); + AXPBY(cgm[index].s1, alpha, cpm[ni][mi].s1, beta, cval.s1); + AXPBY(cgm[index].s2, alpha, cpm[ni][mi].s2, beta, cval.s2); + AXPBY(cgm[index].s3, alpha, cpm[ni][mi].s3, beta, cval.s3); + AXPBY(cgm[index].s4, alpha, cpm[ni][mi].s4, beta, cval.s4); + AXPBY(cgm[index].s5, alpha, cpm[ni][mi].s5, beta, cval.s5); + AXPBY(cgm[index].s6, alpha, cpm[ni][mi].s6, beta, cval.s6); + AXPBY(cgm[index].s7, alpha, cpm[ni][mi].s7, beta, cval.s7); #elif VWM == 16 - AXPBY(cgm[index].s0, alpha, cpm[ni][mi].s0, beta, cgm[index].s0); - AXPBY(cgm[index].s1, alpha, cpm[ni][mi].s1, beta, cgm[index].s1); - AXPBY(cgm[index].s2, alpha, cpm[ni][mi].s2, beta, cgm[index].s2); - AXPBY(cgm[index].s3, alpha, cpm[ni][mi].s3, beta, cgm[index].s3); - AXPBY(cgm[index].s4, alpha, cpm[ni][mi].s4, beta, cgm[index].s4); - AXPBY(cgm[index].s5, alpha, cpm[ni][mi].s5, beta, cgm[index].s5); - AXPBY(cgm[index].s6, alpha, cpm[ni][mi].s6, beta, cgm[index].s6); - AXPBY(cgm[index].s7, alpha, cpm[ni][mi].s7, beta, cgm[index].s7); - AXPBY(cgm[index].s8, alpha, cpm[ni][mi].s8, beta, cgm[index].s8); - AXPBY(cgm[index].s9, alpha, cpm[ni][mi].s9, beta, cgm[index].s9); - AXPBY(cgm[index].sA, alpha, cpm[ni][mi].sA, beta, cgm[index].sA); - AXPBY(cgm[index].sB, alpha, cpm[ni][mi].sB, beta, cgm[index].sB); - AXPBY(cgm[index].sC, alpha, cpm[ni][mi].sC, beta, cgm[index].sC); - AXPBY(cgm[index].sD, alpha, cpm[ni][mi].sD, beta, cgm[index].sD); - AXPBY(cgm[index].sE, alpha, cpm[ni][mi].sE, beta, cgm[index].sE); - AXPBY(cgm[index].sF, alpha, cpm[ni][mi].sF, beta, cgm[index].sF); + AXPBY(cgm[index].s0, alpha, cpm[ni][mi].s0, beta, cval.s0); + AXPBY(cgm[index].s1, alpha, cpm[ni][mi].s1, beta, cval.s1); + AXPBY(cgm[index].s2, alpha, cpm[ni][mi].s2, beta, cval.s2); + AXPBY(cgm[index].s3, alpha, cpm[ni][mi].s3, beta, cval.s3); + AXPBY(cgm[index].s4, alpha, cpm[ni][mi].s4, beta, cval.s4); + AXPBY(cgm[index].s5, alpha, cpm[ni][mi].s5, beta, cval.s5); + AXPBY(cgm[index].s6, alpha, cpm[ni][mi].s6, beta, cval.s6); + AXPBY(cgm[index].s7, alpha, cpm[ni][mi].s7, beta, cval.s7); + AXPBY(cgm[index].s8, alpha, cpm[ni][mi].s8, beta, cval.s8); + AXPBY(cgm[index].s9, alpha, cpm[ni][mi].s9, beta, cval.s9); + AXPBY(cgm[index].sA, alpha, cpm[ni][mi].sA, beta, cval.sA); + AXPBY(cgm[index].sB, alpha, cpm[ni][mi].sB, beta, cval.sB); + AXPBY(cgm[index].sC, alpha, cpm[ni][mi].sC, beta, cval.sC); + AXPBY(cgm[index].sD, alpha, cpm[ni][mi].sD, beta, cval.sD); + AXPBY(cgm[index].sE, alpha, cpm[ni][mi].sE, beta, cval.sE); + AXPBY(cgm[index].sF, alpha, cpm[ni][mi].sF, beta, cval.sF); #endif } } diff --git a/src/kernels/xgemv.opencl b/src/kernels/xgemv.opencl index b1b2fc69..5ea70e0d 100644 --- a/src/kernels/xgemv.opencl +++ b/src/kernels/xgemv.opencl @@ -133,7 +133,8 @@ __kernel void Xgemv(const int m, const int n, const real alpha, const real beta, } // Stores the final result - AXPBY(ygm[gid*y_inc + y_offset], alpha, acc[w], beta, ygm[gid*y_inc + y_offset]); + real yval = ygm[gid*y_inc + y_offset]; + AXPBY(ygm[gid*y_inc + y_offset], alpha, acc[w], beta, yval); } } } @@ -239,7 +240,8 @@ __kernel void XgemvFast(const int m, const int n, const real alpha, const real b #pragma unroll for (int w=0; w