Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/NERDA/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,20 @@
from .preprocessing import create_dataloader
import torch
import numpy as np
from tqdm import tqdm
from nltk.tokenize import sent_tokenize, word_tokenize
from typing import List, Callable
import transformers
import sklearn.preprocessing

try:
from IPython import get_ipython
if 'IPKernelApp' in get_ipython().config:
from tqdm.notebook import tqdm
else:
from tqdm import tqdm
except:
from tqdm import tqdm

def predict(network: torch.nn.Module,
sentences: List[List[str]],
transformer_tokenizer: transformers.PreTrainedTokenizer,
Expand Down
10 changes: 9 additions & 1 deletion src/NERDA/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,15 @@
from transformers import AdamW, get_linear_schedule_with_warmup
import random
import torch
from tqdm import tqdm

try:
from IPython import get_ipython
if 'IPKernelApp' in get_ipython().config:
from tqdm.notebook import tqdm
else:
from tqdm import tqdm
except:
from tqdm import tqdm

def train(model, data_loader, optimizer, device, scheduler, n_tags):
"""One Iteration of Training"""
Expand Down