Skip to content

Commit b1b9e46

Browse files
committed
mpi matrix vector multiplication
1 parent bc43539 commit b1b9e46

File tree

2 files changed

+370
-0
lines changed

2 files changed

+370
-0
lines changed

user/sujith/Mmpimatmul.c

Lines changed: 209 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,209 @@
1+
/* Matrix-vector multiplication using MPI*/
2+
/*
3+
Copyright (C) 2025 University of Texas at Austin
4+
5+
This program is free software; you can redistribute it and/or modify
6+
it under the terms of the GNU General Public License as published by
7+
the Free Software Foundation; either version 2 of the License, or
8+
(at your option) any later version.
9+
10+
This program is distributed in the hope that it will be useful,
11+
but WITHOUT ANY WARRANTY; without even the implied warranty of
12+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13+
GNU General Public License for more details.
14+
15+
You should have received a copy of the GNU General Public License
16+
along with this program; if not, write to the Free Software
17+
Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
18+
*/
19+
20+
21+
#include <limits.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <mpi.h>
#include <rsf.h>
27+
28+
int main(int argc, char* argv[])
29+
{
30+
int rank, size;
31+
int nrows=0, ncols=0; // global matrix dimensions
32+
int myrows=0, offset=0; // local rows (or local columns when adj)
33+
bool adj = false;
34+
sf_file in=NULL, out=NULL, mat=NULL;
35+
float *x=NULL, *y=NULL;
36+
float *alocal=NULL, *ylocal=NULL;
37+
38+
MPI_Init(&argc, &argv);
39+
sf_init(argc, argv);
40+
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
41+
MPI_Comm_size(MPI_COMM_WORLD, &size);
42+
43+
/* Root reads matrix and vector info */
44+
if(rank==0){
45+
in = sf_input("--input"); // vector x
46+
out = sf_output("--output"); // result y
47+
mat = sf_input("mat"); // matrix A
48+
49+
if(!sf_histint(in,"n1",&ncols)) sf_error("No n1= in input vector");
50+
if(!sf_getbool("adj",&adj)) adj=false;
51+
52+
if(!sf_histint(mat,"n1",&ncols)) sf_error("No n1= in matrix");
53+
if(!sf_histint(mat,"n2",&nrows)) sf_error("No n2= in matrix");
54+
55+
sf_putint(out,"n1", adj ? ncols : nrows);
56+
}
57+
58+
/* Broadcast global sizes and adj flag */
59+
MPI_Bcast(&nrows, 1, MPI_INT, 0, MPI_COMM_WORLD);
60+
MPI_Bcast(&ncols, 1, MPI_INT, 0, MPI_COMM_WORLD);
61+
MPI_Bcast(&adj, 1, MPI_C_BOOL, 0, MPI_COMM_WORLD);
62+
63+
/* Decide local rows/columns for this rank */
64+
int rows_per_rank, remainder;
65+
if(adj){
66+
rows_per_rank = ncols / size; // distributing columns when adj
67+
remainder = ncols % size;
68+
} else {
69+
rows_per_rank = nrows / size; // distributing rows when forward
70+
remainder = nrows % size;
71+
}
72+
73+
myrows = rows_per_rank + (rank < remainder ? 1 : 0);
74+
offset = rank*rows_per_rank + (rank < remainder ? rank : remainder);
75+
76+
/* Allocate buffers */
77+
if(adj){
78+
/* alocal stores myrows columns, each column has length nrows:
79+
contiguous layout: alocal[i*nrows + j] = A[j, offset+i] */
80+
alocal = sf_floatalloc((size_t)myrows * nrows);
81+
ylocal = sf_floatalloc(ncols); // full y (will be reduced)
82+
} else {
83+
alocal = sf_floatalloc((size_t)myrows * ncols); // local rows
84+
ylocal = sf_floatalloc(myrows); // local result
85+
}
86+
87+
/* IMPORTANT: allocate x with correct length depending on adj */
88+
int xlen = adj ? nrows : ncols;
89+
x = sf_floatalloc(xlen);
90+
if(!alocal || !ylocal || !x) MPI_Abort(MPI_COMM_WORLD,1);
91+
92+
/* Root reads vector with correct length, then broadcast that many floats */
93+
if(rank==0){
94+
if(adj)
95+
sf_floatread(x, nrows, in); /* for adjoint x length must be nrows */
96+
else
97+
sf_floatread(x, ncols, in); /* for forward x length must be ncols */
98+
}
99+
MPI_Bcast(x, xlen, MPI_FLOAT, 0, MPI_COMM_WORLD);
100+
101+
/* Root reads full matrix and scatters (keeps your column-send approach for adj) */
102+
if(rank==0){
103+
float *Aall = sf_floatalloc((size_t)nrows * ncols);
104+
sf_floatread(Aall, nrows*ncols, mat);
105+
106+
for(int r=0;r<size;r++){
107+
int rrows = rows_per_rank + (r < remainder ? 1 : 0);
108+
int roffs = r*rows_per_rank + (r < remainder ? r : remainder);
109+
110+
if(adj){
111+
/* send rrows columns to rank r; each column has nrows floats */
112+
for(int i=0;i<rrows;i++){
113+
if(r==0){
114+
for(int j=0;j<nrows;j++)
115+
alocal[i*(size_t)nrows + j] = Aall[j*(size_t)ncols + roffs + i];
116+
} else {
117+
float *tmp = malloc((size_t)nrows * sizeof(float));
118+
for(int j=0;j<nrows;j++)
119+
tmp[j] = Aall[j*(size_t)ncols + roffs + i];
120+
MPI_Send(tmp, nrows, MPI_FLOAT, r, 0, MPI_COMM_WORLD);
121+
free(tmp);
122+
}
123+
}
124+
} else {
125+
if(r==0){
126+
memcpy(alocal, Aall + roffs*(size_t)ncols, (size_t)rrows*(size_t)ncols*sizeof(float));
127+
} else {
128+
MPI_Send(Aall + roffs*(size_t)ncols, rrows*ncols, MPI_FLOAT, r, 0, MPI_COMM_WORLD);
129+
}
130+
}
131+
}
132+
free(Aall);
133+
} else {
134+
if(adj){
135+
for(int i=0;i<myrows;i++){
136+
MPI_Recv(alocal + i*(size_t)nrows, nrows, MPI_FLOAT, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
137+
}
138+
} else {
139+
MPI_Recv(alocal, myrows*ncols, MPI_FLOAT, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
140+
}
141+
}
142+
143+
/* Local mat-vec multiply */
144+
if(adj){
145+
/* y = A^T x
146+
alocal is stored per-local-column: alocal[i*nrows + j] == A[j, offset+i]
147+
We'll compute local contribution to full y (length ncols),
148+
so initialize ylocal to zeros and add into ylocal[offset+i]. */
149+
for(int j=0;j<ncols;j++) ylocal[j] = 0.0f; /* zero full-length buffer */
150+
151+
for(int i=0;i<myrows;i++){ /* local columns (0..myrows-1) */
152+
int col = offset + i; /* global column index */
153+
float accum = 0.0f;
154+
for(int j=0;j<nrows;j++){
155+
/* alocal[i*nrows + j] == A[j, col] */
156+
/* x[j] corresponds to row j of A (x length is nrows for adjoint) */
157+
accum += alocal[i*(size_t)nrows + j] * x[j];
158+
}
159+
ylocal[col] = accum; /* write contribution into the correct position */
160+
}
161+
162+
/* reduce contributions across ranks: each rank has nonzero ylocal only
163+
on indices [offset .. offset+myrows-1], other entries zero */
164+
y = sf_floatalloc(ncols);
165+
MPI_Reduce(ylocal, y, ncols, MPI_FLOAT, MPI_SUM, 0, MPI_COMM_WORLD);
166+
} else {
167+
/* y = A x (forward). alocal stores rows: alocal[i*ncols + j] = A[offset+i, j] */
168+
for(int i=0;i<myrows;i++){
169+
float sum = 0.0f;
170+
for(int j=0;j<ncols;j++){
171+
sum += alocal[i*(size_t)ncols + j] * x[j];
172+
}
173+
ylocal[i] = sum;
174+
}
175+
}
176+
177+
/* Gather / write results */
178+
if(!adj){
179+
if(rank==0){
180+
y = sf_floatalloc(nrows);
181+
int pos = 0;
182+
for(int r=0;r<size;r++){
183+
int rrows = rows_per_rank + (r < remainder ? 1 : 0);
184+
if(r==0){
185+
memcpy(y, ylocal, rrows * sizeof(float));
186+
} else {
187+
MPI_Recv(y + pos, rrows, MPI_FLOAT, r, 1, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
188+
}
189+
pos += rrows;
190+
}
191+
sf_floatwrite(y, nrows, out);
192+
free(y);
193+
} else {
194+
MPI_Send(ylocal, myrows, MPI_FLOAT, 0, 1, MPI_COMM_WORLD);
195+
}
196+
} else {
197+
if(rank==0){
198+
sf_floatwrite(y, ncols, out);
199+
free(y);
200+
}
201+
}
202+
203+
free(alocal);
204+
free(ylocal);
205+
free(x);
206+
207+
MPI_Finalize();
208+
return 0;
209+
}

user/sujith/SConstruct

Lines changed: 161 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,161 @@
1+
import os, sys, re, string
2+
sys.path.append('../../framework')
3+
import bldutil
4+
5+
# Space-separated lists of programs built in this directory. Each NAME is
# built from M<NAME>.c (M<NAME>.cc for the C++ lists, M<NAME>.py for pyprogs).
progs = ''          # plain C programs


libprop = ''        # more plain C programs, built exactly like progs

ccprogs = ''        # C++ programs; built only when the c++ API and LAPACK are configured

mpi_progs = 'mpimatmul'   # C programs compiled and linked with MPICC (Mmpimatmul.c)


mpicxx_progs = ''   # C++ MPI programs compiled with MPICXX

pyprogs = ''        # Python scripts installed as sf<NAME> in bindir
pymods = ''         # Python modules installed into pkgdir

# Inside the full Madagascar build the calling SConstruct exports the build
# environment and install directories. If that Import fails (standalone
# build), fall back to a local environment from bldutil.Debug() with no
# install root.
# NOTE(review): the bare except also swallows any other error raised here.
try: # distributed version
    Import('env root pkgdir bindir libdir incdir')
    env = env.Clone()
except: # local version
    env = bldutil.Debug()
    root = None
    # NOTE(review): this also builds the sibling ../lexing directory; it looks
    # inherited from a template -- confirm it is needed for this directory.
    SConscript('../lexing/SConstruct')

# Compile against the in-tree RSF headers and link every program with the
# rsf library (name optionally prefixed by DYNLIB).
env.Prepend(CPPPATH=['../../include'],
            LIBPATH=['../../lib'],
            LIBS=[env.get('DYNLIB','')+'rsf'])
31+
32+
# Define SF_HAS_FFTW for all sources when the build was configured with FFTW.
fftw = env.get('FFTW')
if fftw:
    env.Prepend(CPPDEFINES=['SF_HAS_FFTW'])

# Helper C sources (lowercase first letter) become static objects that the
# programs below can link. Main programs (M*.c) start with an uppercase letter
# and are therefore not matched. RSF_Include presumably produces a header
# from each source -- confirm in the framework.
src = Glob('[a-z]*.c')
for source in src:
    inc = env.RSF_Include(source,prefix='')
    obj = env.StaticObject(source)
    env.Ignore(inc,inc)
    env.Depends(obj,inc)

# C++ support is used only when the c++ API is enabled; LAPACK may then be a
# library list, True, or unset.
if 'c++' in env.get('API',[]):
    lapack = env.get('LAPACK')
else:
    lapack = None

# Helper C++ sources: headers are always generated, objects only with LAPACK.
csrc = Glob('[a-z]*.cc')
for source in csrc:
    inc = env.RSF_Include(source,prefix='')
    env.Ignore(inc,inc)
    if lapack:
        obj = env.StaticObject(source)
        env.Depends(obj,inc)

# MPI compilers from the build configuration (None when MPI is not available).
mpicc = env.get('MPICC')
mpicxx = env.get('MPICXX')
# MPI helper sources named Q*.c are compiled with MPICC.
# NOTE(review): no Q*.c files are part of this commit.
mpi_src = Glob('Q[a-z]*.c')
for source in mpi_src:
    inc = env.RSF_Include(source,prefix='')
    env.Ignore(inc,inc)
    if mpicc:
        obj = env.StaticObject(source,CC=mpicc)
        env.Depends(obj,inc)
65+
66+
# Plain C programs: link M<NAME>.c (plus whatever bldutil.depends adds,
# presumably the helper sources it includes) and install when building in-tree.
mains = Split(progs+' '+libprop)
for prog in mains:
    sources = ['M' + prog]
    bldutil.depends(env,sources,'M'+prog)
    prog = env.Program(prog,[x + '.c' for x in sources])
    if root:
        env.Install(bindir,prog)

# MPI C programs (here: mpimatmul). With MPICC the main source is compiled
# and linked with the MPI compiler. Without it, RSF_Place registers sf<NAME>
# in its place (presumably a stub that reports the missing mpi package).
mpi_mains = Split(mpi_progs)
for prog in mpi_mains:
    sources = ['M' + prog]
    bldutil.depends(env,sources,'M'+prog)
    if mpicc:
        env.StaticObject('M'+prog+'.c',CC=mpicc)
        #for distributed FFTW3
        #prog = env.Program(prog,map(lambda x: x + '.o',sources),CC=mpicc,LIBS=env.get('LIBS')+['fftw3f_mpi'])
        prog = env.Program(prog,[x + '.o' for x in sources],CC=mpicc)
    else:
        prog = env.RSF_Place('sf'+prog,None,var='MPICC',package='mpi')
    if root:
        env.Install(bindir,prog)
87+
88+
89+
# Extra libraries for the C++ programs: rsf++ and vecmatop, plus the LAPACK
# libraries themselves when LAPACK is a library list rather than plain True.
if lapack:
    libsxx = [env.get('DYNLIB','')+'rsf++','vecmatop']
    if not isinstance(lapack,bool):
        libsxx.extend(lapack)
    env.Prepend(LIBS=libsxx)

# C++ programs built from M<NAME>.cc. Without LAPACK, RSF_Place registers
# sf<NAME> instead (presumably a stub reporting the missing lapack package).
#ccsubs = 'lowrank.cc fftomp.c rtmutil.c ksutil.c revolve.c'
ccmains = Split(ccprogs)
for prog in ccmains:
    sources = ['M' + prog + '.cc']
#    if prog == 'cfftrtm3':
#        sources += Split(ccsubs)
    if lapack:
        prog = env.Program(prog,sources)
    else:
        prog = env.RSF_Place('sf'+prog,None,var='LAPACK',package='lapack')
    if root:
        env.Install(bindir,prog)


# ----------------------------------------------------------------------
# C++ MPI programs are linked with the helper objects listed in xxsubs.
# They are built only when MPICXX is configured and revolve.c is present in
# this directory. revolve.c is not distributed; it can be obtained from
# http://dl.acm.org/citation.cfm?id=347846
# mpicxx_progs is empty in this directory, so this loop builds nothing.
# ----------------------------------------------------------------------
xxsubs = 'lowrank fftomp rtmutil ksutil revolve'
mpicxx_mains = Split(mpicxx_progs)
for prog in mpicxx_mains:
    sources = ['M' + prog] + Split(xxsubs)
    if FindFile('revolve.c','.') and mpicxx:
        env.StaticObject('M'+prog+'.cc',CXX=mpicxx)
        prog = env.Program(prog,[x + '.o' for x in sources],CXX=mpicxx)
    else:
        prog = env.RSF_Place('sf'+prog,None,var='MPICXX',package='mpi')
    if root:
        env.Install(bindir,prog)

# Disabled test-program build retained from the template:
# for prog in Split('cmatmult2'):
#     sources = ['Test' + prog,prog]
#     if prog=='cmatmult2':
#         sources.append('cgmres')
#     bldutil.depends(env,sources,prog)
#     sources = [x + '.o' for x in sources]
#     env.Object('Test' + prog + '.c')
#     env.Program(sources,PROGPREFIX='',PROGSUFFIX='.x')
132+
133+
######################################################################
# PYTHON METAPROGRAMS (python API not needed)
######################################################################

# Python programs are not compiled. Each M<NAME>.py is installed as
# sf<NAME> (plus PROGSUFFIX) in bindir and made executable. Modules listed in
# pymods are copied to pkgdir. This only happens in the in-tree build
# (root set).
if root: # no compilation, just rename
    pymains = Split(pyprogs)
    exe = env.get('PROGSUFFIX','')
    for prog in pymains:
        binary = os.path.join(bindir,'sf'+prog+exe)
        env.InstallAs(binary,'M'+prog+'.py')
        env.AddPostAction(binary,Chmod(str(binary),0o755))
    for mod in Split(pymods):
        env.Install(pkgdir,mod+'.py')

######################################################################
# SELF-DOCUMENTATION
######################################################################

# Generate documentation for every program listed above. The pieces are
# merged into sf<directory name>.py, which is installed into pkgdir.
# pymains exists only when root is set, hence the same guard here.
if root:
    user = os.path.basename(os.getcwd())
    main = 'sf%s.py' % user

    docs = [env.Doc(prog,'M' + prog) for prog in mains+mpi_mains] + \
           [env.Doc(prog,'M'+prog+'.py',lang='python') for prog in pymains] + \
           [env.Doc(prog,'M%s.cc' % prog,lang='c++') for prog in ccmains+mpicxx_mains]

    # Regenerate the docs whenever the documentation framework changes.
    env.Depends(docs,'#/framework/rsf/doc.py')
    doc = env.RSF_Docmerge(main,docs)
    env.Install(pkgdir,doc)

0 commit comments

Comments
 (0)