teichmann@5863: /* Copyright (C) 2011, 2012, 2013 by Bundesanstalt für Gewässerkunde teichmann@5863: * Software engineering by Intevation GmbH teichmann@5863: * teichmann@5994: * This file is Free Software under the GNU AGPL (>=v3) teichmann@5863: * and comes with ABSOLUTELY NO WARRANTY! Check out the teichmann@5994: * documentation coming with Dive4Elements River for details. teichmann@5863: */ teichmann@5863: teichmann@5831: package org.dive4elements.river.artifacts.model.sq; sascha@3188: teichmann@6787: import org.dive4elements.artifacts.common.utils.StringUtils; teichmann@5831: import org.dive4elements.river.artifacts.math.fitting.Function; sascha@3188: sascha@3188: import java.util.ArrayList; sascha@3188: import java.util.List; sascha@3188: sascha@3188: import org.apache.commons.math.MathException; sascha@3188: sascha@3188: import org.apache.commons.math.optimization.fitting.CurveFitter; sascha@3188: sascha@3188: import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer; teichmann@6777: import org.apache.commons.math.stat.regression.SimpleRegression; sascha@3188: sascha@3188: import org.apache.log4j.Logger; sascha@3188: sascha@3188: public class Fitting sascha@3552: implements Outlier.Callback sascha@3188: { teichmann@6761: // XXX: Hack to force linear fitting! teichmann@6761: private static final boolean USE_NON_LINEAR_FITTING = teichmann@6761: Boolean.getBoolean("minfo.sq.fitting.nonlinear"); teichmann@6761: sascha@3188: private static Logger log = Logger.getLogger(Fitting.class); sascha@3188: sascha@3552: public interface Callback { sascha@3552: sascha@3552: void afterIteration( sascha@3552: double [] parameters, sascha@3552: SQ [] measurements, sascha@3552: SQ [] outliers, sascha@3552: double standardDeviation, sascha@3552: double chiSqr); sascha@3552: } // interfacte sascha@3552: sascha@3188: protected Function function; sascha@3188: sascha@3552: protected double [] coeffs; sascha@3188: teichmann@5831: protected org.dive4elements.river.artifacts.math.Function instance; sascha@3552: sascha@3552: protected double stdDevFactor; sascha@3188: protected double chiSqr; sascha@3188: sascha@3552: protected Callback callback; sascha@3188: teichmann@6780: protected SQ.View sqView; teichmann@6780: sascha@3188: public Fitting() { sascha@3188: } sascha@3188: teichmann@6780: public Fitting(Function function, double stdDevFactor, SQ.View sqView) { sascha@3188: this.function = function; sascha@3188: this.stdDevFactor = stdDevFactor; teichmann@6780: this.sqView = sqView; sascha@3188: } sascha@3188: sascha@3188: public Function getFunction() { sascha@3188: return function; sascha@3188: } sascha@3188: sascha@3188: public void setFunction(Function function) { sascha@3188: this.function = function; sascha@3188: } sascha@3188: sascha@3188: public double getStdDevFactor() { sascha@3188: return stdDevFactor; sascha@3188: } sascha@3188: sascha@3188: public void setStdDevFactor(double stdDevFactor) { sascha@3188: this.stdDevFactor = stdDevFactor; sascha@3188: } sascha@3188: sascha@3552: @Override sascha@3566: public void initialize(List sqs) throws MathException { sascha@3552: teichmann@6777: if (USE_NON_LINEAR_FITTING teichmann@6787: || function.getParameterNames().length != 2) { teichmann@6777: nonLinearFitting(sqs); teichmann@6777: } teichmann@6777: else { teichmann@6777: linearFitting(sqs); teichmann@6777: } teichmann@6777: } teichmann@6777: teichmann@6777: protected void linearFitting(List sqs) { teichmann@6787: coeffs = linearRegression(sqs); teichmann@6777: instance = function.instantiate(coeffs); teichmann@6777: } teichmann@6777: teichmann@6777: protected double [] linearRegression(List sqs) { teichmann@6777: teichmann@6787: String [] pns = function.getParameterNames(); teichmann@6787: double [] result = new double[pns.length]; teichmann@6787: teichmann@6787: if (sqs.size() < 2) { teichmann@6787: log.debug("not enough points"); teichmann@6787: return result; teichmann@6787: } teichmann@6787: teichmann@6777: SimpleRegression reg = new SimpleRegression(); teichmann@6777: teichmann@6777: for (SQ sq: sqs) { teichmann@6787: double s = sqView.getS(sq); teichmann@6787: double q = sqView.getQ(sq); teichmann@6787: reg.addData(q, s); teichmann@6777: } teichmann@6777: teichmann@6787: double m = reg.getIntercept(); teichmann@6777: double b = reg.getSlope(); teichmann@6777: teichmann@6777: if (log.isDebugEnabled()) { teichmann@6787: log.debug("m: " + m); teichmann@6777: log.debug("b: " + b); teichmann@6777: } teichmann@6777: teichmann@6787: int mIdx = StringUtils.indexOf("m", pns); teichmann@6787: int bIdx = StringUtils.indexOf("b", pns); teichmann@6787: teichmann@6787: if (mIdx == -1 || bIdx == -1) { teichmann@6787: log.error("index not found: " + mIdx + " " + bIdx); teichmann@6787: return result; teichmann@6787: } teichmann@6787: teichmann@6787: result[bIdx] = m; teichmann@6787: result[mIdx] = b; teichmann@6787: teichmann@6787: return result; teichmann@6777: } teichmann@6777: teichmann@6777: teichmann@6777: protected void nonLinearFitting(List sqs) throws MathException { teichmann@6777: teichmann@6777: LevenbergMarquardtOptimizer optimizer = teichmann@6777: new LevenbergMarquardtOptimizer(); sascha@3552: teichmann@6761: CurveFitter cf = new CurveFitter(optimizer); teichmann@6777: sascha@3566: for (SQ sq: sqs) { teichmann@6787: cf.addObservedPoint(sqView.getQ(sq), sqView.getS(sq)); sascha@3552: } sascha@3552: sascha@3552: coeffs = cf.fit( sascha@3552: function, function.getInitialGuess()); sascha@3552: sascha@3552: instance = function.instantiate(coeffs); sascha@3552: teichmann@6761: chiSqr = optimizer.getChiSquare(); teichmann@6761: } teichmann@6761: sascha@3552: @Override sascha@3552: public double eval(SQ sq) { teichmann@6780: double s = instance.value(sqView.getQ(sq)); teichmann@6780: return sqView.getS(sq) - s; sascha@3552: } sascha@3552: sascha@3552: @Override sascha@3566: public void iterationFinished( sascha@3566: double standardDeviation, sascha@3566: SQ outlier, sascha@3566: List remainings sascha@3566: ) { sascha@3552: if (log.isDebugEnabled()) { sascha@3552: log.debug("iterationFinished ----"); sascha@3552: log.debug(" num remainings: " + remainings.size()); sascha@3566: log.debug(" has outlier: " + outlier != null); sascha@3552: log.debug(" standardDeviation: " + standardDeviation); sascha@3552: log.debug(" Chi^2: " + chiSqr); sascha@3552: log.debug("---- iterationFinished"); sascha@3552: } sascha@3552: callback.afterIteration( sascha@3552: coeffs, sascha@3552: remainings.toArray(new SQ[remainings.size()]), sascha@3566: outlier != null ? new SQ [] { outlier} : new SQ [] {}, sascha@3552: standardDeviation, sascha@3552: chiSqr); sascha@3188: } sascha@3188: teichmann@6787: public boolean fit(List sqs, String method, Callback callback) { sascha@3188: sascha@3188: if (sqs.size() < 2) { sascha@3188: log.warn("Too less points for fitting."); sascha@3188: return false; sascha@3188: } sascha@3188: teichmann@6780: sqs = new ArrayList(sqs); teichmann@6780: sascha@3552: this.callback = callback; sascha@3188: sascha@3188: try { rrenkert@5396: Outlier.detectOutliers(this, sqs, stdDevFactor, method); sascha@3188: } sascha@3188: catch (MathException me) { sascha@3188: log.warn(me); sascha@3188: return false; sascha@3188: } sascha@3188: sascha@3188: return true; sascha@3188: } sascha@3188: } sascha@3188: // vim:set ts=4 sw=4 si et sta sts=4 fenc=utf8 :