
With a univariate input, kanpiler fails to learn an additive combination of functions such as sin(5a) + cos(3a) #557

@CrayonNAS


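```python
import torch
import matplotlib.pyplot as plt
from kan import *
from kan.compiler import *
from sympy import *

# Single input variable
a = symbols('a')
input_vars = [a]
# Target: a sum of two univariate terms (multiplication must be explicit)
expression = sin(5*a) + cos(3*a)

# Compile the symbolic expression into a KAN
model = kanpiler(input_vars, expression)

# 100 samples drawn uniformly from [-2, 0)
x = torch.rand(100, 1) * 2 - 2
# Forward pass so the activations are cached for plotting
model(x)
model.plot()
plt.savefig('111.png')

# Extract the symbolic formula and round its constants to 4 digits
formula = ex_round(model.symbolic_formula()[0][0], 4)
print(formula)
```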
When I run this code, I get `1.0*sin(5.0*x_1)`. The model seems to have captured only the sine term of the original expression `sin(5*a) + cos(3*a)` and missed the cosine term `cos(3*a)`.
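A quick way to confirm that exactly the cosine term is missing, independently of pykan, is to compare the reported formula against the target with sympy. The sketch below hard-codes the printed output `1.0*sin(5.0*x_1)` as a string (an assumption: it stands in for `formula` from the script above), lists which additive terms of the target are absent, and measures the numerical error on the sampled range [-2, 0].

```python
import numpy as np
import sympy as sp

a, x_1 = sp.symbols("a x_1")

# Target expression from the issue
target = sp.sin(5 * a) + sp.cos(3 * a)

# Formula printed by the script (hard-coded here as an assumption);
# nsimplify turns the floats 1.0 and 5.0 into exact integers so terms compare structurally
recovered = sp.nsimplify(sp.sympify("1.0*sin(5.0*x_1)")).subs(x_1, a)

# Additive terms of the target that do not appear in the recovered formula
recovered_terms = set(sp.Add.make_args(recovered))
missing = [t for t in sp.Add.make_args(target) if t not in recovered_terms]

# Symbolic residual: what the recovered formula fails to account for
residual = target - recovered

# Numerical error over the same interval the script samples from
f_target = sp.lambdify(a, target, "numpy")
f_recovered = sp.lambdify(a, recovered, "numpy")
xs = np.linspace(-2.0, 0.0, 101)
max_err = float(np.max(np.abs(f_target(xs) - f_recovered(xs))))

print("missing terms:", missing)
print("residual:", residual)
print("max abs error on [-2, 0]:", max_err)
```

On this grid the maximum error is 1.0, reached at a = 0 where cos(3a) = 1, which is consistent with the recovered formula dropping the whole cos(3a) term rather than approximating it poorly.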

Metadata

Assignees: No one assigned
Labels: No labels
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests