27 changes: 13 additions & 14 deletions README.md
@@ -1,20 +1,19 @@
# AttentionXML
[AttentionXML: Label Tree-based Attention-Aware Deep Model for High-Performance Extreme Multi-Label Text Classification](https://arxiv.org/abs/1811.01727)

## Requirements

* python==3.7.4
* click==7.0
* ruamel.yaml==0.16.5
* numpy==1.16.2
* scipy==1.3.1
* scikit-learn==0.21.2
* gensim==3.4.0
* torch==1.0.1
* nltk==3.4
* tqdm==4.31.1
* joblib==0.13.2
* logzero==1.5.0
## Installation

You can install AttentionXML with the following commands:

```bash
git clone
cd AttentionXML
conda create -n attentionxml python=3.13
conda activate attentionxml
# As of 2025/12/16 the latest torch release is 2.9.1; install it directly with:
pip3 install torch --index-url https://download.pytorch.org/whl/cu126
pip install -r requirements.txt
```

## Datasets

2 changes: 1 addition & 1 deletion deepxml/cluster.py
@@ -29,7 +29,7 @@ def build_tree_by_level(sparse_data_x, sparse_data_y, mlb, eps: float, max_leaf:
levels, q = [2**x for x in levels], None
for i in range(len(levels)-1, -1, -1):
if os.path.exists(F'{groups_path}-Level-{i}.npy'):
labels_list = np.load(F'{groups_path}-Level-{i}.npy')
labels_list = np.load(F'{groups_path}-Level-{i}.npy', allow_pickle=True)
q = [(labels_i, labels_f[labels_i]) for labels_i in labels_list]
break
if q is None:
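The `allow_pickle=True` arguments added here (and in the other `np.load` calls below) are needed because NumPy 1.16.3+ defaults `np.load` to `allow_pickle=False`, which refuses to read object arrays such as the ragged label lists this project saves. A minimal sketch of the failure mode, with an illustrative file name:

```python
import numpy as np

# Ragged label lists can only be stored as an object (pickled) array.
labels = np.array([["label_a", "label_b"], ["label_c"]], dtype=object)
np.save("labels.npy", labels)

# On NumPy >= 1.16.3 this raises: ValueError: Object arrays cannot be loaded when allow_pickle=False
# labels = np.load("labels.npy")

# Opting back in restores the old behaviour; only do this for files you trust.
labels = np.load("labels.npy", allow_pickle=True)
```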
15 changes: 11 additions & 4 deletions deepxml/data_utils.py
@@ -43,13 +43,15 @@ def get_word_emb(vec_path, vocab_path=None):
if vocab_path is not None:
with open(vocab_path) as fp:
vocab = {word: idx for idx, word in enumerate(fp)}
return np.load(vec_path), vocab
return np.load(vec_path, allow_pickle=True), vocab
else:
return np.load(vec_path)
return np.load(vec_path, allow_pickle=True)


def get_data(text_file, label_file=None):
return np.load(text_file), np.load(label_file) if label_file is not None else None
text_data = np.load(text_file, allow_pickle=True)
label_data = np.load(label_file, allow_pickle=True) if label_file is not None else None
return text_data, label_data


def convert_to_binary(text_file, label_file=None, max_len=None, vocab=None, pad='<PAD>', unknown='<UNK>'):
@@ -74,6 +76,11 @@ def truncate_text(texts, max_len=500, padding_idx=0, unknown_idx=1):

def get_mlb(mlb_path, labels=None) -> MultiLabelBinarizer:
if os.path.exists(mlb_path):
# Handle sklearn module path changes for backward compatibility
import sys
import sklearn.preprocessing
if 'sklearn.preprocessing.label' not in sys.modules:
sys.modules['sklearn.preprocessing.label'] = sklearn.preprocessing
return joblib.load(mlb_path)
mlb = MultiLabelBinarizer(sparse_output=True)
mlb.fit(labels)
@@ -83,7 +90,7 @@ def get_mlb(mlb_path, labels=None) -> MultiLabelBinarizer:

def get_sparse_feature(feature_file, label_file):
sparse_x, _ = load_svmlight_file(feature_file, multilabel=True)
return normalize(sparse_x), np.load(label_file) if label_file is not None else None
return normalize(sparse_x), np.load(label_file, allow_pickle=True) if label_file is not None else None


def output_res(output_path, name, scores, labels):
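The `sys.modules` alias above is a compatibility shim: `MultiLabelBinarizer` objects pickled with scikit-learn 0.21 reference the internal module `sklearn.preprocessing.label`, which later releases removed, so `joblib.load` fails with a `ModuleNotFoundError` unless the old path is mapped to `sklearn.preprocessing`. If retraining is acceptable, a hypothetical alternative is simply to refit and re-dump the binarizer with the current scikit-learn:

```python
# Hypothetical helper, not part of this PR: rebuild the cached binarizer
# instead of unpickling one produced by an old scikit-learn.
import joblib
from sklearn.preprocessing import MultiLabelBinarizer

def rebuild_mlb(mlb_path, labels):
    """Fit a fresh MultiLabelBinarizer on `labels` and cache it at `mlb_path`."""
    mlb = MultiLabelBinarizer(sparse_output=True)
    mlb.fit(labels)
    joblib.dump(mlb, mlb_path)
    return mlb
```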
7 changes: 4 additions & 3 deletions deepxml/evaluation.py
@@ -27,11 +27,12 @@

def get_mlb(classes: TClass = None, mlb: TMlb = None, targets: TTarget = None):
if classes is not None:
mlb = MultiLabelBinarizer(classes, sparse_output=True)
mlb = MultiLabelBinarizer(sparse_output=True)
mlb.fit([classes])
if mlb is None and targets is not None:
if isinstance(targets, csr_matrix):
mlb = MultiLabelBinarizer(range(targets.shape[1]), sparse_output=True)
mlb.fit(None)
mlb = MultiLabelBinarizer(sparse_output=True)
mlb.fit([list(range(targets.shape[1]))])
else:
mlb = MultiLabelBinarizer(sparse_output=True)
mlb.fit(targets)
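These changes sidestep two incompatibilities: recent scikit-learn rejects positional constructor arguments, and `fit(None)` is no longer accepted. One thing to note is that `fit` derives `classes_` in sorted order, whereas the old `classes`-argument form preserved the order it was given; if the original column order matters, a hedged alternative sketch keeps the keyword form:

```python
from sklearn.preprocessing import MultiLabelBinarizer

# Illustrative label set; `classes=` preserves the supplied column ordering.
classes = ["label_b", "label_a", "label_c"]
mlb = MultiLabelBinarizer(classes=classes, sparse_output=True)
mlb.fit([classes])               # fit on an iterable of iterables covering the classes
print(list(mlb.classes_))        # ['label_b', 'label_a', 'label_c']
```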
15 changes: 9 additions & 6 deletions deepxml/models.py
Expand Up @@ -29,7 +29,10 @@ class Model(object):

"""
def __init__(self, network, model_path, gradient_clip_value=5.0, device_ids=None, **kwargs):
self.model = nn.DataParallel(network(**kwargs).cuda(), device_ids=device_ids)
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device_ids is None and torch.cuda.is_available():
device_ids = list(range(torch.cuda.device_count()))
self.model = nn.DataParallel(network(**kwargs).to(self.device), device_ids=device_ids)
self.loss_fn = nn.BCEWithLogitsLoss()
self.model_path, self.state = model_path, {}
os.makedirs(os.path.split(self.model_path)[0], exist_ok=True)
@@ -64,7 +67,7 @@ def train(self, train_loader: DataLoader, valid_loader: DataLoader, opt_params:
self.swa_init()
for i, (train_x, train_y) in enumerate(train_loader, 1):
global_step += 1
loss = self.train_step(train_x, train_y.cuda())
loss = self.train_step(train_x, train_y.to(self.device))
if global_step % step == 0:
self.swa_step()
self.swap_swa_params()
@@ -99,9 +102,9 @@ def clip_gradient(self):
if self.gradient_clip_value is not None:
max_norm = max(self.gradient_norm_queue)
total_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm * self.gradient_clip_value)
self.gradient_norm_queue.append(min(total_norm, max_norm * 2.0, 1.0))
self.gradient_norm_queue.append(min(total_norm.item(), max_norm * 2.0, 1.0))
if total_norm > max_norm * self.gradient_clip_value:
logger.warn(F'Clipping gradients with total norm {round(total_norm, 5)} '
logger.warn(F'Clipping gradients with total norm {round(total_norm.item(), 5)} '
F'and max norm {round(max_norm, 5)}')

def swa_init(self):
@@ -118,7 +121,7 @@ def swa_step(self):
beta = 1.0 / swa_state['models_num']
with torch.no_grad():
for n, p in self.model.named_parameters():
swa_state[n].mul_(1.0 - beta).add_(beta, p.data)
swa_state[n].mul_(1.0 - beta).add_(p.data, alpha=beta)

def swap_swa_params(self):
if 'swa' in self.state:
@@ -162,7 +165,7 @@ def predict_step(self, data_x: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
self.model.eval()
with torch.no_grad():
scores = torch.sigmoid(self.network(data_x, candidates=candidates, attn_weights=self.attn_weights))
scores, labels = torch.topk(scores * group_scores.cuda(), k)
scores, labels = torch.topk(scores * group_scores.to(self.device), k)
return scores.cpu(), candidates[np.arange(len(data_x)).reshape(-1, 1), labels.cpu()]

def train(self, *args, **kwargs):
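Two migration patterns recur in this file: resolving a `self.device` once so bare `.cuda()` calls become `.to(self.device)` and the model still runs on CPU-only machines, and moving in-place tensor updates to the keyword signatures (`add_(tensor, alpha=...)`) that current PyTorch requires. A condensed, self-contained sketch of both, outside the `Model` class:

```python
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = nn.Linear(16, 4).to(device)   # falls back to CPU when no GPU is present
x = torch.randn(8, 16, device=device)
y = model(x)

# Old overload (removed): swa_param.mul_(1 - beta).add_(beta, p.data)
# New form: the scalar multiplier goes through the `alpha` keyword.
beta = 0.1
swa_param = torch.zeros_like(model.weight)
swa_param.mul_(1.0 - beta).add_(model.weight.data, alpha=beta)
```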
25 changes: 17 additions & 8 deletions deepxml/modules.py
@@ -27,8 +27,9 @@ def __init__(self, vocab_size=None, emb_size=None, emb_init=None, emb_trainable=
if emb_size is not None:
assert emb_size == emb_init.shape[1]
vocab_size, emb_size = emb_init.shape
self.emb = nn.Embedding(vocab_size, emb_size, padding_idx=padding_idx, sparse=True,
_weight=torch.from_numpy(emb_init).float() if emb_init is not None else None)
self.emb = nn.Embedding(vocab_size, emb_size, padding_idx=padding_idx, sparse=True)
if emb_init is not None:
self.emb.weight.data.copy_(torch.from_numpy(emb_init).float())
self.emb.weight.requires_grad = emb_trainable
self.dropout = nn.Dropout(dropout)
self.padding_idx = padding_idx
@@ -54,7 +55,9 @@ def forward(self, inputs, lengths, **kwargs):
init_state = self.init_state.repeat([1, inputs.size(0), 1])
cell_init, hidden_init = init_state[:init_state.size(0)//2], init_state[init_state.size(0)//2:]
idx = torch.argsort(lengths, descending=True)
packed_inputs = nn.utils.rnn.pack_padded_sequence(inputs[idx], lengths[idx], batch_first=True)
# In PyTorch 2.x, pack_padded_sequence requires lengths to be on CPU
lengths_cpu = lengths[idx].cpu()
packed_inputs = nn.utils.rnn.pack_padded_sequence(inputs[idx], lengths_cpu, batch_first=True)
outputs, _ = nn.utils.rnn.pad_packed_sequence(
self.lstm(packed_inputs, (hidden_init, cell_init))[0], batch_first=True)
return self.dropout(outputs[torch.argsort(idx)])
@@ -71,7 +74,7 @@ def __init__(self, labels_num, hidden_size):

def forward(self, inputs, masks):
masks = torch.unsqueeze(masks, 1) # N, 1, L
attention = self.attention(inputs).transpose(1, 2).masked_fill(1.0 - masks, -np.inf) # N, labels_num, L
attention = self.attention(inputs).transpose(1, 2).masked_fill(~masks, -np.inf) # N, labels_num, L
attention = F.softmax(attention, -1)
return attention @ inputs # N, labels_num, hidden_size

@@ -82,14 +85,20 @@ class AttentionWeights(nn.Module):
"""
def __init__(self, labels_num, hidden_size, device_ids=None):
super(AttentionWeights, self).__init__()
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device_ids is None:
device_ids = list(range(1, torch.cuda.device_count()))
device_ids = list(range(1, torch.cuda.device_count())) if torch.cuda.is_available() else [0]
assert labels_num >= len(device_ids)
group_size, plus_num = labels_num // len(device_ids), labels_num % len(device_ids)
self.group = [group_size + 1] * plus_num + [group_size] * (len(device_ids) - plus_num)
assert sum(self.group) == labels_num
self.emb = nn.ModuleList(nn.Embedding(size, hidden_size, sparse=True).cuda(device_ids[i])
for i, size in enumerate(self.group))
# Create embeddings on appropriate devices
if torch.cuda.is_available() and len(device_ids) > 0:
self.emb = nn.ModuleList(nn.Embedding(size, hidden_size, sparse=True).to(f'cuda:{device_ids[i]}')
for i, size in enumerate(self.group))
else:
self.emb = nn.ModuleList(nn.Embedding(size, hidden_size, sparse=True).to(self.device)
for i, size in enumerate(self.group))
std = (6.0 / (labels_num + hidden_size)) ** 0.5
with torch.no_grad():
for emb in self.emb:
@@ -119,7 +128,7 @@ def forward(self, inputs, masks, candidates, attn_weights: nn.Module):
masks = torch.unsqueeze(masks, 1) # N, 1, L
attn_inputs = inputs.transpose(1, 2) # N, hidden, L
attn_weights = self.attention(candidates) if hasattr(self, 'attention') else attn_weights(candidates)
attention = (attn_weights @ attn_inputs).masked_fill(1.0 - masks, -np.inf) # N, sampled_size, L
attention = (attn_weights @ attn_inputs).masked_fill(~masks, -np.inf) # N, sampled_size, L
attention = F.softmax(attention, -1) # N, sampled_size, L
return attention @ inputs # N, sampled_size, hidden_size

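The switch from `masked_fill(1.0 - masks, -np.inf)` to `masked_fill(~masks, -np.inf)` reflects PyTorch requiring a `torch.bool` mask: arithmetic on the mask no longer produces a valid fill mask, and `~` is the boolean negation. A minimal sketch with illustrative shapes, assuming the masks coming from the embedding layer are already boolean:

```python
import torch

scores = torch.randn(2, 3, 5)                                 # N, labels_num, L
masks = torch.tensor([[1, 1, 1, 0, 0],
                      [1, 1, 0, 0, 0]], dtype=torch.bool)     # N, L
masks = masks.unsqueeze(1)                                    # N, 1, L

# Padded positions (mask == False) are hidden before the softmax.
attention = scores.masked_fill(~masks, float('-inf'))
attention = torch.softmax(attention, dim=-1)
```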
10 changes: 5 additions & 5 deletions deepxml/optimizers.py
@@ -102,19 +102,19 @@ def make_sparse(values):

p.data.add_(make_sparse(-step_size * numer.div_(denom)))
if weight_decay > 0.0:
p.data.add_(-group['lr'] * weight_decay, p.data.sparse_mask(grad))
p.data.add_(p.data.sparse_mask(grad), alpha=-group['lr'] * weight_decay)
else:
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(1 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
denom = exp_avg_sq.sqrt().add_(group['eps'])

bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

p.data.addcdiv_(-step_size, exp_avg, denom)
p.data.addcdiv_(exp_avg, denom, value=-step_size)
if weight_decay > 0.0:
p.data.add_(-group['lr'] * weight_decay, p.data)
p.data.add_(p.data, alpha=-group['lr'] * weight_decay)

return loss
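The optimizer edits replace the removed positional overloads (`add_(scalar, tensor)`, `addcmul_(scalar, t1, t2)`, `addcdiv_(scalar, t1, t2)`) with the keyword forms. Schematically, the dense Adam update in this file now reads as in the sketch below (standalone tensors stand in for optimizer state):

```python
import torch

lr, beta1, beta2, eps, step = 1e-3, 0.9, 0.999, 1e-8, 1
grad = torch.randn(10)
exp_avg, exp_avg_sq, param = torch.zeros(10), torch.zeros(10), torch.randn(10)

# Decay the first and second moment running averages (keyword-argument form).
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

denom = exp_avg_sq.sqrt().add_(eps)
bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step
step_size = lr * (bias_correction2 ** 0.5) / bias_correction1

param.addcdiv_(exp_avg, denom, value=-step_size)
```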
14 changes: 8 additions & 6 deletions deepxml/tree.py
@@ -62,7 +62,7 @@ def train_level(self, level, train_x, train_y, valid_x, valid_y):
if level == 0:
while not os.path.exists(F'{self.groups_path}-Level-{level}.npy'):
time.sleep(30)
groups = np.load(F'{self.groups_path}-Level-{level}.npy')
groups = np.load(F'{self.groups_path}-Level-{level}.npy', allow_pickle=True)
train_y, valid_y = self.get_mapping_y(groups, self.labels_num, train_y, valid_y)
labels_num = len(groups)
train_loader = DataLoader(MultiLabelDataset(train_x, train_y),
@@ -83,7 +83,8 @@ def train_level(self, level, train_x, train_y, valid_x, valid_y):
return train_y, model.predict(train_loader, k=self.top), model.predict(valid_loader, k=self.top)
else:
train_group_y, train_group, valid_group = self.train_level(level - 1, train_x, train_y, valid_x, valid_y)
torch.cuda.empty_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()

logger.info('Getting Candidates')
_, group_labels = train_group
@@ -112,12 +113,12 @@
if level < self.level - 1:
while not os.path.exists(F'{self.groups_path}-Level-{level}.npy'):
time.sleep(30)
groups = np.load(F'{self.groups_path}-Level-{level}.npy')
groups = np.load(F'{self.groups_path}-Level-{level}.npy', allow_pickle=True)
train_y, valid_y = self.get_mapping_y(groups, self.labels_num, train_y, valid_y)
labels_num, last_groups = len(groups), self.get_inter_groups(len(groups))
else:
groups, labels_num = None, train_y.shape[1]
last_groups = np.load(F'{self.groups_path}-Level-{level-1}.npy')
last_groups = np.load(F'{self.groups_path}-Level-{level-1}.npy', allow_pickle=True)

train_loader = DataLoader(XMLDataset(train_x, train_y, labels_num=labels_num,
groups=last_groups, group_labels=group_candidates),
@@ -169,11 +170,12 @@ def predict_level(self, level, test_x, k, labels_num):
return model.predict(test_loader, k=k)
else:
if level == self.level - 1:
groups = np.load(F'{self.groups_path}-Level-{level-1}.npy')
groups = np.load(F'{self.groups_path}-Level-{level-1}.npy', allow_pickle=True)
else:
groups = self.get_inter_groups(labels_num)
group_scores, group_labels = self.predict_level(level - 1, test_x, self.top, len(groups))
torch.cuda.empty_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.info(F'Predicting Level-{level}, Top: {k}')
if model is None:
model = XMLModel(network=FastAttentionRNN, labels_num=labels_num,
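Guarding `torch.cuda.empty_cache()` keeps tree training and prediction consistent with the device-agnostic changes in `models.py` on CPU-only machines. Since the same check now appears in two places, a small hypothetical helper could centralize it:

```python
import torch

def maybe_empty_cuda_cache():
    """Release cached GPU memory, but only when a CUDA device is available."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
```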
4 changes: 2 additions & 2 deletions ensemble.py
@@ -18,8 +18,8 @@
def main(prefix, trees):
labels, scores = [], []
for i in range(trees):
labels.append(np.load(F'{prefix}-Tree-{i}-labels.npy'))
scores.append(np.load(F'{prefix}-Tree-{i}-scores.npy'))
labels.append(np.load(F'{prefix}-Tree-{i}-labels.npy', allow_pickle=True))
scores.append(np.load(F'{prefix}-Tree-{i}-scores.npy', allow_pickle=True))
ensemble_labels, ensemble_scores = [], []
for i in tqdm(range(len(labels[0]))):
s = defaultdict(float)
4 changes: 2 additions & 2 deletions evaluation.py
@@ -23,13 +23,13 @@
@click.option('-a', type=click.FLOAT, default=0.55, help='Parameter A for propensity score.')
@click.option('-b', type=click.FLOAT, default=1.5, help='Parameter B for propensity score.')
def main(results, targets, train_labels, a, b):
res, targets = np.load(results), np.load(targets)
res, targets = np.load(results, allow_pickle=True), np.load(targets, allow_pickle=True)
mlb = MultiLabelBinarizer(sparse_output=True)
targets = mlb.fit_transform(targets)
print('Precision@1,3,5:', get_p_1(res, targets, mlb), get_p_3(res, targets, mlb), get_p_5(res, targets, mlb))
print('nDCG@1,3,5:', get_n_1(res, targets, mlb), get_n_3(res, targets, mlb), get_n_5(res, targets, mlb))
if train_labels is not None:
train_labels = np.load(train_labels)
train_labels = np.load(train_labels, allow_pickle=True)
inv_w = get_inv_propensity(mlb.transform(train_labels), a, b)
print('PSPrecision@1,3,5:', get_psp_1(res, targets, inv_w, mlb), get_psp_3(res, targets, inv_w, mlb),
get_psp_5(res, targets, inv_w, mlb))
2 changes: 1 addition & 1 deletion preprocess.py
@@ -47,7 +47,7 @@ def main(text_path, tokenized_path, label_path, vocab_path, emb_path, w2v_model,
vocab, emb_init = build_vocab(fp, w2v_model, vocab_size=vocab_size)
np.save(vocab_path, vocab)
np.save(emb_path, emb_init)
vocab = {word: i for i, word in enumerate(np.load(vocab_path))}
vocab = {word: i for i, word in enumerate(np.load(vocab_path, allow_pickle=True))}
logger.info(F'Vocab Size: {len(vocab)}')

logger.info(F'Getting Dataset: {text_path} Max Length: {max_len}')
22 changes: 11 additions & 11 deletions requirements.txt
@@ -1,11 +1,11 @@
click==7.0
ruamel.yaml==0.16.5
numpy==1.16.2
scipy==1.3.1
scikit-learn==0.21.2
gensim==3.4.0
torch==1.0.1
nltk==3.4
tqdm==4.31.1
joblib==0.13.2
logzero==1.5.0
click>=8.0.0
ruamel.yaml>=0.17.0
numpy>=1.24.0
scipy>=1.14.0
scikit-learn>=1.3.0
gensim>=4.3.0
torch>=2.5.1
nltk>=3.8.0
tqdm>=4.65.0
joblib>=1.3.0
logzero>=1.7.0
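With the pins relaxed from exact versions to lower bounds, the resolved versions depend on the environment, so it is worth printing what actually got installed. A small, illustrative check (module names as imported, e.g. scikit-learn imports as `sklearn`):

```python
# Illustrative post-install check: print the resolved versions of the key dependencies.
import importlib

for name in ("click", "ruamel.yaml", "numpy", "scipy", "sklearn", "gensim",
             "torch", "nltk", "tqdm", "joblib", "logzero"):
    module = importlib.import_module(name)
    print(f"{name:>12}: {getattr(module, '__version__', 'unknown')}")
```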