-
Notifications
You must be signed in to change notification settings - Fork 38
Expand file tree
/
Copy pathpyproject.toml
More file actions
145 lines (132 loc) · 4.14 KB
/
pyproject.toml
File metadata and controls
145 lines (132 loc) · 4.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
[project]
name = "spd"
version = "0.0.1"
description = "Sparse Parameter Decomposition"
requires-python = "==3.13.*"
urls = { Homepage = "https://github.com/goodfire-ai/spd" }
# NOTE(review): newer setuptools (PEP 639) deprecates this table form in favour of an SPDX
# string (`license = "MIT"`); consider switching once the build requirement allows it — confirm.
license = { text = "MIT" }
readme = "README.md"
dependencies = [
    # NOTE(review): torchvision releases are built against a single torch minor version
    # (0.23.x is believed to pair with torch 2.8 — confirm), so the torchvision pin constrains
    # torch more tightly than `torch>=2.6` suggests. Bump these two together.
    "torch>=2.6",
    "torchvision>=0.23,<0.24",
    "pydantic<2.12",  # https://github.com/goodfire-ai/spd/pull/232 , https://github.com/goodfire-ai/spd/issues/221
    "wandb>=0.20.1",  # Avoid wandb.sdk.wandb_manager.ManagerConnectionRefusedError
    "fire",
    "tqdm",
    "ipykernel",
    "transformers",
    "jaxtyping",
    "einops",
    "matplotlib",
    "numpy",
    "python-dotenv",
    "wandb-workspaces==0.1.12",  # See https://github.com/wandb/wandb-workspaces/issues/65
    "sympy",
    "streamlit",
    "streamlit-antd-components",
    # `datasets` below 2.21.0 is incompatible with numpy>=2.0.
    # See https://github.com/huggingface/datasets/issues/6980 and https://github.com/huggingface/datasets/pull/6991
    # (fixed in https://github.com/huggingface/datasets/releases/tag/2.21.0 ).
    "datasets>=2.21.0",
    "scipy>=1.14.1",
    "fastapi",
    "uvicorn",
    "openrouter>=0.1.1",
    "httpx>=0.28.0",
    "zstandard",  # For streaming datasets
]
[project.scripts]
# Console commands installed with the package; each value is a `module:function` entry point.
# Top-level launchers live in `spd.scripts`.
spd-run = "spd.scripts.run_cli:cli"
spd-local = "spd.scripts.run_local:cli"
# Per-subpackage pipelines, each exposed from that subpackage's `scripts` module.
spd-pretrain = "spd.pretrain.scripts.run_slurm:cli"
spd-clustering = "spd.clustering.scripts.run_pipeline:cli"
spd-harvest = "spd.harvest.scripts.run_slurm_cli:cli"
spd-autointerp = "spd.autointerp.scripts.run_slurm_cli:cli"
spd-attributions = "spd.dataset_attributions.scripts.run_slurm_cli:cli"

[dependency-groups]
# PEP 735 development-only dependency groups; these are not part of the published package metadata.
dev = [
    "pytest",
    "pytest-cov",  # for coverage reports
    "pytest-xdist",  # parallel test execution
    "pytest-testmon",  # only re-run tests affected by code changes
    "ruff",
    "basedpyright<1.32.0",  # pyright and wandb issues, see https://github.com/goodfire-ai/spd/pull/232
    "pre-commit",
]
[build-system]
# setuptools 61.0 is the first release that reads PEP 621 `[project]` metadata from
# pyproject.toml, which this file relies on; older versions would silently ignore it.
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"
[tool.setuptools.packages.find]
# Package discovery: pick up `spd` and its subpackages (the glob also matches any other
# top-level package whose name starts with "spd"), searching from the repository root.
include = ["spd*"]
where = ["."]
[tool.ruff]
line-length = 100
# Apply (safe) autofixes by default whenever ruff runs.
fix = true
extend-exclude = ["spd/app/frontend"]

[tool.ruff.lint]
select = [
    "E",  # pycodestyle
    "F",  # Pyflakes
    "UP",  # pyupgrade
    "B",  # flake8-bugbear
    "SIM",  # flake8-simplify
    "I",  # isort
]
ignore = [
    "F722",  # Incompatible with jaxtyping
    "E731",  # lambda assignments are acceptable in several places
    "E501",  # many long lines in the codebase; the formatter still targets line-length
]

[tool.ruff.lint.isort]
# NOTE(review): presumably pinned because local `wandb/` run directories (see the pyright
# exclude) could make isort classify `wandb` as first-party — confirm.
known-third-party = ["wandb"]

[tool.ruff.format]
# Also reformat code snippets embedded in docstrings.
docstring-code-format = true
[tool.pyright]
include = ["spd", "tests"]
exclude = ["**/wandb/**", "spd/utils/linear_sum_assignment.py", "spd/app/frontend"]
# Having type stubs for transformers shaves 10 seconds off basedpyright calls.
stubPath = "typings"

# Infer precise element types for list, dict and set literals.
strictListInference = true
strictDictionaryInference = true
strictSetInference = true

# Diagnostics explicitly switched on.
reportFunctionMemberAccess = true
reportUnknownParameterType = true
reportIncompatibleMethodOverride = true
reportIncompatibleVariableOverride = true
reportOverlappingOverload = true
reportConstantRedefinition = true
reportImportCycles = true
reportPropertyTypeMismatch = true
reportMissingTypeArgument = true
reportUnnecessaryCast = true
reportUnnecessaryComparison = true
reportUnnecessaryContains = true
reportUnusedExpression = true
reportMatchNotExhaustive = true
reportCallIssue = true

# Diagnostics explicitly switched off. Some of these (e.g. reportAny, reportExplicitAny,
# reportUnusedCallResult) are basedpyright-specific rules.
reportPrivateImportUsage = false
reportAny = false
reportUnusedCallResult = false
reportUnknownMemberType = false
reportUnknownVariableType = false
reportUnknownArgumentType = false
reportExplicitAny = false
reportMissingTypeStubs = false
reportImplicitStringConcatenation = false
reportPrivateUsage = false
reportUnannotatedClassAttribute = false
reportUnknownLambdaType = false
[tool.pytest.ini_options]
# importlib mode imports test modules without prepending their directories to sys.path.
addopts = ["--import-mode=importlib"]
# Each entry follows the `action:message:category:module:lineno` warning-filter syntax.
filterwarnings = [
    # DeprecationWarning from fire; see https://github.com/google/python-fire/pull/447
    "ignore::DeprecationWarning:fire:59",
    # Pydantic V1-style deprecation warnings emitted from wandb_workspaces
    "ignore:Pydantic V1 style.*:DeprecationWarning:wandb_workspaces",
]