diff --git a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java
index 955bf778b8b..fc430374599 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java
@@ -71,7 +71,7 @@ public class TemplateRow extends TemplateBase
 	private static final OpOp1[] SUPPORTED_VECT_UNARY = new OpOp1[]{
 		OpOp1.EXP, OpOp1.SQRT, OpOp1.LOG, OpOp1.ABS, OpOp1.ROUND, OpOp1.CEIL, OpOp1.FLOOR, OpOp1.SIGN,
 		OpOp1.SIN, OpOp1.COS, OpOp1.TAN, OpOp1.ASIN, OpOp1.ACOS, OpOp1.ATAN, OpOp1.SINH, OpOp1.COSH, OpOp1.TANH,
-		OpOp1.CUMSUM, OpOp1.CUMMIN, OpOp1.CUMMAX, OpOp1.SPROP, OpOp1.SIGMOID};
+		OpOp1.CUMSUM, OpOp1.ROWCUMSUM, OpOp1.CUMMIN, OpOp1.CUMMAX, OpOp1.SPROP, OpOp1.SIGMOID};
 	private static final OpOp2[] SUPPORTED_VECT_BINARY = new OpOp2[]{
 		OpOp2.MULT, OpOp2.DIV, OpOp2.MINUS, OpOp2.PLUS, OpOp2.POW, OpOp2.MIN, OpOp2.MAX, OpOp2.XOR,
 		OpOp2.EQUAL, OpOp2.NOTEQUAL, OpOp2.LESS, OpOp2.LESSEQUAL, OpOp2.GREATER, OpOp2.GREATEREQUAL,
diff --git a/src/main/java/org/apache/sysds/resource/cost/CPCostUtils.java b/src/main/java/org/apache/sysds/resource/cost/CPCostUtils.java
index 0558e17c9cd..47261198c15 100644
--- a/src/main/java/org/apache/sysds/resource/cost/CPCostUtils.java
+++ b/src/main/java/org/apache/sysds/resource/cost/CPCostUtils.java
@@ -481,6 +481,7 @@ public static long getInstNFLOP(
 				costs = 40;
 				break;
 			case "ucumk+":
+			case "urowcumk+":
 			case "ucummin":
 			case "ucummax":
 			case "ucum*":
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/GPUInstructionParser.java
index 4fca2ad0eae..cf767f3dc07 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/GPUInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/GPUInstructionParser.java
@@ -145,6 +145,7 @@ public class GPUInstructionParser extends InstructionParser
 		// Cumulative Ops
 		String2GPUInstructionType.put( "ucumk+"  , GPUINSTRUCTION_TYPE.BuiltinUnary);
+		String2GPUInstructionType.put( "urowcumk+", GPUINSTRUCTION_TYPE.BuiltinUnary);
 		String2GPUInstructionType.put( "ucum*"   , GPUINSTRUCTION_TYPE.BuiltinUnary);
 		String2GPUInstructionType.put( "ucumk+*" , GPUINSTRUCTION_TYPE.BuiltinUnary);
 		String2GPUInstructionType.put( "ucummin" , GPUINSTRUCTION_TYPE.BuiltinUnary);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/CumulativeOffsetFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/CumulativeOffsetFEDInstruction.java
index 0e333285c0f..c4d1f308774 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/CumulativeOffsetFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/CumulativeOffsetFEDInstruction.java
@@ -51,6 +51,8 @@ private CumulativeOffsetFEDInstruction(Operator op, CPOperand in1, CPOperand in2
 		if ("bcumoffk+".equals(opcode))
 			_uop = new UnaryOperator(Builtin.getBuiltinFnObject("ucumk+"));
+		else if ("browcumoffk+".equals(opcode))
+			_uop = new UnaryOperator(Builtin.getBuiltinFnObject("urowcumk+"));
 		else if ("bcumoff*".equals(opcode))
 			_uop = new UnaryOperator(Builtin.getBuiltinFnObject("ucum*"));
 		else if ("bcumoff+*".equals(opcode))
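The hunks above register the new internal opcode `urowcumk+` (row-wise cumulative sum with Kahan-compensated addition) next to the existing column-wise `ucumk+`, across the codegen row template, the CP cost model, the GPU instruction parser, and the federated cumulative-offset instruction. The standalone sketch below only illustrates the intended semantics of the operation on plain arrays, mirroring what `LocalRowCumsumFunction` later does per block; the class and method names (`RowCumsumSketch`, `rowCumsum`) are hypothetical and not part of the patch.

```java
// Minimal sketch: row-wise cumulative sum with Kahan-compensated accumulation.
// Plain double[][] stands in for MatrixBlock; not SystemDS API.
public class RowCumsumSketch {
	public static double[][] rowCumsum(double[][] in) {
		double[][] out = new double[in.length][];
		for (int i = 0; i < in.length; i++) {
			out[i] = new double[in[i].length];
			double sum = 0, corr = 0; // running sum and Kahan correction term
			for (int j = 0; j < in[i].length; j++) {
				double y = in[i][j] - corr; // compensate previously lost low-order bits
				double t = sum + y;
				corr = (t - sum) - y;       // error introduced by this addition
				sum = t;
				out[i][j] = sum;            // prefix sum of row i up to column j
			}
		}
		return out;
	}

	public static void main(String[] args) {
		double[][] res = rowCumsum(new double[][]{{1, 2, 3}, {4, 5, 6}});
		// expected: [1.0, 3.0, 6.0] and [4.0, 9.0, 15.0]
		System.out.println(java.util.Arrays.deepToString(res));
	}
}
```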
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
index f125c01d4b0..61f9c4e06d0 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
@@ -76,7 +76,7 @@ public static UnaryMatrixFEDInstruction parseInstruction(String str) {
 		in.split(parts[1]);
 		out.split(parts[2]);
 		ValueFunction func = Builtin.getBuiltinFnObject(opcode);
-		if(Arrays.asList(new String[] {"ucumk+", "ucum*", "ucumk+*", "ucummin", "ucummax", "exp", "log", "sigmoid"})
+		if(Arrays.asList(new String[] {"ucumk+", "urowcumk+", "ucum*", "ucumk+*", "ucummin", "ucummax", "exp", "log", "sigmoid"})
 			.contains(opcode)) {
 			UnaryOperator op = new UnaryOperator(func, Integer.parseInt(parts[3]), Boolean.parseBoolean(parts[4]));
 			return new UnaryMatrixFEDInstruction(op, in, out, opcode, str);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/gpu/MatrixBuiltinGPUInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/gpu/MatrixBuiltinGPUInstruction.java
index 0557ccc2791..250b0bf843a 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/gpu/MatrixBuiltinGPUInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/gpu/MatrixBuiltinGPUInstruction.java
@@ -92,6 +92,10 @@ public void processInstruction(ExecutionContext ec) {
 				LibMatrixCUDA.cumulativeScan(ec, ec.getGPUContext(0), getExtendedOpcode(), "cumulative_sum", mat,
 					_output.getName());
 				break;
+			case "urowcumk+":
+				LibMatrixCUDA.cumulativeScan(ec, ec.getGPUContext(0), getExtendedOpcode(), "row_cumulative_sum", mat,
+					_output.getName());
+				break;
 			case "ucum*":
 				LibMatrixCUDA.cumulativeScan(ec, ec.getGPUContext(0), getExtendedOpcode(), "cumulative_prod", mat,
 					_output.getName());
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/CumulativeAggregateSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/CumulativeAggregateSPInstruction.java
index b511ac2e257..8c776abe56d 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/CumulativeAggregateSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/CumulativeAggregateSPInstruction.java
@@ -25,9 +25,11 @@
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysds.runtime.functionobjects.Builtin;
+import org.apache.sysds.runtime.functionobjects.KahanPlus;
 import org.apache.sysds.runtime.functionobjects.PlusMultiply;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.KahanObject;
 import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
 import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -52,62 +54,135 @@ public static CumulativeAggregateSPInstruction parseInstruction( String str ) {
 		CPOperand in1 = new CPOperand(parts[1]);
 		CPOperand out = new CPOperand(parts[2]);
 		AggregateUnaryOperator aggun = InstructionUtils.parseCumulativeAggregateUnaryOperator(opcode);
-		return new CumulativeAggregateSPInstruction(aggun, in1, out, opcode, str);
+		return new CumulativeAggregateSPInstruction(aggun, in1, out, opcode, str);
 	}
-	
+
 	@Override
 	public void processInstruction(ExecutionContext ec) {
 		SparkExecutionContext sec = (SparkExecutionContext)ec;
 		DataCharacteristics mc = sec.getDataCharacteristics(input1.getName());
+
+		//get input
+		JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryMatrixBlockRDDHandleForVariable( input1.getName() );
+
+		if ("urowcumk+".equals(getOpcode())) {
+			processRowCumsum(sec, in, mc);
+		}
+		else {
+			processCumsum(sec, in, mc);
+		}
+	}
+
+	private void processRowCumsum(SparkExecutionContext sec, JavaPairRDD<MatrixIndexes, MatrixBlock> in, DataCharacteristics mc) {
+		JavaPairRDD<MatrixIndexes, MatrixBlock> localRowCumsum =
+			in.mapToPair(new LocalRowCumsumFunction());
+
+		sec.setRDDHandleForVariable(output.getName(), localRowCumsum);
+		sec.addLineageRDD(output.getName(), input1.getName());
+		sec.getDataCharacteristics(output.getName()).set(mc);
+	}
+
+	public static Tuple2<JavaPairRDD<MatrixIndexes, MatrixBlock>, JavaPairRDD<MatrixIndexes, MatrixBlock>>
+		processRowCumsumWithEndValues(JavaPairRDD<MatrixIndexes, MatrixBlock> in) {
+		JavaPairRDD<MatrixIndexes, MatrixBlock> localRowCumsum =
+			in.mapToPair(new LocalRowCumsumFunction());
+
+		JavaPairRDD<MatrixIndexes, MatrixBlock> endValues =
+			localRowCumsum.mapToPair(new ExtractEndValuesFunction());
+
+		return new Tuple2<>(localRowCumsum, endValues);
+	}
+
+	private void processCumsum(SparkExecutionContext sec, JavaPairRDD<MatrixIndexes, MatrixBlock> in, DataCharacteristics mc) {
 		DataCharacteristics mcOut = new MatrixCharacteristics(mc);
 		long rlen = mc.getRows();
 		int blen = mc.getBlocksize();
 		mcOut.setRows((long)(Math.ceil((double)rlen/blen)));
-		
-		//get input
-		JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryMatrixBlockRDDHandleForVariable( input1.getName() );
-		
+
 		//execute unary aggregate (w/ implicit drop correction)
 		AggregateUnaryOperator auop = (AggregateUnaryOperator) _optr;
-		JavaPairRDD<MatrixIndexes, MatrixBlock> out = 
-			in.mapToPair(new RDDCumAggFunction(auop, rlen, blen));
+		JavaPairRDD<MatrixIndexes, MatrixBlock> out =
+			in.mapToPair(new RDDCumAggFunction(auop, rlen, blen));
 		//merge partial aggregates, adjusting for correct number of partitions
 		//as size can significant shrink (1K) but also grow (sparse-dense)
 		int numParts = SparkUtils.getNumPreferredPartitions(mcOut);
 		int minPar = (int)Math.min(SparkExecutionContext.getDefaultParallelism(true), mcOut.getNumBlocks());
 		out = RDDAggregateUtils.mergeByKey(out, Math.max(numParts, minPar), false);
-		
+
 		//put output handle in symbol table
 		sec.setRDDHandleForVariable(output.getName(), out);
 		sec.addLineageRDD(output.getName(), input1.getName());
 		sec.getDataCharacteristics(output.getName()).set(mcOut);
 	}
 
-	private static class RDDCumAggFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> 
+	private static class LocalRowCumsumFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
+		private static final long serialVersionUID = 123L;
+
+		@Override
+		public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> kv) throws Exception {
+			MatrixIndexes idx = kv._1;
+			MatrixBlock inputBlock = kv._2;
+			MatrixBlock outBlock = new MatrixBlock(inputBlock.getNumRows(), inputBlock.getNumColumns(), false);
+
+			for (int i = 0; i < inputBlock.getNumRows(); i++) {
+				KahanObject kbuff = new KahanObject(0, 0);
+				KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
+
+				for (int j = 0; j < inputBlock.getNumColumns(); j++) {
+					double val = inputBlock.get(i, j);
+					kplus.execute2(kbuff, val);
+					outBlock.set(i, j, kbuff._sum);
+				}
+			}
+			// original index, original matrix and local cumsum block
+			return new Tuple2<>(idx, outBlock);
+		}
+	}
+
+	private static class ExtractEndValuesFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
+		private static final long serialVersionUID = 123L;
+
+		@Override
+		public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> kv) throws Exception {
+			MatrixIndexes idx = kv._1;
+			MatrixBlock cumsumBlock = kv._2;
+
+			MatrixBlock endValuesBlock = new MatrixBlock(cumsumBlock.getNumRows(), 1, false);
+			for (int i = 0; i < cumsumBlock.getNumRows(); i++) {
+				if (cumsumBlock.getNumColumns() > 0) {
+					endValuesBlock.set(i, 0, cumsumBlock.get(i, cumsumBlock.getNumColumns() - 1));
+				}
+				else {
+					endValuesBlock.set(i, 0, 0.0);
+				}
+			}
+			return new Tuple2<>(idx, endValuesBlock);
+		}
+	}
+
+	private static class RDDCumAggFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
 		private static final long serialVersionUID = 11324676268945117L;
-		
+
 		private final AggregateUnaryOperator _op;
 		private UnaryOperator _uop = null;
 		private final long _rlen;
 		private final int _blen;
-		
+
 		public RDDCumAggFunction( AggregateUnaryOperator op, long rlen, int blen ) {
 			_op = op;
 			_rlen = rlen;
 			_blen = blen;
 		}
-		
+
 		@Override
-		public Tuple2<MatrixIndexes, MatrixBlock> call( Tuple2<MatrixIndexes, MatrixBlock> arg0 ) 
-			throws Exception 
+		public Tuple2<MatrixIndexes, MatrixBlock> call( Tuple2<MatrixIndexes, MatrixBlock> arg0 )
+			throws Exception
 		{
 			MatrixIndexes ixIn = arg0._1();
 			MatrixBlock blkIn = arg0._2();
 			MatrixIndexes ixOut = new MatrixIndexes();
 			MatrixBlock blkOut = new MatrixBlock();
-			
+
 			//process instruction
 			AggregateUnaryOperator aop = _op;
 			if( aop.aggOp.increOp.fn instanceof PlusMultiply ) { //cumsumprod
@@ -125,19 +200,19 @@ public Tuple2<MatrixIndexes, MatrixBlock> call( Tuple2<MatrixIndexes, MatrixBlock> arg0 )
 			return new Tuple2<>(ixOut, blkOut2);
 		}
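The CumulativeAggregateSPInstruction changes above split rowcumsum into two phases: LocalRowCumsumFunction computes a Kahan-compensated prefix sum inside every block, and ExtractEndValuesFunction records each block's last column as per-row "end values" for the later offset pass. The following standalone sketch only illustrates that block scheme for a single row split into plain double[] blocks; the class name and phase structure here are illustrative, not SystemDS API.

```java
import java.util.Arrays;

// Minimal sketch of the block-wise rowcumsum scheme:
// phase 1 computes a local prefix sum per column block and records its end value;
// phase 2 turns the end values into exclusive running offsets and adds them back.
public class RowCumsumBlockScheme {
	public static void main(String[] args) {
		double[][] blocks = {{1, 2}, {3, 4}, {5, 6}}; // one row, three column blocks

		// phase 1: local prefix sums and per-block end values
		double[] endValues = new double[blocks.length];
		for (int b = 0; b < blocks.length; b++) {
			for (int j = 1; j < blocks[b].length; j++)
				blocks[b][j] += blocks[b][j - 1];
			endValues[b] = blocks[b][blocks[b].length - 1];
		}

		// phase 2: exclusive prefix sum of end values is the offset of each block
		double offset = 0;
		for (int b = 0; b < blocks.length; b++) {
			double next = offset + endValues[b];
			for (int j = 0; j < blocks[b].length; j++)
				blocks[b][j] += offset;
			offset = next;
		}

		// expected: [1, 3], [6, 10], [15, 21]
		System.out.println(Arrays.deepToString(blocks));
	}
}
```

Phase 2 of this sketch corresponds to the offset handling added to CumulativeOffsetSPInstruction in the next file.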
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/CumulativeOffsetSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/CumulativeOffsetSPInstruction.java
index 61b61b15332..9b57713e7ac 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/CumulativeOffsetSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/CumulativeOffsetSPInstruction.java
@@ -40,10 +40,10 @@
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.util.DataConverter;
 import org.apache.sysds.runtime.util.UtilFunctions;
+import scala.Serializable;
 import scala.Tuple2;
 
-import java.util.ArrayList;
-import java.util.Iterator;
+import java.util.*;
 
 public class CumulativeOffsetSPInstruction extends BinarySPInstruction {
 	private UnaryOperator _uop = null;
@@ -92,12 +92,17 @@ public void processInstruction(ExecutionContext ec) {
 		DataCharacteristics mc2 = sec.getDataCharacteristics(input2.getName());
 		long rlen = mc2.getRows();
 		int blen = mc2.getBlocksize();
-		
+
+		if (Opcodes.BROWCUMOFFKP.toString().equals(getOpcode())) {
+			processRowCumsumOffsets(sec, mc1, mc2);
+			return;
+		}
+
 		//get and join inputs
 		JavaPairRDD<MatrixIndexes, MatrixBlock> inData = sec.getBinaryMatrixBlockRDDHandleForVariable(input1.getName());
 		JavaPairRDD<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> joined = null;
 		boolean broadcast = _broadcast && !SparkUtils.isHashPartitioned(inData);
-		
+
 		if( broadcast ) {
 			//broadcast offsets and broadcast join with data
 			PartitionedBroadcast<MatrixBlock> inAgg = sec.getBroadcastForVariable(input2.getName());
@@ -106,18 +111,18 @@ public void processInstruction(ExecutionContext ec) {
 		else {
 			//prepare aggregates (cumsplit of offsets) and repartition join with data
 			joined = inData.join(sec
-				.getBinaryMatrixBlockRDDHandleForVariable(input2.getName())
-				.flatMapToPair(new RDDCumSplitFunction(_initValue, rlen, blen)));
+				.getBinaryMatrixBlockRDDHandleForVariable(input2.getName())
+				.flatMapToPair(new RDDCumSplitFunction(_initValue, rlen, blen)));
 		}
-		
+
 		//execute cumulative offset (apply cumulative op w/ offsets)
 		JavaPairRDD<MatrixIndexes, MatrixBlock> out = joined
-			.mapValues(new RDDCumOffsetFunction(_uop, _cumsumprod));
-		
+			.mapValues(new RDDCumOffsetFunction(_uop, _cumsumprod));
+
 		//put output handle in symbol table
 		if( _cumsumprod )
 			sec.getDataCharacteristics(output.getName())
-				.set(mc1.getRows(), 1, mc1.getBlocksize(), mc1.getBlocksize());
+				.set(mc1.getRows(), 1, mc1.getBlocksize(), mc1.getBlocksize());
 		else //general case
 			updateUnaryOutputDataCharacteristics(sec);
 		sec.setRDDHandleForVariable(output.getName(), out);
@@ -125,6 +130,94 @@ public void processInstruction(ExecutionContext ec) {
 		sec.addLineage(output.getName(), input2.getName(), broadcast);
 	}
 
+	public static JavaPairRDD<MatrixIndexes, MatrixBlock> processRowCumsumOffsetsDirectly(
+		JavaPairRDD<MatrixIndexes, MatrixBlock> localRowCumsum,
+		JavaPairRDD<MatrixIndexes, MatrixBlock> endValues) {
+
+		// Collect end-values of every block of every row for offset calc by grouping by global row index
+		JavaPairRDD<Long, Iterable<Tuple3<Long, Long, double[]>>> rowEndValues = endValues
+			.mapToPair(t -> {
+				// get index of block
+				MatrixIndexes idx = t._1;
+				// get cum matrix block
+				MatrixBlock endValuesBlock = t._2;
+
+				// get row and column block index
+				long rowBlockIdx = idx.getRowIndex();
+				long colBlockIdx = idx.getColumnIndex();
+
+				// Save end value of every row of every block (if block is empty save 0)
+				double[] lastValues = new double[endValuesBlock.getNumRows()];
+				for (int i = 0; i < endValuesBlock.getNumRows(); i++) {
+					lastValues[i] = endValuesBlock.get(i, 0);
+				}
+
+				return new Tuple2<>(rowBlockIdx, new Tuple3<>(rowBlockIdx, colBlockIdx, lastValues));
+			})
+			.groupByKey();
+
+		// compute offset for every block
+		List<Tuple2<Tuple2<Long, Long>, double[]>> offsetList = rowEndValues
+			.flatMapToPair(t -> {
+				Long rowBlockIdx = t._1;
+				List<Tuple3<Long, Long, double[]>> colValues = new ArrayList<>();
+				for (Tuple3<Long, Long, double[]> cv : t._2) {
+					colValues.add(cv);
+				}
+
+				// sort blocks from one row by column index
+				colValues.sort(Comparator.comparing(Tuple3::_2));
+
+				// get number of rows of a block by counting amount of end (row) values of said block
+				int numRows = 0;
+				if (!colValues.isEmpty()) {
+					double[] lastValuesArray = colValues.get(0)._3();
+					numRows = lastValuesArray.length;
+				}
+
+				List<Tuple2<Tuple2<Long, Long>, double[]>> blockOffsets = new ArrayList<>();
+				double[] cumulativeOffsets = new double[numRows];
+
+				for (Tuple3<Long, Long, double[]> colValue : colValues) {
+					Long colBlockIdx = colValue._2();
+					double[] rowendValues = colValue._3();
+
+					// copy current offsets
+					double[] currentOffsets = cumulativeOffsets.clone();
+
+					// and save block indexes with its offsets
+					blockOffsets.add(new Tuple2<>(new Tuple2<>(rowBlockIdx, colBlockIdx), currentOffsets));
+
+					for (int i = 0; i < numRows && i < rowendValues.length; i++) {
+						cumulativeOffsets[i] += rowendValues[i];
+					}
+				}
+				return blockOffsets.iterator();
+			})
+			.collect();
+
+		// convert list to map for easier access to offsets
+		Map<Tuple2<Long, Long>, double[]> offsetMap = new HashMap<>();
+		for (Tuple2<Tuple2<Long, Long>, double[]> entry : offsetList) {
+			offsetMap.put(entry._1, entry._2);
+		}
+		return localRowCumsum.mapToPair(new FinalRowCumsumFunction(offsetMap));
+	}
+
+	private void processRowCumsumOffsets(SparkExecutionContext sec, DataCharacteristics mc1, DataCharacteristics mc2) {
+		JavaPairRDD<MatrixIndexes, MatrixBlock> localRowCumsum =
+			sec.getBinaryMatrixBlockRDDHandleForVariable(input1.getName());
+		JavaPairRDD<MatrixIndexes, MatrixBlock> endValues =
+			sec.getBinaryMatrixBlockRDDHandleForVariable(input2.getName());
+
+		JavaPairRDD<MatrixIndexes, MatrixBlock> out = processRowCumsumOffsetsDirectly(localRowCumsum, endValues);
+
+		sec.setRDDHandleForVariable(output.getName(), out);
+		sec.addLineageRDD(output.getName(), input1.getName());
+		sec.addLineageRDD(output.getName(), input2.getName());
+		sec.getDataCharacteristics(output.getName()).set(mc1);
+	}
+
 	public double getInitValue() {
 		return _initValue;
 	}
@@ -133,36 +226,86 @@ public boolean getBroadcast() {
 		return _broadcast;
 	}
 
-	private static class RDDCumSplitFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> 
+	private static class FinalRowCumsumFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
+		private static final long serialVersionUID = 1L;
+		// map block indexes to the row offsets
+		private final Map<Tuple2<Long, Long>, double[]> offsetMap;
+
+		public FinalRowCumsumFunction(Map<Tuple2<Long, Long>, double[]> offsetMap) {
+			this.offsetMap = offsetMap;
+		}
+
+		@Override
+		public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> tuple) throws Exception {
+			MatrixIndexes idx = tuple._1;
+			MatrixBlock localRowCumsumBlock = tuple._2;
+
+			// key to get the row offset for this block
+			Tuple2<Long, Long> blockKey = new Tuple2<>(idx.getRowIndex(), idx.getColumnIndex());
+			double[] offsets = offsetMap.get(blockKey);
+
+			MatrixBlock outBlock = new MatrixBlock(localRowCumsumBlock.getNumRows(), localRowCumsumBlock.getNumColumns(), false);
+
+			for (int i = 0; i < localRowCumsumBlock.getNumRows(); i++) {
+				double rowOffset = (offsets != null && i < offsets.length) ? offsets[i] : 0.0;
+				for (int j = 0; j < localRowCumsumBlock.getNumColumns(); j++) {
+					double cumsumValue = localRowCumsumBlock.get(i, j);
+					outBlock.set(i, j, cumsumValue + rowOffset);
+				}
+			}
+			// block index and final cumsum block
+			return new Tuple2<>(idx, outBlock);
+		}
+	}
+
+	// helper class
+	private static class Tuple3<T1, T2, T3> implements Serializable {
+		private static final long serialVersionUID = 1L;
+		private final T1 _1;
+		private final T2 _2;
+		private final T3 _3;
+
+		public Tuple3(T1 _1, T2 _2, T3 _3) {
+			this._1 = _1;
+			this._2 = _2;
+			this._3 = _3;
+		}
+
+		public T1 _1() { return _1; }
+		public T2 _2() { return _2; }
+		public T3 _3() { return _3; }
+	}
+
+	private static class RDDCumSplitFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
 		private static final long serialVersionUID = -8407407527406576965L;
-		
+
 		private double _initValue = 0;
 		private int _blen = -1;
 		private long _lastRowBlockIndex;
-		
+
 		public RDDCumSplitFunction( double initValue, long rlen, int blen ) {
 			_initValue = initValue;
 			_blen = blen;
 			_lastRowBlockIndex = (long)Math.ceil((double)rlen/blen);
 		}
-		
+
 		@Override
-		public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call( Tuple2<MatrixIndexes, MatrixBlock> arg0 ) 
-			throws Exception 
+		public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call( Tuple2<MatrixIndexes, MatrixBlock> arg0 )
+			throws Exception
 		{
 			ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<>();
-			
+
 			MatrixIndexes ixIn = arg0._1();
 			MatrixBlock blkIn = arg0._2();
-			
+
 			long rixOffset = (ixIn.getRowIndex()-1)*_blen;
 			boolean firstBlk = (ixIn.getRowIndex() == 1);
 			boolean lastBlk = (ixIn.getRowIndex() == _lastRowBlockIndex );
-			
-			//introduce offsets w/ init value for first row
-			if( firstBlk ) { 
+
+			//introduce offsets w/ init value for first row
+			if( firstBlk ) {
 				MatrixIndexes tmpix = new MatrixIndexes(1, ixIn.getColumnIndex());
 				MatrixBlock tmpblk = new MatrixBlock(1, blkIn.getNumColumns(), blkIn.isInSparseFormat());
 				if( _initValue != 0 ){
@@ -171,7 +314,7 @@ public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call( Tuple2<MatrixIndexes, MatrixBlock> arg0 )
 				ret.add(new Tuple2<>(tmpix, tmpblk));
 			}
-			
+
 			//output splitting (shift by one), preaggregated offset used by subsequent block
 			for( int i=0; i<blkIn.getNumRows(); i++ )
 				ret.add(new Tuple2<>(tmpix, tmpblk));
 			}
-			
+
 			return ret.iterator();
 		}
 	}
-	
-	private static class RDDCumSplitLookupFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> 
+
+	private static class RDDCumSplitLookupFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>>
 	{
 		private static final long serialVersionUID = -2785629043886477479L;
-		
+
 		private final PartitionedBroadcast<MatrixBlock> _pbc;
 		private final double _initValue;
 		private final int _blen;
-		
+
 		public RDDCumSplitLookupFunction(PartitionedBroadcast<MatrixBlock> pbc, double initValue, long rlen, int blen) {
 			_pbc = pbc;
 			_initValue = initValue;
 			_blen = blen;
 		}
-		
+
 		@Override
 		public Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
 			MatrixIndexes ixIn = arg0._1();
 			MatrixBlock blkIn = arg0._2();
-			
+
 			//compute block and row indexes
 			long brix = UtilFunctions.computeBlockIndex(ixIn.getRowIndex()-1, _blen);
 			int rix = UtilFunctions.computeCellInBlock(ixIn.getRowIndex()-1, _blen);
-			
+
 			//lookup offset row and return joined output
 			MatrixBlock off = (ixIn.getRowIndex() == 1) ?
 				new MatrixBlock(1, blkIn.getNumColumns(), _initValue) :
-				_pbc.getBlock((int)brix, (int)ixIn.getColumnIndex()).slice(rix, rix); 
+				_pbc.getBlock((int)brix, (int)ixIn.getColumnIndex()).slice(rix, rix);
 			return new Tuple2<>(ixIn, new Tuple2<>(blkIn,off));
 		}
 	}
 
-	private static class RDDCumOffsetFunction implements Function<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock> 
+	private static class RDDCumOffsetFunction implements Function<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock>
 	{
 		private static final long serialVersionUID = -5804080263258064743L;
 
 		private final UnaryOperator _uop;
 		private final boolean _cumsumprod;
-		
+
 		public RDDCumOffsetFunction(UnaryOperator uop, boolean cumsumprod) {
 			_uop = uop;
 			_cumsumprod = cumsumprod;
@@ -231,17 +374,17 @@ public RDDCumOffsetFunction(UnaryOperator uop, boolean cumsumprod) {
 		@Override
 		public MatrixBlock call(Tuple2<MatrixBlock, MatrixBlock> arg0) throws Exception {
 			//prepare inputs and outputs
-			MatrixBlock dblkIn = arg0._1(); //original data 
+			MatrixBlock dblkIn = arg0._1(); //original data
 			MatrixBlock oblkIn = arg0._2(); //offset row vector
-			
+
 			//allocate output block
 			MatrixBlock blkOut = new MatrixBlock(dblkIn.getNumRows(),
-				_cumsumprod ? 1 : dblkIn.getNumColumns(), false);
-			
+				_cumsumprod ? 1 : dblkIn.getNumColumns(), false);
+
 			//blockwise cumagg computation, incl offset aggregation
 			return LibMatrixAgg.cumaggregateUnaryMatrix(dblkIn, blkOut, _uop,
-				DataConverter.convertToDoubleVector(oblkIn, false,
-					((Builtin)_uop.fn).bFunc == BuiltinCode.CUMSUM));
+				DataConverter.convertToDoubleVector(oblkIn, false,
+					((Builtin)_uop.fn).bFunc == BuiltinCode.CUMSUM));
 		}
 	}
 }
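In CumulativeOffsetSPInstruction above, processRowCumsumOffsetsDirectly groups the per-block end values by row-block index, orders them by column-block index, and derives an exclusive running offset per block before FinalRowCumsumFunction adds it to the local cumsums. The sketch below only mirrors that driver-side offset construction with plain Java collections; buildOffsets and the List-based (rowBlockIdx, colBlockIdx) key are illustrative assumptions, not SystemDS API.

```java
import java.util.*;

// Sketch of the driver-side offset computation for rowcumsum:
// per row of blocks, the per-block end values (ordered by column-block index)
// are turned into exclusive running offsets keyed by (rowBlockIdx, colBlockIdx).
public class RowCumsumOffsetSketch {
	public static Map<List<Long>, double[]> buildOffsets(Map<List<Long>, double[]> endValues) {
		// group end values by row-block index; TreeMap keeps column blocks sorted
		Map<Long, TreeMap<Long, double[]>> byRow = new HashMap<>();
		for (Map.Entry<List<Long>, double[]> e : endValues.entrySet())
			byRow.computeIfAbsent(e.getKey().get(0), k -> new TreeMap<>())
				.put(e.getKey().get(1), e.getValue());

		Map<List<Long>, double[]> offsets = new HashMap<>();
		for (Map.Entry<Long, TreeMap<Long, double[]>> row : byRow.entrySet()) {
			double[] running = null;
			for (Map.Entry<Long, double[]> col : row.getValue().entrySet()) {
				double[] ends = col.getValue();
				if (running == null)
					running = new double[ends.length];
				// current running sums are the offsets of this block (exclusive scan)
				offsets.put(Arrays.asList(row.getKey(), col.getKey()), running.clone());
				for (int i = 0; i < running.length; i++)
					running[i] += ends[i]; // feeds the offsets of the next column block
			}
		}
		return offsets;
	}

	public static void main(String[] args) {
		Map<List<Long>, double[]> ends = new HashMap<>();
		ends.put(Arrays.asList(1L, 1L), new double[]{3}); // end value of block (1,1)
		ends.put(Arrays.asList(1L, 2L), new double[]{7}); // end value of block (1,2)
		// expected: block (1,2) gets offset [3.0]
		System.out.println(Arrays.toString(buildOffsets(ends).get(Arrays.asList(1L, 2L))));
	}
}
```

As in the patch, the offsets are materialized on the driver (via collect) and shipped inside the closure of the final map function, so their footprint grows with the number of blocks times rows per block.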
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/UnaryMatrixSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/UnaryMatrixSPInstruction.java
index eebeef0b2ec..6181d1c1d1a 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/UnaryMatrixSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/UnaryMatrixSPInstruction.java
@@ -21,28 +21,18 @@
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.function.Function;
-import org.apache.spark.api.java.function.PairFunction;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
-import org.apache.sysds.runtime.functionobjects.KahanPlus;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
-import org.apache.sysds.runtime.instructions.cp.KahanObject;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
-import scala.Serializable;
 import scala.Tuple2;
 
-import java.util.ArrayList;
-import java.util.Comparator;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-
 public class UnaryMatrixSPInstruction extends UnarySPInstruction {
 
 	protected UnaryMatrixSPInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String instr) {
@@ -73,189 +63,22 @@ public void processInstruction(ExecutionContext ec) {
 		sec.setRDDHandleForVariable(output.getName(), out);
 		sec.addLineageRDD(output.getName(), input1.getName());
 
-		//FIXME: implement similar to cumsum through 
+		//FIXME: implement similar to cumsum through
 		//  CumulativeAggregateSPInstruction (Spark)
 		//  UnaryMatrixCPInstruction (local cumsum on aggregates)
-		//  CumulativeOffsetSPInstruction (Spark)
-		if ( "urowcumk+".equals(getOpcode()) ) {
-
-			JavaPairRDD<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> localRowcumsum = in.mapToPair( new LocalRowCumsumFunction() );
-
-			// Collect end-values of every block of every row for offset calc by grouping by global row index
-			JavaPairRDD<Long, Iterable<Tuple3<Long, Long, double[]>>> rowEndValues = localRowcumsum
-				.mapToPair( tuple2 -> {
-					// get index of block
-					MatrixIndexes indexes = tuple2._1;
-					// get cum matrix block
-					MatrixBlock localRowcumsumBlock = tuple2._2._2;
-
-					// get row and column block index
-					long rowBlockIndex = indexes.getRowIndex();
-					long colBlockIndex = indexes.getColumnIndex();
-
-					// Save end value of every row of every block (if block is empty save 0)
-					double[] endValues = new double[ localRowcumsumBlock.getNumRows() ];
-
-					for ( int i = 0; i < localRowcumsumBlock.getNumRows(); i ++ ) {
-						if (localRowcumsumBlock.getNumColumns() > 0)
-							endValues[i] = localRowcumsumBlock.get(i, localRowcumsumBlock.getNumColumns() - 1);
-						else
-							endValues[i] = 0.0 ;
-					}
-					return new Tuple2<>(rowBlockIndex, new Tuple3<>(rowBlockIndex, colBlockIndex, endValues));
-				}
-				).groupByKey();
-
-			// compute offset for every block
-			List< Tuple2<Tuple2<Long, Long>, double[]> > offsetList = rowEndValues
-				.flatMapToPair(tuple2 -> {
-					Long rowBlockIndex = tuple2._1;
-					List< Tuple3<Long, Long, double[]> > colValues = new ArrayList<>();
-					for ( Tuple3<Long, Long, double[]> cv : tuple2._2 )
-						colValues.add(cv);
-
-					// sort blocks from one row by column index
-					colValues.sort(Comparator.comparing(Tuple3::_2));
-
-					// get number of rows of a block by counting amount of end (row) values of said block
-					int numberOfRows = 0;
-					if ( !colValues.isEmpty() ) {
-						Tuple3<Long, Long, double[]> firstTuple = colValues.get(0);
-						double[] lastValuesArray = firstTuple._3();
-						numberOfRows = lastValuesArray.length;
-					}
-
-					List<Tuple2<Tuple2<Long, Long>, double[]>> blockOffsets = new ArrayList<>();
-					double[] cumulativeOffsets = new double[numberOfRows];
-					for (Tuple3<Long, Long, double[]> colValue : colValues) {
-						Long colBlockIndex = colValue._2();
-						double[] endValues = colValue._3();
-
-						// copy current offsets
-						double[] currentOffsets = cumulativeOffsets.clone();
-
-						// and save block indexes with its offsets
-						blockOffsets.add( new Tuple2<>(new Tuple2<>(rowBlockIndex, colBlockIndex), currentOffsets) );
-
-						for ( int i = 0; i < numberOfRows && i < endValues.length; i++ ) {
-							cumulativeOffsets[i] += endValues[i];
-						}
-					}
-					return blockOffsets.iterator();
-				}
-				).collect();
-
-			// convert list to map for easier access to offsets
-			Map< Tuple2<Long, Long>, double[] > offsetMap = new HashMap<>();
-			for (Tuple2<Tuple2<Long, Long>, double[]> offset : offsetList) {
-				offsetMap.put(offset._1, offset._2);
-			}
-
-			out = localRowcumsum.mapToPair( new FinalRowCumsumFunction(offsetMap)) ;
-
-			updateUnaryOutputDataCharacteristics(sec);
-			sec.setRDDHandleForVariable(output.getName(), out);
-			sec.addLineageRDD(output.getName(), input1.getName());
-		}
-	}
-
-
-
-	private static class LocalRowCumsumFunction implements PairFunction< Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock> > {
-		private static final long serialVersionUID = 2388003441846068046L;
-
-		@Override
-		public Tuple2< MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock> > call(Tuple2<MatrixIndexes, MatrixBlock> tuple2) {
-
-			MatrixBlock inputBlock = tuple2._2;
-			MatrixBlock cumsumBlock = new MatrixBlock( inputBlock.getNumRows(), inputBlock.getNumColumns(), false );
-
-
-			for ( int i = 0; i < inputBlock.getNumRows(); i++ ) {
-
-				KahanObject kbuff = new KahanObject(0, 0);
-				KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
-
-				for ( int j = 0; j < inputBlock.getNumColumns(); j++ ) {
-
-					double val = inputBlock.get(i, j);
-					kplus.execute2(kbuff, val);
-					cumsumBlock.set(i, j, kbuff._sum);
-				}
-			}
-			// original index, original matrix and local cumsum block
-			return new Tuple2<>( tuple2._1, new Tuple2<>(inputBlock, cumsumBlock) );
-		}
-	}
-
-
-	private static class FinalRowCumsumFunction implements PairFunction<Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock> >, MatrixIndexes, MatrixBlock> {
-		private static final long serialVersionUID = -6738155890298916270L;
-		// map block indexes to the row offsets
-		private final Map< Tuple2<Long, Long>, double[] > offsetMap;
-
-		public FinalRowCumsumFunction(Map<Tuple2<Long, Long>, double[]> offsetMap) {
-			this.offsetMap = offsetMap;
-		}
-
-
-		@Override
-		public Tuple2<MatrixIndexes, MatrixBlock> call( Tuple2< MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock> > tuple ) {
-
-			MatrixIndexes indexes = tuple._1;
-			MatrixBlock inputBlock = tuple._2._1;
-			MatrixBlock localRowCumsumBlock = tuple._2._2;
-
-			// key to get the row offset for this block
-			Tuple2<Long, Long> blockKey = new Tuple2<>( indexes.getRowIndex(), indexes.getColumnIndex()) ;
-			double[] offsets = offsetMap.get(blockKey);
-
-			MatrixBlock cumsumBlock = new MatrixBlock( inputBlock.getNumRows(), inputBlock.getNumColumns(), false );
-
-
-			for ( int i = 0; i < inputBlock.getNumRows(); i++ ) {
-
-				double rowOffset = 0.0;
-				if ( offsets != null && i < offsets.length ) {
-					rowOffset = offsets[i];
-				}
-
-				for ( int j = 0; j < inputBlock.getNumColumns(); j++ ) {
-					double cumsumValue = localRowCumsumBlock.get(i, j);
-					cumsumBlock.set(i, j, cumsumValue + rowOffset);
-				}
-			}
-
-			// block index and final cumsum block
-			return new Tuple2<>(indexes, cumsumBlock);
-		}
-	}
-
-
-
-	// helper class
-	private static class Tuple3<Type1, Type2, Type3> implements Serializable {
-
-		private static final long serialVersionUID = 123;
-		private final Type2 _2;
-		private final Type3 _3;
-
-
-		public Tuple3( Type1 _1, Type2 _2, Type3 _3 ) {
-			this._2 = _2;
-			this._3 = _3;
-		}
-
-		public Type2 _2() {
-			return _2;
-		}
-
-		public Type3 _3() {
-			return _3;
-		}
+		//  CumulativeOffsetSPInstruction (Spark)
+
+		// rowcumsum processing
+		JavaPairRDD<MatrixIndexes, MatrixBlock> localRowcumsum = sec.getBinaryMatrixBlockRDDHandleForVariable(input1.getName());
+		Tuple2<JavaPairRDD<MatrixIndexes, MatrixBlock>, JavaPairRDD<MatrixIndexes, MatrixBlock>> results =
+			CumulativeAggregateSPInstruction.processRowCumsumWithEndValues(localRowcumsum);
+		JavaPairRDD<MatrixIndexes, MatrixBlock> rowEndValues = CumulativeOffsetSPInstruction.processRowCumsumOffsetsDirectly(results._1, results._2);
+
+		sec.setRDDHandleForVariable(output.getName(), rowEndValues);
+		sec.addLineageRDD(output.getName(), input1.getName());
+		updateUnaryOutputDataCharacteristics(sec);
 	}
 
 	private static class RDDMatrixBuiltinUnaryOp implements Function<MatrixBlock, MatrixBlock>