Description
In applications such as link, Neuromancer and its various losses are helpful for adapting function approximators to satisfy constraints. However, it is challenging to iteratively update these function approximators using newly acquired/sampled batches of data that are not already contained in the Neuromancer environment.
Example: x_batch is iteratively sampled from a data source outside of the Neuromancer framework. Using these samples to update mu_node generally requires creating a new DataLoader and associated Trainer, which comes with overhead.
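import torch
from torch import optim
from torch.utils.data import DataLoader
from neuromancer.constraint import variable, Objective
from neuromancer.dataset import DictDataset
from neuromancer.loss import PenaltyLoss
from neuromancer.modules import blocks
from neuromancer.problem import Problem
from neuromancer.system import Node, System
from neuromancer.trainer import Trainer

H = 2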
A, B, Q = 3*[torch.tensor([[1.]])]
# system definition
dx_fun = lambda x, u: x @ A.T + u @ B.T
dx_node = Node(dx_fun, ['x', 'u'], ['x'])
mu_node = Node(blocks.Linear(1, 1, bias=False), ['x'], ['u'])
l_fun = lambda x: Q*x**2
l_node = Node(l_fun, ['x'], ['l'])
cl_system = System([mu_node, dx_node, l_node], nsteps=H+1)
# problem definition
x, u, l = variable('x'), variable('u'), variable('l')
l_loss = Objective(var=H*l[:, :-1, :], name='stage_loss') # cost for steps k<H
loss = PenaltyLoss([l_loss], [])
problem = Problem([cl_system], loss)
opt = optim.AdamW(mu_node.parameters(), lr=0.001)
# iterative training loop
for _ in range(10):
    # hypothetically, get a batch of data from a non-Neuromancer source
    x_batch = torch.randn(256, 1)
    # transfer the data into the Neuromancer framework to update the model
    train_dict = DictDataset({'x': x_batch.unsqueeze(1)})
    train_loader = DataLoader(train_dict, batch_size=256, shuffle=True,
                              collate_fn=train_dict.collate_fn)
    trainer = Trainer(problem, train_loader,
                      optimizer=opt,
                      train_metric='train_loss',
                      eval_metric='train_loss',
                      epochs=1,
                      epoch_verbose=5)
    trainer.current_epoch = 1  # optional, silences the trainer entirely
    updated_model = trainer.train()
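For comparison, here is a rough sketch of the kind of per-batch update that would avoid rebuilding the DataLoader and Trainer on every iteration. The helper name `manual_update` is hypothetical. It assumes that calling `problem` on a dict that includes a `'name'` key returns a dict holding the loss under `'<name>_loss'`; that assumption should be checked against the installed Neuromancer version before relying on it.

```python
def manual_update(problem, optimizer, batch, metric="train_loss"):
    """Run one gradient step on `batch` without a DataLoader or Trainer.

    Hypothetical helper. Assumes `problem(batch)` returns a dict that
    contains `metric`, and that the returned loss supports `.backward()`.
    """
    optimizer.zero_grad()      # clear gradients from the previous batch
    output = problem(batch)    # forward pass through the problem
    loss = output[metric]      # pick the scalar loss to minimize
    loss.backward()            # backpropagate
    optimizer.step()           # update the parameters held by the optimizer
    return loss


# Intended use with the objects defined above (not executed here):
# for _ in range(10):
#     x_batch = torch.randn(256, 1)
#     batch = {'x': x_batch.unsqueeze(1), 'name': 'train'}  # 'name' key is an assumption
#     manual_update(problem, opt, batch)
```

This skips everything else a Trainer provides (callbacks, logging, evaluation, early stopping), so it only covers the single-batch update case described in this issue.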
Separately, silencing the trainer entirely currently requires setting current_epoch to a value greater than 0 and setting epoch_verbose to a value greater than epochs. It would be convenient if the trainer accepted a verbose boolean as input to enable or disable printing entirely.
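Until such a flag exists, one generic stopgap is to redirect stdout around the training call. The sketch below uses only the Python standard library. `silenced` and `noisy_train` are hypothetical names, with `noisy_train` standing in for `trainer.train()`. It assumes the trainer reports progress with plain `print` calls to `sys.stdout`; output sent through the `logging` module or to stderr would not be affected.

```python
import contextlib
import io


@contextlib.contextmanager
def silenced():
    """Temporarily discard everything written to sys.stdout."""
    with contextlib.redirect_stdout(io.StringIO()):
        yield


def noisy_train():
    # Hypothetical stand-in for trainer.train(): prints progress, returns a result.
    print("epoch: 0  train_loss: 1.0")
    return "best_model"


# The printed progress line is swallowed; the return value is still available.
with silenced():
    result = noisy_train()
```

In the loop above, the same pattern would wrap the `trainer.train()` call, which avoids the need to adjust current_epoch and epoch_verbose just to suppress output.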