Feature request: Trainer flexibility for outside data #245

@tbanker

Description

In applications such as [link], Neuromancer and its various losses are helpful for adapting function approximators to satisfy constraints. However, it is challenging to iteratively update these function approximators using newly acquired/sampled batches of data that are not already contained in the Neuromancer environment.

Example: x_batch is iteratively sampled from a data source outside of the Neuromancer framework. Using these samples to update mu_node generally requires creating a new DataLoader and associated Trainer, which comes with overhead.

# imports assumed for this example (module paths per neuromancer >= 1.4)
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

from neuromancer.system import Node, System
from neuromancer.modules import blocks
from neuromancer.constraint import variable, Objective
from neuromancer.loss import PenaltyLoss
from neuromancer.problem import Problem
from neuromancer.dataset import DictDataset
from neuromancer.trainer import Trainer

H = 2
A, B, Q = 3*[torch.tensor([[1.]])]

# system definition: closed-loop system of policy, dynamics, and stage cost
dx_fun = lambda x, u: x @ A.T + u @ B.T
dx_node = Node(dx_fun, ['x', 'u'], ['x'])
mu_node = Node(blocks.Linear(1, 1, bias=False), ['x'], ['u'])
l_fun = lambda x: Q*x**2
l_node = Node(l_fun, ['x'], ['l'])
cl_system = System([mu_node, dx_node, l_node], nsteps=H+1)

# problem definition
x, u, l = variable('x'), variable('u'), variable('l')
l_loss = Objective(var=H*l[:, :-1, :], name='stage_loss')  # cost for steps k < H
loss = PenaltyLoss([l_loss], [])
problem = Problem([cl_system], loss)
opt = optim.AdamW(mu_node.parameters(), lr=0.001)

# iterative training loop
for _ in range(10):
    # hypothetically, get a batch of data from a non-neuromancer source
    x_batch = torch.randn(256, 1)

    # transfer the data into the neuromancer framework to update the model;
    # note that a new DataLoader and Trainer are constructed on every iteration
    train_dict = DictDataset({'x': x_batch.unsqueeze(1)})
    train_loader = DataLoader(train_dict, batch_size=256, shuffle=True, collate_fn=train_dict.collate_fn)
    trainer = Trainer(problem, train_loader,
                      optimizer=opt,
                      train_metric='train_loss',
                      eval_metric='train_loss',
                      epochs=1,
                      epoch_verbose=5)
    trainer.current_epoch = 1  # optional, to silence the trainer (see below)
    updated_model = trainer.train()
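
For reference, the kind of lightweight update this request is after can be approximated today by stepping the Problem directly and skipping the DataLoader/Trainer construction. The sketch below is an assumption, not a documented API: it relies on Problem's forward pass accepting a plain dict with a 'name' entry and returning its loss under the prefixed key 'train_loss' (consistent with the train_metric used above).

# minimal manual update loop (sketch): no DataLoader/Trainer rebuilt per batch
for _ in range(10):
    x_batch = torch.randn(256, 1)                         # batch from an outside source
    batch = {'x': x_batch.unsqueeze(1), 'name': 'train'}  # 'name' assumed to prefix output keys
    out = problem(batch)                                  # rollout + loss evaluation
    opt.zero_grad()
    out['train_loss'].backward()                          # assumed loss key, matching train_metric
    opt.step()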

Additionally, to silence the trainer entirely, current_epoch needs to be set greater than 0 and epoch_verbose needs to be set greater than epochs. It would be convenient if the Trainer accepted a verbose boolean argument to enable/disable printing entirely.
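
As a hypothetical sketch of the requested interface (the verbose flag below does not exist in the current Trainer; it is the proposed addition):

trainer = Trainer(problem, train_loader,
                  optimizer=opt,
                  train_metric='train_loss',
                  eval_metric='train_loss',
                  epochs=1,
                  verbose=False)  # proposed flag: one switch to disable all printing,
                                  # replacing the current_epoch/epoch_verbose workaround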
