Skip to content

Commit 8f47a1b

Browse files
datatypes for math methods extended
1 parent 31ed748 commit 8f47a1b

7 files changed

Lines changed: 541 additions & 93 deletions

numpower.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3717,9 +3717,11 @@ static NDArray *ndarray_resolve_unary_input(zval *array, int *owned)
37173717
zend_throw_error(NULL,
37183718
"Numeric string expected, got a whitespace-only value.");
37193719
} else {
3720+
/* Length-bounded (`%.*s`) so an embedded NUL doesn't
3721+
truncate the offending literal in the diagnostic. */
37203722
zend_throw_error(NULL,
3721-
"Numeric string expected, got malformed literal: \"%s\".",
3722-
p);
3723+
"Numeric string expected, got malformed literal: \"%.*s\".",
3724+
(int)n, p);
37233725
}
37243726
return NULL;
37253727
}

src/dd_math.c

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -236,11 +236,6 @@ static const ndarray_dd_t DD_LN2 = {
236236
0.6931471805599453, /* hi: 0.693147180559945286... (closest double) */
237237
2.3190468138462996e-17 /* lo: residual = ln2 - hi */
238238
};
239-
static const ndarray_dd_t DD_LN10 = {
240-
/* ln(10) = 2.30258509299404568401799145468... */
241-
2.302585092994046,
242-
-2.1707562233822494e-16
243-
};
244239
static const ndarray_dd_t DD_LOG2_E = {
245240
/* 1/ln(2) = 1.44269504088896340735992468100... */
246241
1.4426950408889634,
@@ -259,9 +254,13 @@ static const ndarray_dd_t DD_LOG10_E = {
259254
* |r| ≤ ln(2)/2 ≈ 0.347. Then exp(x) = 2^k · exp(r); the 2^k factor
260255
* is exact in fp64 (just shifts the exponent), and exp(r) is evaluated
261256
* via the Taylor series 1 + r + r²/2! + r³/3! + … in DD arithmetic
262-
* using Horner's method. Twenty terms suffice for |r| ≤ 0.347 because
263-
* the (i+1)-th term shrinks by factor ≤ 0.347/(i+1) — at i = 19 the
264-
* term magnitude is well below DD epsilon (~2⁻¹⁰⁶ ≈ 1.2e-32).
257+
* using Horner's method. The series is summed through r²⁴/24!: at the
258+
* worst-case |r| ≤ ln(2)/2 ≈ 0.3466 the first omitted term r²⁵/25! ≈
259+
* 6.8e-37 is far below DD epsilon (~2⁻¹⁰⁶ ≈ 1.2e-32), so the result
260+
* carries full ~32-digit DD precision. (Twenty terms — the original
261+
* cutoff — left r²¹/21! ≈ 4.2e-30 in the remainder, capping accuracy
262+
* at ~29 digits and making the GPU DD path diverge from the CPU
263+
* libquadmath path at the 31st digit.)
265264
*
266265
* Handles overflow (`exp(x) > DBL_MAX`) by returning +inf and underflow
267266
* (`exp(x) < DBL_MIN_SUBNORMAL`) by returning 0. NaN propagates.
@@ -281,9 +280,9 @@ ndarray_dd_t ndarray_dd_exp(ndarray_dd_t a) {
281280
ndarray_dd_t k_dd = ndarray_dd_from_double(k_d);
282281
ndarray_dd_t r = ndarray_dd_sub(a, ndarray_dd_mul(k_dd, DD_LN2));
283282

284-
/* Horner evaluation of 1 + r·(1 + r/2·(1 + r/3·(… + r/20))) */
283+
/* Horner evaluation of 1 + r·(1 + r/2·(1 + r/3·(… + r/24))) */
285284
ndarray_dd_t result = ndarray_dd_from_double(1.0);
286-
for (int i = 20; i >= 1; i--) {
285+
for (int i = 24; i >= 1; i--) {
287286
/* result = 1 + (r/i) · result */
288287
ndarray_dd_t r_over_i = ndarray_dd_div(r, ndarray_dd_from_double((double)i));
289288
result = ndarray_dd_add(ndarray_dd_from_double(1.0),
@@ -332,8 +331,8 @@ ndarray_dd_t ndarray_dd_expm1(ndarray_dd_t a) {
332331
* shift m into [√0.5, √2) ≈ [0.707, 1.414); then |u| ≤ 0.172. The
333332
* atanh-style series ln(m) = 2·(u + u³/3 + u⁵/5 + u⁷/7 + …) converges
334333
* about twice as fast as the plain Taylor of ln(1+y) because the
335-
* even-power terms vanish. Eleven odd terms (u^21/21) give ~30 sig
336-
* digits at the |u| ≤ 0.172 boundary.
334+
* even-power terms vanish. Twenty-six odd terms (through u^51/51) give
335+
* full ~32-digit DD precision at the |u| ≤ 0.172 boundary.
337336
*
338337
* Final: log(x) = 2·Σ + e·ln(2).
339338
*
@@ -390,22 +389,39 @@ ndarray_dd_t ndarray_dd_log(ndarray_dd_t a) {
390389
/**
391390
* @brief DD-precision log1p(x) = log(1 + x).
392391
*
393-
* For |x| ≤ 0.5 use the Taylor series directly:
394-
* log1p(x) = x − x²/2 + x³/3 − x⁴/4 + …
395-
* evaluated via Horner so the cancellation at small x is avoided.
396-
* For |x| > 0.5 fall back to `log(1 + x)` — 1 + x has no cancellation
397-
* there.
392+
* For |x| ≤ 0.5 the value 1 + x suffers catastrophic cancellation of
393+
* x's sub-fp64 information (when |x| ≲ fp64 epsilon the whole of x lands
394+
* in the lo limb and is rounded away by the subsequent range reduction).
395+
* Use instead the area-hyperbolic-tangent identity
396+
* log1p(x) = 2·atanh( x / (2 + x) ),
397+
* with u = x / (2 + x). The divisor 2 + x stays in [1.5, 2.5] so it
398+
* never cancels and the DD add/divide preserve x's lo limb in full;
399+
* |u| ≤ 0.2 over the branch, so the odd series 2·(u + u³/3 + … + u⁵¹/51)
400+
* (26 odd terms) is below DD epsilon. For |x| > 0.5 there is no
401+
* cancellation in 1 + x, so defer to `dd_log(1 + x)` — that path also
402+
* covers the x ≤ −1 (→ NaN / −inf) and +inf edges.
398403
*
399404
* @param[in] a Input DD value (a > −1 for a finite result).
400405
* @return log(1 + a) in DD precision.
401406
*/
402407
ndarray_dd_t ndarray_dd_log1p(ndarray_dd_t a) {
403408
if (ndarray_dd_isnan(a)) return a;
404-
/* `dd_add(1, a)` preserves DD precision even when |a| is at DD
405-
epsilon — the lo limb of the sum captures the contribution of `a`
406-
past fp64's 53 bits. So `dd_log(1 + a)` is precision-faithful
407-
across the full input range without needing a Taylor branch. */
408-
return ndarray_dd_log(ndarray_dd_add(ndarray_dd_from_double(1.0), a));
409+
if (a.hi >= 0.5 || a.hi <= -0.5) {
410+
return ndarray_dd_log(ndarray_dd_add(ndarray_dd_from_double(1.0), a));
411+
}
412+
ndarray_dd_t one = ndarray_dd_from_double(1.0);
413+
ndarray_dd_t u = ndarray_dd_div(a,
414+
ndarray_dd_add(ndarray_dd_from_double(2.0), a));
415+
ndarray_dd_t u2 = ndarray_dd_mul(u, u);
416+
/* Same 26-odd-term atanh ladder as ndarray_dd_log; |u| ≤ 0.2 here so
417+
the truncated term u^53/53 is far below DD epsilon. */
418+
ndarray_dd_t sum = ndarray_dd_div(one, ndarray_dd_from_double(51.0));
419+
for (int k = 49; k >= 1; k -= 2) {
420+
ndarray_dd_t inv_k = ndarray_dd_div(one, ndarray_dd_from_double((double)k));
421+
sum = ndarray_dd_add(inv_k, ndarray_dd_mul(u2, sum));
422+
}
423+
ndarray_dd_t r = ndarray_dd_mul(u, sum);
424+
return ndarray_dd_add(r, r); /* · 2 */
409425
}
410426

411427
/**

src/ndmath/arithmetics.c

Lines changed: 94 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3627,13 +3627,28 @@ static int unary_validate_numeric_string(const char *str, const char *which) {
36273627
return -1;
36283628
}
36293629
const char *p = unary_skip_sign_ws(str);
3630-
/* Accept inf / nan tokens (case-insensitive, with optional trailing
3631-
junk consistent with strtod's contract). */
3630+
/* Accept inf / infinity / nan tokens (case-insensitive). The token
3631+
must consume the rest of the (trimmed) literal — trailing junk
3632+
such as "infX" / "nanZ" is rejected rather than silently read as
3633+
a valid prefix the way strtod would, mirroring the strict array-
3634+
input inferrer `ndarray_infer_dtype_from_string`. */
36323635
char low3[4] = {0};
36333636
for (int i = 0; i < 3 && p[i]; i++) {
36343637
low3[i] = (char)(p[i] | 0x20);
36353638
}
36363639
if (!strncmp(low3, "inf", 3) || !strncmp(low3, "nan", 3)) {
3640+
const char *t = p + 3;
3641+
if (low3[0] == 'i') { /* maybe the "infinity" spelling */
3642+
char low5[6] = {0};
3643+
for (int i = 0; i < 5 && t[i]; i++) low5[i] = (char)(t[i] | 0x20);
3644+
if (!strncmp(low5, "inity", 5)) t += 5;
3645+
}
3646+
while (*t == ' ' || *t == '\t' || *t == '\n' || *t == '\r') t++;
3647+
if (*t != '\0') {
3648+
zend_throw_error(NULL,
3649+
"NDArray clip: '%s' is not a valid number: %s.", which, str);
3650+
return -1;
3651+
}
36373652
return 0;
36383653
}
36393654
int saw_digit = 0;
@@ -3670,12 +3685,65 @@ static int unary_validate_numeric_string(const char *str, const char *which) {
36703685
/**
36713686
* @brief Skip leading ASCII whitespace, returning the first non-space char's
36723687
* pointer. Mirrors `strtoll`'s leading-whitespace handling.
3688+
*
3689+
* @param[in] s NUL-terminated string to scan.
3690+
* @return Pointer into @p s at the first non-whitespace character (the
3691+
* terminating NUL when @p s is empty or all whitespace).
36733692
*/
36743693
static inline const char *unary_skip_ws(const char *s) {
36753694
while (*s == ' ' || *s == '\t' || *s == '\n' || *s == '\r') s++;
36763695
return s;
36773696
}
36783697

3698+
/** Special-value kind of a validated clip-bound literal. */
3699+
typedef enum {
3700+
UNARY_FINITE = 0, UNARY_POS_INF, UNARY_NEG_INF, UNARY_NAN
3701+
} unary_special_t;
3702+
3703+
/**
3704+
* @brief Classify an already-validated clip-bound literal as finite, ±inf,
3705+
* or nan (case-insensitive, honouring an optional leading sign).
3706+
*
3707+
* @param[in] str NUL-terminated, syntactically validated literal.
3708+
* @return The special-value kind; `UNARY_FINITE` for an ordinary number.
3709+
*/
3710+
static unary_special_t unary_classify_special(const char *str) {
3711+
const char *p = unary_skip_ws(str);
3712+
int neg = 0;
3713+
if (*p == '+' || *p == '-') { neg = (*p == '-'); p++; }
3714+
char low3[4] = {0};
3715+
for (int i = 0; i < 3 && p[i]; i++) low3[i] = (char)(p[i] | 0x20);
3716+
if (!strncmp(low3, "inf", 3)) return neg ? UNARY_NEG_INF : UNARY_POS_INF;
3717+
if (!strncmp(low3, "nan", 3)) return UNARY_NAN;
3718+
return UNARY_FINITE;
3719+
}
3720+
3721+
/**
3722+
* @brief Write the representable extreme of an integer dtype into @p out_buf.
3723+
*
3724+
* Used to give an inf/nan clip bound PyTorch's "no bound" semantics on the
3725+
* 8 integer dtypes (strtoll/strtoull would otherwise read zero digits from
3726+
* the token and yield 0, collapsing the clip range).
3727+
*
3728+
* @param[in] dt Canonical dtype string.
3729+
* @param[in] want_max Non-zero → dtype maximum; zero → dtype minimum
3730+
* (0 for the unsigned dtypes).
3731+
* @param[out] out_buf Buffer of `elsize(dt)` bytes to receive the value.
3732+
* @return 1 when @p dt is one of the 8 integer dtypes (value written);
3733+
* 0 otherwise (a float dtype — caller handles it).
3734+
*/
3735+
static int unary_write_int_extreme(const char *dt, int want_max, void *out_buf) {
3736+
if (!strcmp(dt, "int8")) { *(int8_t *)out_buf = want_max ? INT8_MAX : INT8_MIN; return 1; }
3737+
if (!strcmp(dt, "int16")) { *(int16_t *)out_buf = want_max ? INT16_MAX : INT16_MIN; return 1; }
3738+
if (!strcmp(dt, "int32")) { *(int32_t *)out_buf = want_max ? INT32_MAX : INT32_MIN; return 1; }
3739+
if (!strcmp(dt, "int64")) { *(int64_t *)out_buf = want_max ? INT64_MAX : INT64_MIN; return 1; }
3740+
if (!strcmp(dt, "uint8")) { *(uint8_t *)out_buf = want_max ? UINT8_MAX : 0; return 1; }
3741+
if (!strcmp(dt, "uint16")) { *(uint16_t *)out_buf = want_max ? UINT16_MAX : 0; return 1; }
3742+
if (!strcmp(dt, "uint32")) { *(uint32_t *)out_buf = want_max ? UINT32_MAX : 0; return 1; }
3743+
if (!strcmp(dt, "uint64")) { *(uint64_t *)out_buf = want_max ? UINT64_MAX : 0; return 1; }
3744+
return 0;
3745+
}
3746+
36793747
/**
36803748
* @brief Parse @p str into the typed scalar buffer @p out_buf for @p dt.
36813749
*
@@ -3703,6 +3771,21 @@ static int unary_parse_typed_scalar(const char *dt, const char *str,
37033771
const char *which, void *out_buf) {
37043772
if (unary_validate_numeric_string(str, which) < 0) return -1;
37053773

3774+
/* inf / nan bounds on integer dtypes: strtoll/strtoull read zero
3775+
digits from the token and yield 0, which would collapse the clip
3776+
range. Map the token to the dtype's representable extreme so an
3777+
inf bound acts as PyTorch's "no bound" (−inf → MIN, +inf → MAX),
3778+
and a nan bound becomes the no-op extreme for whichever side it
3779+
sits on (min → MIN, max → MAX), matching how the float path
3780+
silently ignores a nan bound. Float dtypes fall through so strtod
3781+
yields a real ±inf / nan. */
3782+
unary_special_t sp = unary_classify_special(str);
3783+
if (sp != UNARY_FINITE) {
3784+
int want_max = (sp == UNARY_POS_INF) ||
3785+
(sp == UNARY_NAN && !strcmp(which, "max"));
3786+
if (unary_write_int_extreme(dt, want_max, out_buf)) return 0;
3787+
}
3788+
37063789
/* Narrow integer dtypes — saturate the bound to the dtype range so
37073790
out-of-range literals don't wrap via the implicit `(T)strtoll(...)`
37083791
cast inside `ndarray_set_from_string`. int64/uint64 keep the
@@ -3713,39 +3796,39 @@ static int unary_parse_typed_scalar(const char *dt, const char *str,
37133796
if (!strcmp(dt, "uint8")) {
37143797
if (is_neg) { *(uint8_t *)out_buf = 0; return 0; }
37153798
unsigned long long v = strtoull(p, NULL, 10);
3716-
*(uint8_t *)out_buf = (uint8_t)(v > 0xFFu ? 0xFFu : v);
3799+
*(uint8_t *)out_buf = (uint8_t)(v > UINT8_MAX ? UINT8_MAX : v);
37173800
return 0;
37183801
}
37193802
if (!strcmp(dt, "uint16")) {
37203803
if (is_neg) { *(uint16_t *)out_buf = 0; return 0; }
37213804
unsigned long long v = strtoull(p, NULL, 10);
3722-
*(uint16_t *)out_buf = (uint16_t)(v > 0xFFFFu ? 0xFFFFu : v);
3805+
*(uint16_t *)out_buf = (uint16_t)(v > UINT16_MAX ? UINT16_MAX : v);
37233806
return 0;
37243807
}
37253808
if (!strcmp(dt, "uint32")) {
37263809
if (is_neg) { *(uint32_t *)out_buf = 0; return 0; }
37273810
unsigned long long v = strtoull(p, NULL, 10);
3728-
*(uint32_t *)out_buf = (uint32_t)(v > 0xFFFFFFFFu ? 0xFFFFFFFFu : v);
3811+
*(uint32_t *)out_buf = (uint32_t)(v > UINT32_MAX ? UINT32_MAX : v);
37293812
return 0;
37303813
}
37313814
if (!strcmp(dt, "int8")) {
37323815
long long v = strtoll(str, NULL, 10);
3733-
if (v > 0x7F) v = 0x7F;
3734-
else if (v < -0x80) v = -0x80;
3816+
if (v > INT8_MAX) v = INT8_MAX;
3817+
else if (v < INT8_MIN) v = INT8_MIN;
37353818
*(int8_t *)out_buf = (int8_t)v;
37363819
return 0;
37373820
}
37383821
if (!strcmp(dt, "int16")) {
37393822
long long v = strtoll(str, NULL, 10);
3740-
if (v > 0x7FFF) v = 0x7FFF;
3741-
else if (v < -0x8000) v = -0x8000;
3823+
if (v > INT16_MAX) v = INT16_MAX;
3824+
else if (v < INT16_MIN) v = INT16_MIN;
37423825
*(int16_t *)out_buf = (int16_t)v;
37433826
return 0;
37443827
}
37453828
if (!strcmp(dt, "int32")) {
37463829
long long v = strtoll(str, NULL, 10);
3747-
if (v > 0x7FFFFFFFLL) v = 0x7FFFFFFFLL;
3748-
else if (v < -0x80000000LL) v = -0x80000000LL;
3830+
if (v > INT32_MAX) v = INT32_MAX;
3831+
else if (v < INT32_MIN) v = INT32_MIN;
37493832
*(int32_t *)out_buf = (int32_t)v;
37503833
return 0;
37513834
}

0 commit comments

Comments
 (0)