Skip to content

Commit 805a227

Browse files
committed
Assert standard form in _to_std_string
1 parent c78dae4 commit 805a227

1 file changed

Lines changed: 37 additions & 13 deletions

File tree

src/std.jl

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,33 @@ function throw_not_std(arg::Tensor)
210210
throw(DomainError(arg, "Cannot write expression in standard notation"))
211211
end
212212

213+
"""
214+
Returns a string representation of the input expression. The input must be in standard form.
215+
"""
216+
function _to_std_string end
217+
218+
function is_standard_form(arg::Tensor)
219+
free_ids = unique(get_free_indices(arg))
220+
221+
if length(free_ids) > 2
222+
return false
223+
end
224+
225+
if length(free_ids) == 2
226+
if first(free_ids) isa Upper && last(free_ids) isa Lower ||
227+
first(free_ids) isa Lower && last(free_ids) isa Upper
228+
return true
229+
end
230+
231+
return false
232+
end
233+
234+
return true
235+
end
236+
213237
function _to_std_string(arg::Monomial)
238+
@assert is_standard_form(arg)
239+
214240
ids = get_indices(arg)
215241

216242
if length(ids) == 2
@@ -227,28 +253,26 @@ function _to_std_string(arg::Monomial)
227253
elseif typeof(ids[1]) == Lower
228254
return arg.id * ""
229255
end
230-
elseif isempty(ids)
231-
return arg.id
232256
end
233257

234-
throw_not_std(arg)
258+
return arg.id
235259
end
236260

237261
function _to_std_string(arg::KrD)
262+
@assert is_standard_form(arg)
263+
238264
ids = get_indices(arg)
239265

240-
if length(ids) == 2
241-
if typeof(ids[1]) == Upper && typeof(ids[2]) == Lower
242-
return "I"
243-
elseif typeof(ids[1]) == Lower && typeof(ids[2]) == Upper
244-
return "Iᵀ"
245-
end
266+
if typeof(ids[1]) == Upper && typeof(ids[2]) == Lower
267+
return "I"
268+
elseif typeof(ids[1]) == Lower && typeof(ids[2]) == Upper
269+
return "Iᵀ"
246270
end
247-
248-
throw_not_std(arg)
249271
end
250272

251273
function _to_std_string(arg::Zero)
274+
@assert is_standard_form(arg)
275+
252276
ids = get_indices(arg)
253277

254278
if length(ids) == 2
@@ -264,8 +288,6 @@ function _to_std_string(arg::Zero)
264288
return "vec(0)ᵀ"
265289
end
266290
end
267-
268-
throw_not_std(arg)
269291
end
270292

271293
function _to_std_string(arg::Real)
@@ -319,6 +341,8 @@ function get_contra_covariant_matrix(arg1::Tensor, arg2::Tensor)
319341
end
320342

321343
function _to_std_string(arg::BinaryOperation{Mult})
344+
@assert is_standard_form(arg)
345+
322346
if is_elementwise_multiplication(arg.arg1, arg.arg2)
323347
indices = get_indices(arg)
324348
target_indices = unique(eliminate_indices(indices))

0 commit comments

Comments
 (0)