/*
 * holefiller.h: abstract wrapper class that fills in removable
 * singularities in functions.
 */

/*
 * This class is a wrapper that sits in front of another spigot,
 * somewhat like enforce.cpp's Enforcer. You provide it with a value
 * x, which might be one of a finite set of special x values; the
 * basic semantics are that if x does turn out to be one of those
 * values then there's a corresponding special y value for each one
 * that should be returned, and if x turns out to be anything else
 * then there's a 'general y' that should be returned instead.
 *
 * You provide the special x and y values and the general y value by
 * means of subclassing HoleFiller itself; they're all given by
 * virtual methods.
 *
 * The point of this is for situations in which you have a means of
 * implementing _most_ of a continuous function f, apart from a small
 * number of 'removable singularities' - points where the function
 * really is continuous, but for some annoying reason, you have to use
 * a different evaluation strategy to compute its value there. But you
 * can't just write the obvious code (if x == this then return that,
 * else do the general thing), because that causes an exactness hazard
 * in the case where x _turns out_ to be the special value but that
 * wasn't obvious up front.
 *
 * Enter HoleFiller. HoleFiller will keep track of which special
 * value(s) the input x has not yet been proved to miss, and it will
 * return intervals narrowing about the corresponding special y value
 * for as long as there's still one of those in play. But if it then
 * turns out that x is not a special value after all, it'll switch to
 * returning intervals about whatever the general fallback evaluator
 * returns.
 *
 * In order to do this correctly, it must not return an interval in
 * the first phase which rules out any value that f(x) might still
 * turn out to have. So another thing your subclass must provide is a
 * function which takes as input the currently known interval about
 * the special value to which we have restricted x, and returns an
 * interval about the corresponding special y which also contains any
 * other value that f(x) might take on the input interval. That way,
 * if we later abandon the special value, we won't have returned any
 * interval that excludes the real answer to f(x).
 */
class HoleFiller : public BinaryIntervalSource {
    std::vector<BracketingGenerator *> bgs;
    BracketingGenerator *xbg, *ybg;
    int specindex;
    int crState;

  protected:
    Spigot *x;

  public:
    HoleFiller(Spigot *ax) : x(ax), xbg(NULL), ybg(NULL)
    {
        crState = -1;
        dprint("hello HoleFiller");
    }

    ~HoleFiller()
    {
        for (int i = 0; i < (int)bgs.size(); i++)
            if (bgs[i])
                delete bgs[i];
        if (xbg)
            delete xbg;
        if (ybg)
            delete ybg;
        delete x;
    }

    virtual Spigot *replace() {
        /*
         * See if we can _immediately_ detect exact equality to a
         * special-case input. If so, delete ourself and return that y
         * value instead.
         */
        Spigot *a, *diff;
        bigint n, d;
        for (int i = 0; (a = xspecial(i)) != NULL; i++) {
            diff = spigot_sub(x->clone(), a);
            if (diff->is_rational(&n, &d) && n == 0) {
                Spigot *ret = yspecial(i);
                delete diff;
                delete this;
                return ret;
            }
        }

        return this;
    }

    virtual Spigot *xspecial(int) = 0;
    virtual Spigot *yspecial(int) = 0;
    virtual Spigot *ygeneral(Spigot *) = 0;
    virtual bool combine(bigint *ret_lo, bigint *ret_hi, unsigned *ret_bits,
                         const bigint &xnlo, const bigint &xnhi,
                         const bigint &ynlo, const bigint &ynhi,
                         unsigned dbits, int /*index*/) {
        /*
         * This function may be reimplemented by a subclass. Its job
         * is: given an interval (xnlo,xnhi)/dbits saying how close x
         * is to the (index)th special input, and an interval
         * (ynlo,ynhi)/dbits bracketing the corresponding special
         * output, return an interval that we can be sure the _real_
         * output is in.
         *
         * We provide a default implementation here based on the
         * assumption that the function is Lipschitz-continuous within
         * a range of +-1/2 about the target, because that's really
         * easy and quite a weak constraint. If you felt like
         * squeezing more performance out then you could reimplement
         * this to be more aggressive; if you find yourself
         * implementing a function which is really steep at the
         * special point (perhaps even infinitely steep, e.g. if it
         * approaches the point in a sqrt-like way) then you _must_
         * reimplement this to be more conservative.
         *
         * Return false if we can't get any information at all yet.
         */
        bigint deviation = -xnlo;
        if (deviation < xnhi)
            deviation = xnhi;
        if (dbits < 1 || (deviation >> (dbits-1)) != 0)
            return false;
        *ret_bits = dbits;
        *ret_lo = ynlo - deviation;
        *ret_hi = ynhi + deviation;
        return true;
    }

    virtual HoleFiller *clone() = 0;

    virtual void gen_bin_interval(bigint *ret_lo, bigint *ret_hi,
                                  unsigned *ret_bits)
    {
        crBegin;

        /*
         * Set up a BracketingGenerator for x-a, for each special
         * input a.
         */
        {
            Spigot *a;
            for (int i = 0; (a = xspecial(i)) != NULL; i++)
                bgs.push_back(new BracketingGenerator
                              (spigot_sub(x->clone(), a)));
        }

        /*
         * Loop round and round, fetching more information about each
         * of those, until we narrow to 0 or 1 of them potentially
         * zero.
         *
         * This is a setup phase in which we don't return to the
         * caller at all, because if the input interval doesn't even
         * narrow to less than the distance between our special values
         * (if we have more than one of them), then no useful output
         * was going to be generated anyway.
         */
        while (1) {
            int n = 0;                 // count the still-active bgs
            bigint nlo, nhi;
            unsigned dbits;

            for (int i = 0; i < (int)bgs.size(); i++) {
                bgs[i]->get_bracket_shift(&nlo, &nhi, &dbits);
                dprint("input bracket for specials[%d]: (%b,%b) / 2^%d",
                       i, &nlo, &nhi, (int)dbits);

                if (nlo > 0 || nhi < 0) {
                    /*
                     * We know the sign of x-a, i.e. we know it's
                     * nonzero. Discard this special value.
                     */
                    delete bgs[i];
                    bgs[i] = NULL;
                } else if (nlo == 0 && nhi == 0) {
                    /*
                     * We know x-a is _exactly_ zero, i.e. this is
                     * precisely a special value of the function.
                     */
                    ybg = new BracketingGenerator(yspecial(i));
                    goto passthrough;  // multilevel break (sorry)
                } else {
                    /*
                     * This one is still a possible.
                     */
                    n++;
                }
            }

            if (n == 0) {
                /*
                 * No remaining possibilities for special values, so
                 * we just switch to constructing the general return
                 * value.
                 */
                ybg = new BracketingGenerator(ygeneral(x->clone()));
                goto passthrough;
            }

            if (n == 1) {
                /*
                 * There's one remaining possibility for a special
                 * value. This is the interesting case.
                 */
                for (int i = 0; i < (int)bgs.size(); i++) {
                    if (bgs[i]) {
                        specindex = i;
                        xbg = bgs[i];
                        bgs[i] = NULL;
                        ybg = new BracketingGenerator(yspecial(i));
                        goto narrowing;
                    }
                }
            }
        }

      narrowing:
        /*
         * Now we've got a single potential special value. Narrow an
         * interval about the SV output for as long as we can't prove
         * the input is not the SV input.
         */
        while (true) {
            {
                bigint xnlo, xnhi, ynlo, ynhi;
                unsigned xdbits, ydbits;

                xbg->get_bracket_shift(&xnlo, &xnhi, &xdbits);
                dprint("narrowing input bracket: (%b,%b) / 2^%d",
                       &xnlo, &xnhi, (int)xdbits);

                if (xnlo > 0 || xnhi < 0) {
                    /*
                     * Turns out we don't have the special value after
                     * all; reinitialise ybg with the general value, and
                     * go on to the passthrough phase.
                     */
                    delete xbg;
                    xbg = NULL;
                    delete ybg;
                    ybg = new BracketingGenerator(ygeneral(x->clone()));
                    goto passthrough;
                } else if (xnlo == 0 && xnhi == 0) {
                    /*
                     * Turns out we have _exactly_ the special value.
                     */
                    delete xbg;
                    xbg = NULL;
                    goto passthrough;
                }

                ybg->get_bracket_shift(&ynlo, &ynhi, &ydbits);
                dprint("narrowing output bracket: (%b,%b) / 2^%d",
                       &ynlo, &ynhi, (int)ydbits);

                /*
                 * Normalise the input and output brackets to the same
                 * denominator, which we leave in ydbits.
                 */
                if (ydbits < xdbits) {
                    ynlo <<= xdbits-ydbits;
                    ynhi <<= xdbits-ydbits;
                    ydbits = xdbits;
                } else if (ydbits > xdbits) {
                    xnlo <<= ydbits-xdbits;
                    xnhi <<= ydbits-xdbits;
                }
                dprint("narrowing normalised brackets: input (%b,%b) / 2^%d,"
                       " output (%b,%b) / 2^%d",
                       &xnlo, &xnhi, (int)ydbits, // (intentionally not xdbits)
                       &ynlo, &ynhi, (int)ydbits);

                /*
                 * And let our subclass figure out how closely it can
                 * afford to narrow the resulting interval.
                 */
                combine(ret_lo, ret_hi, ret_bits,
                        xnlo, xnhi, ynlo, ynhi, ydbits, specindex);
            }

            dprint("combined output bracket: (%b,%b) / 2^%d",
                   ret_lo, ret_hi, (int)*ret_bits);
            crReturnV;
        }

      passthrough:
        /*
         * The simple part: now we know what kind of output we're
         * generating, just pass through the results of an ordinary
         * BracketingGenerator.
         */

        // FIXME: here we'd really like to pass through an exact
        // termination, if we're a rational!

        while (1) {
            ybg->get_bracket_shift(ret_lo, ret_hi, ret_bits);
            dprint("passthrough bracket: (%b,%b) / 2^%d",
                   ret_lo, ret_hi, (int)*ret_bits);
            crReturnV;
        }

        crEnd;
    }
};
