diff --git a/kge/model/embedder/lookup_embedder.py b/kge/model/embedder/lookup_embedder.py
index 85cb41b72..fe14894fc 100644
--- a/kge/model/embedder/lookup_embedder.py
+++ b/kge/model/embedder/lookup_embedder.py
@@ -7,7 +7,7 @@
 from kge.model import KgeEmbedder
 from kge.misc import round_to_points
 
-from typing import List, Dict
+from typing import List, Dict, Union
 
 
 class LookupEmbedder(KgeEmbedder):
@@ -44,6 +44,8 @@ def __init__(
         # initialize weights
         self._init_embeddings(self._embeddings.weight.data)
 
+        self._embeddings_frozen = None
+
         # TODO handling negative dropout because using it with ax searches for now
         dropout = self.get_option("dropout")
         if dropout < 0:
@@ -89,7 +91,10 @@ def init_pretrained(self, pretrained_embedder: KgeEmbedder) -> None:
         )
 
     def embed(self, indexes: Tensor) -> Tensor:
-        return self._postprocess(self._embeddings(indexes.long()))
+        return self._postprocess(self._embed(indexes))
+
+    def _embed(self, indexes: Tensor) -> Tensor:
+        return self._embeddings(indexes.long())
 
     def embed_all(self) -> Tensor:
         return self._postprocess(self._embeddings_all())
@@ -109,6 +114,97 @@ def _embeddings_all(self) -> Tensor:
     def _get_regularize_weight(self) -> Tensor:
         return self.get_option("regularize_weight")
 
+    def freeze(self, freeze_indexes: Union[List, Tensor]) -> None:
+        """Freeze the embeddings at the indexes given by freeze_indexes.
+
+        This method overrides this instance's _embed() and _embeddings_all()
+        methods so that frozen and trainable embeddings are looked up separately.
+        """
+        num_freeze = len(freeze_indexes)
+
+        original_weights = self._embeddings.weight.data
+
+        if isinstance(freeze_indexes, list):
+            freeze_indexes = torch.tensor(
+                freeze_indexes, device=self.config.get("job.device")
+            ).long()
+
+        self._embeddings_frozen = torch.nn.Embedding(
+            num_freeze, self.dim, sparse=self.sparse,
+        )
+        self._embeddings = torch.nn.Embedding(
+            self.vocab_size - num_freeze, self.dim, sparse=self.sparse,
+        )
+
+        # freeze_mask[i] is True iff the embedding with global index i
+        # is frozen
+        freeze_mask = torch.zeros(
+            self.vocab_size, dtype=torch.bool, device=self.config.get("job.device")
+        )
+        freeze_mask[freeze_indexes] = True
+
+        # assign current values to the new embeddings
+        self._embeddings_frozen.weight.data = original_weights[freeze_mask]
+        self._embeddings.weight.data = original_weights[~freeze_mask]
+
+        # exclude the frozen embeddings from gradient updates
+        self._embeddings_frozen.weight.requires_grad = False
+
+        # global_to_local_mapper[i] holds the position of global index i within
+        # either the frozen or the trainable embedding tensor
+        global_to_local_mapper = torch.zeros(
+            self.vocab_size, dtype=torch.long, device=self.config.get("job.device")
+        )
+        global_to_local_mapper[freeze_mask] = torch.arange(
+            num_freeze, device=self.config.get("job.device")
+        )
+        global_to_local_mapper[~freeze_mask] = torch.arange(
+            self.vocab_size - num_freeze, device=self.config.get("job.device")
+        )
+
+        def _embed(indexes: Tensor) -> Tensor:
+
+            emb = torch.empty(
+                (len(indexes), self.dim), device=self._embeddings.weight.device
+            )
+
+            frozen_indexes_mask = freeze_mask[indexes.long()]
+
+            emb[frozen_indexes_mask] = self._embeddings_frozen(
+                global_to_local_mapper[indexes[frozen_indexes_mask].long()]
+            )
+
+            emb[~frozen_indexes_mask] = self._embeddings(
+                global_to_local_mapper[indexes[~frozen_indexes_mask].long()]
+            )
+            return emb
+
+        def _embeddings_all() -> Tensor:
+
+            emb = torch.empty(
+                (self.vocab_size, self.dim), device=self._embeddings.weight.device
+            )
+
+            emb[freeze_mask] = self._embeddings_frozen(
+                torch.arange(
+                    num_freeze,
+                    dtype=torch.long,
+                    device=self._embeddings_frozen.weight.device,
+                )
+            )
+
+            emb[~freeze_mask] = self._embeddings(
+                torch.arange(
+                    self.vocab_size - num_freeze,
+                    dtype=torch.long,
+                    device=self._embeddings.weight.device,
+                )
+            )
+            return emb
+
+        self._embeddings_all = _embeddings_all
+        self._embed = _embed
+
     def penalty(self, **kwargs) -> List[Tensor]:
         # TODO factor out to a utility method
         result = super().penalty(**kwargs)
@@ -135,7 +231,7 @@ def penalty(self, **kwargs) -> List[Tensor]:
                 unique_indexes, counts = torch.unique(
                     kwargs["indexes"], return_counts=True
                 )
-                parameters = self._embeddings(unique_indexes)
+                parameters = self._embed(unique_indexes)
                 if p % 2 == 1:
                     parameters = torch.abs(parameters)
                 result += [
diff --git a/kge/model/embedder/lookup_embedder.yaml b/kge/model/embedder/lookup_embedder.yaml
index ea7a01f57..bda119e3d 100644
--- a/kge/model/embedder/lookup_embedder.yaml
+++ b/kge/model/embedder/lookup_embedder.yaml
@@ -39,6 +39,13 @@ lookup_embedder:
     # the packaged model
     # if false initialize other embeddings normally
     ensure_all: False
+
+  # Freeze a subset of the embeddings during training. Expects a file with one
+  # entity/relation id per line, given either as an absolute path or as a
+  # filename relative to the dataset folder. The embeddings of these ids are
+  # held constant during training. Leave empty to disable freezing.
+  freeze:
+    ids_file: ""
 
   # Dropout used for the embeddings.
   dropout: 0.
diff --git a/kge/model/kge_model.py b/kge/model/kge_model.py
index 044b7500c..bf3dd70a0 100644
--- a/kge/model/kge_model.py
+++ b/kge/model/kge_model.py
@@ -12,6 +12,7 @@
 from kge.misc import filename_in_module
 from kge.util import load_checkpoint
 from typing import Any, Dict, List, Optional, Union, Tuple
+from kge.util.io import file_to_list
 
 from typing import TYPE_CHECKING
 
@@ -436,6 +437,49 @@ def load_pretrained_model(
             self._relation_embedder.init_pretrained(
                 pretrained_relations_model.get_p_embedder()
             )
+        # freeze embeddings if desired
+        for embedder, name in [
+            (self._relation_embedder, "relation"),
+            (self._entity_embedder, "entity"),
+        ]:
+            freeze_file = embedder.get_option("freeze.ids_file")
+            if freeze_file != "":
+                if not os.path.isfile(freeze_file):
+                    freeze_file = os.path.join(self.dataset.folder, freeze_file)
+                if not os.path.isfile(freeze_file):
+                    raise FileNotFoundError(
+                        f"Could not find freeze file for {name} embedder"
+                    )
+                else:
+                    ids = file_to_list(freeze_file)
+                    id_map = self.dataset.load_map(f"{name}_ids")
+                    freeze_indexes = list(
+                        map(lambda _id: id_map.index(_id), ids)
+                    )
+                    model = self.config.get("model")
+                    if (
+                        model == "reciprocal_relations_model"
+                        and name == "relation"
+                    ):
+                        # this is the base model and num_relations is already
+                        # twice the number of original relations
+                        reciprocal_indexes = list(
+                            map(
+                                lambda idx: idx
+                                + self.dataset.num_relations() // 2,
+                                freeze_indexes,
+                            )
+                        )
+                        freeze_indexes.extend(reciprocal_indexes)
+                    if len(freeze_indexes) > len(set(freeze_indexes)):
+                        raise Exception(
+                            f"Ids for freezing {name} embeddings must be unique."
+                        )
+
+                    self.config.log(
+                        f"Freezing {name} embeddings found in {freeze_file}"
+                    )
+                    embedder.freeze(freeze_indexes)
 
         #: Scorer
         self._scorer: RelationalScorer
diff --git a/kge/util/io.py b/kge/util/io.py
index 5ba35007b..099c94d5c 100644
--- a/kge/util/io.py
+++ b/kge/util/io.py
@@ -44,3 +44,10 @@ def load_checkpoint(checkpoint_file: str, device="cpu"):
     checkpoint["file"] = checkpoint_file
     checkpoint["folder"] = os.path.dirname(checkpoint_file)
     return checkpoint
+
+
+def file_to_list(file: str):
+    """Return the lines of a file as a list."""
+    with open(file, "r") as f:
+        data = f.read().rstrip("\n").splitlines()
+    return data
diff --git a/tests/test_freeze.py b/tests/test_freeze.py
new file mode 100644
index 000000000..0ffb862f8
--- /dev/null
+++ b/tests/test_freeze.py
@@ -0,0 +1,96 @@
+import unittest
+import os
+import torch
+from tests.util import create_config, empty_cache, get_cache_dir
+from kge.misc import kge_base_dir
+from kge.model.kge_model import KgeModel
+from kge.job import TrainingJob
+from kge.dataset import Dataset
+
+
+class TestFreeze(unittest.TestCase):
+    def setUp(self) -> None:
+        self.dataset_name = "toy"
+        self.folder = os.path.join(get_cache_dir(), "test_freeze")
+        self.config = create_config(self.dataset_name)
+        self.config.folder = self.folder
+        self.config.init_folder()
+        self.config.set("train.max_epochs", 1)
+        self.dataset = Dataset.create(config=self.config)
+
+    def tearDown(self) -> None:
+        empty_cache()
+
+    def test_freeze(self) -> None:
+        """Test that frozen embeddings stay constant during training.
+
+        Ensures that, after calling freeze() on the LookupEmbedder, the frozen
+        embeddings are held constant during training.
+
+        """
+
+        model = KgeModel.create(config=self.config, dataset=self.dataset)
+
+        # freeze every other entity and relation embedding
+        freeze_indexes_ent = list(range(0, model.dataset.num_entities(), 2))
+        freeze_indexes_rel = list(range(0, model.dataset.num_relations(), 2))
+
+        entity_embedder = model.get_o_embedder()
+        relation_embedder = model.get_p_embedder()
+
+        # copy before freeze
+        frozen_emb_rel = (
+            relation_embedder.embed(torch.tensor(freeze_indexes_rel)).clone().detach()
+        )
+
+        frozen_emb_ent = (
+            entity_embedder.embed(torch.tensor(freeze_indexes_ent)).clone().detach()
+        )
+
+        # freeze
+        entity_embedder.freeze(freeze_indexes_ent)
+        relation_embedder.freeze(freeze_indexes_rel)
+
+        training_job = TrainingJob.create(
+            config=model.config, dataset=model.dataset, model=model
+        )
+        training_job.run()
+
+        frozen_emb_rel_after = relation_embedder.embed(torch.tensor(freeze_indexes_rel))
+        frozen_emb_ent_after = entity_embedder.embed(torch.tensor(freeze_indexes_ent))
+
+        # ensure the frozen embeddings have not been changed
+        self.assertTrue(
+            torch.all(torch.eq(frozen_emb_ent, frozen_emb_ent_after)),
+            msg="Frozen entity parameters changed during training",
+        )
+
+        self.assertTrue(
+            torch.all(torch.eq(frozen_emb_rel, frozen_emb_rel_after)),
+            msg="Frozen relation parameters changed during training",
+        )
+
+    def test_scores_after_freeze(self) -> None:
+        """Test that score computation is unchanged after calling freeze()."""
+
+        model = KgeModel.create(config=self.config, dataset=self.dataset)
+
+        # freeze every other entity and relation embedding
+        freeze_indexes_ent = list(range(0, model.dataset.num_entities(), 2))
+        freeze_indexes_rel = list(range(0, model.dataset.num_relations(), 2))
+
+        entity_embedder = model.get_o_embedder()
+        relation_embedder = model.get_p_embedder()
+
+        triples = self.dataset.split("train")
+        scores_before = model.score_spo(triples[:, 0], triples[:, 1], triples[:, 2])
+
+        entity_embedder.freeze(freeze_indexes_ent)
+        relation_embedder.freeze(freeze_indexes_rel)
+
+        scores_after = model.score_spo(triples[:, 0], triples[:, 1], triples[:, 2])
+
+        self.assertTrue(
+            torch.all(torch.eq(scores_before, scores_after)),
+            msg="Model score computation has changed after calling freeze.",
+        )
diff --git a/tests/util.py b/tests/util.py
index 70e4349bc..641e6f6e5 100644
--- a/tests/util.py
+++ b/tests/util.py
@@ -1,7 +1,8 @@
 import os
 from kge import Config
 from kge.misc import kge_base_dir
-
+from os import path
+import shutil
 
 def create_config(test_dataset_name: str, model: str = "complex") -> Config:
     config = Config()
@@ -16,3 +17,16 @@
 
 def get_dataset_folder(dataset_name):
     return os.path.join(kge_base_dir(), "tests", "data", dataset_name)
+
+
+def get_cache_dir():
    return os.path.join(kge_base_dir(), "tests", "data", "cache")
+
+
+def empty_cache():
+    for file in os.listdir(get_cache_dir()):
+        obj = path.join(get_cache_dir(), file)
+        if os.path.isfile(obj):
+            os.remove(obj)
+        else:
+            shutil.rmtree(obj)
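
Usage sketch (illustration, not part of the patch): the snippet below shows how the new LookupEmbedder.freeze() API introduced above could be exercised programmatically; it mirrors tests/test_freeze.py. The output folder name "freeze_example" is an assumption, everything else (helpers, classes, calls) is taken from the diff. Alternatively, freezing can be driven from the configuration by pointing the embedder's freeze.ids_file option (see lookup_embedder.yaml above) at a file with one entity/relation id per line.

import os

from tests.util import create_config, get_cache_dir
from kge.dataset import Dataset
from kge.job import TrainingJob
from kge.model.kge_model import KgeModel

# build a small config on the "toy" test dataset (helper from tests/util.py)
config = create_config("toy")
config.folder = os.path.join(get_cache_dir(), "freeze_example")  # assumed folder name
config.init_folder()
config.set("train.max_epochs", 1)

dataset = Dataset.create(config=config)
model = KgeModel.create(config=config, dataset=dataset)

# hold every other entity embedding constant; the remaining ones stay trainable
model.get_o_embedder().freeze(list(range(0, dataset.num_entities(), 2)))

training_job = TrainingJob.create(config=config, dataset=dataset, model=model)
training_job.run()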