Mercurial > dive4elements > river
diff artifacts/src/main/java/org/dive4elements/river/artifacts/model/fixings/Fitting.java @ 5838:5aa05a7a34b7
Rename modules to more fitting names.
author | Sascha L. Teichmann <teichmann@intevation.de> |
---|---|
date | Thu, 25 Apr 2013 15:23:37 +0200 |
parents | flys-artifacts/src/main/java/org/dive4elements/river/artifacts/model/fixings/Fitting.java@bd047b71ab37 |
children | 4897a58c8746 |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/artifacts/src/main/java/org/dive4elements/river/artifacts/model/fixings/Fitting.java Thu Apr 25 15:23:37 2013 +0200 @@ -0,0 +1,242 @@ +package org.dive4elements.river.artifacts.model.fixings; + +import gnu.trove.TDoubleArrayList; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.commons.math.MathException; +import org.apache.commons.math.optimization.fitting.CurveFitter; +import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer; +import org.apache.commons.math.stat.descriptive.moment.StandardDeviation; +import org.apache.log4j.Logger; + +import org.dive4elements.river.artifacts.math.GrubbsOutlier; +import org.dive4elements.river.artifacts.math.fitting.Function; + +public class Fitting +{ + private static Logger log = Logger.getLogger(Fitting.class); + + /** Use instance of this factory to find meta infos for outliers. */ + public interface QWDFactory { + + QWD create(double q, double w); + + } // interface QWFactory + + public static final QWDFactory QWD_FACTORY = new QWDFactory() { + @Override + public QWD create(double q, double w) { + return new QWD(q, w); + } + }; + + protected boolean checkOutliers; + protected Function function; + protected QWDFactory qwdFactory; + protected double chiSqr; + protected double [] parameters; + protected ArrayList<QWI> removed; + protected QWD [] referenced; + protected double standardDeviation; + + + public Fitting() { + removed = new ArrayList<QWI>(); + } + + public Fitting(Function function) { + this(function, QWD_FACTORY); + } + + public Fitting(Function function, QWDFactory qwdFactory) { + this(function, qwdFactory, false); + } + + public Fitting( + Function function, + QWDFactory qwdFactory, + boolean checkOutliers + ) { + this(); + this.function = function; + this.qwdFactory = qwdFactory; + this.checkOutliers = checkOutliers; + } + + public Function getFunction() { + return function; + } + + public void setFunction(Function function) { + this.function = function; + } + + public boolean getCheckOutliers() { + return checkOutliers; + } + + public void setCheckOutliers(boolean checkOutliers) { + this.checkOutliers = checkOutliers; + } + + public double getChiSquare() { + return chiSqr; + } + + public void reset() { + chiSqr = 0.0; + parameters = null; + removed.clear(); + referenced = null; + standardDeviation = 0.0; + } + + public boolean hasOutliers() { + return !removed.isEmpty(); + } + + public List<QWI> getOutliers() { + return removed; + } + + public QWI [] outliersToArray() { + return removed.toArray(new QWI[removed.size()]); + } + + public QWD [] referencedToArray() { + return referenced != null ? (QWD [])referenced.clone() : null; + } + + public double getMaxQ() { + double maxQ = -Double.MAX_VALUE; + if (referenced != null) { + for (QWI qw: referenced) { + if (qw.getQ() > maxQ) { + maxQ = qw.getQ(); + } + } + } + return maxQ; + } + + public double [] getParameters() { + return parameters; + } + + public double getStandardDeviation() { + return standardDeviation; + } + + public boolean fit(double [] qs, double [] ws) { + + TDoubleArrayList xs = new TDoubleArrayList(qs.length); + TDoubleArrayList ys = new TDoubleArrayList(ws.length); + + for (int i = 0; i < qs.length; ++i) { + if (!Double.isNaN(qs[i]) && !Double.isNaN(ws[i])) { + xs.add(qs[i]); + ys.add(ws[i]); + } + else { + log.warn("remove invalid value " + qs[i] + " " + ws[i]); + } + } + + if (xs.size() < 2) { + log.warn("Too less points."); + return false; + } + + List<Double> inputs = new ArrayList<Double>(xs.size()); + + org.dive4elements.river.artifacts.math.Function instance = null; + + LevenbergMarquardtOptimizer lmo = null; + + for (;;) { + parameters = null; + for (double tolerance = 1e-10; tolerance < 1e-3; tolerance *= 10d) { + + lmo = new LevenbergMarquardtOptimizer(); + lmo.setCostRelativeTolerance(tolerance); + lmo.setOrthoTolerance(tolerance); + lmo.setParRelativeTolerance(tolerance); + + CurveFitter cf = new CurveFitter(lmo); + + for (int i = 0, N = xs.size(); i < N; ++i) { + cf.addObservedPoint(xs.getQuick(i), ys.getQuick(i)); + } + + try { + parameters = cf.fit(function, function.getInitialGuess()); + break; + } + catch (MathException me) { + if (log.isDebugEnabled()) { + log.debug("tolerance " + tolerance + " + failed."); + } + } + } + if (parameters == null) { + return false; + } + + // This is the paraterized function for a given km. + instance = function.instantiate(parameters); + + if (!checkOutliers) { + break; + } + + inputs.clear(); + + for (int i = 0, N = xs.size(); i < N; ++i) { + double y = instance.value(xs.getQuick(i)); + if (Double.isNaN(y)) { + y = Double.MAX_VALUE; + } + inputs.add(Double.valueOf(ys.getQuick(i) - y)); + } + + Integer outlier = GrubbsOutlier.findOutlier(inputs); + + if (outlier == null) { + break; + } + + int idx = outlier.intValue(); + removed.add( + qwdFactory.create( + xs.getQuick(idx), ys.getQuick(idx))); + xs.remove(idx); + ys.remove(idx); + } + + StandardDeviation stdDev = new StandardDeviation(); + + referenced = new QWD[xs.size()]; + for (int i = 0; i < referenced.length; ++i) { + QWD qwd = qwdFactory.create(xs.getQuick(i), ys.getQuick(i)); + + if (qwd == null) { + log.warn("QW creation failed!"); + } + else { + referenced[i] = qwd; + double dw = (qwd.getW() - instance.value(qwd.getQ()))*100.0; + qwd.setDeltaW(dw); + stdDev.increment(dw); + } + } + + standardDeviation = stdDev.getResult(); + + chiSqr = lmo.getChiSquare(); + + return true; + } +} +// vim:set ts=4 sw=4 si et sta sts=4 fenc=utf8 :