1+ """
2+ 配置模块索引
3+
4+ 使用方式:
5+ # 方式1: 通过get_config函数
6+ from funrec.config import get_config
7+ config = get_config("fm")
8+
9+ # 方式2: 直接导入特定配置
10+ from funrec.config.config_fm import CONFIG
11+ config = CONFIG
12+ """
13+ from dataclasses import dataclass
14+ from typing import Dict , Any
15+
16+ from .config_afm import CONFIG as AFM_CONFIG
17+ from .config_apg import CONFIG as APG_CONFIG
18+ from .config_autoint import CONFIG as AUTOINT_CONFIG
19+ from .config_biassvd import CONFIG as BIASSVD_CONFIG
20+ from .config_dcn import CONFIG as DCN_CONFIG
21+ from .config_deepfm import CONFIG as DEEPFM_CONFIG
22+ from .config_dien import CONFIG as DIEN_CONFIG
23+ from .config_din import CONFIG as DIN_CONFIG
24+ from .config_dsin import CONFIG as DSIN_CONFIG
25+ from .config_dssm import CONFIG as DSSM_CONFIG
26+ from .config_eges import CONFIG as EGES_CONFIG
27+ from .config_essm import CONFIG as ESSM_CONFIG
28+ from .config_fibinet import CONFIG as FIBINET_CONFIG
29+ from .config_fm import CONFIG as FM_CONFIG
30+ from .config_fm_recall import CONFIG as FM_RECALL_CONFIG
31+ from .config_funksvd import CONFIG as FUNKSVD_CONFIG
32+ from .config_hmoe import CONFIG as HMOE_CONFIG
33+ from .config_hstu import CONFIG as HSTU_CONFIG
34+ from .config_item2vec import CONFIG as ITEM2VEC_CONFIG
35+ from .config_item_cf import CONFIG as ITEM_CF_CONFIG
36+ from .config_m2m import CONFIG as M2M_CONFIG
37+ from .config_mind import CONFIG as MIND_CONFIG
38+ from .config_mmoe import CONFIG as MMOE_CONFIG
39+ from .config_nfm import CONFIG as NFM_CONFIG
40+ from .config_pepnet import CONFIG as PEPNET_CONFIG
41+ from .config_ple import CONFIG as PLE_CONFIG
42+ from .config_pnn import CONFIG as PNN_CONFIG
43+ from .config_prm import CONFIG as PRM_CONFIG
44+ from .config_prs import CONFIG as PRS_CONFIG
45+ from .config_sasrec import CONFIG as SASREC_CONFIG
46+ from .config_sdm import CONFIG as SDM_CONFIG
47+ from .config_shared_bottom import CONFIG as SHARED_BOTTOM_CONFIG
48+ from .config_star import CONFIG as STAR_CONFIG
49+ from .config_swing import CONFIG as SWING_CONFIG
50+ from .config_user_cf import CONFIG as USER_CF_CONFIG
51+ from .config_wide_deep import CONFIG as WIDE_DEEP_CONFIG
52+ from .config_xdeepfm import CONFIG as XDEEPFM_CONFIG
53+ from .config_youtubednn import CONFIG as YOUTUBEDNN_CONFIG
54+
55+
56+ # 配置映射字典 - 通过模型名称获取配置
57+ CONFIG_MAPPING : Dict [str , Dict [str , Any ]] = {
58+ "afm" : AFM_CONFIG ,
59+ "apg" : APG_CONFIG ,
60+ "autoint" : AUTOINT_CONFIG ,
61+ "biassvd" : BIASSVD_CONFIG ,
62+ "dcn" : DCN_CONFIG ,
63+ "deepfm" : DEEPFM_CONFIG ,
64+ "dien" : DIEN_CONFIG ,
65+ "din" : DIN_CONFIG ,
66+ "dsin" : DSIN_CONFIG ,
67+ "dssm" : DSSM_CONFIG ,
68+ "eges" : EGES_CONFIG ,
69+ "essm" : ESSM_CONFIG ,
70+ "fibinet" : FIBINET_CONFIG ,
71+ "fm" : FM_CONFIG ,
72+ "fm_recall" : FM_RECALL_CONFIG ,
73+ "funksvd" : FUNKSVD_CONFIG ,
74+ "hmoe" : HMOE_CONFIG ,
75+ "hstu" : HSTU_CONFIG ,
76+ "item2vec" : ITEM2VEC_CONFIG ,
77+ "item_cf" : ITEM_CF_CONFIG ,
78+ "m2m" : M2M_CONFIG ,
79+ "mind" : MIND_CONFIG ,
80+ "mmoe" : MMOE_CONFIG ,
81+ "nfm" : NFM_CONFIG ,
82+ "pepnet" : PEPNET_CONFIG ,
83+ "ple" : PLE_CONFIG ,
84+ "pnn" : PNN_CONFIG ,
85+ "prm" : PRM_CONFIG ,
86+ "prs" : PRS_CONFIG ,
87+ "sasrec" : SASREC_CONFIG ,
88+ "sdm" : SDM_CONFIG ,
89+ "shared_bottom" : SHARED_BOTTOM_CONFIG ,
90+ "star" : STAR_CONFIG ,
91+ "swing" : SWING_CONFIG ,
92+ "user_cf" : USER_CF_CONFIG ,
93+ "wide_deep" : WIDE_DEEP_CONFIG ,
94+ "xdeepfm" : XDEEPFM_CONFIG ,
95+ "youtubednn" : YOUTUBEDNN_CONFIG ,
96+ }
97+
98+
99+ def get_config (model_name : str ) -> Dict [str , Any ]:
100+ """
101+ 根据模型名称获取配置字典
102+
103+ Args:
104+ model_name: 模型名称 (如 "fm", "afm")
105+
106+ Returns:
107+ 配置字典
108+
109+ Raises:
110+ KeyError: 如果模型配置不存在
111+ """
112+ if model_name not in CONFIG_MAPPING :
113+ available_models = ", " .join (sorted (CONFIG_MAPPING .keys ()))
114+ raise KeyError (f"模型 \" { model_name } \" 的配置不存在。可用模型: { available_models } " )
115+
116+ return CONFIG_MAPPING [model_name ]
117+
118+ @dataclass
119+ class Config :
120+ """配置类,包含所有配置部分。"""
121+
122+ data : Dict [str , Any ]
123+ features : Dict [str , Any ]
124+ training : Dict [str , Any ]
125+ evaluation : Dict [str , Any ]
126+
127+ def load_config (model_name : str ) -> Config :
128+ """
129+ 根据模型名称加载配置,返回Config对象
130+
131+ Args:
132+ model_name: 模型名称 (如 "fm", "afm")
133+
134+ Returns:
135+ Config对象,支持config.data、config.features等属性访问
136+
137+ Raises:
138+ KeyError: 如果模型配置不存在
139+ """
140+ config_dict = get_config (model_name )
141+
142+ return Config (
143+ data = config_dict .get ("data" , {}),
144+ features = config_dict .get ("features" , {}),
145+ training = config_dict .get ("training" , {}),
146+ evaluation = config_dict .get ("evaluation" , {}),
147+ )
148+
149+ # 导出所有配置
150+ __all__ = ["load_config" , "get_config" , "Config" , "CONFIG_MAPPING" ] + ["AFM_CONFIG" , "APG_CONFIG" , "AUTOINT_CONFIG" , "BIASSVD_CONFIG" , "DCN_CONFIG" , "DEEPFM_CONFIG" , "DIEN_CONFIG" , "DIN_CONFIG" , "DSIN_CONFIG" , "DSSM_CONFIG" , "EGES_CONFIG" , "ESSM_CONFIG" , "FIBINET_CONFIG" , "FM_CONFIG" , "FM_RECALL_CONFIG" , "FUNKSVD_CONFIG" , "HMOE_CONFIG" , "HSTU_CONFIG" , "ITEM2VEC_CONFIG" , "ITEM_CF_CONFIG" , "M2M_CONFIG" , "MIND_CONFIG" , "MMOE_CONFIG" , "NFM_CONFIG" , "PEPNET_CONFIG" , "PLE_CONFIG" , "PNN_CONFIG" , "PRM_CONFIG" , "PRS_CONFIG" , "SASREC_CONFIG" , "SDM_CONFIG" , "SHARED_BOTTOM_CONFIG" , "STAR_CONFIG" , "SWING_CONFIG" , "USER_CF_CONFIG" , "WIDE_DEEP_CONFIG" , "XDEEPFM_CONFIG" , "YOUTUBEDNN_CONFIG" ]
0 commit comments