diff flys-artifacts/src/main/java/de/intevation/flys/artifacts/model/sq/Fitting.java @ 3552:1df6984628c3

S/Q: Extented the result data model of the S/Q calculation to store the curve coefficients for each iteration step of the outlier elimination. flys-artifacts/trunk@5146 c6561f87-3c4e-4783-a992-168aeb5c3f6f
author Sascha L. Teichmann <sascha.teichmann@intevation.de>
date Fri, 27 Jul 2012 12:36:09 +0000
parents 49fe2ed03c12
children 8d0f06b76e09
line wrap: on
line diff
--- a/flys-artifacts/src/main/java/de/intevation/flys/artifacts/model/sq/Fitting.java	Fri Jul 27 08:36:24 2012 +0000
+++ b/flys-artifacts/src/main/java/de/intevation/flys/artifacts/model/sq/Fitting.java	Fri Jul 27 12:36:09 2012 +0000
@@ -15,27 +15,42 @@
 import org.apache.log4j.Logger;
 
 public class Fitting
+implements   Outlier.Callback
 {
     private static Logger log = Logger.getLogger(Fitting.class);
 
+    public interface Callback {
+
+        void afterIteration(
+            double [] parameters,
+            SQ []     measurements,
+            SQ []     outliers,
+            double    standardDeviation,
+            double    chiSqr);
+    } // interfacte
+
     protected Function function;
 
-    protected double [] parameters;
+    protected double [] coeffs;
 
-    protected double stdDevFactor;
+    protected de.intevation.flys.artifacts.math.Function instance;
+
+    protected List<SQ> remainings;
+    protected List<SQ> outliers;
 
     protected double standardDeviation;
-
+    protected double stdDevFactor;
     protected double chiSqr;
 
-    protected SQ [] remaining;
-
-    protected List<SQ []> outliers;
+    protected Callback callback;
 
     public Fitting() {
+        remainings = new ArrayList<SQ>();
+        outliers   = new ArrayList<SQ>();
     }
 
     public Fitting(Function function, double stdDevFactor) {
+        this();
         this.function     = function;
         this.stdDevFactor = stdDevFactor;
     }
@@ -48,14 +63,6 @@
         this.function = function;
     }
 
-    public double [] getParameters() {
-        return parameters;
-    }
-
-    public void setParameters(double [] parameters) {
-        this.parameters = parameters;
-    }
-
     public double getStdDevFactor() {
         return stdDevFactor;
     }
@@ -64,45 +71,66 @@
         this.stdDevFactor = stdDevFactor;
     }
 
-    public double getStandardDeviation() {
-        return standardDeviation;
+    @Override
+    public void initialize(Iterator<SQ> good) throws MathException {
+
+        LevenbergMarquardtOptimizer lmo =
+            new LevenbergMarquardtOptimizer();
+
+        CurveFitter cf = new CurveFitter(lmo);
+        while (good.hasNext()) {
+            SQ sq = good.next();
+            cf.addObservedPoint(sq.getQ(), sq.getS());
+        }
+
+        coeffs = cf.fit(
+            function, function.getInitialGuess());
+
+        instance = function.instantiate(coeffs);
+
+        chiSqr = lmo.getChiSquare();
+
     }
 
-    public void setStandardDeviation(double standardDeviation) {
+    @Override
+    public double eval(SQ sq) {
+        double s = instance.value(sq.q);
+        return sq.s - s;
+    }
+
+    @Override
+    public void outlier(SQ sq) {
+        outliers.add(sq);
+    }
+
+    @Override
+    public void remaining(SQ sq) {
+        remainings.add(sq);
+    }
+
+    @Override
+    public void standardDeviation(double standardDeviation) {
         this.standardDeviation = standardDeviation;
     }
 
-    public double getChiSqr() {
-        return chiSqr;
-    }
-
-    public void setChiSqr(double chiSqr) {
-        this.chiSqr = chiSqr;
-    }
-
-    public SQ [] getRemaining() {
-        return remaining;
-    }
-
-    public void setRemaining(SQ [] remaining) {
-        this.remaining = remaining;
-    }
-
-    public List<SQ []> getOutliers() {
-        return outliers;
-    }
-
-    public void setOutliers(List<SQ []> outliers) {
-        this.outliers = outliers;
-    }
-
-    public void reset() {
-        outliers          = null;
-        remaining         = null;
-        parameters        = null;
-        standardDeviation = 0d;
-        standardDeviation = 0d;
-        chiSqr            = 0d;
+    @Override
+    public void iterationFinished() {
+        if (log.isDebugEnabled()) {
+            log.debug("iterationFinished ----");
+            log.debug(" num remainings: " + remainings.size());
+            log.debug(" num outliers: " + outliers.size());
+            log.debug(" standardDeviation: " + standardDeviation);
+            log.debug(" Chi^2: " + chiSqr);
+            log.debug("---- iterationFinished");
+        }
+        callback.afterIteration(
+            coeffs,
+            remainings.toArray(new SQ[remainings.size()]),
+            outliers.toArray(new SQ[outliers.size()]),
+            standardDeviation,
+            chiSqr);
+        remainings.clear();
+        outliers.clear();
     }
 
     protected static final List<SQ> onlyValid(List<SQ> sqs) {
@@ -118,7 +146,7 @@
         return good;
     }
 
-    public boolean fit(List<SQ> sqs) {
+    public boolean fit(List<SQ> sqs, Callback callback) {
 
         sqs = onlyValid(sqs);
 
@@ -127,94 +155,10 @@
             return false;
         }
 
-        final LevenbergMarquardtOptimizer lmo =
-            new LevenbergMarquardtOptimizer();
-
-        CurveFitter cf = new CurveFitter(lmo);
-
-        for (SQ sq: sqs) {
-            cf.addObservedPoint(sq.getQ(), sq.getS());
-        }
-
-        try {
-            parameters = cf.fit(function, function.getInitialGuess());
-        }
-        catch (MathException me) {
-            log.warn(me);
-            return false;
-        }
-
-        chiSqr = lmo.getChiSquare();
-
-        final de.intevation.flys.artifacts.math.Function [] instance = {
-            function.instantiate(parameters)
-        };
+        this.callback = callback;
 
         try {
-            remaining = Outlier.detectOutliers(
-                new Outlier.Callback() {
-
-                    List<List<SQ>> outliers =
-                        new ArrayList<List<SQ>>();
-
-                    int currentIteration;
-
-                    @Override
-                    public double eval(SQ sq) {
-                        double s = instance[0].value(sq.q);
-                        return s - sq.s;
-                    }
-
-                    @Override
-                    public void iteration(int i) {
-                        currentIteration = i;
-                    }
-
-                    @Override
-                    public void outlier(SQ sq) {
-                        if (currentIteration > outliers.size()) {
-                            outliers.add(new ArrayList<SQ>(2));
-                        }
-                        outliers.get(currentIteration-1).add(sq);
-                    }
-
-                    @Override
-                    public void standardDeviation(double stdDev) {
-                        setStandardDeviation(stdDev);
-                    }
-
-                    @Override
-                    public void reinitialize(Iterator<SQ> good)
-                    throws MathException
-                    {
-                        CurveFitter cf = new CurveFitter(lmo);
-                        while (good.hasNext()) {
-                            SQ sq = good.next();
-                            cf.addObservedPoint(sq.getQ(), sq.getS());
-                        }
-
-                        parameters = cf.fit(
-                            function, function.getInitialGuess());
-
-                        instance[0] = function.instantiate(parameters);
-
-                        chiSqr = lmo.getChiSquare();
-                    }
-
-                    @Override
-                    public void finished() {
-                        List<SQ []> result =
-                            new ArrayList<SQ []>(outliers.size());
-
-                        for (List<SQ> ols: outliers) {
-                            result.add(ols.toArray(new SQ[ols.size()]));
-                        }
-
-                        setOutliers(result);
-                    }
-                },
-                sqs,
-                stdDevFactor);
+            Outlier.detectOutliers(this, sqs, stdDevFactor);
         }
         catch (MathException me) {
             log.warn(me);

http://dive4elements.wald.intevation.org