Mercurial > dive4elements > river
diff flys-artifacts/src/main/java/de/intevation/flys/artifacts/model/sq/Fitting.java @ 3552:1df6984628c3
S/Q: Extented the result data model of the S/Q calculation to
store the curve coefficients for each iteration step
of the outlier elimination.
flys-artifacts/trunk@5146 c6561f87-3c4e-4783-a992-168aeb5c3f6f
author | Sascha L. Teichmann <sascha.teichmann@intevation.de> |
---|---|
date | Fri, 27 Jul 2012 12:36:09 +0000 |
parents | 49fe2ed03c12 |
children | 8d0f06b76e09 |
line wrap: on
line diff
--- a/flys-artifacts/src/main/java/de/intevation/flys/artifacts/model/sq/Fitting.java Fri Jul 27 08:36:24 2012 +0000 +++ b/flys-artifacts/src/main/java/de/intevation/flys/artifacts/model/sq/Fitting.java Fri Jul 27 12:36:09 2012 +0000 @@ -15,27 +15,42 @@ import org.apache.log4j.Logger; public class Fitting +implements Outlier.Callback { private static Logger log = Logger.getLogger(Fitting.class); + public interface Callback { + + void afterIteration( + double [] parameters, + SQ [] measurements, + SQ [] outliers, + double standardDeviation, + double chiSqr); + } // interfacte + protected Function function; - protected double [] parameters; + protected double [] coeffs; - protected double stdDevFactor; + protected de.intevation.flys.artifacts.math.Function instance; + + protected List<SQ> remainings; + protected List<SQ> outliers; protected double standardDeviation; - + protected double stdDevFactor; protected double chiSqr; - protected SQ [] remaining; - - protected List<SQ []> outliers; + protected Callback callback; public Fitting() { + remainings = new ArrayList<SQ>(); + outliers = new ArrayList<SQ>(); } public Fitting(Function function, double stdDevFactor) { + this(); this.function = function; this.stdDevFactor = stdDevFactor; } @@ -48,14 +63,6 @@ this.function = function; } - public double [] getParameters() { - return parameters; - } - - public void setParameters(double [] parameters) { - this.parameters = parameters; - } - public double getStdDevFactor() { return stdDevFactor; } @@ -64,45 +71,66 @@ this.stdDevFactor = stdDevFactor; } - public double getStandardDeviation() { - return standardDeviation; + @Override + public void initialize(Iterator<SQ> good) throws MathException { + + LevenbergMarquardtOptimizer lmo = + new LevenbergMarquardtOptimizer(); + + CurveFitter cf = new CurveFitter(lmo); + while (good.hasNext()) { + SQ sq = good.next(); + cf.addObservedPoint(sq.getQ(), sq.getS()); + } + + coeffs = cf.fit( + function, function.getInitialGuess()); + + instance = function.instantiate(coeffs); + + chiSqr = lmo.getChiSquare(); + } - public void setStandardDeviation(double standardDeviation) { + @Override + public double eval(SQ sq) { + double s = instance.value(sq.q); + return sq.s - s; + } + + @Override + public void outlier(SQ sq) { + outliers.add(sq); + } + + @Override + public void remaining(SQ sq) { + remainings.add(sq); + } + + @Override + public void standardDeviation(double standardDeviation) { this.standardDeviation = standardDeviation; } - public double getChiSqr() { - return chiSqr; - } - - public void setChiSqr(double chiSqr) { - this.chiSqr = chiSqr; - } - - public SQ [] getRemaining() { - return remaining; - } - - public void setRemaining(SQ [] remaining) { - this.remaining = remaining; - } - - public List<SQ []> getOutliers() { - return outliers; - } - - public void setOutliers(List<SQ []> outliers) { - this.outliers = outliers; - } - - public void reset() { - outliers = null; - remaining = null; - parameters = null; - standardDeviation = 0d; - standardDeviation = 0d; - chiSqr = 0d; + @Override + public void iterationFinished() { + if (log.isDebugEnabled()) { + log.debug("iterationFinished ----"); + log.debug(" num remainings: " + remainings.size()); + log.debug(" num outliers: " + outliers.size()); + log.debug(" standardDeviation: " + standardDeviation); + log.debug(" Chi^2: " + chiSqr); + log.debug("---- iterationFinished"); + } + callback.afterIteration( + coeffs, + remainings.toArray(new SQ[remainings.size()]), + outliers.toArray(new SQ[outliers.size()]), + standardDeviation, + chiSqr); + remainings.clear(); + outliers.clear(); } protected static final List<SQ> onlyValid(List<SQ> sqs) { @@ -118,7 +146,7 @@ return good; } - public boolean fit(List<SQ> sqs) { + public boolean fit(List<SQ> sqs, Callback callback) { sqs = onlyValid(sqs); @@ -127,94 +155,10 @@ return false; } - final LevenbergMarquardtOptimizer lmo = - new LevenbergMarquardtOptimizer(); - - CurveFitter cf = new CurveFitter(lmo); - - for (SQ sq: sqs) { - cf.addObservedPoint(sq.getQ(), sq.getS()); - } - - try { - parameters = cf.fit(function, function.getInitialGuess()); - } - catch (MathException me) { - log.warn(me); - return false; - } - - chiSqr = lmo.getChiSquare(); - - final de.intevation.flys.artifacts.math.Function [] instance = { - function.instantiate(parameters) - }; + this.callback = callback; try { - remaining = Outlier.detectOutliers( - new Outlier.Callback() { - - List<List<SQ>> outliers = - new ArrayList<List<SQ>>(); - - int currentIteration; - - @Override - public double eval(SQ sq) { - double s = instance[0].value(sq.q); - return s - sq.s; - } - - @Override - public void iteration(int i) { - currentIteration = i; - } - - @Override - public void outlier(SQ sq) { - if (currentIteration > outliers.size()) { - outliers.add(new ArrayList<SQ>(2)); - } - outliers.get(currentIteration-1).add(sq); - } - - @Override - public void standardDeviation(double stdDev) { - setStandardDeviation(stdDev); - } - - @Override - public void reinitialize(Iterator<SQ> good) - throws MathException - { - CurveFitter cf = new CurveFitter(lmo); - while (good.hasNext()) { - SQ sq = good.next(); - cf.addObservedPoint(sq.getQ(), sq.getS()); - } - - parameters = cf.fit( - function, function.getInitialGuess()); - - instance[0] = function.instantiate(parameters); - - chiSqr = lmo.getChiSquare(); - } - - @Override - public void finished() { - List<SQ []> result = - new ArrayList<SQ []>(outliers.size()); - - for (List<SQ> ols: outliers) { - result.add(ols.toArray(new SQ[ols.size()])); - } - - setOutliers(result); - } - }, - sqs, - stdDevFactor); + Outlier.detectOutliers(this, sqs, stdDevFactor); } catch (MathException me) { log.warn(me);