Commit 5311ba01 authored by cdanger's avatar cdanger

Optimizing commutative multary numeric operators (used in add and

multiply functions) - merge/replace all constant args into/with one
parent 9af15aa5
......@@ -18,11 +18,14 @@
*/
package org.ow2.authzforce.core.pdp.impl.func;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import org.ow2.authzforce.core.pdp.api.IndeterminateEvaluationException;
import org.ow2.authzforce.core.pdp.api.StatusHelper;
import org.ow2.authzforce.core.pdp.api.expression.ConstantPrimitiveAttributeValueExpression;
import org.ow2.authzforce.core.pdp.api.expression.Expression;
import org.ow2.authzforce.core.pdp.api.func.BaseFirstOrderFunctionCall.EagerSinglePrimitiveTypeEval;
import org.ow2.authzforce.core.pdp.api.func.FirstOrderFunctionCall;
......@@ -31,6 +34,8 @@ import org.ow2.authzforce.core.pdp.api.func.SingleParameterTypedFirstOrderFuncti
import org.ow2.authzforce.core.pdp.api.value.Datatype;
import org.ow2.authzforce.core.pdp.api.value.NumericValue;
import org.ow2.authzforce.core.pdp.api.value.Value;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* A class that implements all the numeric *-add functions (as opposed to date/time *-add-* functions).
......@@ -40,12 +45,11 @@ import org.ow2.authzforce.core.pdp.api.value.Value;
*
* @version $Id: $
*/
final class NumericArithmeticFunction<AV extends NumericValue<?, AV>>
extends SingleParameterTypedFirstOrderFunction<AV, AV>
final class NumericArithmeticFunction<AV extends NumericValue<?, AV>> extends SingleParameterTypedFirstOrderFunction<AV, AV>
{
private static final Logger LOGGER = LoggerFactory.getLogger(NumericArithmeticFunction.class);
private static final IllegalArgumentException UNDEF_PARAMETER_TYPES_EXCEPTION = new IllegalArgumentException(
"Undefined function parameter types");
private static final IllegalArgumentException UNDEF_PARAMETER_TYPES_EXCEPTION = new IllegalArgumentException("Undefined function parameter types");
private static <AV extends Value> List<Datatype<AV>> validate(final List<Datatype<AV>> paramTypes)
{
......@@ -62,13 +66,24 @@ final class NumericArithmeticFunction<AV extends NumericValue<?, AV>>
V eval(Deque<V> args) throws IllegalArgumentException, ArithmeticException;
}
/**
* Multary/Multiary/Polyadic operator
*
* @see "https://en.wikipedia.org/wiki/Arity#Other_names"
*
* @param <V>
*/
interface MultaryOperation<V extends NumericValue<?, V>> extends StaticOperation<V>
{
boolean isCommutative();
}
private static final class Call<V extends NumericValue<?, V>> extends EagerSinglePrimitiveTypeEval<V, V>
{
private final String invalidArgsErrMsg;
private final StaticOperation<V> op;
private Call(final SingleParameterTypedFirstOrderFunctionSignature<V, V> functionSig,
final StaticOperation<V> op, final List<Expression<?>> args, final Datatype<?>[] remainingArgTypes)
private Call(final SingleParameterTypedFirstOrderFunctionSignature<V, V> functionSig, final StaticOperation<V> op, final List<Expression<?>> args, final Datatype<?>[] remainingArgTypes)
throws IllegalArgumentException
{
super(functionSig, args, remainingArgTypes);
......@@ -104,8 +119,7 @@ final class NumericArithmeticFunction<AV extends NumericValue<?, AV>>
* whether this is a varargs function (like Java varargs method), i.e. last arg has variable-length
*
*/
NumericArithmeticFunction(final String funcURI, final boolean varArgs, final List<Datatype<AV>> paramTypes,
final StaticOperation<AV> op) throws IllegalArgumentException
NumericArithmeticFunction(final String funcURI, final boolean varArgs, final List<Datatype<AV>> paramTypes, final StaticOperation<AV> op) throws IllegalArgumentException
{
super(funcURI, validate(paramTypes).get(0), varArgs, paramTypes);
this.op = op;
......@@ -113,17 +127,54 @@ final class NumericArithmeticFunction<AV extends NumericValue<?, AV>>
/** {@inheritDoc} */
@Override
public FirstOrderFunctionCall<AV> newCall(final List<Expression<?>> argExpressions,
final Datatype<?>... remainingArgTypes) throws IllegalArgumentException
public FirstOrderFunctionCall<AV> newCall(final List<Expression<?>> argExpressions, final Datatype<?>... remainingArgTypes) throws IllegalArgumentException
{
/**
* TODO: optimize call to "add" (resp. "multiply") function call by checking all static/constant arguments and
* if there are more than one, pre-compute their sum (resp. product) and replace these arguments with one
* argument that is this sum (resp. product) in the function call. Indeed, 'add' function is commutative and
* (constant in upper case, variables in lower case): add(C1, C2, x, y...) = add(C1+C2, x, y...). Similarly,
* multiply(C1, C2, x, y...) = multiply(C1*C2, x, y...)
* If this.op is a commutative function (e.g. add or multiply function), we can simplify arguments if there are multiple constants. Indeed, if C1,...Cm are constants, then:
* <p>
* op(x1,..., x_{n1-1}, C1, x_n1, ..., x_{n2-1} C2, x_n2, ..., Cm, x_nm...) = op( C, x1.., x_{n1-1}, x_n1, x_{n2-2}, x_n2...), where C (constant) = op(C1, C2..., Cm)
* </p>
* In this case, we can pre-compute constant C and replace all constant args with one: C
*
*/
if (op instanceof MultaryOperation && ((MultaryOperation<AV>) op).isCommutative())
{
/*
* Constant argExpressions
*/
final Deque<AV> constants = new ArrayDeque<>(argExpressions.size());
/*
* Remaining variable argExpressions
*/
final List<Expression<?>> finalArgExpressions = new ArrayList<>(argExpressions.size());
final Datatype<AV> paramType = this.functionSignature.getParameterType();
for (final Expression<?> argExp : argExpressions)
{
final Value v = argExp.getValue();
if (v == null)
{
// variable
finalArgExpressions.add(argExp);
}
else
{
// constant
constants.add(paramType.cast(v));
}
}
if (constants.size() > 1)
{
/*
* we can replace all constant args C1, C2... with one constant C = op(C1, C2...)
*/
final AV constantResult = op.eval(constants);
LOGGER.warn("Function {}: optimizing call to this commutative function: replacing/merging constant args {} with/into one: {}", this.functionSignature, constants, constantResult);
finalArgExpressions.add(new ConstantPrimitiveAttributeValueExpression<>(paramType, constantResult));
return new Call<>(functionSignature, op, finalArgExpressions, remainingArgTypes);
}
}
return new Call<>(functionSignature, op, argExpressions, remainingArgTypes);
}
......
......@@ -23,6 +23,7 @@ import java.util.Deque;
import org.ow2.authzforce.core.pdp.api.value.DoubleValue;
import org.ow2.authzforce.core.pdp.api.value.IntegerValue;
import org.ow2.authzforce.core.pdp.api.value.NumericValue;
import org.ow2.authzforce.core.pdp.impl.func.NumericArithmeticFunction.MultaryOperation;
import org.ow2.authzforce.core.pdp.impl.func.NumericArithmeticFunction.StaticOperation;
final class NumericArithmeticOperators
......@@ -43,8 +44,14 @@ final class NumericArithmeticOperators
}
static final class AddOperator<NAV extends NumericValue<?, NAV>> implements StaticOperation<NAV>
static final class AddOperator<NAV extends NumericValue<?, NAV>> implements MultaryOperation<NAV>
{
@Override
public boolean isCommutative()
{
return true;
}
@Override
public NAV eval(final Deque<NAV> args)
{
......@@ -54,9 +61,15 @@ final class NumericArithmeticOperators
}
static final class MultiplyOperator<NAV extends NumericValue<?, NAV>> implements StaticOperation<NAV>
static final class MultiplyOperator<NAV extends NumericValue<?, NAV>> implements MultaryOperation<NAV>
{
@Override
public boolean isCommutative()
{
return true;
}
@Override
public NAV eval(final Deque<NAV> args)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment