Skip to content

Commit 04a7c4f

Browse files
Update fit.jl
1 parent 1b30628 commit 04a7c4f

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

src/fit.jl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,20 +69,27 @@ function regife(
6969
end
7070
formula, formula_fes = FixedEffectModels.parse_fe(formula)
7171
has_fes = formula_fes != FormulaTerm(ConstantTerm(0), ConstantTerm(0))
72+
7273
fes, feids, fekeys = FixedEffectModels.parse_fixedeffect(df, formula_fes)
73-
has_fes_intercept = false
74-
## Compute factors, an array of AbtractFixedEffects
74+
has_fe_intercept = any(fe.interaction isa UnitWeights for fe in fes)
75+
76+
# remove intercept if absorbed by fixed effects
77+
if has_fe_intercept
78+
formula = FormulaTerm(formula.lhs, tuple(InterceptTerm{false}(), (term for term in eachterm(formula.rhs) if !isa(term, Union{ConstantTerm,InterceptTerm}))...))
79+
end
80+
has_intercept = hasintercept(formula)
81+
82+
7583
if has_fes
76-
if any([isa(fe.interaction, Ones) for fe in fes])
77-
formula = FormulaTerm(formula.lhs, tuple(ConstantTerm(0), (t for t in eachterm(formula.rhs) if t!= ConstantTerm(1))...))
78-
has_fes_intercept = true
84+
if any(fe.interaction isa UnitWeights for fe in fes)
85+
has_fe_intercept = true
7986
end
8087
fes = FixedEffect[fe[esample] for fe in fes]
8188
feM = AbstractFixedEffectSolver{Float64}(fes, weights, Val{:cpu})
8289
end
8390

8491

85-
has_intercept = ConstantTerm(1) eachterm(formula.rhs)
92+
8693

8794

8895
iterations = 0
@@ -102,7 +109,7 @@ function regife(
102109
formula_schema = apply_schema(formula, schema(formula, subdf, contrasts), StatisticalModel)
103110

104111
y = convert(Vector{Float64}, response(formula_schema, subdf))
105-
tss_total = tss(y, has_intercept || has_fes_intercept, weights)
112+
tss_total = tss(y, has_intercept | has_fe_intercept, weights)
106113

107114
X = convert(Matrix{Float64}, modelmatrix(formula_schema, subdf))
108115

@@ -228,7 +235,7 @@ function regife(
228235
# compute various r2
229236
nobs = sum(esample)
230237
rss = sum(abs2, residualsm)
231-
_tss = tss(ym ./ sqrtw, has_intercept || has_fes_intercept, weights)
238+
_tss = tss(ym ./ sqrtw, has_intercept | has_fe_intercept, weights)
232239
r2_within = 1 - rss / _tss
233240

234241
rss = sum(abs2, residuals)

0 commit comments

Comments
 (0)