Skip to content

Commit 0a72978

Browse files
authored
Fallbacks for product sectors involving fusion tensors (#80)
* fallbacks for product sectors involving fusion tensors * attempt at refactor * apply suggested change * attempt to make v1.10 inference happy * potato brain moment * move around auxiliary functions
1 parent a90386c commit 0a72978

2 files changed

Lines changed: 71 additions & 83 deletions

File tree

src/auxiliary.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ function _kron(A, B)
1313
end
1414
return C
1515
end
16+
_kron_promote(A::Number, B::Number, _, _) = A * B
17+
function _kron_promote(A₁, B₁, sz₁, sz₂)
18+
return _kron(A₁ isa Number ? fill(A₁, sz₁) : A₁, B₁ isa Number ? fill(B₁, sz₂) : B₁)
19+
end
1620

1721
# Manhattan based distance enumeration: I is supposed to be one-based index
1822
# TODO: is there any way to make this faster?

src/product.jl

Lines changed: 67 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -87,111 +87,95 @@ end
8787
_firstsector(x::ProductSector) = x.sectors[1]
8888
_tailsector(x::ProductSector) = ProductSector(Base.tail(x.sectors))
8989

90+
# handles R-, A- and B-symbols correctly because of Frobenius reciprocity
91+
# i.e. Nsymbol(a, b, c) = Nsymbol(c, dual(b), a) = Nsymbol(dual(a), c, b)
92+
# and for braided categories Nsymbol(a, b, c) = Nsymbol(b, a, c)
93+
_symbol_size((a, b, c)::NTuple{3, Sector}) = (n = Nsymbol(a, b, c); (n, n))
94+
_symbol_size((a, b, c, d, e, f)::NTuple{6, Sector}) = (Nsymbol(a, b, e), Nsymbol(e, c, d), Nsymbol(b, c, f), Nsymbol(a, f, d))
95+
96+
@inline function _kron_promote_inputs(sectors)
97+
heads = map(_firstsector, sectors)
98+
tails = map(_tailsector, sectors)
99+
sz₁ = _symbol_size(heads)
100+
sz₂ = _symbol_size(tails)
101+
return heads, tails, sz₁, sz₂
102+
end
103+
90104
function Fsymbol(a::I, b::I, c::I, d::I, e::I, f::I) where {I <: ProductSector}
91-
heads = map(_firstsector, (a, b, c, d, e, f))
92-
tails = map(_tailsector, (a, b, c, d, e, f))
93-
F₁ = Fsymbol(heads...)
94-
F₂ = Fsymbol(tails...)
95-
if F₁ isa Number && F₂ isa Number
96-
return F₁ * F₂
97-
elseif F₁ isa Number
98-
a₁, b₁, c₁, d₁, e₁, f₁ = heads
99-
sz₁ = (
100-
Nsymbol(a₁, b₁, e₁), Nsymbol(e₁, c₁, d₁), Nsymbol(b₁, c₁, f₁), Nsymbol(a₁, f₁, d₁),
101-
)
102-
F₁′ = fill(F₁, sz₁)
103-
return _kron(F₁′, F₂)
104-
elseif F₂ isa Number
105-
a₂, b₂, c₂, d₂, e₂, f₂ = tails
106-
sz₂ = (
107-
Nsymbol(a₂, b₂, e₂), Nsymbol(e₂, c₂, d₂), Nsymbol(b₂, c₂, f₂), Nsymbol(a₂, f₂, d₂),
108-
)
109-
F₂′ = fill(F₂, sz₂)
110-
return _kron(F₁, F₂′)
111-
else
112-
return _kron(F₁, F₂)
113-
end
105+
heads, tails, sz₁, sz₂ = _kron_promote_inputs((a, b, c, d, e, f))
106+
V₁ = Fsymbol(heads...)
107+
V₂ = Fsymbol(tails...)
108+
return _kron_promote(V₁, V₂, sz₁, sz₂)
114109
end
115-
function Fsymbol(
116-
a::I, b::I, c::I, d::I, e::I, f::I
117-
) where {I <: ProductSector{<:Tuple{Sector}}}
110+
function Fsymbol(a::I, b::I, c::I, d::I, e::I, f::I) where {I <: ProductSector{<:Tuple{Sector}}}
118111
return Fsymbol(map(_firstsector, (a, b, c, d, e, f))...)
119112
end
113+
function Fsymbol_from_fusiontensor(a::I, b::I, c::I, d::I, e::I, f::I) where {I <: ProductSector}
114+
heads, tails, sz₁, sz₂ = _kron_promote_inputs((a, b, c, d, e, f))
115+
V₁ = Fsymbol_from_fusiontensor(heads...)
116+
V₂ = Fsymbol_from_fusiontensor(tails...)
117+
return _kron_promote(V₁, V₂, sz₁, sz₂)
118+
end
119+
function Fsymbol_from_fusiontensor(a::I, b::I, c::I, d::I, e::I, f::I) where {I <: ProductSector{<:Tuple{Sector}}}
120+
return Fsymbol_from_fusiontensor(map(_firstsector, (a, b, c, d, e, f))...)
121+
end
120122

121123
function Rsymbol(a::I, b::I, c::I) where {I <: ProductSector}
122-
heads = map(_firstsector, (a, b, c))
123-
tails = map(_tailsector, (a, b, c))
124-
R₁ = Rsymbol(heads...)
125-
R₂ = Rsymbol(tails...)
126-
if R₁ isa Number && R₂ isa Number
127-
R₁ * R₂
128-
elseif R₁ isa Number
129-
a₁, b₁, c₁ = heads
130-
sz₁ = (Nsymbol(a₁, b₁, c₁), Nsymbol(b₁, a₁, c₁)) # 0 x 0 or 1 x 1
131-
R₁′ = fill(R₁, sz₁)
132-
return _kron(R₁′, R₂)
133-
elseif R₂ isa Number
134-
a₂, b₂, c₂ = tails
135-
sz₂ = (Nsymbol(a₂, b₂, c₂), Nsymbol(b₂, a₂, c₂)) # 0 x 0 or 1 x 1
136-
R₂′ = fill(R₂, sz₂)
137-
return _kron(R₁, R₂′)
138-
else
139-
return _kron(R₁, R₂)
140-
end
124+
heads, tails, sz₁, sz₂ = _kron_promote_inputs((a, b, c))
125+
V₁ = Rsymbol(heads...)
126+
V₂ = Rsymbol(tails...)
127+
return _kron_promote(V₁, V₂, sz₁, sz₂)
141128
end
142129
function Rsymbol(a::I, b::I, c::I) where {I <: ProductSector{<:Tuple{Sector}}}
143130
return Rsymbol(map(_firstsector, (a, b, c))...)
144131
end
132+
function Rsymbol_from_fusiontensor(a::I, b::I, c::I) where {I <: ProductSector}
133+
heads, tails, sz₁, sz₂ = _kron_promote_inputs((a, b, c))
134+
V₁ = Rsymbol_from_fusiontensor(heads...)
135+
V₂ = Rsymbol_from_fusiontensor(tails...)
136+
return _kron_promote(V₁, V₂, sz₁, sz₂)
137+
end
138+
function Rsymbol_from_fusiontensor(a::I, b::I, c::I) where {I <: ProductSector{<:Tuple{Sector}}}
139+
return Rsymbol_from_fusiontensor(map(_firstsector, (a, b, c))...)
140+
end
145141

146142
function Bsymbol(a::I, b::I, c::I) where {I <: ProductSector}
147-
heads = map(_firstsector, (a, b, c))
148-
tails = map(_tailsector, (a, b, c))
149-
B₁ = Bsymbol(heads...)
150-
B₂ = Bsymbol(tails...)
151-
if B₁ isa Number && B₂ isa Number
152-
B₁ * B₂
153-
elseif B₁ isa Number
154-
a₁, b₁, c₁ = heads
155-
sz₁ = (Nsymbol(a₁, b₁, c₁), Nsymbol(c₁, dual(b₁), a₁)) # 0 x 0 or 1 x 1
156-
B₁′ = fill(B₁, sz₁)
157-
return _kron(B₁′, B₂)
158-
elseif B₂ isa Number
159-
a₂, b₂, c₂ = tails
160-
sz₂ = (Nsymbol(a₂, b₂, c₂), Nsymbol(c₂, dual(b₂), a₂)) # 0 x 0 or 1 x 1
161-
B₂′ = fill(B₂, sz₂)
162-
return _kron(B₁, B₂′)
163-
else
164-
return _kron(B₁, B₂)
165-
end
143+
heads, tails, sz₁, sz₂ = _kron_promote_inputs((a, b, c))
144+
V₁ = Bsymbol(heads...)
145+
V₂ = Bsymbol(tails...)
146+
return _kron_promote(V₁, V₂, sz₁, sz₂)
166147
end
167148
function Bsymbol(a::I, b::I, c::I) where {I <: ProductSector{<:Tuple{Sector}}}
168149
return Bsymbol(map(_firstsector, (a, b, c))...)
169150
end
151+
function Bsymbol_from_fusiontensor(a::I, b::I, c::I) where {I <: ProductSector}
152+
heads, tails, sz₁, sz₂ = _kron_promote_inputs((a, b, c))
153+
V₁ = Bsymbol_from_fusiontensor(heads...)
154+
V₂ = Bsymbol_from_fusiontensor(tails...)
155+
return _kron_promote(V₁, V₂, sz₁, sz₂)
156+
end
157+
function Bsymbol_from_fusiontensor(a::I, b::I, c::I) where {I <: ProductSector{<:Tuple{Sector}}}
158+
return Bsymbol_from_fusiontensor(map(_firstsector, (a, b, c))...)
159+
end
170160

171161
function Asymbol(a::I, b::I, c::I) where {I <: ProductSector}
172-
heads = map(_firstsector, (a, b, c))
173-
tails = map(_tailsector, (a, b, c))
174-
A₁ = Asymbol(heads...)
175-
A₂ = Asymbol(tails...)
176-
if A₁ isa Number && A₂ isa Number
177-
A₁ * A₂
178-
elseif A₁ isa Number
179-
a₁, b₁, c₁ = heads
180-
sz₁ = (Nsymbol(a₁, b₁, c₁), Nsymbol(dual(a₁), c₁, b₁)) # 0 x 0 or 1 x 1
181-
A₁′ = fill(A₁, sz₁)
182-
return _kron(A₁′, A₂)
183-
elseif A₂ isa Number
184-
a₂, b₂, c₂ = tails
185-
sz₂ = (Nsymbol(a₂, b₂, c₂), Nsymbol(dual(a₂), c₂, b₂)) # 0 x 0 or 1 x 1
186-
A₂′ = fill(A₂, sz₂)
187-
return _kron(A₁, A₂′)
188-
else
189-
return _kron(A₁, A₂)
190-
end
162+
heads, tails, sz₁, sz₂ = _kron_promote_inputs((a, b, c))
163+
V₁ = Asymbol(heads...)
164+
V₂ = Asymbol(tails...)
165+
return _kron_promote(V₁, V₂, sz₁, sz₂)
191166
end
192167
function Asymbol(a::I, b::I, c::I) where {I <: ProductSector{<:Tuple{Sector}}}
193168
return Asymbol(map(_firstsector, (a, b, c))...)
194169
end
170+
function Asymbol_from_fusiontensor(a::I, b::I, c::I) where {I <: ProductSector}
171+
heads, tails, sz₁, sz₂ = _kron_promote_inputs((a, b, c))
172+
V₁ = Asymbol_from_fusiontensor(heads...)
173+
V₂ = Asymbol_from_fusiontensor(tails...)
174+
return _kron_promote(V₁, V₂, sz₁, sz₂)
175+
end
176+
function Asymbol_from_fusiontensor(a::I, b::I, c::I) where {I <: ProductSector{<:Tuple{Sector}}}
177+
return Asymbol_from_fusiontensor(map(_firstsector, (a, b, c))...)
178+
end
195179

196180
frobenius_schur_phase(p::ProductSector) = prod(frobenius_schur_phase, p.sectors)
197181
frobenius_schur_indicator(p::ProductSector) = prod(frobenius_schur_indicator, p.sectors)

0 commit comments

Comments
 (0)