teichmann@5863: /* Copyright (C) 2011, 2012, 2013 by Bundesanstalt für Gewässerkunde teichmann@5863: * Software engineering by Intevation GmbH teichmann@5863: * teichmann@5994: * This file is Free Software under the GNU AGPL (>=v3) teichmann@5863: * and comes with ABSOLUTELY NO WARRANTY! Check out the teichmann@5994: * documentation coming with Dive4Elements River for details. teichmann@5863: */ teichmann@5863: teichmann@5831: package org.dive4elements.river.artifacts.model.fixings; sascha@3011: sascha@3107: import gnu.trove.TDoubleArrayList; sascha@3011: sascha@3107: import java.util.ArrayList; sascha@3107: import java.util.List; sascha@3011: sascha@3011: import org.apache.commons.math.MathException; sascha@3011: import org.apache.commons.math.optimization.fitting.CurveFitter; sascha@3011: import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer; sascha@3107: import org.apache.commons.math.stat.descriptive.moment.StandardDeviation; rrenkert@4794: import org.apache.log4j.Logger; sascha@3011: teichmann@5831: import org.dive4elements.river.artifacts.math.GrubbsOutlier; teichmann@5831: import org.dive4elements.river.artifacts.math.fitting.Function; sascha@3011: sascha@3011: public class Fitting sascha@3011: { sascha@3011: private static Logger log = Logger.getLogger(Fitting.class); sascha@3011: sascha@3011: /** Use instance of this factory to find meta infos for outliers. */ sascha@3096: public interface QWDFactory { sascha@3011: sascha@3096: QWD create(double q, double w); sascha@3011: sascha@3011: } // interface QWFactory sascha@3011: sascha@3096: public static final QWDFactory QWD_FACTORY = new QWDFactory() { sascha@3022: @Override sascha@3096: public QWD create(double q, double w) { sascha@3096: return new QWD(q, w); sascha@3022: } sascha@3022: }; sascha@3022: sascha@3729: protected boolean checkOutliers; sascha@3729: protected Function function; sascha@3729: protected QWDFactory qwdFactory; sascha@3729: protected double chiSqr; sascha@3729: protected double [] parameters; sascha@3729: protected ArrayList removed; sascha@3729: protected QWD [] referenced; sascha@3729: protected double standardDeviation; sascha@3011: sascha@3011: sascha@3011: public Fitting() { sascha@3729: removed = new ArrayList(); sascha@3011: } sascha@3011: sascha@3011: public Fitting(Function function) { sascha@3096: this(function, QWD_FACTORY); sascha@3011: } sascha@3011: sascha@3096: public Fitting(Function function, QWDFactory qwdFactory) { sascha@3096: this(function, qwdFactory, false); sascha@3022: } sascha@3022: sascha@3022: public Fitting( sascha@3096: Function function, sascha@3096: QWDFactory qwdFactory, sascha@3096: boolean checkOutliers sascha@3022: ) { sascha@3022: this(); sascha@3022: this.function = function; sascha@3096: this.qwdFactory = qwdFactory; sascha@3022: this.checkOutliers = checkOutliers; sascha@3011: } sascha@3011: sascha@3011: public Function getFunction() { sascha@3011: return function; sascha@3011: } sascha@3011: sascha@3011: public void setFunction(Function function) { sascha@3011: this.function = function; sascha@3011: } sascha@3011: sascha@3022: public boolean getCheckOutliers() { sascha@3022: return checkOutliers; sascha@3022: } sascha@3022: sascha@3022: public void setCheckOutliers(boolean checkOutliers) { sascha@3022: this.checkOutliers = checkOutliers; sascha@3022: } sascha@3022: sascha@3011: public double getChiSquare() { sascha@3011: return chiSqr; sascha@3011: } sascha@3011: sascha@3011: public void reset() { sascha@3011: chiSqr = 0.0; sascha@3011: parameters = null; sascha@3011: removed.clear(); sascha@3022: referenced = null; sascha@3107: standardDeviation = 0.0; sascha@3011: } sascha@3011: sascha@3011: public boolean hasOutliers() { sascha@3011: return !removed.isEmpty(); sascha@3011: } sascha@3011: sascha@3729: public List getOutliers() { sascha@3011: return removed; sascha@3011: } sascha@3011: sascha@3729: public QWI [] outliersToArray() { sascha@3729: return removed.toArray(new QWI[removed.size()]); sascha@3011: } sascha@3011: sascha@3096: public QWD [] referencedToArray() { sascha@3096: return referenced != null ? (QWD [])referenced.clone() : null; sascha@3022: } sascha@3022: sascha@3065: public double getMaxQ() { sascha@3065: double maxQ = -Double.MAX_VALUE; sascha@3065: if (referenced != null) { sascha@3729: for (QWI qw: referenced) { teichmann@6868: double q = qw.getQ(); teichmann@6868: if (q > maxQ) { teichmann@6868: maxQ = q; sascha@3065: } sascha@3065: } sascha@3065: } sascha@3065: return maxQ; sascha@3065: } sascha@3065: sascha@3011: public double [] getParameters() { sascha@3011: return parameters; sascha@3011: } sascha@3011: sascha@3107: public double getStandardDeviation() { sascha@3107: return standardDeviation; sascha@3107: } sascha@3107: sascha@3011: public boolean fit(double [] qs, double [] ws) { sascha@3011: sascha@3011: TDoubleArrayList xs = new TDoubleArrayList(qs.length); sascha@3011: TDoubleArrayList ys = new TDoubleArrayList(ws.length); sascha@3011: sascha@3011: for (int i = 0; i < qs.length; ++i) { sascha@3011: if (!Double.isNaN(qs[i]) && !Double.isNaN(ws[i])) { sascha@3011: xs.add(qs[i]); sascha@3011: ys.add(ws[i]); sascha@3011: } sascha@3011: } sascha@3011: sascha@3011: if (xs.size() < 2) { sascha@3073: log.warn("Too less points."); sascha@3011: return false; sascha@3011: } sascha@3011: sascha@3565: List inputs = new ArrayList(xs.size()); sascha@3011: teichmann@5831: org.dive4elements.river.artifacts.math.Function instance = null; sascha@3096: sascha@3202: LevenbergMarquardtOptimizer lmo = null; sascha@3202: sascha@3011: for (;;) { sascha@3202: parameters = null; aheinecke@7300: for (double tolerance = 1e-10; tolerance < 1e-1; tolerance *= 10d) { sascha@3011: sascha@3202: lmo = new LevenbergMarquardtOptimizer(); sascha@3202: lmo.setCostRelativeTolerance(tolerance); sascha@3202: lmo.setOrthoTolerance(tolerance); sascha@3202: lmo.setParRelativeTolerance(tolerance); sascha@3202: sascha@3202: CurveFitter cf = new CurveFitter(lmo); sascha@3202: sascha@3202: for (int i = 0, N = xs.size(); i < N; ++i) { sascha@3202: cf.addObservedPoint(xs.getQuick(i), ys.getQuick(i)); sascha@3202: } sascha@3202: sascha@3202: try { sascha@3202: parameters = cf.fit(function, function.getInitialGuess()); sascha@3202: break; sascha@3202: } sascha@3202: catch (MathException me) { sascha@3202: if (log.isDebugEnabled()) { sascha@3202: log.debug("tolerance " + tolerance + " + failed."); sascha@3202: } sascha@3202: } sascha@3011: } sascha@3202: if (parameters == null) { aheinecke@7300: /* aheinecke@7300: log.debug("Parameters is null"); aheinecke@7300: for (int i = 0, N = xs.size(); i < N; ++i) { aheinecke@7300: log.debug("DATA: " + xs.getQuick(i) + " " + ys.getQuick(i)); aheinecke@7300: }*/ sascha@3011: return false; sascha@3011: } sascha@3011: raimund@3127: // This is the paraterized function for a given km. raimund@3127: instance = function.instantiate(parameters); raimund@3127: sascha@3022: if (!checkOutliers) { sascha@3011: break; sascha@3011: } sascha@3011: sascha@3011: inputs.clear(); sascha@3011: sascha@3011: for (int i = 0, N = xs.size(); i < N; ++i) { sascha@3011: double y = instance.value(xs.getQuick(i)); sascha@3011: if (Double.isNaN(y)) { sascha@3565: y = Double.MAX_VALUE; sascha@3011: } sascha@3565: inputs.add(Double.valueOf(ys.getQuick(i) - y)); sascha@3011: } sascha@3011: rrenkert@4794: Integer outlier = GrubbsOutlier.findOutlier(inputs); sascha@3011: sascha@3565: if (outlier == null) { sascha@3011: break; sascha@3011: } sascha@3011: sascha@3565: int idx = outlier.intValue(); sascha@3565: removed.add( sascha@3565: qwdFactory.create( sascha@3565: xs.getQuick(idx), ys.getQuick(idx))); sascha@3565: xs.remove(idx); sascha@3565: ys.remove(idx); sascha@3011: } sascha@3011: sascha@3107: StandardDeviation stdDev = new StandardDeviation(); sascha@3107: sascha@3096: referenced = new QWD[xs.size()]; sascha@3022: for (int i = 0; i < referenced.length; ++i) { sascha@3096: QWD qwd = qwdFactory.create(xs.getQuick(i), ys.getQuick(i)); ingo@3066: sascha@3096: if (qwd == null) { ingo@3066: log.warn("QW creation failed!"); ingo@3066: } ingo@3066: else { sascha@3096: referenced[i] = qwd; sascha@3096: double dw = (qwd.getW() - instance.value(qwd.getQ()))*100.0; sascha@3096: qwd.setDeltaW(dw); sascha@3107: stdDev.increment(dw); ingo@3066: } sascha@3022: } sascha@3022: sascha@3107: standardDeviation = stdDev.getResult(); sascha@3107: sascha@3011: chiSqr = lmo.getChiSquare(); sascha@3011: sascha@3011: return true; sascha@3011: } sascha@3011: } sascha@3011: // vim:set ts=4 sw=4 si et sta sts=4 fenc=utf8 :