160 changes: 160 additions & 0 deletions jeta/cnp.py
@@ -0,0 +1,160 @@
from functools import partial
from typing import Callable, List, Tuple

import jax
import jax.numpy as jnp
from flax.training.train_state import TrainState
from jax.numpy import ndarray


class CNP:
@staticmethod
@partial(jax.jit, static_argnums=(3, 4))
def meta_train_step(
encoder_state: TrainState,
decoder_state: TrainState,
tasks,
loss_fn: Callable[[ndarray, ndarray], ndarray],
metrics: Tuple[Callable[[ndarray, ndarray], ndarray], ...] = (),
) -> Tuple[TrainState, TrainState, ndarray, List[ndarray]]:
"""Performs a single meta-training step on a batch of tasks.

The function first adapts to the support set and then evaluates its performance
on the query set.

Args:
encoder_state (TrainState): Contains information regarding the current encoder state.
decoder_state (TrainState): Contains information regarding the current decoder state.
tasks ((x_train, y_train), (x_test, y_test)): Batch of tasks to be evaluated on.
loss_fn ((logits, targets) -> loss): Loss Function.
metrics (Tuple[(ndarray, ndarray) -> ndarray]): Tuple of metrics to be evaluated on the query set.

Returns:
Tuple[TrainState, TrainState, jnp.ndarray, List[jnp.ndarray]]: (Next_Encoder_State, Next_Decoder_State, Loss, metrics).
"""

def batch_meta_train_loss(encoder_params, decoder_params):
batch_encoder_state = encoder_state.replace(params=encoder_params)
batch_decoder_state = decoder_state.replace(params=decoder_params)

loss, metrics_value = jax.vmap(
CNP.meta_loss,
in_axes=(None, None, None, 0, None, None),
)(batch_encoder_state, batch_decoder_state, loss_fn, tasks, metrics, True)
return loss.mean(), [metric.mean() for metric in metrics_value]

(loss, metrics_value), (encoder_grads, decoder_grads) = jax.value_and_grad(
batch_meta_train_loss, argnums=(0, 1), has_aux=True
)(encoder_state.params, decoder_state.params)

encoder_state = encoder_state.apply_gradients(grads=encoder_grads)
decoder_state = decoder_state.apply_gradients(grads=decoder_grads)

return encoder_state, decoder_state, loss, metrics_value

@staticmethod
@partial(jax.jit, static_argnums=(3, 4))
def meta_test_step(
encoder_state: TrainState,
decoder_state: TrainState,
tasks,
loss_fn: Callable[[ndarray, ndarray], ndarray],
metrics: Tuple[Callable[[ndarray, ndarray], ndarray], ...] = (),
) -> Tuple[ndarray, List[ndarray]]:
"""Performs a single meta-testing step on a batch of tasks.

The function first adapts to the support set and then evaluates its performance
on the query set.

Args:
encoder_state (TrainState): Contains information regarding the current encoder state.
decoder_state (TrainState): Contains information regarding the current decoder state.
tasks ((x_train, y_train), (x_test, y_test)): Batch of tasks to be evaluated on.
loss_fn ((logits, targets) -> loss): Loss Function.
metrics (Tuple[(ndarray, ndarray) -> ndarray]): Tuple of metrics to be evaluated on the query set.

Returns:
Tuple[jnp.ndarray, List[jnp.ndarray]]: (Loss, metrics).
"""

loss, metrics_value = jax.vmap(
CNP.meta_loss,
in_axes=(None, None, None, 0, None, None),
)(encoder_state, decoder_state, loss_fn, tasks, metrics, False)
return loss.mean(), [metric.mean() for metric in metrics_value]

@staticmethod
def meta_loss(
encoder_state: TrainState,
decoder_state: TrainState,
loss_fn,
task,
metrics,
train,
) -> Tuple[ndarray, List[ndarray]]:
"""Calculates the Meta Loss of a task

Args:
encoder_state (TrainState): Contains information regarding the current encoder state.
decoder_state (TrainState): Contains information regarding the current decoder state.
loss_fn ((logits, targets) -> loss): Loss Function.
task ((x_train, y_train), (x_test, y_test)): Task to be evaluated on.
metrics (Tuple[(ndarray, ndarray) -> ndarray]): Tuple of metrics to be evaluated on the query set.
train (bool): Whether the encoder/decoder functions are in train/test mode.

Returns:
Tuple[jnp.ndarray, List[jnp.ndarray]]: (Loss, metrics).
"""
support_set, query_set = task

# Adaptation step
r = encoder_state.apply_fn(encoder_state.params, *support_set, train=train)

# Evaluation step
x_test, y_test = query_set
logits = decoder_state.apply_fn(decoder_state.params, r, x_test, train=train)

# Calculate metrics
metrics_value = [metric(logits, y_test) for metric in metrics]

return loss_fn(logits, y_test), metrics_value


class Aggregator:
"""Class containing different aggregation functions for the aggregate step in a CNP model."""

@staticmethod
def regression(r: ndarray) -> ndarray:
"""Aggregation function for regression tasks.

This function returns the mean calculated along the batch dimension.

Args:
r (shots, ...): Encoder output for each datapoint in a batch.

Returns:
ndarray: Aggregated output.
"""
aggregates = r.mean(0)
return aggregates

@staticmethod
def classification(r: ndarray, y: ndarray, ways: int) -> ndarray:
"""Aggregation function for classification tasks.

This function returns the class-wise mean calculated along the batch dimension and concatenates the results.

Args:
r (ways * shots, ...): Encoder output for each datapoint in a batch.
y (ways * shots, ...): Target values for the support set.
ways (int): Number of classes per task.

Returns:
(ways, ...): Aggregated output.
"""
ways_idx = jnp.arange(ways)
aggregates = jax.vmap(
lambda a, b, c: jnp.where(b == c, a.T, 0).T.sum(0), in_axes=(None, None, 0)
)(r, y, ways_idx)
aggregates = aggregates.reshape(-1) / ways
return aggregates
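

if __name__ == "__main__":
    # Hypothetical end-to-end sketch (illustrative only): one way the classes above
    # could be wired together for a 1-D regression CNP. The Encoder/Decoder modules,
    # layer widths, shot counts, and optimizer settings below are assumptions for
    # demonstration and are not part of the jeta API.
    import flax.linen as nn
    import optax

    class Encoder(nn.Module):
        @nn.compact
        def __call__(self, x, y, train=True):
            # Encode each (x, y) support pair, then aggregate into one task vector.
            h = jnp.concatenate([x, y], axis=-1)  # (shots, x_dim + y_dim)
            h = nn.relu(nn.Dense(128)(h))
            h = nn.Dense(128)(h)
            return Aggregator.regression(h)  # (128,) task representation

    class Decoder(nn.Module):
        @nn.compact
        def __call__(self, r, x, train=True):
            # Condition every query point on the aggregated task representation.
            r = jnp.broadcast_to(r, (x.shape[0], r.shape[-1]))
            h = jnp.concatenate([r, x], axis=-1)
            h = nn.relu(nn.Dense(128)(h))
            return nn.Dense(1)(h)  # one prediction per query point

    def mse(logits, targets):
        return jnp.mean((logits - targets) ** 2)

    key = jax.random.PRNGKey(0)
    x_dummy, y_dummy = jnp.zeros((5, 1)), jnp.zeros((5, 1))

    encoder, decoder = Encoder(), Decoder()
    encoder_state = TrainState.create(
        apply_fn=encoder.apply,
        params=encoder.init(key, x_dummy, y_dummy),
        tx=optax.adam(1e-3),
    )
    decoder_state = TrainState.create(
        apply_fn=decoder.apply,
        params=decoder.init(key, jnp.zeros((128,)), x_dummy),
        tx=optax.adam(1e-3),
    )

    # A batch of 4 tasks: 5-shot support sets and 10-point query sets.
    tasks = (
        (jnp.zeros((4, 5, 1)), jnp.zeros((4, 5, 1))),
        (jnp.zeros((4, 10, 1)), jnp.zeros((4, 10, 1))),
    )

    # loss_fn and metrics are static (jit) arguments, so metrics is passed as a
    # hashable empty tuple here.
    encoder_state, decoder_state, loss, _ = CNP.meta_train_step(
        encoder_state, decoder_state, tasks, mse, ()
    )
    print("meta-train loss:", loss)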
137 changes: 137 additions & 0 deletions jeta/models.py
@@ -0,0 +1,137 @@
from typing import List, Tuple

import flax.linen as nn


def MLP(dims: List[int], activations="relu") -> nn.Module:
"""Create a multi-layer perceptron (MLP) module.
This consists of Dense layers, with each hidden layer followed by the chosen activation.

Args:
dims (List[int]): Dimensions of the MLP layers.
activations (str, optional): Activation function to be used (relu, leaky relu, softmax, tanh). Defaults to 'relu'.

Returns:
nn.Module: Returns a Flax Module that can be used as an MLP.
"""

if activations.lower() == "relu":
activation = nn.relu
elif activations.lower() == "leaky relu":
activation = nn.leaky_relu
elif activations.lower() == "softmax":
activation = nn.softmax
elif activations.lower() == "tanh":
activation = nn.tanh
else:
raise ValueError(
f"Invalid activation function\nExpected: (relu, leaky relu, softmax, tanh) but got: {activations}"
)

class Model(nn.Module):
@nn.compact
def __call__(self, x, train=True):
for dim in dims[:-1]:
x = nn.Dense(dim)(x)
x = activation(x)
x = nn.Dense(dims[-1])(x)
return x

mlp = Model()
return mlp
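
# Hypothetical usage sketch (illustrative only; the input shape and layer sizes
# below are assumptions, not part of the API):
#
#   import jax
#   import jax.numpy as jnp
#
#   mlp = MLP([64, 64, 1], activations="tanh")
#   params = mlp.init(jax.random.PRNGKey(0), jnp.ones((8, 2)))
#   out = mlp.apply(params, jnp.ones((8, 2)))  # shape (8, 1)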


def ConvBlock(
channels: int, kernel_size: Tuple[int], strides: Tuple[int], padding: str
) -> nn.Module:
"""Create a convolutional block module.
This consists of a convolution layer, batch normalization and a relu activation.

Args:
channels (int): Number of filters in the convolution layer.
kernel_size (Tuple[int]): Size of the convolution kernel.
strides (Tuple[int]): Strides of the convolution kernel.
padding (str): Padding of the convolution.

Returns:
nn.Module: Returns a Flax Module that can be used as a convolutional block.
"""

class Model(nn.Module):
@nn.compact
def __call__(self, x, train=True):
x = nn.Conv(channels, kernel_size, strides, padding)(x)
x = nn.BatchNorm(use_running_average=not train)(x)
x = nn.relu(x)
return x

conv_block = Model()
return conv_block


def CNN(dims: List[int]) -> nn.Module:
"""Create a convolutional neural network (CNN) module.
Each Convolution block consists of a convolution layer, batch normalization and a relu activation.

Args:
dims (List[int]): Number of filters in each CNN layer.

Returns:
nn.Module: Returns a Flax Module that can be used as a CNN.
"""

class Model(nn.Module):
@nn.compact
def __call__(self, x, train=True):
for dim in dims:
x = ConvBlock(dim, (3, 3), (1, 1), "SAME")(x, train)
return x

cnn = Model()
return cnn
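
# Hypothetical usage sketch (illustrative only; shapes are assumptions). Because
# each ConvBlock contains a BatchNorm layer, the 'batch_stats' collection must be
# marked mutable when running in train mode:
#
#   import jax
#   import jax.numpy as jnp
#
#   cnn = CNN([32, 64])
#   x = jnp.ones((4, 28, 28, 1))  # NHWC batch
#   variables = cnn.init(jax.random.PRNGKey(0), x, train=False)
#   y, updates = cnn.apply(variables, x, train=True, mutable=["batch_stats"])
#   y_eval = cnn.apply(variables, x, train=False)  # shape (4, 28, 28, 64)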


def RNN(dims: List[int], type: str = "lstm", activation: str = "relu") -> nn.Module:
"""Create a recurrent neural network (RNN) module.
This consists of recurrent (LSTM or GRU) layers, with each hidden layer followed by the chosen activation.

Args:
dims (List[int]): Number of units in each recurrent layer.
type (str, optional): Type of the recurrent cell (lstm, gru). Defaults to 'lstm'.
activation (str, optional): Activation function to be used (relu, leaky relu, softmax, tanh). Defaults to 'relu'.

Returns:
nn.Module: Returns a Flax Module that can be used as an RNN.
"""

if type.lower() == "lstm":
Cell = nn.OptimizedLSTMCell
elif type.lower() == "gru":
Cell = nn.GRUCell
else:
raise ValueError(f"Invalid type\nExpected: (lstm, gru) but got: {type}")

if activation.lower() == "relu":
activation = nn.relu
elif activation.lower() == "leaky relu":
activation = nn.leaky_relu
elif activation.lower() == "softmax":
activation = nn.softmax
elif activation.lower() == "tanh":
activation = nn.tanh
else:
raise ValueError(
f"Invalid activation function\nExpected: (relu, leaky relu, softmax, tanh) but got: {activation}"
)

class Model(nn.Module):
    @nn.compact
    def __call__(self, x, train=True):
        # x is expected to have shape (batch, time, features); nn.RNN scans each
        # recurrent cell over the time axis, and each hidden layer's outputs are
        # passed through the chosen activation.
        for dim in dims[:-1]:
            x = nn.RNN(Cell(dim))(x)
            x = activation(x)
        x = nn.RNN(Cell(dims[-1]))(x)
        return x

rnn = Model()
return rnn
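
# Hypothetical usage sketch (illustrative only; the sequence shape and layer
# sizes below are assumptions, not part of the API):
#
#   import jax
#   import jax.numpy as jnp
#
#   rnn = RNN([32, 16], type="gru", activation="tanh")
#   x = jnp.ones((8, 20, 4))  # (batch, time, features)
#   params = rnn.init(jax.random.PRNGKey(0), x)
#   out = rnn.apply(params, x)  # shape (8, 20, 16)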