view artifacts/src/main/java/org/dive4elements/river/artifacts/model/sq/SQRelationCalculation.java @ 6787:51eb6491c537

S/Q: Excel compat completed: Now the data is linearized before fitting. This can be prevented by setting the system property "minfo.sq.calcution.non.linear.fitting" to true.
author Sascha L. Teichmann <teichmann@intevation.de>
date Thu, 08 Aug 2013 18:14:38 +0200
parents b8f94e865875
children 978ab716a15e
line wrap: on
line source
/* Copyright (C) 2011, 2012, 2013 by Bundesanstalt für Gewässerkunde
 * Software engineering by Intevation GmbH
 *
 * This file is Free Software under the GNU AGPL (>=v3)
 * and comes with ABSOLUTELY NO WARRANTY! Check out the
 * documentation coming with Dive4Elements River for details.
 */

package org.dive4elements.river.artifacts.model.sq;

import org.dive4elements.artifacts.common.utils.StringUtils;
import org.dive4elements.river.artifacts.access.SQRelationAccess;

import org.dive4elements.river.artifacts.math.fitting.Function;
import org.dive4elements.river.artifacts.math.fitting.FunctionFactory;

import org.dive4elements.river.artifacts.model.Calculation;
import org.dive4elements.river.artifacts.model.CalculationResult;
import org.dive4elements.river.artifacts.model.DateRange;
import org.dive4elements.river.artifacts.model.Parameters;

import org.dive4elements.river.backend.SedDBSessionHolder;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.log4j.Logger;

/**
 * Calculation of an S/Q relation for one river location and time period.
 *
 * <p>For every fraction of the measurements a function is fitted to the
 * S/Q data. By default the data is fitted with the "linear" function using
 * the {@link LogSQ} view/factory and the resulting coefficients are mapped
 * back to the parameter names of the "sq-pow" function. If the system
 * property {@code minfo.sq.calcution.non.linear.fitting} is {@code true}
 * the "sq-pow" function is fitted directly.</p>
 */
public class SQRelationCalculation extends Calculation {

    private static final Logger log =
        Logger.getLogger(SQRelationCalculation.class);

    /**
     * If true the power function is fitted directly instead of fitting
     * the linear function. Evaluated once when this class is initialized.
     * NOTE: the property name is spelled "calcution" (sic); it is kept
     * unchanged because renaming it would change which property is read.
     */
    public static final boolean NON_LINEAR_FITTING =
        Boolean.getBoolean("minfo.sq.calcution.non.linear.fitting");

    /** Name under which the power function is looked up. */
    public static final String SQ_POW_FUNCTION_NAME = "sq-pow";

    /** Name under which the linear function is looked up. */
    public static final String SQ_LIN_FUNCTION_NAME = "linear";

    protected String    river;
    protected double    location;
    protected DateRange period;
    protected double    outliers;
    private   String    method;

    public SQRelationCalculation() {
    }

    /**
     * Reads all input values from the given access object. For every
     * missing value a problem is added. The fields of this calculation
     * are only set if no problem was found at all.
     *
     * @param access Source of the input values.
     */
    public SQRelationCalculation(SQRelationAccess access) {

        String    river    = access.getRiver();
        Double    location = access.getLocation();
        DateRange period   = access.getPeriod();
        Double    outliers = access.getOutliers();
        String    method   = access.getOutlierMethod();

        if (river == null) {
            // TODO: i18n
            addProblem("sq.missing.river");
        }

        if (location == null) {
            // TODO: i18n
            addProblem("sq.missing.location");
        }

        if (period == null) {
            // TODO: i18n
            addProblem("sq.missing.periods");
        }

        if (outliers == null) {
            // TODO: i18n
            addProblem("sq.missing.outliers");
        }

        if (method == null) {
            //TODO: i18n
            addProblem("sq.missing.method");
        }

        // The Double values are unboxed here; this is only safe because
        // a null value would have been reported as a problem above.
        if (!hasProblems()) {
            this.river    = river;
            this.location = location;
            this.period   = period;
            this.outliers = outliers;
            this.method   = method;
        }
    }


    /**
     * Runs the calculation. If problems were already reported no
     * calculation is done. Otherwise the calculation runs between
     * acquiring and releasing the {@link SedDBSessionHolder}; the release
     * happens even if the calculation throws.
     *
     * @return The result of the calculation carrying this calculation
     *         as the report of problems.
     */
    public CalculationResult calculate() {
        log.debug("SQRelationCalculation.calculate");

        if (hasProblems()) {
            return new CalculationResult(this);
        }

        SedDBSessionHolder.acquire();
        try {
            return internalCalculate();
        }
        finally {
            SedDBSessionHolder.release();
        }
    }

    /** Maps an array of fitted coefficients to another array. */
    public interface TransformCoeffs {
        double [] transform(double [] coeffs);
    }

    /** Returns the given coefficients unchanged (same array instance). */
    public static final TransformCoeffs IDENTITY_TRANS =
        new TransformCoeffs() {
            @Override
            public double [] transform(double [] coeffs) {
                return coeffs;
            }
        };

    /**
     * For exactly two coefficients a new array
     * <code>{ exp(coeffs[1]), coeffs[0] }</code> is returned.
     * Arrays of any other length are returned unchanged.
     */
    public static final TransformCoeffs LINEAR_TRANS =
        new TransformCoeffs() {
            @Override
            public double [] transform(double [] coeffs) {
                // Only build the log strings if they are really logged.
                boolean debug = log.isDebugEnabled();
                if (debug) {
                    log.debug("before transform: " + Arrays.toString(coeffs));
                }
                if (coeffs.length == 2) {
                    coeffs = new double [] { Math.exp(coeffs[1]), coeffs[0] };
                }
                if (debug) {
                    log.debug("after transform: " + Arrays.toString(coeffs));
                }
                return coeffs;
            }
        };

    /**
     * Reports that a fitting function could not be found.
     *
     * @param name The name of the function that was looked up.
     * @return A result with no S/Q results, carrying the problem.
     */
    private CalculationResult missingFunction(String name) {
        log.error("No '" + name + "' function found.");
        // TODO: i18n
        addProblem("sq.missing.sq.function");
        return new CalculationResult(new SQResult[0], this);
    }

    /**
     * Loads the measurements and fits every fraction. A fraction for
     * which the fitting fails is reported as a problem and gets an
     * empty fraction result.
     *
     * @return A result with a single {@link SQResult} for the location,
     *         or a result without S/Q results if a required fitting
     *         function is not available.
     */
    protected CalculationResult internalCalculate() {

        FunctionFactory functions = FunctionFactory.getInstance();

        // The power function is needed in both modes: its parameter
        // names are the ones the results are reported under.
        Function powFunction = functions.getFunction(SQ_POW_FUNCTION_NAME);

        if (powFunction == null) {
            return missingFunction(SQ_POW_FUNCTION_NAME);
        }

        Function         function;
        SQ.View          sqView;
        SQ.Factory       sqFactory;
        ParameterCreator pc;

        if (NON_LINEAR_FITTING) {
            log.debug("Use non linear fitting.");
            sqView    = SQ.SQ_VIEW;
            sqFactory = SQ.SQ_FACTORY;
            function  = powFunction;
            pc = new ParameterCreator(
                powFunction.getParameterNames(),
                powFunction.getParameterNames());
        }
        else {
            log.debug("Use linear fitting.");
            sqView    = LogSQ.LOG_SQ_VIEW;
            sqFactory = LogSQ.LOG_SQ_FACTORY;
            function  = functions.getFunction(SQ_LIN_FUNCTION_NAME);
            if (function == null) {
                return missingFunction(SQ_LIN_FUNCTION_NAME);
            }
            // Fitted coefficients of the linear function are mapped
            // back to the parameter names of the power function.
            pc = new LinearParameterCreator(
                powFunction.getParameterNames(),
                function.getParameterNames());
        }

        Measurements measurements =
            MeasurementFactory.getMeasurements(
                river, location, period, sqFactory);

        SQFractionResult [] fractionResults =
            new SQFractionResult[SQResult.NUMBER_FRACTIONS];

        for (int i = 0; i < fractionResults.length; ++i) {
            List<SQ> sqs = measurements.getSQs(i);

            List<SQFractionResult.Iteration> iterations =
                doFitting(function, sqs, sqView, pc);

            if (iterations == null) {
                // TODO: i18n
                addProblem("sq.fitting.failed." + i);
                fractionResults[i] = new SQFractionResult();
            }
            else {
                fractionResults[i] = new SQFractionResult(
                    sqs.toArray(new SQ[sqs.size()]),
                    iterations);
            }
        }

        return new CalculationResult(
            new SQResult[] { new SQResult(location, fractionResults) },
            this);
    }

    /**
     * Fits the function to the given S/Q values and records every
     * iteration reported by the fitting.
     *
     * @param function The function to fit.
     * @param sqs      The S/Q values to fit the function to.
     * @param sqView   The view passed on to the fitting.
     * @param pc       Creates the parameters of an iteration from the
     *                 fitted coefficients.
     * @return The recorded iterations or null if the fitting reported
     *         a failure.
     */
    protected List<SQFractionResult.Iteration> doFitting(
        final Function         function,
        List<SQ>               sqs,
        SQ.View                sqView,
        final ParameterCreator pc
    ) {
        final List<SQFractionResult.Iteration> iterations =
            new ArrayList<SQFractionResult.Iteration>();

        Fitting fitting = new Fitting(function, outliers, sqView);

        boolean success = fitting.fit(
            sqs,
            method,
            new Fitting.Callback() {
                @Override
                public void afterIteration(
                    double [] coeffs,
                    SQ []     measurements,
                    SQ []     outliers,
                    double    standardDeviation,
                    double    chiSqr
                ) {
                    // Note: the parameter 'outliers' shadows the field
                    // of the same name of the enclosing calculation.
                    Parameters parameters = pc.createParameters(
                        coeffs,
                        standardDeviation,
                        chiSqr);
                    iterations.add(new SQFractionResult.Iteration(
                        parameters,
                        measurements,
                        outliers));
                }
            });

        return success ? iterations : null;
    }

    /**
     * Creates {@link Parameters} from fitted coefficients, the
     * standard deviation and chi square.
     */
    public static class ParameterCreator {

        /** Parameter names the coefficients are stored under. */
        protected String [] origNames;

        /** Parameter names of the function that was actually fitted. */
        protected String [] proxyNames;

        public ParameterCreator(String [] origNames, String [] proxyNames) {
            this.origNames  = origNames;
            this.proxyNames = proxyNames;
        }

        /**
         * Hook to map the fitted coefficients to the coefficients
         * stored under the original names. This implementation
         * returns them unchanged.
         */
        protected double [] transformCoeffs(double [] coeffs) {
            return coeffs;
        }

        /**
         * Creates parameters with the columns "chi_sqr", "std_dev"
         * followed by the original names and fills a single new row.
         *
         * @param coeffs            The fitted coefficients.
         * @param standardDeviation Stored in column "std_dev".
         * @param chiSqr            Stored in column "chi_sqr".
         * @return The new parameters.
         */
        public Parameters createParameters(
            double [] coeffs,
            double    standardDeviation,
            double    chiSqr
        ) {
            String [] columns = new String[origNames.length + 2];
            columns[0] = "chi_sqr";
            columns[1] = "std_dev";
            System.arraycopy(origNames, 0, columns, 2, origNames.length);
            Parameters parameters = new Parameters(columns);
            int row = parameters.newRow();
            parameters.set(row, origNames, transformCoeffs(coeffs));
            parameters.set(row, "chi_sqr", chiSqr);
            parameters.set(row, "std_dev", standardDeviation);
            return parameters;
        }
    }

    /** We need to transform the coeffs back to the original function. */
    public static class LinearParameterCreator extends ParameterCreator {

        public LinearParameterCreator(
            String [] origNames,
            String [] proxyNames
        ) {
            super(origNames, proxyNames);
        }

        /**
         * Maps the coefficients of the fitted function (parameters "m"
         * and "b") to the original parameters "a" and "b":
         * a = exp(proxy "b") and b = proxy "m".
         * NOTE(review): presumably proxy is y = m*x + b fitted on
         * logarithmized data and the original is a * x^b - confirm
         * against the function definitions.
         *
         * @param coeffs The coefficients of the fitted function.
         * @return A transformed copy or the given array itself if one
         *         of the parameter names could not be located.
         */
        @Override
        protected double [] transformCoeffs(double [] coeffs) {

            // Positions in the coefficients of the fitted function.
            int slopeP     = StringUtils.indexOf("m", proxyNames);
            int interceptP = StringUtils.indexOf("b", proxyNames);

            // Positions in the coefficients of the original function.
            int aO = StringUtils.indexOf("a", origNames);
            int bO = StringUtils.indexOf("b", origNames);

            if (slopeP == -1 || interceptP == -1 || aO == -1 || bO == -1) {
                log.error("index not found: "
                    + slopeP + " " + interceptP + " "
                    + aO + " " + bO);
                return coeffs;
            }

            // Work on a copy: the targets are written while the
            // sources still have to be read from the input.
            double [] ncoeffs = coeffs.clone();
            ncoeffs[aO] = Math.exp(coeffs[interceptP]);
            ncoeffs[bO] = coeffs[slopeP];

            if (log.isDebugEnabled()) {
                log.debug("before transform: " + Arrays.toString(coeffs));
                log.debug("after transform: " + Arrays.toString(ncoeffs));
            }

            return ncoeffs;
        }
    }
}
// vim:set ts=4 sw=4 si et sta sts=4 fenc=utf-8 :

http://dive4elements.wald.intevation.org