teichmann@5831: package org.dive4elements.river.artifacts.model.sq; sascha@3188: teichmann@5831: import org.dive4elements.river.artifacts.math.fitting.Function; sascha@3188: sascha@3188: import java.util.ArrayList; sascha@3188: import java.util.List; sascha@3188: sascha@3188: import org.apache.commons.math.MathException; sascha@3188: sascha@3188: import org.apache.commons.math.optimization.fitting.CurveFitter; sascha@3188: sascha@3188: import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer; sascha@3188: sascha@3188: import org.apache.log4j.Logger; sascha@3188: sascha@3188: public class Fitting sascha@3552: implements Outlier.Callback sascha@3188: { sascha@3188: private static Logger log = Logger.getLogger(Fitting.class); sascha@3188: sascha@3552: public interface Callback { sascha@3552: sascha@3552: void afterIteration( sascha@3552: double [] parameters, sascha@3552: SQ [] measurements, sascha@3552: SQ [] outliers, sascha@3552: double standardDeviation, sascha@3552: double chiSqr); sascha@3552: } // interfacte sascha@3552: sascha@3188: protected Function function; sascha@3188: sascha@3552: protected double [] coeffs; sascha@3188: teichmann@5831: protected org.dive4elements.river.artifacts.math.Function instance; sascha@3552: sascha@3552: protected double stdDevFactor; sascha@3188: protected double chiSqr; sascha@3188: sascha@3552: protected Callback callback; sascha@3188: sascha@3188: public Fitting() { sascha@3188: } sascha@3188: sascha@3188: public Fitting(Function function, double stdDevFactor) { sascha@3552: this(); sascha@3188: this.function = function; sascha@3188: this.stdDevFactor = stdDevFactor; sascha@3188: } sascha@3188: sascha@3188: public Function getFunction() { sascha@3188: return function; sascha@3188: } sascha@3188: sascha@3188: public void setFunction(Function function) { sascha@3188: this.function = function; sascha@3188: } sascha@3188: sascha@3188: public double getStdDevFactor() { sascha@3188: return stdDevFactor; sascha@3188: } sascha@3188: sascha@3188: public void setStdDevFactor(double stdDevFactor) { sascha@3188: this.stdDevFactor = stdDevFactor; sascha@3188: } sascha@3188: sascha@3552: @Override sascha@3566: public void initialize(List sqs) throws MathException { sascha@3552: sascha@3552: LevenbergMarquardtOptimizer lmo = sascha@3552: new LevenbergMarquardtOptimizer(); sascha@3552: sascha@3552: CurveFitter cf = new CurveFitter(lmo); sascha@3566: for (SQ sq: sqs) { sascha@3552: cf.addObservedPoint(sq.getQ(), sq.getS()); sascha@3552: } sascha@3552: sascha@3552: coeffs = cf.fit( sascha@3552: function, function.getInitialGuess()); sascha@3552: sascha@3552: instance = function.instantiate(coeffs); sascha@3552: sascha@3552: chiSqr = lmo.getChiSquare(); sascha@3188: } sascha@3188: sascha@3552: @Override sascha@3552: public double eval(SQ sq) { sascha@3552: double s = instance.value(sq.q); sascha@3552: return sq.s - s; sascha@3552: } sascha@3552: sascha@3552: @Override sascha@3566: public void iterationFinished( sascha@3566: double standardDeviation, sascha@3566: SQ outlier, sascha@3566: List remainings sascha@3566: ) { sascha@3552: if (log.isDebugEnabled()) { sascha@3552: log.debug("iterationFinished ----"); sascha@3552: log.debug(" num remainings: " + remainings.size()); sascha@3566: log.debug(" has outlier: " + outlier != null); sascha@3552: log.debug(" standardDeviation: " + standardDeviation); sascha@3552: log.debug(" Chi^2: " + chiSqr); sascha@3552: log.debug("---- iterationFinished"); sascha@3552: } sascha@3552: callback.afterIteration( sascha@3552: coeffs, sascha@3552: remainings.toArray(new SQ[remainings.size()]), sascha@3566: outlier != null ? new SQ [] { outlier} : new SQ [] {}, sascha@3552: standardDeviation, sascha@3552: chiSqr); sascha@3188: } sascha@3188: sascha@3188: protected static final List onlyValid(List sqs) { sascha@3188: sascha@3188: List good = new ArrayList(sqs.size()); sascha@3188: sascha@3188: for (SQ sq: sqs) { sascha@3188: if (sq.isValid()) { sascha@3188: good.add(sq); sascha@3188: } sascha@3188: } sascha@3188: sascha@3188: return good; sascha@3188: } sascha@3188: rrenkert@5396: public boolean fit(List sqs, String method, Callback callback) { sascha@3188: sascha@3188: sqs = onlyValid(sqs); sascha@3188: sascha@3188: if (sqs.size() < 2) { sascha@3188: log.warn("Too less points for fitting."); sascha@3188: return false; sascha@3188: } sascha@3188: sascha@3552: this.callback = callback; sascha@3188: sascha@3188: try { rrenkert@5396: Outlier.detectOutliers(this, sqs, stdDevFactor, method); sascha@3188: } sascha@3188: catch (MathException me) { sascha@3188: log.warn(me); sascha@3188: return false; sascha@3188: } sascha@3188: sascha@3188: return true; sascha@3188: } sascha@3188: } sascha@3188: // vim:set ts=4 sw=4 si et sta sts=4 fenc=utf8 :