Mercurial > dive4elements > river
view artifacts/src/main/java/org/dive4elements/river/artifacts/model/sq/Fitting.java @ 6761:9479cb7c8cd5
flys/issue748: Force linear curve fitting. This is a real hack! Set the system property "minfo.sq.fitting.nonlinear" to re-enable the old behavior.
author | Sascha L. Teichmann <teichmann@intevation.de> |
---|---|
date | Tue, 06 Aug 2013 17:00:49 +0200 |
parents | af13ceeba52a |
children | 48f6780c372d |
line wrap: on
line source
/* Copyright (C) 2011, 2012, 2013 by Bundesanstalt für Gewässerkunde * Software engineering by Intevation GmbH * * This file is Free Software under the GNU AGPL (>=v3) * and comes with ABSOLUTELY NO WARRANTY! Check out the * documentation coming with Dive4Elements River for details. */ package org.dive4elements.river.artifacts.model.sq; import org.dive4elements.river.artifacts.math.fitting.Function; import org.dive4elements.river.artifacts.math.fitting.Linear; import java.util.ArrayList; import java.util.List; import org.apache.commons.math.MathException; import org.apache.commons.math.optimization.fitting.CurveFitter; import org.apache.commons.math.optimization.general.AbstractLeastSquaresOptimizer; import org.apache.commons.math.optimization.general.GaussNewtonOptimizer; import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer; import org.apache.log4j.Logger; public class Fitting implements Outlier.Callback { // XXX: Hack to force linear fitting! private static final boolean USE_NON_LINEAR_FITTING = Boolean.getBoolean("minfo.sq.fitting.nonlinear"); private static Logger log = Logger.getLogger(Fitting.class); public interface Callback { void afterIteration( double [] parameters, SQ [] measurements, SQ [] outliers, double standardDeviation, double chiSqr); } // interfacte protected Function function; protected double [] coeffs; protected org.dive4elements.river.artifacts.math.Function instance; protected double stdDevFactor; protected double chiSqr; protected Callback callback; public Fitting() { } public Fitting(Function function, double stdDevFactor) { this(); this.function = function; this.stdDevFactor = stdDevFactor; } public Function getFunction() { return function; } public void setFunction(Function function) { this.function = function; } public double getStdDevFactor() { return stdDevFactor; } public void setStdDevFactor(double stdDevFactor) { this.stdDevFactor = stdDevFactor; } @Override public void initialize(List<SQ> sqs) throws MathException { AbstractLeastSquaresOptimizer optimizer = getOptimizer(); CurveFitter cf = new CurveFitter(optimizer); double [] values = new double[2]; for (SQ sq: sqs) { values[0] = sq.getQ(); values[1] = sq.getS(); transformInputValues(values); cf.addObservedPoint(values[0], values[1]); } coeffs = cf.fit( function, function.getInitialGuess()); transformCoeffsBack(coeffs); instance = function.instantiate(coeffs); chiSqr = optimizer.getChiSquare(); } protected Function getFunction(Function function) { return USE_NON_LINEAR_FITTING ? function : Linear.INSTANCE; } protected void transformInputValues(double [] values) { if (!USE_NON_LINEAR_FITTING) { for (int i = 0; i < values.length; ++i) { values[i] = Math.log(values[i]); } } } protected AbstractLeastSquaresOptimizer getOptimizer() { return USE_NON_LINEAR_FITTING ? new LevenbergMarquardtOptimizer() : new GaussNewtonOptimizer(false); } protected void transformCoeffsBack(double [] coeffs) { if (!USE_NON_LINEAR_FITTING && coeffs.length > 0) { coeffs[0] = Math.exp(coeffs[0]); } } @Override public double eval(SQ sq) { double s = instance.value(sq.q); return sq.s - s; } @Override public void iterationFinished( double standardDeviation, SQ outlier, List<SQ> remainings ) { if (log.isDebugEnabled()) { log.debug("iterationFinished ----"); log.debug(" num remainings: " + remainings.size()); log.debug(" has outlier: " + outlier != null); log.debug(" standardDeviation: " + standardDeviation); log.debug(" Chi^2: " + chiSqr); log.debug("---- iterationFinished"); } callback.afterIteration( coeffs, remainings.toArray(new SQ[remainings.size()]), outlier != null ? new SQ [] { outlier} : new SQ [] {}, standardDeviation, chiSqr); } protected static final List<SQ> onlyValid(List<SQ> sqs) { List<SQ> good = new ArrayList<SQ>(sqs.size()); for (SQ sq: sqs) { if (sq.isValid()) { good.add(sq); } } return good; } public boolean fit(List<SQ> sqs, String method, Callback callback) { sqs = onlyValid(sqs); if (sqs.size() < 2) { log.warn("Too less points for fitting."); return false; } this.callback = callback; try { Outlier.detectOutliers(this, sqs, stdDevFactor, method); } catch (MathException me) { log.warn(me); return false; } return true; } } // vim:set ts=4 sw=4 si et sta sts=4 fenc=utf8 :