/* trapezoid.c -- Parallel Trapezoidal Rule with MPI_Bcast
 *
 * Input: Two command-line arguments: a (left endpoint) and b (right endpoint).
 * Output:  Estimate of the integral from a to b of f(x)
 *    using the trapezoidal rule and n trapezoids.
 *
 * Algorithm:
 *    1.  Each process calculates "its" interval of
 *        integration.
 *    2.  Each process estimates the integral of f(x)
 *        over its interval using the trapezoidal rule.
 *    3a. The local integrals of all processes are summed onto
 *        process 0 with MPI_Reduce.
 *    3b. Process 0 prints the result.
 *
 * Notes:  
 *    1.  First parameter is a, the left end of the range, and 
 *        the second parameter is b, the right end of the range. 
 *    2.  f(x) and n are hardwired.
 *    3.  The number of processes (p) should evenly divide
 *        the number of trapezoids (n = 1024)
 *
 */
#include <stdio.h>
#include <stdlib.h>

/* We'll be using MPI routines, definitions, etc. */
#include "mpi.h"


float Trap(float local_a, float local_b, int local_n,
           float h);       /* Calculate local integral  */

/* Parse text as a floating-point number and store it in *value.
 * Returns 1 on success, 0 if text is empty or is not entirely a number. */
static int parse_endpoint(const char *text, float *value) {
    char   *end;
    double  parsed = strtod(text, &end);

    if (end == text || *end != '\0') {
        return 0;
    }
    *value = (float)parsed;
    return 1;
}

/* main -- estimate the integral of f(x) over [a, b] with the trapezoidal
 * rule, splitting the n trapezoids across all MPI processes.
 *
 * Command line: <left> <right> (the endpoints a and b).
 * Process 0 parses the endpoints and broadcasts them; every process
 * integrates its own sub-interval; the partial results are summed onto
 * process 0, which prints the estimate.
 *
 * Returns 0 on success and EXIT_FAILURE when the arguments are missing
 * or are not numbers.
 */
int main(int argc, char** argv) {
    int         my_rank;      /* My process rank                      */
    int         p;            /* The number of processes              */
    float       endpoint[2];  /* Left and right                       */
    int         n = 1024;     /* Number of trapezoids                 */
    float       h;            /* Trapezoid base length                */
    float       local_a;      /* Left endpoint my process             */
    float       local_b;      /* Right endpoint my process            */
    int         local_n;      /* Number of trapezoids for my process  */
    int         base_n;       /* Trapezoids every process gets        */
    int         extra;        /* Leftover trapezoids (n mod p)        */
    int         first_trap;   /* Index of my first trapezoid          */
    float       integral;     /* Integral over my interval            */
    float       total = 0.0f; /* Total integral (used on process 0)   */

    /* Let the system do what it needs to start up MPI */
    MPI_Init(&argc, &argv);

    /* Get my process rank */
    MPI_Comm_rank(MPI_COMM_WORLD, &my_rank);

    /* Find out how many processes are being used */
    MPI_Comm_size(MPI_COMM_WORLD, &p);

    if (argc != 3) {
        if (my_rank == 0) {
            printf("Usage:  mpirun -np <numprocs> trapezoid <left> <right>\n");
        }
        MPI_Finalize();
        return EXIT_FAILURE;
    }

    /* Only process 0 parses the endpoints; the others receive them
     * through the broadcast below. */
    if (my_rank == 0) {
        if (!parse_endpoint(argv[1], &endpoint[0]) ||
            !parse_endpoint(argv[2], &endpoint[1])) {
            fprintf(stderr, "trapezoid: <left> and <right> must be numbers\n");
            /* The other processes are about to wait in MPI_Bcast, so the
             * whole job has to be torn down, not just this process. */
            MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
            return EXIT_FAILURE;
        }
    }

    /* This is not very efficient here, but demonstrates
       the use of MPI_Bcast.                              */
    MPI_Bcast(endpoint, 2, MPI_FLOAT, 0, MPI_COMM_WORLD);

    h = (endpoint[1]-endpoint[0])/n;    /* h is the same for all processes */

    /* Spread the trapezoids as evenly as possible: every process gets
     * n/p of them and the first (n mod p) processes take one more, so
     * no trapezoid is dropped when p does not divide n. */
    base_n = n/p;
    extra = n%p;
    local_n = base_n + (my_rank < extra ? 1 : 0);
    first_trap = my_rank*base_n + (my_rank < extra ? my_rank : extra);

    if (my_rank == 0) {
        printf("a=%f, b=%f, Local number of trapezoids=%d\n",
               endpoint[0], endpoint[1], local_n);
    }

    /* My interval starts after the first_trap trapezoids owned by the
     * lower ranks and spans local_n trapezoids of width h. */
    local_a = endpoint[0] + first_trap*h;
    local_b = local_a + local_n*h;

    /* A process left without trapezoids (p > n) contributes nothing. */
    integral = (local_n > 0) ? Trap(local_a, local_b, local_n, h) : 0.0f;

    /* Add up the integrals calculated by each process */
    MPI_Reduce(&integral, &total, 1, MPI_FLOAT, MPI_SUM, 0, MPI_COMM_WORLD);

    /* Print the result */
    if (my_rank == 0) {
        printf("With n = %d trapezoids, our estimate\n",
            n);
        printf("of the integral from %f to %f = %f\n",
            endpoint[0], endpoint[1], total);
    }

    /* Shut down MPI */
    MPI_Finalize();
    return 0;
} /*  main  */


/* Trap -- trapezoidal-rule estimate of the integral of f over
 * [local_a, local_b].
 *
 *   local_a  left endpoint of the interval
 *   local_b  right endpoint of the interval
 *   local_n  number of trapezoids in the interval
 *   h        width of each trapezoid
 *
 * Returns h * ( (f(local_a) + f(local_b))/2 + sum of f at the
 * local_n-1 interior points ), or 0 when local_n is not positive
 * (there is nothing to integrate).
 */
float Trap(
          float  local_a   /* in */,
          float  local_b   /* in */,
          int    local_n   /* in */,
          float  h         /* in */) {

    float integral;   /* Store result in integral  */
    int i;

    float f(float x); /* function we're integrating */

    /* No trapezoids means an empty interval; without this guard the
     * endpoint term below would still contribute f(local_a)*h. */
    if (local_n <= 0) {
        return 0.0f;
    }

    integral = (f(local_a) + f(local_b))/2.0f;
    for (i = 1; i <= local_n-1; i++) {
        /* Compute each interior point from its index rather than by
         * repeatedly adding h, so rounding error does not build up. */
        integral = integral + f(local_a + i*h);
    }
    return integral*h;
} /*  Trap  */


/* f -- the integrand: returns the square of x. */
float f(float x) {
    const float x_squared = x*x;

    return x_squared;
} /* f */


