rts-wapr.c

     1  //! @file rts-wapr.c
     2  //! @author J. Marcel van der Veer
     3  //
     4  //! @section Copyright
     5  //
     6  // This file is part of VIF - vintage FORTRAN compiler.
     7  // Copyright 2020-2025 J. Marcel van der Veer <algol68g@xs4all.nl>.
     8  //
     9  //! @section License
    10  //
    11  // This program is free software; you can redistribute it and/or modify it 
    12  // under the terms of the GNU General Public License as published by the 
    13  // Free Software Foundation; either version 3 of the License, or 
    14  // (at your option) any later version.
    15  //
    16  // This program is distributed in the hope that it will be useful, but 
    17  // WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY 
    18  // or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for 
    19  // more details. You should have received a copy of the GNU General Public 
    20  // License along with this program. If not, see <http://www.gnu.org/licenses/>.
    21  
    22  //! @section Synopsis
    23  //!
    24  //! Runtime support implementing the Lambert W-function.
    25  
    26  #include <vif.h>
    27  
    28  // The Lambert W function y = W(x) is the solution to the equation y * exp(y) = x. 
    29  //
    30  // Original FORTRAN77 version by Andrew Barry, S. J. Barry, Patricia Culligan-Hensley.
    31  // Original C version by John Burkardt, distributed under the MIT license [2014].
    32  // Adapted for VIF by J.M. van der Veer [2024].
    33  //
    34  // Reference:
    35  //   Andrew Barry, S. J. Barry, Patricia Culligan-Hensley,
    36  //   Algorithm 743: WAPR - A Fortran routine for calculating real values of the W-function,
    37  //   ACM Transactions on Mathematical Software,
    38  //   Volume 21, Number 2, June 1995, pages 172-181.
    39  
    40  real_8 bisect (real_8 xx, int_4 nb, int_4 * ner, int_4 l);
    41  real_8 crude (real_8 xx, int_4 nb);
    42  int_4 nbits_compute ();
    43  real_8 wapr (real_8 x, int_4 nb, int_4 * nerror, int_4 l);
    44  real_8 _wapr (real_8 * x, int_4 * nb, int_4 * nerror, int_4 * l);
    45  
    46  real_8 bisect (real_8 xx, int_4 nb, int_4 *ner, int_4 l)
    47  {
    48  // BISECT approximates the W function using bisection.
    49  // After TOMS algorithm 743.
    50  //
    51  // Discussion:
    52  //
    53  //   The parameter TOL, which determines the accuracy of the bisection
    54  //   method, is calculated using NBITS (assuming the final bit is lost
    55  //   due to rounding error).
    56  //
    57  //   N0 is the maximum number of iterations used in the bisection
    58  //   method.
    59  //
    60  //   For XX close to 0 for Wp, the exponential approximation is used.
    61  //   The approximation is exact to O(XX^8) so, depending on the value
    62  //   of NBITS, the range of application of this formula varies. Outside
    63  //   this range, the usual bisection method is used.
    64  //
    65  // Parameters:
    66  //
    67  //   Input, real_8 XX, the argument.
    68  //
    69  //   Input, int_4 NB, indicates the branch of the W function.
    70  //   0, the upper branch;
    71  //   nonzero, the lower branch.
    72  //
    73  //   Output, int_4 *NER, the error flag.
    74  //   0, success;
    75  //   1, the routine did not converge.  Perhaps reduce NBITS and try again.
    76  //
    77  //   Input, int_4 L, the offset indicator.
    78  //   1, XX represents the offset of the argument from -exp(-1).
    79  //   not 1, XX is the actual argument.
    80  //
    81  //   Output, real_8 BISECT, the value of W(X), as determined
    82  
    83    const int_4 n0 = 500;
    84    int_4 i;
    85    real_8 d, f, fd, r, test, tol, u, x, value = 0.0;
    86    static int_4 nbits = 0;
    87    *ner = 0;
    88    if (nbits == 0) {
    89      nbits = nbits_compute ();
    90    }
    91    if (l == 1) {
    92      x = xx - exp (-1.0);
    93    } else {
    94      x = xx;
    95    }
    96    if (nb == 0) {
    97      test = 1.0 / pow (pow (2.0, nbits), (1.0 / 7.0));
    98      if (fabs (x) < test) {
    99        return x * exp (-x * exp (-x * exp (-x * exp (-x * exp (-x * exp (-x))))));
   100      } else {
   101        u = crude (x, nb) + 1.0e-3;
   102        tol = fabs (u) / pow (2.0, nbits);
   103        d = fmax (u - 2.0e-3, -1.0);
   104        for (i = 1; i <= n0; i++) {
   105    r = 0.5 * (u - d);
   106    value = d + r;
   107  // Find root using w*exp(w)-x to avoid ln(0) error.
   108    if (x < exp (1.0)) {
   109      f = value * exp (value) - x;
   110      fd = d * exp (d) - x;
   111    }
   112  // Find root using ln(w/x)+w to avoid overflow error.
   113    else {
   114      f = log (value / x) + value;
   115      fd = log (d / x) + d;
   116    }
   117    if (f == 0.0) {
   118      return value;
   119    }
   120    if (fabs (r) <= tol) {
   121      return value;
   122    }
   123    if (0.0 < fd * f) {
   124      d = value;
   125    } else {
   126      u = value;
   127    }
   128        }
   129      }
   130    } else {
   131      d = crude (x, nb) - 1.0e-3;
   132      u = fmin (d + 2.0e-3, -1.0);
   133      tol = fabs (u) / pow (2.0, nbits);
   134      for (i = 1; i <= n0; i++) {
   135        r = 0.5 * (u - d);
   136        value = d + r;
   137        f = value * exp (value) - x;
   138        if (f == 0.0) {
   139    return value;
   140        }
   141        if (fabs (r) <= tol) {
   142    return value;
   143        }
   144        fd = d * exp (d) - x;
   145        if (0.0 < fd * f) {
   146    d = value;
   147        } else {
   148    u = value;
   149        }
   150      }
   151    }
   152  // The iteration did not converge.
   153    *ner = 1;
   154    return value;
   155  }
   156  
   157  real_8 crude (real_8 xx, int_4 nb)
   158  {
   159  // CRUDE returns a crude approximation for the W function.
   160  //
   161  // Parameters:
   162  //
   163  //   Input, real_8 XX, the argument.
   164  //
   165  //   Input, int_4 NB, indicates the desired branch.
   166  //   * 0, the upper branch;
   167  //   * nonzero, the lower branch.
   168  //
   169  //   Output, real_8 CRUDE, the crude approximation to W at XX.
   170  
   171    real_8 an2, reta, t, ts, zl;
   172    static int_4 init = 0;
   173    static real_8 c13, em, em2, em9, eta, s2, s21, s22, s23;
   174  // Various mathematical constants.
   175    if (init == 0) {
   176      init = 1;
   177      em = -exp (-1.0);
   178      em9 = -exp (-9.0);
   179      c13 = 1.0 / 3.0;
   180      em2 = 2.0 / em;
   181      s2 = sqrt (2.0);
   182      s21 = 2.0 * s2 - 3.0;
   183      s22 = 4.0 - 3.0 * s2;
   184      s23 = s2 - 2.0;
   185    }
   186  // Crude Wp.
   187    if (nb == 0) {
   188      if (xx <= 20.0) {
   189        reta = s2 * sqrt (1.0 - xx / em);
   190        an2 = 4.612634277343749 * sqrt (sqrt (reta + 1.09556884765625));
   191        return reta / (1.0 + reta / (3.0 + (s21 * an2 + s22) * reta / (s23 * (an2 + reta)))) - 1.0;
   192      } else {
   193        zl = log (xx);
   194        return log (xx / log (xx / pow (zl, exp (-1.124491989777808 / (0.4225028202459761 + zl)))));
   195      }
   196    } else {
   197  // Crude Wm.
   198      if (xx <= em9) {
   199        zl = log (-xx);
   200        t = -1.0 - zl;
   201        ts = sqrt (t);
   202        return zl - (2.0 * ts) / (s2 + (c13 - t / (270.0 + ts * 127.0471381349219)) * ts);
   203      } else {
   204        zl = log (-xx);
   205        eta = 2.0 - em2 * xx;
   206        return log (xx / log (-xx / ((1.0 - 0.5043921323068457 * (zl + 1.0)) * (sqrt (eta) + eta / 3.0) + 1.0)));
   207      }
   208    }
   209  }
   210  
   211  int_4 nbits_compute ()
   212  {
   213  // NBITS_COMPUTE computes the mantissa length minus one.
   214  //
   215  // Discussion:
   216  //
   217  //   NBITS is the number of bits (less 1) in the mantissa of the
   218  //   floating point number number representation of your machine.
   219  //   It is used to determine the level of accuracy to which the W
   220  //   function should be calculated.
   221  //
   222  // Parameters:
   223  //
   224  //   Output, int_4 NBITS_COMPUTE, the mantissa length, in bits, minus one.
   225  //
   226    int m = 14;
   227    return _i1mach (&m) - 1;
   228  }
   229  
   230  real_8 wapr (real_8 x, int_4 nb, int_4 *nerror, int_4 l)
   231  {
   232  // WAPR approximates the W function.
   233  //
   234  // Discussion:
   235  //
   236  //   The call will fail if the input value X is out of range.
   237  //   The range requirement for the upper branch is:
   238  //     -exp(-1) <= X.
   239  //   The range requirement for the lower branch is:
   240  //     -exp(-1) < X < 0.
   241  //
   242  // Parameters:
   243  //
   244  //   Input, real_8 X, the argument.
   245  //
   246  //   Input, int_4 NB, indicates the desired branch.
   247  //   * 0, the upper branch;
   248  //   * nonzero, the lower branch.
   249  //
   250  //   Output, int_4 *NERROR, the error flag.
   251  //   * 0, successful call.
   252  //   * 1, failure, the input X is out of range.
   253  //
   254  //   Input, int_4 L, indicates the interpretation of X.
   255  //   * 1, X is actually the offset from -(exp-1), so compute W(X-exp(-1)).
   256  //   * not 1, X is the argument; compute W(X);
   257  //
   258  //   Output, real_8 WAPR, the approximate value of W(X).
   259  
   260    int_4 i;
   261    real_8 an2, delx, eta, reta, t, temp, temp2, ts, xx, zl, zn, value = 0.0;
   262    static int_4 init = 0, nbits, niter = 1;
   263    static real_8 an3, an4, an5, an6, c13, c23, d12, em, em2, em9;
   264    static real_8 s2, s21, s22, s23, tb, x0, x1;
   265    *nerror = 0;
   266    if (init == 0) {
   267      init = 1;
   268      nbits = nbits_compute ();
   269      if (56 <= nbits) {
   270        niter = 2;
   271      }
   272  // Various mathematical constants.
   273      em = -exp (-1.0);
   274      em9 = -exp (-9.0);
   275      c13 = 1.0 / 3.0;
   276      c23 = 2.0 * c13;
   277      em2 = 2.0 / em;
   278      d12 = -em2;
   279      tb = pow (0.5, nbits);
   280      x0 = pow (tb, 1.0 / 6.0) * 0.5;
   281      x1 = (1.0 - 17.0 * pow (tb, 2.0 / 7.0)) * em;
   282      an3 = 8.0 / 3.0;
   283      an4 = 135.0 / 83.0;
   284      an5 = 166.0 / 39.0;
   285      an6 = 3167.0 / 3549.0;
   286      s2 = sqrt (2.0);
   287      s21 = 2.0 * s2 - 3.0;
   288      s22 = 4.0 - 3.0 * s2;
   289      s23 = s2 - 2.0;
   290    }
   291    if (l == 1) {
   292      delx = x;
   293      if (delx < 0.0) {
   294        *nerror = 1;
   295        RTE ("wapr", "offset X must be non-negative");
   296      }
   297      xx = x + em;
   298    } else {
   299      if (x < em) {
   300        *nerror = 1;
   301        return value;
   302      } else if (x == em) {
   303        value = -1.0;
   304        return value;
   305      }
   306      xx = x;
   307      delx = xx - em;
   308    }
   309  // Calculations for Wp.
   310    if (nb == 0) {
   311      if (fabs (xx) <= x0) {
   312        value = xx / (1.0 + xx / (1.0 + xx / (2.0 + xx / (0.6 + 0.34 * xx))));
   313        return value;
   314      } else if (xx <= x1) {
   315        reta = sqrt (d12 * delx);
   316        value = reta / (1.0 + reta / (3.0 + reta / (reta / (an4 + reta / (reta * an6 + an5)) + an3))) - 1.0;
   317        return value;
   318      } else if (xx <= 20.0) {
   319        reta = s2 * sqrt (1.0 - xx / em);
   320        an2 = 4.612634277343749 * sqrt (sqrt (reta + 1.09556884765625));
   321        value = reta / (1.0 + reta / (3.0 + (s21 * an2 + s22) * reta / (s23 * (an2 + reta)))) - 1.0;
   322      } else {
   323        zl = log (xx);
   324        value = log (xx / log (xx / pow (zl, exp (-1.124491989777808 / (0.4225028202459761 + zl)))));
   325      }
   326    }
   327  // Calculations for Wm.
   328    else {
   329      if (0.0 <= xx) {
   330        *nerror = 1;
   331        return value;
   332      } else if (xx <= x1) {
   333        reta = sqrt (d12 * delx);
   334        value = reta / (reta / (3.0 + reta / (reta / (an4 + reta / (reta * an6 - an5)) - an3)) - 1.0) - 1.0;
   335        return value;
   336      } else if (xx <= em9) {
   337        zl = log (-xx);
   338        t = -1.0 - zl;
   339        ts = sqrt (t);
   340        value = zl - (2.0 * ts) / (s2 + (c13 - t / (270.0 + ts * 127.0471381349219)) * ts);
   341      } else {
   342        zl = log (-xx);
   343        eta = 2.0 - em2 * xx;
   344        value = log (xx / log (-xx / ((1.0 - 0.5043921323068457 * (zl + 1.0)) * (sqrt (eta) + eta / 3.0) + 1.0)));
   345      }
   346    }
   347    for (i = 1; i <= niter; i++) {
   348      zn = log (xx / value) - value;
   349      temp = 1.0 + value;
   350      temp2 = temp + c23 * zn;
   351      temp2 = 2.0 * temp * temp2;
   352      value = value * (1.0 + (zn / temp) * (temp2 - zn) / (temp2 - 2.0 * zn));
   353    }
   354    return value;
   355  }
   356  
   357  real_8 _wapr (real_8 *x, int_4 *nb, int_4 *nerror, int_4 *l)
   358  {
   359  // F77 API.
   360    return wapr (*x, *nb, nerror, *l);
   361  }


© 2002-2025 J.M. van der Veer (jmvdveer@xs4all.nl)