From 39cbf956ee26f513342721ca40d8efae1baa77dc Mon Sep 17 00:00:00 2001
From: TITC
Date: Mon, 23 May 2022 22:14:08 +0800
Subject: [PATCH 1/7] compatible with torch.utils.data.DataLoader

---
 pix2tex/dataset/dataset.py | 28 +++++++++++++++++++++++++---
 1 file changed, 25 insertions(+), 3 deletions(-)

diff --git a/pix2tex/dataset/dataset.py b/pix2tex/dataset/dataset.py
index aa6dec5..ff68ea9 100644
--- a/pix2tex/dataset/dataset.py
+++ b/pix2tex/dataset/dataset.py
@@ -1,6 +1,7 @@
 import torch
 import torch.nn.functional as F
 from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import IterableDataset
 import numpy as np
 import imagesize
 import logging
@@ -15,10 +16,10 @@
 from pix2tex.utils.utils import in_model_path
 from pix2tex.dataset.transforms import train_transform, test_transform
+import math
 
-
-class Im2LatexDataset:
+class Im2LatexDataset(IterableDataset):
     keep_smaller_batches = False
     shuffle = True
     batchsize = 16
@@ -75,13 +76,14 @@ def __init__(self, equations=None, images=None, tokenizer=None, shuffle=True, ba
                     self.data[(width, height)].append((eqs[self.indices[i]], im))
         except KeyboardInterrupt:
             pass
+        # formula&image pairs grouped by image size
         self.data = dict(self.data)
         self._get_size()
 
         iter(self)
 
     def __len__(self):
-        return self.size
+        return self.size  # total number of batches given the batchsize
 
     def __iter__(self):
         self.i = 0
@@ -101,6 +103,7 @@ def __iter__(self):
             self.pairs = np.random.permutation(np.array(self.pairs, dtype=object))
         else:
             self.pairs = np.array(self.pairs, dtype=object)
+        self.pairs = self.pairs[self.start:self.end]
         self.size = len(self.pairs)
         return self
 
@@ -121,6 +124,8 @@ def prepare_data(self, batch):
         """
 
         eqs, ims = batch.T
+        # for im in ims:
+        #     print(im,self.img_list.index(im), len([_ for _ in self.img_list if _ ==im]),len(self.img_list),hash("".join(self.img_list)))
         tok = self.tokenizer(list(eqs), return_token_type_ids=False)
         # pad with bos and eos token
         for k, p in zip(tok, [[self.bos_token_id, self.eos_token_id], [1, 1]]):
@@ -169,6 +174,9 @@ def load(self, filename, args=[]):
             filename = os.path.realpath(tempf)
         with open(filename, 'rb') as file:
             x = pickle.load(file)
+        x.start = 0
+        x.end = x.size
+        # x.img_list = [_[1] for ele in x.pairs for _ in ele]
         return x
 
     def combine(self, x):
@@ -230,6 +238,20 @@ def generate_tokenizer(equations, output, vocab_size):
     tokenizer.save(path=output, pretty=False)
 
 
+def worker_init_fn(worker_id):
+    worker_info = torch.utils.data.get_worker_info()
+    dataset = worker_info.dataset  # the dataset copy in this worker process
+    overall_start = dataset.start
+    overall_end = dataset.size
+    # configure the dataset to only process the split workload
+    per_worker = int(math.ceil((overall_end - overall_start) /
+                               float(worker_info.num_workers)))
+    worker_id = worker_info.id
+    dataset.start = overall_start + worker_id * per_worker
+    dataset.end = min(dataset.start + per_worker, overall_end)
+
+
+
 if __name__ == '__main__':
     import argparse
     parser = argparse.ArgumentParser(description='Train model', add_help=False)

From 26583afff83496b852a1fb301337aec3d86a73f7 Mon Sep 17 00:00:00 2001
From: TITC
Date: Mon, 23 May 2022 23:39:06 +0800
Subject: [PATCH 2/7] fix bug: shuffle applied multiple times in subprocesses

---
 pix2tex/dataset/dataset.py | 31 +++++++++++++++++--------------
 1 file changed, 17 insertions(+), 14 deletions(-)

diff --git a/pix2tex/dataset/dataset.py b/pix2tex/dataset/dataset.py
index ff68ea9..2438c9c 100644
--- a/pix2tex/dataset/dataset.py
+++ b/pix2tex/dataset/dataset.py
@@ -68,6 +68,7 @@ def __init__(self, 
equations=None, images=None, tokenizer=None, shuffle=True, ba self.pad = pad self.keep_smaller_batches = keep_smaller_batches self.test = test + self.subprocess = False # check the image dimension for every image and group them together try: for i, im in tqdm(enumerate(self.images), total=len(self.images)): @@ -88,22 +89,23 @@ def __len__(self): def __iter__(self): self.i = 0 self.transform = test_transform if self.test else train_transform - self.pairs = [] - for k in self.data: - info = np.array(self.data[k], dtype=object) - p = torch.randperm(len(info)) if self.shuffle else torch.arange(len(info)) - for i in range(0, len(info), self.batchsize): - batch = info[p[i:i+self.batchsize]] - if len(batch.shape) == 1: - batch = batch[None, :] - if len(batch) < self.batchsize and not self.keep_smaller_batches: - continue - self.pairs.append(batch) + if not self.subprocess: + self.pairs = [] + for k in self.data: + info = np.array(self.data[k], dtype=object) + p = torch.randperm(len(info)) if self.shuffle else torch.arange(len(info)) + for i in range(0, len(info), self.batchsize): + batch = info[p[i:i+self.batchsize]] + if len(batch.shape) == 1: + batch = batch[None, :] + if len(batch) < self.batchsize and not self.keep_smaller_batches: + continue + self.pairs.append(batch) + self.pairs = self.pairs[self.start:self.end] if self.shuffle: self.pairs = np.random.permutation(np.array(self.pairs, dtype=object)) else: self.pairs = np.array(self.pairs, dtype=object) - self.pairs = self.pairs[self.start:self.end] self.size = len(self.pairs) return self @@ -125,7 +127,7 @@ def prepare_data(self, batch): eqs, ims = batch.T # for im in ims: - # print(im,self.img_list.index(im), len([_ for _ in self.img_list if _ ==im]),len(self.img_list),hash("".join(self.img_list))) + # print(im) tok = self.tokenizer(list(eqs), return_token_type_ids=False) # pad with bos and eos token for k, p in zip(tok, [[self.bos_token_id, self.eos_token_id], [1, 1]]): @@ -176,7 +178,7 @@ def load(self, filename, args=[]): x = pickle.load(file) x.start = 0 x.end = x.size - # x.img_list = [_[1] for ele in x.pairs for _ in ele] + x.subprocess = False return x def combine(self, x): @@ -249,6 +251,7 @@ def worker_init_fn(worker_id): worker_id = worker_info.id dataset.start = overall_start + worker_id * per_worker dataset.end = min(dataset.start + per_worker, overall_end) + dataset.subprocess=True From cf2832dc75fb20cceda7740c7a9223aadb00b6c4 Mon Sep 17 00:00:00 2001 From: Lukas Blecher Date: Tue, 24 May 2022 00:13:49 +0200 Subject: [PATCH 3/7] worker_init_fn -> iter, workaround for shuffle --- pix2tex/dataset/dataset.py | 70 ++++++++++++++++++-------------------- 1 file changed, 33 insertions(+), 37 deletions(-) diff --git a/pix2tex/dataset/dataset.py b/pix2tex/dataset/dataset.py index 2438c9c..11e22d0 100644 --- a/pix2tex/dataset/dataset.py +++ b/pix2tex/dataset/dataset.py @@ -34,6 +34,7 @@ class Im2LatexDataset(IterableDataset): eos_token_id = 2 transform = train_transform data = defaultdict(lambda: []) + permutation = None def __init__(self, equations=None, images=None, tokenizer=None, shuffle=True, batchsize=16, max_seq_len=1024, max_dimensions=(1024, 512), min_dimensions=(32, 32), pad=False, keep_smaller_batches=False, test=False): @@ -43,7 +44,7 @@ def __init__(self, equations=None, images=None, tokenizer=None, shuffle=True, ba equations (str, optional): Path to equations. Defaults to None. images (str, optional): Directory where images are saved. Defaults to None. tokenizer (str, optional): Path to saved tokenizer. Defaults to None. 
-        shuffle (bool, opitonal): Defaults to True. 
+        shuffle (bool, optional): Defaults to True.
         batchsize (int, optional): Defaults to 16.
         max_seq_len (int, optional): Defaults to 1024.
         max_dimensions (tuple(int, int), optional): Maximal dimensions the model can handle
@@ -68,7 +69,6 @@ def __init__(self, equations=None, images=None, tokenizer=None, shuffle=True, ba
         self.pad = pad
         self.keep_smaller_batches = keep_smaller_batches
         self.test = test
-        self.subprocess = False
         # check the image dimension for every image and group them together
         try:
             for i, im in tqdm(enumerate(self.images), total=len(self.images)):
@@ -80,7 +80,7 @@ def __init__(self, equations=None, images=None, tokenizer=None, shuffle=True, ba
         # formula&image pairs grouped by image size
         self.data = dict(self.data)
         self._get_size()
-
+        self._shuffle()
         iter(self)
 
     def __len__(self):
@@ -89,23 +89,27 @@ def __len__(self):
 
     def __iter__(self):
         self.i = 0
         self.transform = test_transform if self.test else train_transform
-        if not self.subprocess:
-            self.pairs = []
-            for k in self.data:
-                info = np.array(self.data[k], dtype=object)
-                p = torch.randperm(len(info)) if self.shuffle else torch.arange(len(info))
-                for i in range(0, len(info), self.batchsize):
-                    batch = info[p[i:i+self.batchsize]]
-                    if len(batch.shape) == 1:
-                        batch = batch[None, :]
-                    if len(batch) < self.batchsize and not self.keep_smaller_batches:
-                        continue
-                    self.pairs.append(batch)
-            self.pairs = self.pairs[self.start:self.end]
-        if self.shuffle:
-            self.pairs = np.random.permutation(np.array(self.pairs, dtype=object))
+        self.pairs = []
+        for k in self.data:
+            info = np.array(self.data[k], dtype=object)
+            for i in range(0, len(info), self.batchsize):
+                batch = info[i:i+self.batchsize]
+                if len(batch.shape) == 1:
+                    batch = batch[None, :]
+                if len(batch) < self.batchsize and not self.keep_smaller_batches:
+                    continue
+                self.pairs.append(batch)
+        worker_info = torch.utils.data.get_worker_info()
+        if worker_info is not None:
+            # configure the dataset to only process the split workload
+            per_worker = int(math.ceil(self.size/float(worker_info.num_workers)))
+            worker_id = worker_info.id
+            self.start = worker_id * per_worker
+            self.end = min(self.start + per_worker, self.size)
         else:
-            self.pairs = np.array(self.pairs, dtype=object)
+            self.start, self.end = 0, self.size
+
+        self.pairs = np.array(self.pairs, dtype=object)[self.permutation[self.start:self.end]]
         self.size = len(self.pairs)
         return self
 
@@ -162,6 +166,15 @@ def _get_size(self):
         for k in self.data:
             div, mod = divmod(len(self.data[k]), self.batchsize)
             self.size += div  # + (1 if mod > 0 else 0)
+        if self.permutation is None or len(self.permutation) != self.size:
+            self._shuffle()
+
+    def _shuffle(self):
+        if self.shuffle:
+            self.permutation = np.random.permutation(self.size)
+        else:
+            self.permutation = np.arange(self.size)
+        return self
 
     def load(self, filename, args=[]):
         """returns a pickled version of a dataset
@@ -176,9 +189,7 @@ def load(self, filename, args=[]):
             filename = os.path.realpath(tempf)
         with open(filename, 'rb') as file:
             x = pickle.load(file)
-        x.start = 0
-        x.end = x.size
-        x.subprocess = False
+        x._get_size()
         return x
 
     def combine(self, x):
@@ -240,21 +251,6 @@ def generate_tokenizer(equations, output, vocab_size):
     tokenizer.save(path=output, pretty=False)
 
 
-def worker_init_fn(worker_id):
-    worker_info = torch.utils.data.get_worker_info()
-    dataset = worker_info.dataset  # the dataset copy in this worker process
-    overall_start = dataset.start
-    overall_end = dataset.size
-    # configure the dataset to only process the split workload
-    per_worker = int(math.ceil((overall_end - overall_start) /
-                               float(worker_info.num_workers)))
-    worker_id = worker_info.id
-    dataset.start = overall_start + worker_id * per_worker
-    dataset.end = min(dataset.start + per_worker, overall_end)
-    dataset.subprocess=True
-
-
-
 if __name__ == '__main__':
     import argparse
     parser = argparse.ArgumentParser(description='Train model', add_help=False)
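
The pattern this patch moves into __iter__ is the standard recipe from the
torch.utils.data.IterableDataset documentation: every DataLoader worker
process receives its own copy of the dataset, so __iter__ asks
get_worker_info() which worker it is running in and keeps only that worker's
slice of the work. A minimal self-contained sketch of the recipe follows; the
RangeDataset class is illustrative only and not part of pix2tex:

    import math
    import torch
    from torch.utils.data import IterableDataset, DataLoader

    class RangeDataset(IterableDataset):
        """Toy iterable dataset yielding the integers in [start, end)."""

        def __init__(self, start, end):
            self.start, self.end = start, end

        def __iter__(self):
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is None:
                # single-process loading: iterate over the full range
                start, end = self.start, self.end
            else:
                # worker process: keep only this worker's disjoint slice
                per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
                start = self.start + worker_info.id * per_worker
                end = min(start + per_worker, self.end)
            return iter(range(start, end))

    # batch_size=None disables automatic batching, as in the patches;
    # every integer is yielded exactly once across the two workers
    loader = DataLoader(RangeDataset(0, 10), num_workers=2, batch_size=None)
    assert sorted(int(x) for x in loader) == list(range(10))

Shuffling is the subtle part, and the reason this patch stores a permutation
on the dataset instead of shuffling inside __iter__: each worker re-runs
__iter__ on its own copy, so independent reshuffles would let the slices
overlap, and batches could be seen twice or skipped. Slicing one shared
permutation, created in the main process before the workers start, keeps the
per-worker splits disjoint.
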
From da29abbc6fb947174ffa44bc47d2add16912ed2b Mon Sep 17 00:00:00 2001
From: Lukas Blecher
Date: Tue, 24 May 2022 12:50:43 +0200
Subject: [PATCH 4/7] Update dataset.py

---
 pix2tex/dataset/dataset.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/pix2tex/dataset/dataset.py b/pix2tex/dataset/dataset.py
index 11e22d0..3a7bcac 100644
--- a/pix2tex/dataset/dataset.py
+++ b/pix2tex/dataset/dataset.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn.functional as F
 from torch.nn.utils.rnn import pad_sequence
-from torch.utils.data import IterableDataset
+from torch.utils.data import IterableDataset, DataLoader
 import numpy as np
 import imagesize
 import logging
@@ -240,6 +240,17 @@ def update(self, **kwargs):
         iter(self)
 
 
+class Dataloader(DataLoader):
+    def __init__(self, dataset: Im2LatexDataset, batch_size=1, shuffle=False, *args, **kwargs):
+        self.dataset = dataset
+        self.dataset.update(batchsize=batch_size, shuffle=shuffle)
+        super().__init__(self.dataset, *args, shuffle=False, batch_size=None, **kwargs)
+
+    def __iter__(self):
+        self.dataset._shuffle()
+        return super().__iter__()
+
+
 def generate_tokenizer(equations, output, vocab_size):
     from tokenizers import Tokenizer, pre_tokenizers
     from tokenizers.models import BPE

From 9c6f96e538b49382d6e9f79c695a586bdcff6ddf Mon Sep 17 00:00:00 2001
From: TITC
Date: Tue, 24 May 2022 19:58:42 +0800
Subject: [PATCH 5/7] small modifications per review discussion

https://github.com/lukas-blecher/LaTeX-OCR/pull/154#issuecomment-1135765832
---
 pix2tex/dataset/dataset.py             |  2 +-
 pix2tex/eval.py                        |  6 +++---
 pix2tex/model/settings/config-vit.yaml |  1 +
 pix2tex/model/settings/config.yaml     |  1 +
 pix2tex/model/settings/debug.yaml      |  4 ++++
 pix2tex/train.py                       | 20 ++++++++++----------
 pix2tex/utils/utils.py                 |  1 +
 7 files changed, 21 insertions(+), 14 deletions(-)

diff --git a/pix2tex/dataset/dataset.py b/pix2tex/dataset/dataset.py
index 3a7bcac..8815405 100644
--- a/pix2tex/dataset/dataset.py
+++ b/pix2tex/dataset/dataset.py
@@ -243,7 +243,7 @@ def update(self, **kwargs):
 class Dataloader(DataLoader):
     def __init__(self, dataset: Im2LatexDataset, batch_size=1, shuffle=False, *args, **kwargs):
         self.dataset = dataset
-        self.dataset.update(batchsize=batch_size, shuffle=shuffle)
+        self.dataset.update(batchsize=batch_size, shuffle=shuffle, *args, **kwargs)
         super().__init__(self.dataset, *args, shuffle=False, batch_size=None, **kwargs)
 
     def __iter__(self):
diff --git a/pix2tex/eval.py b/pix2tex/eval.py
index 8742988..e66364d 100644
--- a/pix2tex/eval.py
+++ b/pix2tex/eval.py
@@ -1,4 +1,4 @@
-from pix2tex.dataset.dataset import Im2LatexDataset
+from pix2tex.dataset.dataset import Im2LatexDataset, Dataloader
 import argparse
 import logging
 import yaml
@@ -28,12 +28,12 @@ def detokenize(tokens, tokenizer):
 
 
 @torch.no_grad()
-def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: int = None, name: str = 'test'):
+def evaluate(model: Model, dataset: Dataloader, args: Munch, num_batches: int = None, name: str = 'test'):
    """evaluates the model. 
Returns bleu score on the dataset Args: model (torch.nn.Module): the model - dataset (Im2LatexDataset): test dataset + dataset (Dataloader): test dataset args (Munch): arguments num_batches (int): How many batches to evaluate on. Defaults to None (all batches). name (str, optional): name of the test e.g. val or test for wandb. Defaults to 'test'. diff --git a/pix2tex/model/settings/config-vit.yaml b/pix2tex/model/settings/config-vit.yaml index 3d94e84..162880a 100644 --- a/pix2tex/model/settings/config-vit.yaml +++ b/pix2tex/model/settings/config-vit.yaml @@ -1,4 +1,5 @@ gpu_devices: null #[0,1,2,3,4,5,6,7] +num_workers: 0 betas: - 0.9 - 0.999 diff --git a/pix2tex/model/settings/config.yaml b/pix2tex/model/settings/config.yaml index a579f9e..c19dca5 100644 --- a/pix2tex/model/settings/config.yaml +++ b/pix2tex/model/settings/config.yaml @@ -1,4 +1,5 @@ gpu_devices: null #[0,1,2,3,4,5,6,7] +num_workers: 0 backbone_layers: - 2 - 3 diff --git a/pix2tex/model/settings/debug.yaml b/pix2tex/model/settings/debug.yaml index 94e3b77..7026fa2 100644 --- a/pix2tex/model/settings/debug.yaml +++ b/pix2tex/model/settings/debug.yaml @@ -65,3 +65,7 @@ pad: False pad_token: 0 bos_token: 1 eos_token: 2 + +#devices(GPU&CPU) +num_workers: 0 +gpu_devices: null #[0,1,2,3,4,5,6,7] \ No newline at end of file diff --git a/pix2tex/train.py b/pix2tex/train.py index bd2f599..5e8d21e 100644 --- a/pix2tex/train.py +++ b/pix2tex/train.py @@ -1,4 +1,4 @@ -from pix2tex.dataset.dataset import Im2LatexDataset +from pix2tex.dataset.dataset import Im2LatexDataset, Dataloader import os import argparse import logging @@ -16,12 +16,12 @@ def train(args): - dataloader = Im2LatexDataset().load(args.data) - dataloader.update(**args, test=False) - valdataloader = Im2LatexDataset().load(args.valdata) + train_dataset = Im2LatexDataset().load(args.data) + train_dataloader = Dataloader(train_dataset, **args, test=False) + val_dataset = Im2LatexDataset().load(args.valdata) valargs = args.copy() valargs.update(batchsize=args.testbatchsize, keep_smaller_batches=True, test=True) - valdataloader.update(**valargs) + val_dataloader = Dataloader(val_dataset, **valargs) device = args.device model = get_model(args) if torch.cuda.is_available() and not args.no_cuda: @@ -47,7 +47,7 @@ def save_models(e, step=0): try: for e in range(args.epoch, args.epochs): args.epoch = e - dset = tqdm(iter(dataloader)) + dset = tqdm(iter(train_dataloader)) for i, (seq, im) in enumerate(dset): if seq is not None and im is not None: opt.zero_grad() @@ -63,20 +63,20 @@ def save_models(e, step=0): dset.set_description('Loss: %.4f' % total_loss) if args.wandb: wandb.log({'train/loss': total_loss}) - if (i+1+len(dataloader)*e) % args.sample_freq == 0: - bleu_score, edit_distance, token_accuracy = evaluate(model, valdataloader, args, num_batches=int(args.valbatches*e/args.epochs), name='val') + if (i+1+len(train_dataloader)*e) % args.sample_freq == 0: + bleu_score, edit_distance, token_accuracy = evaluate(model, val_dataloader, args, num_batches=int(args.valbatches*e/args.epochs), name='val') if bleu_score > max_bleu and token_accuracy > max_token_acc: max_bleu, max_token_acc = bleu_score, token_accuracy save_models(e, step=i) if (e+1) % args.save_freq == 0: - save_models(e, step=len(dataloader)) + save_models(e, step=len(train_dataloader)) if args.wandb: wandb.log({'train/epoch': e+1}) except KeyboardInterrupt: if e >= 2: save_models(e, step=i) raise KeyboardInterrupt - save_models(e, step=len(dataloader)) + save_models(e, step=len(train_dataloader)) if __name__ == 
'__main__': diff --git a/pix2tex/utils/utils.py b/pix2tex/utils/utils.py index 2b5f920..cff29ff 100644 --- a/pix2tex/utils/utils.py +++ b/pix2tex/utils/utils.py @@ -55,6 +55,7 @@ def parse_args(args, **kwargs) -> Munch: args.update(kwargs) args.wandb = not kwargs.debug and not args.debug args.device = get_device(args, kwargs.no_cuda) + args.num_workers = args.get('num_workers', 0) args.max_dimensions = [args.max_width, args.max_height] args.min_dimensions = [args.get('min_width', 32), args.get('min_height', 32)] if 'decoder_args' not in args or args.decoder_args is None: From 20c0220ca86cb2f88cc26a6efa6a36cd36836995 Mon Sep 17 00:00:00 2001 From: Lukas Blecher Date: Tue, 24 May 2022 18:30:34 +0200 Subject: [PATCH 6/7] dataloader specifies batch size & shuffle toggle Better split between the dataloader and dataset classes --- pix2tex/dataset/dataset.py | 8 ++++---- pix2tex/eval.py | 2 +- pix2tex/train.py | 12 +++++------- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/pix2tex/dataset/dataset.py b/pix2tex/dataset/dataset.py index 8815405..a4ef906 100644 --- a/pix2tex/dataset/dataset.py +++ b/pix2tex/dataset/dataset.py @@ -237,14 +237,14 @@ def update(self, **kwargs): tokenizer_file = os.path.realpath(tokenizer_file) self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file) self._get_size() - iter(self) + return iter(self) class Dataloader(DataLoader): - def __init__(self, dataset: Im2LatexDataset, batch_size=1, shuffle=False, *args, **kwargs): + def __init__(self, dataset: Im2LatexDataset, batch_size=1, shuffle=False, drop_last=True, num_workers=0): self.dataset = dataset - self.dataset.update(batchsize=batch_size, shuffle=shuffle, *args, **kwargs) - super().__init__(self.dataset, *args, shuffle=False, batch_size=None, **kwargs) + self.dataset.update(batchsize=batch_size, shuffle=shuffle, keep_smaller_batches=not drop_last) + super().__init__(self.dataset, num_workers=num_workers, shuffle=False, batch_size=None) def __iter__(self): self.dataset._shuffle() diff --git a/pix2tex/eval.py b/pix2tex/eval.py index e66364d..b47f486 100644 --- a/pix2tex/eval.py +++ b/pix2tex/eval.py @@ -46,7 +46,7 @@ def evaluate(model: Model, dataset: Dataloader, args: Munch, num_batches: int = log = {} bleus, edit_dists, token_acc = [], [], [] bleu_score, edit_distance, token_accuracy = 0, 1, 0 - pbar = tqdm(enumerate(iter(dataset)), total=len(dataset)) + pbar = tqdm(enumerate(dataset), total=len(dataset)) for i, (seq, im) in pbar: if seq is None or im is None: continue diff --git a/pix2tex/train.py b/pix2tex/train.py index 5e8d21e..f1cc1ab 100644 --- a/pix2tex/train.py +++ b/pix2tex/train.py @@ -16,12 +16,10 @@ def train(args): - train_dataset = Im2LatexDataset().load(args.data) - train_dataloader = Dataloader(train_dataset, **args, test=False) - val_dataset = Im2LatexDataset().load(args.valdata) - valargs = args.copy() - valargs.update(batchsize=args.testbatchsize, keep_smaller_batches=True, test=True) - val_dataloader = Dataloader(val_dataset, **valargs) + train_dataset = Im2LatexDataset().load(args.data).update(**args, test=False) + train_dataloader = Dataloader(train_dataset, batch_size=args.batchsize, num_workers=args.num_workers) + val_dataset = Im2LatexDataset().load(args.valdata).update(**args, test=True) + val_dataloader = Dataloader(val_dataset, batch_size=args.testbatchsize, num_workers=args.num_workers, drop_last=False) device = args.device model = get_model(args) if torch.cuda.is_available() and not args.no_cuda: @@ -47,7 +45,7 @@ def save_models(e, step=0): try: for e 
in range(args.epoch, args.epochs):
             args.epoch = e
-            dset = tqdm(iter(train_dataloader))
+            dset = tqdm(train_dataloader)
             for i, (seq, im) in enumerate(dset):
                 if seq is not None and im is not None:
                     opt.zero_grad()

From b90567d07f9c0225fe49b626ddcbf20731050b12 Mon Sep 17 00:00:00 2001
From: TITC
Date: Wed, 25 May 2022 09:21:36 +0800
Subject: [PATCH 7/7] eval needs tokenizer & add pin_memory

---
 pix2tex/dataset/dataset.py | 5 +++--
 pix2tex/train.py           | 4 ++--
 pix2tex/utils/utils.py     | 1 +
 3 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/pix2tex/dataset/dataset.py b/pix2tex/dataset/dataset.py
index a4ef906..7e7fc37 100644
--- a/pix2tex/dataset/dataset.py
+++ b/pix2tex/dataset/dataset.py
@@ -241,10 +241,11 @@ def update(self, **kwargs):
 
 
 class Dataloader(DataLoader):
-    def __init__(self, dataset: Im2LatexDataset, batch_size=1, shuffle=False, drop_last=True, num_workers=0):
+    def __init__(self, dataset: Im2LatexDataset, batch_size=1, shuffle=False, drop_last=True, num_workers=0, pin_memory=False):
         self.dataset = dataset
+        self.tokenizer = dataset.tokenizer
         self.dataset.update(batchsize=batch_size, shuffle=shuffle, keep_smaller_batches=not drop_last)
-        super().__init__(self.dataset, num_workers=num_workers, shuffle=False, batch_size=None)
+        super().__init__(self.dataset, num_workers=num_workers, shuffle=False, batch_size=None, pin_memory=pin_memory)
 
     def __iter__(self):
         self.dataset._shuffle()
diff --git a/pix2tex/train.py b/pix2tex/train.py
index f1cc1ab..309cd18 100644
--- a/pix2tex/train.py
+++ b/pix2tex/train.py
@@ -17,9 +17,9 @@ def train(args):
 
     train_dataset = Im2LatexDataset().load(args.data).update(**args, test=False)
-    train_dataloader = Dataloader(train_dataset, batch_size=args.batchsize, num_workers=args.num_workers)
+    train_dataloader = Dataloader(train_dataset, batch_size=args.batchsize, num_workers=args.num_workers, pin_memory=args.pin_memory)
     val_dataset = Im2LatexDataset().load(args.valdata).update(**args, test=True)
-    val_dataloader = Dataloader(val_dataset, batch_size=args.testbatchsize, num_workers=args.num_workers, drop_last=False)
+    val_dataloader = Dataloader(val_dataset, batch_size=args.testbatchsize, num_workers=args.num_workers, drop_last=False, pin_memory=args.pin_memory)
     device = args.device
     model = get_model(args)
     if torch.cuda.is_available() and not args.no_cuda:
diff --git a/pix2tex/utils/utils.py b/pix2tex/utils/utils.py
index cff29ff..e07ac4c 100644
--- a/pix2tex/utils/utils.py
+++ b/pix2tex/utils/utils.py
@@ -56,6 +56,7 @@ def parse_args(args, **kwargs) -> Munch:
     args.wandb = not kwargs.debug and not args.debug
     args.device = get_device(args, kwargs.no_cuda)
     args.num_workers = args.get('num_workers', 0)
+    args.pin_memory = args.get('pin_memory', False)
     args.max_dimensions = [args.max_width, args.max_height]
     args.min_dimensions = [args.get('min_width', 32), args.get('min_height', 32)]
     if 'decoder_args' not in args or args.decoder_args is None:
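
With the series complete, train.py builds its loaders in two steps: load the
Im2LatexDataset, then wrap it in the Dataloader subclass, which forwards
batch_size, shuffle and drop_last to the dataset (the dataset does its own
size-grouped batching) and hands batch_size=None to torch's DataLoader so
automatic batching stays off. A short usage sketch; the path 'dataset.pkl' is
a placeholder, not a file shipped with the repository:

    from pix2tex.dataset.dataset import Im2LatexDataset, Dataloader

    # load a pickled dataset (placeholder path)
    dataset = Im2LatexDataset().load('dataset.pkl')

    # each item the dataset yields is already a complete size-grouped batch,
    # so the wrapper only adds worker processes and memory pinning on top
    loader = Dataloader(dataset, batch_size=16, shuffle=True,
                        num_workers=4, pin_memory=True)

    for seq, im in loader:
        ...  # seq: padded token tensors, im: batched image tensor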