aboutsummaryrefslogtreecommitdiffstats
path: root/test/mppa
diff options
context:
space:
mode:
author Cyril SIX <cyril.six@kalray.eu> 2018-04-18 16:44:43 +0200
committer Cyril SIX <cyril.six@kalray.eu> 2018-04-18 16:44:43 +0200
commit 41a048fa4bb9ddefd4e4acff2207251bb3ddbf06 (patch)
tree d2395bd4f0a130631d5f60f5b3da9aabcc96ae02 /test/mppa
parent b7021853e651ddde91450cc83d3c77c5377efc06 (diff)
download compcert-kvx-41a048fa4bb9ddefd4e4acff2207251bb3ddbf06.tar.gz
compcert-kvx-41a048fa4bb9ddefd4e4acff2207251bb3ddbf06.zip
MPPA - Added divide & conquer test matmul
Diffstat (limited to 'test/mppa')
-rw-r--r-- test/mppa/mmult/mmult.c 88
-rw-r--r-- test/mppa/mmult/mmult.h 1
2 files changed, 87 insertions, 2 deletions
diff --git a/test/mppa/mmult/mmult.c b/test/mppa/mmult/mmult.c
index 16dcf34c..c9e7ad5e 100644
--- a/test/mppa/mmult/mmult.c
+++ b/test/mppa/mmult/mmult.c
@@ -31,8 +31,91 @@ void mmult_col(uint64_t C[][SIZE], uint64_t A[][SIZE], uint64_t B[][SIZE]){
C[i][j] += A[i][k] * B[k][j];
}
+typedef struct mblock {
+ int imin, imax, jmin, jmax;
+ uint64_t *mat;
+} mblock;
+
+#define MAT_XY(mat, x, y) (mat)[(x)*SIZE + (y)]
+#define MAT_IJ(block, i, j) MAT_XY((block)->mat, (block)->imin + (i), block->jmin + (j))
+
+int strassen_mul(mblock *C, const mblock *A, const mblock *B){
+ const int size = C->imax - C->imin;
+
+ for (int i = 0 ; i < size ; i++)
+ for (int j = 0 ; j < size ; j++)
+ for (int k = 0 ; k < size ; k++)
+ MAT_IJ(C, i, j) += MAT_IJ(A, i, k) * MAT_IJ(B, k, j);
+}
+
+#define BLOCK_X_MID(block) ((block)->imin + (block)->imax) / 2
+#define BLOCK_Y_MID(block) ((block)->jmin + (block)->jmax) / 2
+
+#define MAKE_MBLOCK(newb, block, I, J) \
+ mblock newb = {.mat=(block)->mat};\
+ if ((I) == 0){\
+ newb.imin = (block)->imin;\
+ newb.imax = BLOCK_X_MID((block));\
+ } else {\
+ newb.imin = BLOCK_X_MID((block));\
+ newb.imax = (block)->imax;\
+ } if ((J) == 0){\
+ newb.jmin = (block)->jmin;\
+ newb.jmax = BLOCK_Y_MID((block));\
+ } else {\
+ newb.jmin = BLOCK_Y_MID((block));\
+ newb.jmax = (block)->jmax;\
+ }
+
+int strassen_part(mblock *C, const mblock *A, const mblock *B);
+
+void strassen_wrap(mblock *C , char IC, char JC,
+ const mblock *A, char IA, char JA,
+ const mblock *B, char IB, char JB){
+ MAKE_MBLOCK(Cb, C, IC, JC);
+ MAKE_MBLOCK(Ab, A, IA, JA);
+ MAKE_MBLOCK(Bb, B, IB, JB);
+
+ strassen_part(&Cb, &Ab, &Bb);
+}
+
+
+int strassen_part(mblock *C, const mblock *A, const mblock *B){
+ const int size = C->imax - C->imin;
+
+ if (size % 2 == 1)
+ strassen_mul(C, A, B);
+ else{
+ /* C_00 = A_00 B_00 + A_01 B_10 */
+ strassen_wrap(C, 0, 0, A, 0, 0, B, 0, 0);
+ strassen_wrap(C, 0, 0, A, 0, 1, B, 1, 0);
+
+ /* C_10 = A_10 B_00 + A_11 B_10 */
+ strassen_wrap(C, 1, 0, A, 1, 0, B, 0, 0);
+ strassen_wrap(C, 1, 0, A, 1, 1, B, 1, 0);
+
+ /* C_01 = A_00 B_01 + A_01 B_11 */
+ strassen_wrap(C, 0, 1, A, 0, 0, B, 0, 1);
+ strassen_wrap(C, 0, 1, A, 0, 1, B, 1, 1);
+
+ /* C_11 = A_10 B_01 + A_11 B_11 */
+ strassen_wrap(C, 1, 1, A, 1, 0, B, 0, 1);
+ strassen_wrap(C, 1, 1, A, 1, 1, B, 1, 1);
+ }
+
+}
+
+void mmult_strassen(uint64_t C[][SIZE], uint64_t A[][SIZE], uint64_t B[][SIZE]){
+ mblock Cb = {.mat = (uint64_t *) C, .imin = 0, .imax = SIZE, .jmin = 0, .jmax = SIZE};
+ mblock Ab = {.mat = (uint64_t *) A , .imin = 0, .imax = SIZE, .jmin = 0, .jmax = SIZE};
+ mblock Bb = {.mat = (uint64_t *) B , .imin = 0, .imax = SIZE, .jmin = 0, .jmax = SIZE};
+
+ strassen_part(&Cb, &Ab, &Bb);
+}
+
#ifdef __UNIT_TEST_MMULT__
-static uint64_t C1[SIZE][SIZE], C2[SIZE][SIZE], A[SIZE][SIZE], B[SIZE][SIZE];
+static uint64_t C1[SIZE][SIZE], C2[SIZE][SIZE], C3[SIZE][SIZE];
+static uint64_t A[SIZE][SIZE], B[SIZE][SIZE];
int main(void){
srand(42);
@@ -45,10 +128,11 @@ int main(void){
mmult_row(C1, A, B);
mmult_col(C2, A, B);
+ mmult_strassen(C3, A, B);
for (int i = 0 ; i < SIZE ; i++)
for (int j = 0 ; j < SIZE ; j++)
- if (C1[i][j] != C2[i][j])
+ if (!(C1[i][j] == C2[i][j] && C1[i][j] == C3[i][j]))
return -1;
return 0;
diff --git a/test/mppa/mmult/mmult.h b/test/mppa/mmult/mmult.h
index 50c04afd..3721784a 100644
--- a/test/mppa/mmult/mmult.h
+++ b/test/mppa/mmult/mmult.h
@@ -5,5 +5,6 @@
void mmult_row(uint64_t *A, const uint64_t *B, const uint64_t *C);
void mmult_column(uint64_t *A, const uint64_t *B, const uint64_t *C);
+void mmult_strassen(uint64_t *A, const uint64_t *B, const uint64_t *C);
#endif /* __MMULT_H__ */