Skip to content

Commit 81e5217

Browse files
authored
Nick/matlab support (#360)
* Internal: Try new link checker * Internal: Add codespell and fix typos. * Internal: See if codespell precommit finds config. * Internal: Found config. Now enable reading it * MATLAB: Add initial support for more matlab support. Closes #350
1 parent c3249ad commit 81e5217

File tree

6 files changed

+172
-0
lines changed

6 files changed

+172
-0
lines changed

pyttb/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from pyttb.import_data import import_data
2020
from pyttb.khatrirao import khatrirao
2121
from pyttb.ktensor import ktensor
22+
from pyttb.matlab import matlab_support
2223
from pyttb.sptenmat import sptenmat
2324
from pyttb.sptensor import sptendiag, sptenrand, sptensor
2425
from pyttb.sptensor3 import sptensor3
@@ -51,6 +52,7 @@ def ignore_warnings(ignore=True):
5152
import_data.__name__,
5253
khatrirao.__name__,
5354
ktensor.__name__,
55+
matlab_support.__name__,
5456
sptenmat.__name__,
5557
sptendiag.__name__,
5658
sptenrand.__name__,

pyttb/matlab/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Partial support of MATLAB users in PYTTB."""

pyttb/matlab/matlab_support.py

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
"""A limited number of utilities to support users coming from MATLAB."""
2+
3+
# Copyright 2024 National Technology & Engineering Solutions of Sandia,
4+
# LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the
5+
# U.S. Government retains certain rights in this software.
6+
7+
from typing import Optional, Union
8+
9+
import numpy as np
10+
11+
from pyttb.tensor import tensor
12+
13+
from .matlab_utilities import _matlab_array_str
14+
15+
PRINT_CLASSES = Union[tensor, np.ndarray]
16+
17+
18+
def matlab_print(
19+
data: Union[tensor, np.ndarray],
20+
format: Optional[str] = None,
21+
name: Optional[str] = None,
22+
):
23+
"""Print data in a format more similar to MATLAB.
24+
25+
Arguments
26+
---------
27+
data: Object to print
28+
format: Numerical formatting
29+
"""
30+
if not isinstance(data, (tensor, np.ndarray)):
31+
raise ValueError(
32+
f"matlab_print only supports inputs of type {PRINT_CLASSES} but got"
33+
f" {type(data)}."
34+
)
35+
if isinstance(data, np.ndarray):
36+
print(_matlab_array_str(data, format, name))
37+
return
38+
print(data._matlab_str(format, name))

pyttb/matlab/matlab_utilities.py

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
"""Internal tools to aid in building MATLAB support.
2+
3+
Tensor classes can use these common tools, where matlab_support uses tensors.
4+
matlab_support can depend on this, but tensors and this shouldn't depend on it.
5+
Probably best for everything here to be private functions.
6+
"""
7+
8+
# Copyright 2024 National Technology & Engineering Solutions of Sandia,
9+
# LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the
10+
# U.S. Government retains certain rights in this software.
11+
12+
import textwrap
13+
from typing import Optional, Tuple, Union
14+
15+
import numpy as np
16+
17+
18+
def _matlab_array_str(
19+
array: np.ndarray,
20+
format: Optional[str] = None,
21+
name: Optional[str] = None,
22+
skip_name: bool = False,
23+
) -> str:
24+
"""Convert numpy array to string more similar to MATLAB."""
25+
if name is None:
26+
name = type(array).__name__
27+
header_str = ""
28+
body_str = ""
29+
if len(array.shape) > 2:
30+
matlab_str = ""
31+
# Iterate over all possible slices (in Fortran order)
32+
for index in np.ndindex(
33+
array.shape[2:][::-1]
34+
): # Skip the first two dimensions and reverse the order
35+
original_index = index[::-1] # Reverse the order back to the original
36+
# Construct the slice indices
37+
slice_indices: Tuple[Union[int, slice], ...] = (
38+
slice(None),
39+
slice(None),
40+
*original_index,
41+
)
42+
slice_data = array[slice_indices]
43+
matlab_str += f"{name}(:,:, {', '.join(map(str, original_index))}) ="
44+
matlab_str += "\n"
45+
array_str = _matlab_array_str(slice_data, format, name, skip_name=True)
46+
matlab_str += textwrap.indent(array_str, "\t")
47+
matlab_str += "\n"
48+
return matlab_str[:-1] # Trim extra newline
49+
elif len(array.shape) == 2:
50+
header_str += f"{name}(:,:) ="
51+
for row in array:
52+
if format is None:
53+
body_str += " ".join(f"{val}" for val in row)
54+
else:
55+
body_str += " ".join(f"{val:{format}}" for val in row)
56+
body_str += "\n"
57+
else:
58+
header_str += f"{name}(:) ="
59+
for val in array:
60+
if format is None:
61+
body_str += f"{val}"
62+
else:
63+
body_str += f"{val:{format}}"
64+
body_str += "\n"
65+
66+
if skip_name:
67+
return body_str
68+
return header_str + "\n" + textwrap.indent(body_str[:-1], "\t")

pyttb/tensor.py

+19
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from __future__ import annotations
88

99
import logging
10+
import textwrap
1011
from collections.abc import Iterable
1112
from inspect import signature
1213
from itertools import combinations_with_replacement, permutations
@@ -30,6 +31,7 @@
3031
from scipy import sparse
3132

3233
import pyttb as ttb
34+
from pyttb.matlab.matlab_utilities import _matlab_array_str
3335
from pyttb.pyttb_utils import (
3436
IndexVariant,
3537
OneDArray,
@@ -2723,6 +2725,23 @@ def __repr__(self):
27232725

27242726
__str__ = __repr__
27252727

2728+
def _matlab_str(
2729+
self, format: Optional[str] = None, name: Optional[str] = None
2730+
) -> str:
2731+
"""Non-standard representation to be more similar to MATLAB."""
2732+
header = name
2733+
if name is None:
2734+
name = "data"
2735+
if header is None:
2736+
header = "This"
2737+
2738+
matlab_str = f"{header} is a tensor of shape " + " x ".join(
2739+
map(str, self.shape)
2740+
)
2741+
2742+
array_str = _matlab_array_str(self.data, format, name)
2743+
return matlab_str + "\n" + textwrap.indent(array_str, "\t")
2744+
27262745

27272746
def tenones(shape: Shape, order: Union[Literal["F"], Literal["C"]] = "F") -> tensor:
27282747
"""Create a tensor of all ones.

tests/matlab/test_matlab_support.py

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright 2024 National Technology & Engineering Solutions of Sandia,
2+
# LLC (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the
3+
# U.S. Government retains certain rights in this software.
4+
5+
import numpy as np
6+
import pytest
7+
8+
from pyttb import matlab_support, tensor
9+
10+
11+
def test_matlab_printing_negative():
12+
with pytest.raises(ValueError):
13+
matlab_support.matlab_print("foo")
14+
15+
16+
def test_np_printing():
17+
"""These are just smoke tests since formatting needs manual style verification."""
18+
# Check different dimensionality support
19+
one_d_array = np.ones((1,))
20+
matlab_support.matlab_print(one_d_array)
21+
two_d_array = np.ones((1, 1))
22+
matlab_support.matlab_print(two_d_array)
23+
three_d_array = np.ones((1, 1, 1))
24+
matlab_support.matlab_print(three_d_array)
25+
26+
# Check name and format
27+
matlab_support.matlab_print(one_d_array, format="5.1f", name="X")
28+
matlab_support.matlab_print(two_d_array, format="5.1f", name="X")
29+
matlab_support.matlab_print(three_d_array, format="5.1f", name="X")
30+
31+
32+
def test_dense_printing():
33+
"""These are just smoke tests since formatting needs manual style verification."""
34+
# Check different dimensionality support
35+
example = tensor(np.arange(16), shape=(2, 2, 2, 2))
36+
# 4D
37+
matlab_support.matlab_print(example)
38+
# 2D
39+
matlab_support.matlab_print(example[:, :, 0, 0])
40+
# 1D
41+
matlab_support.matlab_print(example[:, 0, 0, 0])
42+
43+
# Check name and format
44+
matlab_support.matlab_print(example, format="5.1f", name="X")

0 commit comments

Comments
 (0)