diff artifacts/src/main/java/org/dive4elements/river/artifacts/model/fixings/Fitting.java @ 5838:5aa05a7a34b7

Rename modules to more fitting names.
author Sascha L. Teichmann <teichmann@intevation.de>
date Thu, 25 Apr 2013 15:23:37 +0200
parents flys-artifacts/src/main/java/org/dive4elements/river/artifacts/model/fixings/Fitting.java@bd047b71ab37
children 4897a58c8746
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/artifacts/src/main/java/org/dive4elements/river/artifacts/model/fixings/Fitting.java	Thu Apr 25 15:23:37 2013 +0200
@@ -0,0 +1,242 @@
+package org.dive4elements.river.artifacts.model.fixings;
+
+import gnu.trove.TDoubleArrayList;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.commons.math.MathException;
+import org.apache.commons.math.optimization.fitting.CurveFitter;
+import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer;
+import org.apache.commons.math.stat.descriptive.moment.StandardDeviation;
+import org.apache.log4j.Logger;
+
+import org.dive4elements.river.artifacts.math.GrubbsOutlier;
+import org.dive4elements.river.artifacts.math.fitting.Function;
+
+public class Fitting
+{
+    private static Logger log = Logger.getLogger(Fitting.class);
+
+    /** Use instance of this factory to find meta infos for outliers. */
+    public interface QWDFactory {
+
+        QWD create(double q, double w);
+
+    } // interface QWFactory
+
+    public static final QWDFactory QWD_FACTORY = new QWDFactory() {
+        @Override
+        public QWD create(double q, double w) {
+            return new QWD(q, w);
+        }
+    };
+
+    protected boolean        checkOutliers;
+    protected Function       function;
+    protected QWDFactory     qwdFactory;
+    protected double         chiSqr;
+    protected double []      parameters;
+    protected ArrayList<QWI> removed;
+    protected QWD []         referenced;
+    protected double         standardDeviation;
+
+
+    public Fitting() {
+        removed = new ArrayList<QWI>();
+    }
+
+    public Fitting(Function function) {
+        this(function, QWD_FACTORY);
+    }
+
+    public Fitting(Function function, QWDFactory qwdFactory) {
+        this(function, qwdFactory, false);
+    }
+
+    public Fitting(
+        Function   function,
+        QWDFactory qwdFactory,
+        boolean    checkOutliers
+    ) {
+        this();
+        this.function      = function;
+        this.qwdFactory    = qwdFactory;
+        this.checkOutliers = checkOutliers;
+    }
+
+    public Function getFunction() {
+        return function;
+    }
+
+    public void setFunction(Function function) {
+        this.function = function;
+    }
+
+    public boolean getCheckOutliers() {
+        return checkOutliers;
+    }
+
+    public void setCheckOutliers(boolean checkOutliers) {
+        this.checkOutliers = checkOutliers;
+    }
+
+    public double getChiSquare() {
+        return chiSqr;
+    }
+
+    public void reset() {
+        chiSqr     = 0.0;
+        parameters = null;
+        removed.clear();
+        referenced = null;
+        standardDeviation = 0.0;
+    }
+
+    public boolean hasOutliers() {
+        return !removed.isEmpty();
+    }
+
+    public List<QWI> getOutliers() {
+        return removed;
+    }
+
+    public QWI [] outliersToArray() {
+        return removed.toArray(new QWI[removed.size()]);
+    }
+
+    public QWD [] referencedToArray() {
+        return referenced != null ? (QWD [])referenced.clone() : null;
+    }
+
+    public double getMaxQ() {
+        double maxQ = -Double.MAX_VALUE;
+        if (referenced != null) {
+            for (QWI qw: referenced) {
+                if (qw.getQ() > maxQ) {
+                    maxQ = qw.getQ();
+                }
+            }
+        }
+        return maxQ;
+    }
+
+    public double [] getParameters() {
+        return parameters;
+    }
+
+    public double getStandardDeviation() {
+        return standardDeviation;
+    }
+
+    public boolean fit(double [] qs, double [] ws) {
+
+        TDoubleArrayList xs = new TDoubleArrayList(qs.length);
+        TDoubleArrayList ys = new TDoubleArrayList(ws.length);
+
+        for (int i = 0; i < qs.length; ++i) {
+            if (!Double.isNaN(qs[i]) && !Double.isNaN(ws[i])) {
+                xs.add(qs[i]);
+                ys.add(ws[i]);
+            }
+            else {
+                log.warn("remove invalid value " + qs[i] + " " + ws[i]);
+            }
+        }
+
+        if (xs.size() < 2) {
+            log.warn("Too less points.");
+            return false;
+        }
+
+        List<Double> inputs = new ArrayList<Double>(xs.size());
+
+        org.dive4elements.river.artifacts.math.Function instance = null;
+
+        LevenbergMarquardtOptimizer lmo = null;
+
+        for (;;) {
+            parameters = null;
+            for (double tolerance = 1e-10; tolerance < 1e-3; tolerance *= 10d) {
+
+                lmo = new LevenbergMarquardtOptimizer();
+                lmo.setCostRelativeTolerance(tolerance);
+                lmo.setOrthoTolerance(tolerance);
+                lmo.setParRelativeTolerance(tolerance);
+
+                CurveFitter cf = new CurveFitter(lmo);
+
+                for (int i = 0, N = xs.size(); i < N; ++i) {
+                    cf.addObservedPoint(xs.getQuick(i), ys.getQuick(i));
+                }
+
+                try {
+                    parameters = cf.fit(function, function.getInitialGuess());
+                    break;
+                }
+                catch (MathException me) {
+                    if (log.isDebugEnabled()) {
+                        log.debug("tolerance " + tolerance + " + failed.");
+                    }
+                }
+            }
+            if (parameters == null) {
+                return false;
+            }
+
+            // This is the paraterized function for a given km.
+            instance = function.instantiate(parameters);
+
+            if (!checkOutliers) {
+                break;
+            }
+
+            inputs.clear();
+
+            for (int i = 0, N = xs.size(); i < N; ++i) {
+                double y = instance.value(xs.getQuick(i));
+                if (Double.isNaN(y)) {
+                    y = Double.MAX_VALUE;
+                }
+                inputs.add(Double.valueOf(ys.getQuick(i) - y));
+            }
+
+            Integer outlier = GrubbsOutlier.findOutlier(inputs);
+
+            if (outlier == null) {
+                break;
+            }
+
+            int idx = outlier.intValue();
+            removed.add(
+                qwdFactory.create(
+                    xs.getQuick(idx), ys.getQuick(idx)));
+            xs.remove(idx);
+            ys.remove(idx);
+        }
+
+        StandardDeviation stdDev = new StandardDeviation();
+
+        referenced = new QWD[xs.size()];
+        for (int i = 0; i < referenced.length; ++i) {
+            QWD qwd = qwdFactory.create(xs.getQuick(i), ys.getQuick(i));
+
+            if (qwd == null) {
+                log.warn("QW creation failed!");
+            }
+            else {
+                referenced[i] = qwd;
+                double dw = (qwd.getW() - instance.value(qwd.getQ()))*100.0;
+                qwd.setDeltaW(dw);
+                stdDev.increment(dw);
+            }
+        }
+
+        standardDeviation = stdDev.getResult();
+
+        chiSqr = lmo.getChiSquare();
+
+        return true;
+    }
+}
+// vim:set ts=4 sw=4 si et sta sts=4 fenc=utf8 :

http://dive4elements.wald.intevation.org