teichmann@5863: /* Copyright (C) 2011, 2012, 2013 by Bundesanstalt für Gewässerkunde teichmann@5863: * Software engineering by Intevation GmbH teichmann@5863: * teichmann@5994: * This file is Free Software under the GNU AGPL (>=v3) teichmann@5863: * and comes with ABSOLUTELY NO WARRANTY! Check out the teichmann@5994: * documentation coming with Dive4Elements River for details. teichmann@5863: */ teichmann@5863: teichmann@5831: package org.dive4elements.river.artifacts.model.sq; sascha@3188: teichmann@5831: import org.dive4elements.river.artifacts.math.fitting.Function; sascha@3188: sascha@3188: import java.util.ArrayList; sascha@3188: import java.util.List; sascha@3188: sascha@3188: import org.apache.commons.math.MathException; sascha@3188: sascha@3188: import org.apache.commons.math.optimization.fitting.CurveFitter; sascha@3188: sascha@3188: import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer; teichmann@6777: import org.apache.commons.math.stat.regression.SimpleRegression; sascha@3188: sascha@3188: import org.apache.log4j.Logger; sascha@3188: sascha@3188: public class Fitting sascha@3552: implements Outlier.Callback sascha@3188: { teichmann@6761: // XXX: Hack to force linear fitting! teichmann@6761: private static final boolean USE_NON_LINEAR_FITTING = teichmann@6761: Boolean.getBoolean("minfo.sq.fitting.nonlinear"); teichmann@6761: sascha@3188: private static Logger log = Logger.getLogger(Fitting.class); sascha@3188: sascha@3552: public interface Callback { sascha@3552: sascha@3552: void afterIteration( sascha@3552: double [] parameters, sascha@3552: SQ [] measurements, sascha@3552: SQ [] outliers, sascha@3552: double standardDeviation, sascha@3552: double chiSqr); sascha@3552: } // interfacte sascha@3552: sascha@3188: protected Function function; sascha@3188: sascha@3552: protected double [] coeffs; sascha@3188: teichmann@5831: protected org.dive4elements.river.artifacts.math.Function instance; sascha@3552: sascha@3552: protected double stdDevFactor; sascha@3188: protected double chiSqr; sascha@3188: sascha@3552: protected Callback callback; sascha@3188: sascha@3188: public Fitting() { sascha@3188: } sascha@3188: sascha@3188: public Fitting(Function function, double stdDevFactor) { sascha@3552: this(); sascha@3188: this.function = function; sascha@3188: this.stdDevFactor = stdDevFactor; sascha@3188: } sascha@3188: sascha@3188: public Function getFunction() { sascha@3188: return function; sascha@3188: } sascha@3188: sascha@3188: public void setFunction(Function function) { sascha@3188: this.function = function; sascha@3188: } sascha@3188: sascha@3188: public double getStdDevFactor() { sascha@3188: return stdDevFactor; sascha@3188: } sascha@3188: sascha@3188: public void setStdDevFactor(double stdDevFactor) { sascha@3188: this.stdDevFactor = stdDevFactor; sascha@3188: } sascha@3188: sascha@3552: @Override sascha@3566: public void initialize(List sqs) throws MathException { sascha@3552: teichmann@6777: if (USE_NON_LINEAR_FITTING teichmann@6777: || function.getInitialGuess().length != 2) { teichmann@6777: nonLinearFitting(sqs); teichmann@6777: } teichmann@6777: else { teichmann@6777: linearFitting(sqs); teichmann@6777: } teichmann@6777: } teichmann@6777: teichmann@6777: protected void linearFitting(List sqs) { teichmann@6777: teichmann@6777: coeffs = linearRegression(sqs); teichmann@6777: teichmann@6777: instance = function.instantiate(coeffs); teichmann@6777: } teichmann@6777: teichmann@6777: protected double [] linearRegression(List sqs) { teichmann@6777: teichmann@6777: SimpleRegression reg = new SimpleRegression(); teichmann@6777: teichmann@6777: int invalidPoints = 0; teichmann@6777: for (SQ sq: sqs) { teichmann@6777: double s = sq.getS(); teichmann@6777: double q = sq.getQ(); teichmann@6777: if (s <= 0d || q <= 0d) { teichmann@6777: ++invalidPoints; teichmann@6777: continue; teichmann@6777: } teichmann@6777: reg.addData(Math.log(q), Math.log(s)); teichmann@6777: } teichmann@6777: teichmann@6777: if (sqs.size() - invalidPoints < 2) { teichmann@6777: log.debug("not enough points"); teichmann@6777: return new double [] { 0, 0 }; teichmann@6777: } teichmann@6777: teichmann@6777: double a = Math.exp(reg.getIntercept()); teichmann@6777: double b = reg.getSlope(); teichmann@6777: teichmann@6777: if (log.isDebugEnabled()) { teichmann@6777: log.debug("invalid points: " + teichmann@6777: invalidPoints + " (" + sqs.size() + ")"); teichmann@6777: log.debug("a: " + a + " (" + Math.log(a) + ")"); teichmann@6777: log.debug("b: " + b); teichmann@6777: } teichmann@6777: teichmann@6777: return new double [] { a, b }; teichmann@6777: } teichmann@6777: teichmann@6777: teichmann@6777: protected void nonLinearFitting(List sqs) throws MathException { teichmann@6777: teichmann@6777: LevenbergMarquardtOptimizer optimizer = teichmann@6777: new LevenbergMarquardtOptimizer(); sascha@3552: teichmann@6761: CurveFitter cf = new CurveFitter(optimizer); teichmann@6777: sascha@3566: for (SQ sq: sqs) { teichmann@6777: cf.addObservedPoint(sq.getS(), sq.getQ()); sascha@3552: } sascha@3552: sascha@3552: coeffs = cf.fit( sascha@3552: function, function.getInitialGuess()); sascha@3552: sascha@3552: instance = function.instantiate(coeffs); sascha@3552: teichmann@6761: chiSqr = optimizer.getChiSquare(); teichmann@6761: } teichmann@6761: sascha@3552: @Override sascha@3552: public double eval(SQ sq) { sascha@3552: double s = instance.value(sq.q); sascha@3552: return sq.s - s; sascha@3552: } sascha@3552: sascha@3552: @Override sascha@3566: public void iterationFinished( sascha@3566: double standardDeviation, sascha@3566: SQ outlier, sascha@3566: List remainings sascha@3566: ) { sascha@3552: if (log.isDebugEnabled()) { sascha@3552: log.debug("iterationFinished ----"); sascha@3552: log.debug(" num remainings: " + remainings.size()); sascha@3566: log.debug(" has outlier: " + outlier != null); sascha@3552: log.debug(" standardDeviation: " + standardDeviation); sascha@3552: log.debug(" Chi^2: " + chiSqr); sascha@3552: log.debug("---- iterationFinished"); sascha@3552: } sascha@3552: callback.afterIteration( sascha@3552: coeffs, sascha@3552: remainings.toArray(new SQ[remainings.size()]), sascha@3566: outlier != null ? new SQ [] { outlier} : new SQ [] {}, sascha@3552: standardDeviation, sascha@3552: chiSqr); sascha@3188: } sascha@3188: sascha@3188: protected static final List onlyValid(List sqs) { sascha@3188: sascha@3188: List good = new ArrayList(sqs.size()); sascha@3188: sascha@3188: for (SQ sq: sqs) { sascha@3188: if (sq.isValid()) { sascha@3188: good.add(sq); sascha@3188: } sascha@3188: } sascha@3188: sascha@3188: return good; sascha@3188: } sascha@3188: rrenkert@5396: public boolean fit(List sqs, String method, Callback callback) { sascha@3188: sascha@3188: sqs = onlyValid(sqs); sascha@3188: sascha@3188: if (sqs.size() < 2) { sascha@3188: log.warn("Too less points for fitting."); sascha@3188: return false; sascha@3188: } sascha@3188: sascha@3552: this.callback = callback; sascha@3188: sascha@3188: try { rrenkert@5396: Outlier.detectOutliers(this, sqs, stdDevFactor, method); sascha@3188: } sascha@3188: catch (MathException me) { sascha@3188: log.warn(me); sascha@3188: return false; sascha@3188: } sascha@3188: sascha@3188: return true; sascha@3188: } sascha@3188: } sascha@3188: // vim:set ts=4 sw=4 si et sta sts=4 fenc=utf8 :