sascha@3011: package de.intevation.flys.artifacts.model.fixings; sascha@3011: sascha@3011: import de.intevation.flys.artifacts.math.fitting.Function; sascha@3011: sascha@3011: import de.intevation.flys.artifacts.math.Outlier; sascha@3011: sascha@3011: import de.intevation.flys.artifacts.math.Outlier.IndexedValue; sascha@3011: import de.intevation.flys.artifacts.math.Outlier.Outliers; sascha@3011: sascha@3011: import org.apache.commons.math.MathException; sascha@3011: sascha@3011: import org.apache.commons.math.optimization.fitting.CurveFitter; sascha@3011: sascha@3011: import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer; sascha@3011: sascha@3011: import gnu.trove.TDoubleArrayList; sascha@3011: sascha@3011: import org.apache.log4j.Logger; sascha@3011: sascha@3011: import java.util.ArrayList; sascha@3011: import java.util.List; sascha@3011: sascha@3011: public class Fitting sascha@3011: { sascha@3011: private static Logger log = Logger.getLogger(Fitting.class); sascha@3011: sascha@3011: /** Use instance of this factory to find meta infos for outliers. */ sascha@3011: public interface QWFactory { sascha@3011: sascha@3011: QW create(double q, double w); sascha@3011: sascha@3011: } // interface QWFactory sascha@3011: sascha@3011: protected Function function; sascha@3011: protected QWFactory qwFactory; sascha@3011: protected double chiSqr; sascha@3011: protected double [] parameters; sascha@3011: protected ArrayList removed; sascha@3011: sascha@3011: sascha@3011: public Fitting() { sascha@3011: removed = new ArrayList(); sascha@3011: } sascha@3011: sascha@3011: public Fitting(Function function) { sascha@3011: this(); sascha@3011: this.function = function; sascha@3011: } sascha@3011: sascha@3011: public Fitting(Function function, QWFactory qwFactory) { sascha@3011: this(function); sascha@3011: this.qwFactory = qwFactory; sascha@3011: } sascha@3011: sascha@3011: public Function getFunction() { sascha@3011: return function; sascha@3011: } sascha@3011: sascha@3011: public void setFunction(Function function) { sascha@3011: this.function = function; sascha@3011: } sascha@3011: sascha@3011: public double getChiSquare() { sascha@3011: return chiSqr; sascha@3011: } sascha@3011: sascha@3011: public void reset() { sascha@3011: chiSqr = 0.0; sascha@3011: parameters = null; sascha@3011: removed.clear(); sascha@3011: } sascha@3011: sascha@3011: public boolean hasOutliers() { sascha@3011: return !removed.isEmpty(); sascha@3011: } sascha@3011: sascha@3011: public List getOutliers() { sascha@3011: return removed; sascha@3011: } sascha@3011: sascha@3011: public QW [] outliersToArray() { sascha@3011: return removed.toArray(new QW[removed.size()]); sascha@3011: } sascha@3011: sascha@3011: public double [] getParameters() { sascha@3011: return parameters; sascha@3011: } sascha@3011: sascha@3011: public boolean fit(double [] qs, double [] ws) { sascha@3011: sascha@3011: TDoubleArrayList xs = new TDoubleArrayList(qs.length); sascha@3011: TDoubleArrayList ys = new TDoubleArrayList(ws.length); sascha@3011: sascha@3011: for (int i = 0; i < qs.length; ++i) { sascha@3011: if (!Double.isNaN(qs[i]) && !Double.isNaN(ws[i])) { sascha@3011: xs.add(qs[i]); sascha@3011: ys.add(ws[i]); sascha@3011: } sascha@3011: else { sascha@3011: log.warn("remove invalid value " + qs[i] + " " + ws[i]); sascha@3011: } sascha@3011: } sascha@3011: sascha@3011: if (xs.size() < 2) { sascha@3011: return false; sascha@3011: } sascha@3011: sascha@3011: LevenbergMarquardtOptimizer lmo = new LevenbergMarquardtOptimizer(); sascha@3011: sascha@3011: double [] parameters; sascha@3011: sascha@3011: List inputs = new ArrayList(xs.size()); sascha@3011: sascha@3011: for (;;) { sascha@3011: CurveFitter cf = new CurveFitter(lmo); sascha@3011: sascha@3011: for (int i = 0, N = xs.size(); i < N; ++i) { sascha@3011: cf.addObservedPoint(xs.getQuick(i), ys.getQuick(i)); sascha@3011: } sascha@3011: sascha@3011: try { sascha@3011: parameters = cf.fit(function, function.getInitialGuess()); sascha@3011: } sascha@3011: catch (MathException me) { sascha@3011: log.warn(me); sascha@3011: return false; sascha@3011: } sascha@3011: sascha@3011: if (qwFactory == null) { sascha@3011: break; sascha@3011: } sascha@3011: sascha@3011: inputs.clear(); sascha@3011: sascha@3011: // This is the paraterized function for a given km. sascha@3011: de.intevation.flys.artifacts.math.Function instance = sascha@3011: function.instantiate(parameters); sascha@3011: sascha@3011: for (int i = 0, N = xs.size(); i < N; ++i) { sascha@3011: double y = instance.value(xs.getQuick(i)); sascha@3011: if (Double.isNaN(y)) { sascha@3011: continue; sascha@3011: } sascha@3011: inputs.add(new IndexedValue(i, ys.getQuick(i) - y)); sascha@3011: } sascha@3011: sascha@3011: Outliers outliers = Outlier.findOutliers(inputs); sascha@3011: sascha@3011: if (!outliers.hasOutliers()) { sascha@3011: break; sascha@3011: } sascha@3011: sascha@3011: List rem = outliers.getRemoved(); sascha@3011: sascha@3011: for (int i = rem.size()-1; i >= 0; --i) { sascha@3011: int idx = rem.get(i).getIndex(); sascha@3011: removed.add( sascha@3011: qwFactory.create( sascha@3011: xs.getQuick(idx), ys.getQuick(idx))); sascha@3011: xs.remove(idx); sascha@3011: ys.remove(idx); sascha@3011: } sascha@3011: } sascha@3011: sascha@3011: chiSqr = lmo.getChiSquare(); sascha@3011: sascha@3011: return true; sascha@3011: } sascha@3011: } sascha@3011: // vim:set ts=4 sw=4 si et sta sts=4 fenc=utf8 :