Skip to content

Commit 9249996

Browse files
datatypes for math methods extended
1 parent 76afe42 commit 9249996

7 files changed

Lines changed: 548 additions & 41 deletions

numpower.c

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3581,12 +3581,10 @@ static const char *ndarray_infer_dtype_from_string(const char *str, size_t len)
35813581
literal contains a fractional / exponent part (→ fp128) and where
35823582
the magnitude digits live (for the int64-vs-uint64 split). Any
35833583
character that does not fit the grammar fails the scan. */
3584-
size_t k = i;
3585-
int has_sign = 0;
3586-
int is_neg = 0;
3584+
size_t k = i;
3585+
int is_neg = 0;
35873586
if (k < end && (str[k] == '+' || str[k] == '-')) {
3588-
has_sign = 1;
3589-
is_neg = (str[k] == '-');
3587+
is_neg = (str[k] == '-');
35903588
k++;
35913589
}
35923590
/* Mantissa integer part. */
@@ -3605,7 +3603,6 @@ static const char *ndarray_infer_dtype_from_string(const char *str, size_t len)
36053603
}
36063604
/* At least one digit in mantissa (integer or fractional). */
36073605
if (mant_int_len == 0 && frac_len == 0) {
3608-
(void)has_sign;
36093606
return NULL;
36103607
}
36113608

@@ -3631,17 +3628,25 @@ static const char *ndarray_infer_dtype_from_string(const char *str, size_t len)
36313628
return "float128";
36323629
}
36333630

3634-
/* Pure integer literal. */
3635-
if (is_neg) {
3636-
return "int64";
3637-
}
3638-
/* Magnitude check on the integer digits (skip leading zeros). */
3631+
/* Pure integer literal — magnitude check on the digits (skip leading
3632+
zeros) decides int64 vs uint64 vs float128. */
36393633
const char *p = str + mant_int_start;
36403634
size_t m = mant_int_len;
36413635
while (m > 1 && *p == '0') { p++; m--; }
36423636

36433637
static const char int64_max_str[] = "9223372036854775807"; /* 19 digits */
3638+
static const char int64_min_mag[] = "9223372036854775808"; /* |INT64_MIN|, 19 digits */
36443639
static const char uint64_max_str[] = "18446744073709551615"; /* 20 digits */
3640+
3641+
if (is_neg) {
3642+
/* Negative magnitude: int64 holds down to INT64_MIN = -9223372036854775808;
3643+
anything wider must escalate to float128 (uint64 cannot represent
3644+
negatives, and strtoll would silently saturate to INT64_MIN with
3645+
errno=ERANGE if we'd routed it through int64). */
3646+
if (m > 19) return "float128";
3647+
if (m == 19 && memcmp(p, int64_min_mag, 19) > 0) return "float128";
3648+
return "int64";
3649+
}
36453650
if (m > 20) {
36463651
/* Past UINT64_MAX — escalate to fp128 to keep precision rather
36473652
than saturate the integer dtypes. */

src/debug.c

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,9 @@ print_array_float32(float* buffer, int ndims, int* shape, int* strides, int cur_
9797
}
9898

9999
if (ndims == 0) {
100-
sprintf(str, "%g\n", buffer[0]);
100+
float v0 = buffer[0];
101+
if (isnan(v0)) sprintf(str, "nan\n");
102+
else sprintf(str, "%g\n", (double)v0);
101103
return str;
102104
}
103105

@@ -114,8 +116,14 @@ print_array_float32(float* buffer, int ndims, int* shape, int* strides, int cur_
114116
for (int k = 0; k < ndims; k++) {
115117
offset += index[k] * strides[k];
116118
}
117-
// Print the element
118-
sprintf(str + strlen(str), "%.8g", buffer[offset / sizeof(float)]);
119+
// Print the element — NaN canonicalized to unsigned form to
120+
// match PyTorch / Python repr (glibc `%g` emits "-nan" for
121+
// sign-bit-set NaN; users wouldn't expect the sign on a
122+
// numerical printout). Same normalization applies to fp64
123+
// and to every dtype routed through `ndarray_element_to_string`.
124+
float ve = buffer[offset / sizeof(float)];
125+
if (isnan(ve)) sprintf(str + strlen(str), "nan");
126+
else sprintf(str + strlen(str), "%.8g", (double)ve);
119127

120128
// Print a comma if this is not the last element in the dimension
121129
if (i < shape[cur_dim] - 1) {
@@ -234,7 +242,9 @@ print_array_float64(double* buffer, int ndims, int* shape, int* strides, int cur
234242
}
235243

236244
if (ndims == 0) {
237-
sprintf(str, "%g\n", buffer[0]);
245+
double v0 = buffer[0];
246+
if (isnan(v0)) sprintf(str, "nan\n");
247+
else sprintf(str, "%g\n", v0);
238248
return str;
239249
}
240250

@@ -251,8 +261,11 @@ print_array_float64(double* buffer, int ndims, int* shape, int* strides, int cur
251261
for (int k = 0; k < ndims; k++) {
252262
offset += index[k] * strides[k];
253263
}
254-
// Print the element
255-
sprintf(str + strlen(str), "%.16g", buffer[offset / sizeof(double)]);
264+
// Print the element — NaN canonicalized to unsigned form
265+
// (see analogous fp32 path above).
266+
double ve = buffer[offset / sizeof(double)];
267+
if (isnan(ve)) sprintf(str + strlen(str), "nan");
268+
else sprintf(str + strlen(str), "%.16g", ve);
256269

257270
// Print a comma if this is not the last element in the dimension
258271
if (i < shape[cur_dim] - 1) {

src/ndarray_types.c

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -530,17 +530,26 @@ void ndarray_element_to_string(const char *type,
530530
} else if (strcmp(type, "float16") == 0) {
531531
uint16_t fp16;
532532
memcpy(&fp16, data + byte_offset, 2);
533-
snprintf(buf, bufsize, "%.6g", ndarray_fp16_to_double(fp16));
533+
double v = ndarray_fp16_to_double(fp16);
534+
/* NaN-sign normalization: glibc `%g` prints "-nan" for a quiet
535+
NaN with the sign bit set, which PyTorch / Python `repr` do
536+
not (they always print "nan"). Match PyTorch parity across
537+
every fp dtype — fp128 / fp8 already canonicalize the same
538+
way; fp16/fp32/fp64 now follow suit. */
539+
if (isnan(v)) snprintf(buf, bufsize, "nan");
540+
else snprintf(buf, bufsize, "%.6g", v);
534541

535542
} else if (strcmp(type, "float32") == 0) {
536543
float v;
537544
memcpy(&v, data + byte_offset, 4);
538-
snprintf(buf, bufsize, "%.8g", (double)v);
545+
if (isnan(v)) snprintf(buf, bufsize, "nan");
546+
else snprintf(buf, bufsize, "%.8g", (double)v);
539547

540548
} else if (strcmp(type, "float64") == 0) {
541549
double v;
542550
memcpy(&v, data + byte_offset, 8);
543-
snprintf(buf, bufsize, "%.16g", v);
551+
if (isnan(v)) snprintf(buf, bufsize, "nan");
552+
else snprintf(buf, bufsize, "%.16g", v);
544553

545554
} else if (strcmp(type, "float128") == 0) {
546555
ndarray_fp128_t v;

src/ndmath/arithmetics.c

Lines changed: 89 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3667,13 +3667,31 @@ static int unary_validate_numeric_string(const char *str, const char *which) {
36673667
return 0;
36683668
}
36693669

3670+
/**
 * @brief Skip leading ASCII whitespace and return a pointer to the first
 *        non-whitespace character (or to the terminating NUL).
 *
 * The skipped set is exactly what `isspace` accepts in the "C" locale:
 * space, `\t`, `\n`, `\v`, `\f` and `\r`. That is the set `strtoll` /
 * `strtoull` discard before looking for a sign, so a sign probe done on
 * the returned pointer sees the same character those parsers will treat
 * as the sign. Skipping fewer characters (e.g. omitting `\v` / `\f`)
 * would let a literal such as "\v-5" slip past a `*p == '-'` check while
 * `strtoull` still consumes the minus and wraps the value.
 *
 * @param[in] s NUL-terminated string; must not be NULL.
 * @return Pointer into @p s at the first non-whitespace character.
 */
static inline const char *unary_skip_ws(const char *s) {
    while (*s == ' ' || *s == '\t' || *s == '\n' ||
           *s == '\v' || *s == '\f' || *s == '\r') {
        s++;
    }
    return s;
}
3678+
36703679
/**
36713680
* @brief Parse @p str into the typed scalar buffer @p out_buf for @p dt.
36723681
*
36733682
* Validates the string syntactically first so callers get a clean error
3674-
* instead of a silent 0 coerced from a malformed input. Routes through
3675-
* `ndarray_set_from_string` so `float128` / `int64` / `uint64` strings
3676-
* preserve precision. Other dtypes route through `double`.
3683+
* instead of a silent 0 coerced from a malformed input. For integer
3684+
* dtypes the value is *saturated* to the dtype's representable range
3685+
* (PyTorch `clamp` semantics): a negative bound for an unsigned dtype
3686+
* collapses to 0; a magnitude exceeding the signed dtype's `INT*_MAX`
3687+
* saturates to that max (or `INT*_MIN` if negative); without this
3688+
* saturation, `clip(uint8 tensor, -50, 100)` would silently wrap `-50`
3689+
* via the modulo-2^N cast inside `ndarray_set_from_string`, then see
3690+
* `lo (206) > hi (100)` and collapse every element to `100`. For
3691+
* float dtypes (and `int64`/`uint64`/`float128` where wide-precision
3692+
* strings carry the only loss-free intake), the call falls through to
3693+
* `ndarray_set_from_string` so `strtoll`/`strtoull`/`strtoflt128` keep
3694+
* the full source precision.
36773695
*
36783696
* @param[in] dt Canonical dtype string.
36793697
* @param[in] str Decimal literal.
@@ -3684,6 +3702,66 @@ static int unary_validate_numeric_string(const char *str, const char *which) {
36843702
static int unary_parse_typed_scalar(const char *dt, const char *str,
36853703
const char *which, void *out_buf) {
36863704
if (unary_validate_numeric_string(str, which) < 0) return -1;
3705+
3706+
/* Narrow integer dtypes — saturate the bound to the dtype range so
3707+
out-of-range literals don't wrap via the implicit `(T)strtoll(...)`
3708+
cast inside `ndarray_set_from_string`. int64/uint64 keep the
3709+
wide-precision intake path (their saturating boundary is exactly
3710+
at the strtoll/strtoull edge already). */
3711+
const char *p = unary_skip_ws(str);
3712+
int is_neg = (*p == '-');
3713+
if (!strcmp(dt, "uint8")) {
3714+
if (is_neg) { *(uint8_t *)out_buf = 0; return 0; }
3715+
unsigned long long v = strtoull(p, NULL, 10);
3716+
*(uint8_t *)out_buf = (uint8_t)(v > 0xFFu ? 0xFFu : v);
3717+
return 0;
3718+
}
3719+
if (!strcmp(dt, "uint16")) {
3720+
if (is_neg) { *(uint16_t *)out_buf = 0; return 0; }
3721+
unsigned long long v = strtoull(p, NULL, 10);
3722+
*(uint16_t *)out_buf = (uint16_t)(v > 0xFFFFu ? 0xFFFFu : v);
3723+
return 0;
3724+
}
3725+
if (!strcmp(dt, "uint32")) {
3726+
if (is_neg) { *(uint32_t *)out_buf = 0; return 0; }
3727+
unsigned long long v = strtoull(p, NULL, 10);
3728+
*(uint32_t *)out_buf = (uint32_t)(v > 0xFFFFFFFFu ? 0xFFFFFFFFu : v);
3729+
return 0;
3730+
}
3731+
if (!strcmp(dt, "int8")) {
3732+
long long v = strtoll(str, NULL, 10);
3733+
if (v > 0x7F) v = 0x7F;
3734+
else if (v < -0x80) v = -0x80;
3735+
*(int8_t *)out_buf = (int8_t)v;
3736+
return 0;
3737+
}
3738+
if (!strcmp(dt, "int16")) {
3739+
long long v = strtoll(str, NULL, 10);
3740+
if (v > 0x7FFF) v = 0x7FFF;
3741+
else if (v < -0x8000) v = -0x8000;
3742+
*(int16_t *)out_buf = (int16_t)v;
3743+
return 0;
3744+
}
3745+
if (!strcmp(dt, "int32")) {
3746+
long long v = strtoll(str, NULL, 10);
3747+
if (v > 0x7FFFFFFFLL) v = 0x7FFFFFFFLL;
3748+
else if (v < -0x80000000LL) v = -0x80000000LL;
3749+
*(int32_t *)out_buf = (int32_t)v;
3750+
return 0;
3751+
}
3752+
/* uint64: strtoull silently wraps a negative literal modulo 2^64
3753+
(`strtoull("-50")` returns `UINT64_MAX - 49`) — saturate to 0
3754+
explicitly. ERANGE on a positive overflow is what strtoull would
3755+
cap at UINT64_MAX anyway. */
3756+
if (!strcmp(dt, "uint64")) {
3757+
if (is_neg) { *(uint64_t *)out_buf = 0; return 0; }
3758+
/* Fall through to ndarray_set_from_string for the positive path
3759+
so wide-precision literals route through the same parser used
3760+
elsewhere. */
3761+
}
3762+
/* int64 / uint64 (positive) / float* : strtoll / strtoull / strtod /
3763+
strtoflt128 already saturate at the dtype's edge under ERANGE,
3764+
matching the behaviour we want without an explicit upper-bound check. */
36873765
ndarray_set_from_string(dt, (char *)out_buf, 0, str);
36883766
return 0;
36893767
}
@@ -3912,20 +3990,14 @@ static int unary_run_cpu_inplace(void *data, long n, const char *dt,
39123990
}
39133991
default: break;
39143992
}
3915-
/* Normalize NaN sign bit to canonical +NaN. libquadmath /
3916-
libm leak a sign-bit-set "-nan" out of `logq(-x)`,
3917-
`sqrtq(-x)`, `log1pq(-x)` etc. The fp64 path returns
3918-
NaN with the sign bit set too, but PHP's float
3919-
stringifier hides the sign (`var_dump(NAN)` prints
3920-
"NAN"); `quadmath_snprintf` honours it, so the user
3921-
sees the inconsistency only on fp128. Force-clear the
3922-
sign bit so display matches the rest of the unary
3923-
family. Skip on NDARRAY_UNOP_SIGN — that op uses NaN
3924-
propagation as a meaningful value (PyTorch parity) and
3925-
the input's sign bit is part of its signal. */
3926-
if (op != NDARRAY_UNOP_SIGN && NDARRAY_FP128_ISNAN(y)) {
3927-
y = NDARRAY_FP128_NAN();
3928-
}
3993+
/* NaN-sign canonicalization happens at stringification time
3994+
(`ndarray_fp128_to_string`) rather than here. This keeps
3995+
the in-memory bit pattern mathematically faithful:
3996+
`NumPower::negative(NaN)` flips the sign bit (matches
3997+
NumPy / PyTorch `neg` on NaN), `NumPower::positive(NaN)`
3998+
preserves the input, while `__toString` / `toArray`
3999+
render every NaN as the unsigned `"nan"` literal across
4000+
every fp dtype. */
39294001
p[i] = y;
39304002
}
39314003
return 0;

tests/math/115-unary-math-explog-string-scalar.phpt

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,15 @@ check("sign('9223372036854775808') uint64",
7676
is_string(NumPower::sign('9223372036854775808')),true);
7777
check("sign('18446744073709551615') uint64",
7878
is_string(NumPower::sign('18446744073709551615')),true);
79-
/* Negative magnitude can never fit uint64 → must stay int64 (signed). */
80-
check("sign('-18446744073709551615') int64",
81-
is_int(NumPower::sign('-18446744073709551615')), true);
79+
/* Negative magnitudes up to |INT64_MIN| = 9223372036854775808 fit int64;
80+
anything larger escalates to float128 to avoid the silent INT64_MIN
81+
saturation a naïve `strtoll` would deliver. */
82+
check("sign('-9223372036854775808') int64",
83+
is_int(NumPower::sign('-9223372036854775808')), true);
84+
check("sign('-9223372036854775809') float128",
85+
is_string(NumPower::sign('-9223372036854775809')), true);
86+
check("sign('-18446744073709551615') float128",
87+
is_string(NumPower::sign('-18446744073709551615')), true);
8288

8389
/* ── MATHEMATICAL FUNCTIONS ───────────────────────────────────────────── */
8490

@@ -220,7 +226,9 @@ OK sign('0') is int (int64)
220226
OK sign('9223372036854775807') int64
221227
OK sign('9223372036854775808') uint64
222228
OK sign('18446744073709551615') uint64
223-
OK sign('-18446744073709551615') int64
229+
OK sign('-9223372036854775808') int64
230+
OK sign('-9223372036854775809') float128
231+
OK sign('-18446744073709551615') float128
224232
OK abs('-3.5') (fp128)
225233
OK abs('-100') (int64)
226234
OK abs('18446744073709551615') (uint64)

0 commit comments

Comments
 (0)