How would you calculate j in this function?


Consider the function below, which converts the result of a * b into a pair of numbers i and j, where:

  1. a, b, x, y are ints (suppose int is always at least 32 bits wide)
  2. a and b are <= n*m, where n = 10^3 and m = 10^5; n*m = 10^8 is called BASE.
  3. a * b can be written as i*BASE + j, with 0 <= j < BASE

How would you calculate j without using any type larger than int? (Be careful about overflow: signed int overflow is undefined behavior.)

#include <iostream>
#include <cstdlib>

using namespace std;

int n = 1000, m = 100000;

struct N {
        int i, j;
};

N f(int a, int b) {
        N x;
        int a0, a1, b0, b1, o;
        a1 = a / n;
        a0 = a - (a1 * n); // a0 = a % n
        b1 = b / m;
        b0 = b - (b1 * m);  // b0 = b % m
        o = a1 * b1 + (a0 * b1) / n + (b0 * a1) / m;
        x.i = o;
        x.j = 0; // CALCULATE J WITH INTs MATH
        return x;
}

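int main(int argc, char* argv[]) {
        if (argc < 3) {  // argv[1] and argv[2] are required
                cerr << "usage: " << argv[0] << " a b" << endl;
                return 1;
        }
        int a = atoi(argv[1]),
            b = atoi(argv[2]);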
        N x = f(a, b);
        cout << a << " * " << b << " = " << x.i << "*" << n*m
             << " + " << x.j << endl;
        cout << "which is: " << (long long)a * b << endl;
        return 0;
}


You started correctly, but lost the plot around the calculation of o. First, my assumptions: you don't want to deal with any integer greater than n*m, so taking a value mod n*m is cheating. I say this because, given m > 2^16, I have to assume int is 32 bits wide, which can already hold your numbers (up to n*m = 10^8) without overflowing.

In any case, you have correctly (I guess, since the purpose of n and m is not specified) written:

a=a0 + a1*n (a0<n)
b=b0 + b1*m (b0<m)

So, if we do the math:

a*b = a0*b0 + a0*b1*m + a1*b0*n + a1*b1*n*m
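As a quick sanity check of the expansion (outside the int-only rule, since it uses long long purely for verification), the sketch below recomputes the four terms and compares them with a * b. The function name expansion_holds is hypothetical; it is not part of the solution.

```cpp
#include <cassert>

// Hypothetical checker (verification only): uses long long to confirm that
// a*b == a0*b0 + a0*b1*m + a1*b0*n + a1*b1*n*m for 0 <= a, b <= n*m.
bool expansion_holds(long long a, long long b) {
    const long long n = 1000, m = 100000;
    assert(0 <= a && a <= n * m && 0 <= b && b <= n * m);
    const long long a1 = a / n, a0 = a - a1 * n;  // a = a0 + a1*n, 0 <= a0 < n
    const long long b1 = b / m, b0 = b - b1 * m;  // b = b0 + b1*m, 0 <= b0 < m
    const long long rhs = a0 * b0 + a0 * b1 * m + a1 * b0 * n + a1 * b1 * n * m;
    return rhs == a * b;
}
```

For example, a = 123456 splits into a1 = 123, a0 = 456, and b = 9876543 splits into b1 = 98, b0 = 76543; expansion_holds(123456, 9876543) returns true.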

Here, a0*b0 < n*m (because a0 < n and b0 < m), so it is part of j, and a1*b1*n*m is a multiple of n*m, so a1*b1 is part of i. It is the other two terms that you need to split in two again. But you cannot compute each one and take it mod n*m, since that would be cheating (as per my rule above). If you write:

a0*b1 = a0b1_0 + a0b1_1*n

You get:

a0*b1*m = a0b1_0*m + a0b1_1*n*m

Since a0b1_0 < n, a0b1_0*m < n*m, which means this part goes to j. Obviously, a0b1_1 goes to i.
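The split of the cross term a0*b1*m can be sketched as follows, assuming (as the EDIT below requires) b < m^2, so that b1 < m and a0*b1 < n*m fits in an int. The struct Parts and the function split_a0b1 are hypothetical names for illustration.

```cpp
#include <cassert>

// Hypothetical result: a0*b1*m == to_j + to_i * (n*m), with 0 <= to_j < n*m.
struct Parts {
    int to_j;  // a0b1_0 * m, the part that belongs to j
    int to_i;  // a0b1_1, the part that belongs to i
};

// Splits a0*b1*m without ever forming the full product, which may overflow.
Parts split_a0b1(int a0, int b1, int n, int m) {
    assert(0 <= a0 && a0 < n && 0 <= b1 && b1 < m);
    const int p = a0 * b1;       // < n*m, fits in a 32-bit int
    const int hi = p / n;        // a0b1_1
    const int lo = p - hi * n;   // a0b1_0, 0 <= lo < n
    return Parts{lo * m, hi};    // lo*m < n*m
}
```

With a0 = 456 and b1 = 98, a0*b1 = 44688, so j receives 688*m = 68800000 and i receives 44.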

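Repeat the same logic for a1*b0 (split it as a1b0_0 + a1b0_1*m, so that a1b0_0*n < n*m), and you have three terms to add up for j (a0*b0, a0b1_0*m, a1b0_0*n) and three more to add up for i (a1*b1, a0b1_1, a1b0_1).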
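Putting the pieces together, here is a minimal sketch of f under the constraints from the EDIT below (0 <= a < n^2, 0 <= b < m^2) and assuming int is at least 32 bits. The helper add_mod is a hypothetical name; it adds one term to j without letting any intermediate value reach n*m and carries 1 into i when needed, along the lines of the EDIT's second bullet.

```cpp
#include <cassert>

struct N {
    int i, j;
};

const int n = 1000, m = 100000;

// Adds addend (0 <= addend < n*m) to x.j (0 <= x.j < n*m) without any
// intermediate value reaching n*m, carrying 1 into x.i when the true sum does.
void add_mod(N& x, int addend) {
    const int base = n * m;
    if (base - x.j <= addend) {        // x.j + addend >= base
        x.j = x.j - (base - addend);   // == x.j + addend - base, computed safely
        x.i += 1;
    } else {
        x.j += addend;
    }
}

// Computes a*b == x.i * (n*m) + x.j using only int arithmetic.
// Assumes 0 <= a < n*n and 0 <= b < m*m (so a1 < n and b1 < m).
N f(int a, int b) {
    assert(0 <= a && a < n * n && 0 <= b);
    const int a1 = a / n, a0 = a - a1 * n;   // a = a0 + a1*n
    const int b1 = b / m, b0 = b - b1 * m;   // b = b0 + b1*m

    const int p = a0 * b1;                   // < n*m
    const int p1 = p / n, p0 = p - p1 * n;   // a0*b1*m = p0*m + p1*(n*m)
    const int q = a1 * b0;                   // < n*m
    const int q1 = q / m, q0 = q - q1 * m;   // a1*b0*n = q0*n + q1*(n*m)

    N x;
    x.i = a1 * b1 + p1 + q1;                 // the three terms for i
    x.j = a0 * b0;                           // < n*m
    add_mod(x, p0 * m);                      // p0 < n, so p0*m < n*m
    add_mod(x, q0 * n);                      // q0 < m, so q0*n < n*m
    return x;
}
```

For a = 123456 and b = 9876543 this gives i = 12193 and j = 18492608, i.e. 1219318492608 = 12193 * 10^8 + 18492608.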


EDIT: Forgot to mention a few things:

  • You need the constraints a < n^2 and b < m^2 for this to work, so that a1 < n and b1 < m and the cross products a0*b1 and a1*b0 stay below n*m. Otherwise, you need more "words" ai, e.g. a = a0 + a1*n + a2*n^2, with each ai < n.

  • The final sum for j may reach or exceed n*m. You need to watch for overflow: check n*m - j <= addend (where j is the running sum), or use similar logic. When this happens, add 1 to i and compute j + addend - n*m without overflow, e.g. as j - (n*m - addend).
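To cover the question's full range (0 <= a, b <= n*m) instead of a < n^2, you need more words, as the first bullet says. The sketch below swaps in a slightly different split than the answer's: since n*m = 10^8 = (10^4)^2, it splits both a and b into base-W words with W = 10^4 (W, split_words, add_carry and f_full are assumptions for illustration, not from the question). Every product of two words is below W^2 = n*m, so it fits in a 32-bit int, and j is accumulated with the carry check from the second bullet.

```cpp
#include <cassert>

struct N {
    int i, j;
};

const int W = 10000;      // assumed word size: W * W == n*m == 10^8
const int BASE = W * W;   // n*m

// Adds addend (0 <= addend < BASE) into x.j, carrying into x.i,
// without any intermediate value reaching BASE.
void add_carry(N& x, int addend) {
    if (BASE - x.j <= addend) {
        x.j -= BASE - addend;
        x.i += 1;
    } else {
        x.j += addend;
    }
}

// Splits v (0 <= v <= BASE) into words: v == w[0] + w[1]*W + w[2]*W*W.
void split_words(int v, int w[3]) {
    w[2] = v / BASE;   // 1 only when v == BASE
    v -= w[2] * BASE;
    w[1] = v / W;
    w[0] = v - w[1] * W;
}

// Computes a*b == x.i * BASE + x.j for 0 <= a, b <= BASE using only int math.
N f_full(int a, int b) {
    assert(0 <= a && a <= BASE && 0 <= b && b <= BASE);
    int aw[3], bw[3];
    split_words(a, aw);
    split_words(b, bw);

    N x = {0, 0};
    for (int s = 0; s < 3; ++s) {
        for (int t = 0; t < 3; ++t) {
            const int p = aw[s] * bw[t];   // always < BASE
            switch (s + t) {               // p contributes p * W^(s+t)
                case 0:
                    add_carry(x, p);
                    break;
                case 1: {
                    const int hi = p / W, lo = p - hi * W;  // p*W = lo*W + hi*BASE
                    add_carry(x, lo * W);
                    x.i += hi;
                    break;
                }
                case 2:
                    x.i += p;                  // p*W^2 == p*BASE
                    break;
                case 3:
                    x.i += p * W;              // one word is at most 1, so p < W
                    break;
                default:
                    x.i += p * BASE;           // s == t == 2, so p <= 1
                    break;
            }
        }
    }
    return x;
}
```

For example, f_full(99999999, 99999999) yields i = 99999998 and j = 1, and f_full(100000000, 100000000) yields i = 100000000 and j = 0.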