System information

- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
  Linux 6.1.38-1-MANJARO #1 SMP PREEMPT_DYNAMIC Wed Jul 5 23:49:30 UTC 2023 x86_64 GNU/Linux
- Flax, jax, jaxlib versions (obtain with `pip show flax jax jaxlib`):

```
Name: flax
Version: 0.7.0
Summary: Flax: A neural network library for JAX designed for flexibility
Home-page:
Author:
Author-email: Flax team <[email protected]>
License:
Location: /home/captainamerica/Programming/TinySeg-Net/venv311/lib/python3.11/site-packages
Requires: jax, msgpack, numpy, optax, orbax-checkpoint, PyYAML, rich, tensorstore, typing-extensions
Required-by:
---
Name: jax
Version: 0.4.13
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: [email protected]
License: Apache-2.0
Location: /home/captainamerica/Programming/TinySeg-Net/venv311/lib/python3.11/site-packages
Requires: ml-dtypes, numpy, opt-einsum, scipy
Required-by: chex, flax, objax, optax, orbax-checkpoint
---
Name: jaxlib
Version: 0.4.13+cuda12.cudnn89
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: [email protected]
License: Apache-2.0
Location: /home/captainamerica/Programming/TinySeg-Net/venv311/lib/python3.11/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: chex, objax, optax, orbax-checkpoint
```

- Python version: 3.11.3
- GPU/TPU model and memory: NVIDIA GeForce RTX 3070
- CUDA version (if applicable): 12.2
Problem you have encountered:

I have the following snippet from some free-time research:
```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Callable, NamedTuple


class MultiplierOut(NamedTuple):
    gate1: jnp.ndarray
    gate2: jnp.ndarray
    gate3: jnp.ndarray


class ConvBlock(nn.Module):
    inp_channels: int
    out_channels: int
    kernel_size: int
    groups: int = 1
    dilation: int = 1
    padding: str = "SAME"
    activation: Callable = nn.gelu

    def setup(self):
        self.conv_layer = nn.Conv(
            self.out_channels,
            [self.kernel_size, self.kernel_size],
            padding=self.padding,
            kernel_dilation=self.dilation,
            feature_group_count=self.groups,
        )
        self.batch_norm = nn.BatchNorm(use_running_average=True)

    @nn.compact
    def __call__(self, x):
        x = self.conv_layer(x)
        x = self.batch_norm(x)
        x = self.activation(x)
        return x


class InvertedResBlockWithMultiplier(nn.Module):
    inp_channels: int
    interim_channels: int
    out_channels: int

    def setup(self):
        self.conv1 = ConvBlock(
            self.inp_channels,
            self.interim_channels,
            kernel_size=1,
        )
        self.conv2 = ConvBlock(
            self.interim_channels,
            self.interim_channels,
            kernel_size=3,
            groups=self.interim_channels,
            dilation=2,
        )
        self.conv3 = ConvBlock(
            self.interim_channels,
            self.out_channels,
            kernel_size=1,
        )

    @nn.compact
    def __call__(self, x, gate_out: MultiplierOut):
        assert self.interim_channels % self.interim_channels == 0  # always true as written
        conv1_out = self.conv1(x)
        conv1_gated = jnp.multiply(gate_out.gate1, conv1_out)
        conv2_out = self.conv2(conv1_out)
        conv2_gated = jnp.multiply(gate_out.gate2, conv2_out)
        conv3_out = self.conv3(conv2_out)
        conv3_gated = jnp.multiply(gate_out.gate3, conv3_out)
        return conv3_gated


class MultiplierBlock(nn.Module):
    inp_channels: int
    gate_1_channels: int
    gate_2_channels: int
    gate_3_channels: int

    def setup(self):
        self.gate1 = ConvBlock(
            self.inp_channels,
            self.gate_1_channels,
            kernel_size=5,
            groups=self.inp_channels,
        )
        self.gate2 = ConvBlock(
            self.inp_channels,
            self.gate_2_channels,
            kernel_size=5,
            groups=self.inp_channels,
        )
        self.gate3 = ConvBlock(
            self.inp_channels,
            self.gate_3_channels,
            kernel_size=5,
            groups=self.inp_channels,
        )

    @nn.compact
    def __call__(self, x):
        gate1_out = self.gate1(x)
        gate2_out = self.gate2(x)
        gate3_out = self.gate3(x)
        return MultiplierOut(gate1=gate1_out, gate2=gate2_out, gate3=gate3_out)


class MagnifyBlock(nn.Module):
    inp_channels: int
    interim_channels: int
    out_channels: int
    reduce: bool = True
    residual: bool = True

    def setup(self):
        self.gate_block = MultiplierBlock(
            self.inp_channels,
            self.interim_channels,
            self.interim_channels,
            self.out_channels,
        )
        self.value_block = ConvBlock(
            self.inp_channels,
            self.out_channels,
            kernel_size=1,
        )
        self.context_block = InvertedResBlockWithMultiplier(
            self.inp_channels,
            self.interim_channels,
            self.out_channels,
        )

    @nn.compact
    def __call__(self, x):
        gate_out = self.gate_block(x)
        value_out = self.value_block(x)
        context_out = self.context_block(value_out, gate_out)
        output = value_out * nn.sigmoid(context_out)
        if self.residual:
            output = output + value_out
        if self.reduce:
            output = nn.max_pool(output, window_shape=(2, 2), strides=(2, 2))
        return output


class MagnifyMobileNet(nn.Module):
    inp_num_classes: int

    def setup(self):
        print(self.inp_num_classes)
        self.inp_batch_norm = nn.BatchNorm(use_running_average=True)
        self.first_conv_block = ConvBlock(3, 64, kernel_size=3)
        self.first_magnify_block = MagnifyBlock(64, 128, 64, reduce=False)
        self.second_magnify_block = MagnifyBlock(64, 128, 64, reduce=False)
        self.third_magnify_block = MagnifyBlock(64, 128, 64)
        self.fourth_magnify_block = MagnifyBlock(64, 128, 64, reduce=False)
        self.fifth_magnify_block = MagnifyBlock(64, 128, 64, reduce=False)
        self.sixth_magnify_block = MagnifyBlock(64, 128, 64)
        self.output_conv1 = ConvBlock(64, 32, kernel_size=3)
        self.output_conv2 = ConvBlock(
            inp_channels=32,
            out_channels=8,
            kernel_size=3,
            activation=nn.sigmoid,
        )

    @nn.compact
    def __call__(self, x):
        input_shape = x.shape
        x = self.inp_batch_norm(x)
        x = self.first_conv_block(x)
        x = self.second_magnify_block(x)  # note: first_magnify_block is never called
        x = self.third_magnify_block(x)
        x = self.fourth_magnify_block(x)
        x = self.fifth_magnify_block(x)
        x = self.sixth_magnify_block(x)
        x = self.output_conv1(x)
        x = self.output_conv2(x)
        x = jnp.sum(x, axis=-1, keepdims=True) / 8
        # Upsampling back to the input resolution is disabled for now:
        # x = jax.image.resize(
        #     x,
        #     (x.shape[0], input_shape[1], input_shape[2], x.shape[3]),
        #     jax.image.ResizeMethod.CUBIC,
        # )
        return x


# Create the model and apply initialization.
rng = jax.random.PRNGKey(0)
input_shape = (1, 512, 512, 3)
model = MagnifyMobileNet(inp_num_classes=3)
params = model.init(rng, jnp.ones(input_shape))
```
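To sanity-check the initialization, the shapes in the resulting variable tree can be inspected, e.g. (a quick sketch using `jax.tree_util.tree_map`):

```python
import jax

# Print the shape of every array in the initialized variable tree
# (the 'params' collection plus the 'batch_stats' from BatchNorm).
print(jax.tree_util.tree_map(lambda a: a.shape, params))
```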
Before proceeding further, I wanted to measure the throughput of the new model. For that I used the code below:
```python
import time
import tqdm

input_shape = (1, 512, 512, 3)

@jax.jit
def infer(x):
    output = model.apply(params, x)
    return output

# Warm-up call to trigger compilation before timing.
output = infer(jnp.ones(input_shape))
del output

s = time.time()
for i in tqdm.tqdm(range(1000)):
    output = infer(jnp.ones(input_shape))
    output = None
e = time.time()
print(1000 / (e - s))
```
The output was:

36.772840093497244

I was a bit skeptical, so I tried another snippet in which the output is not returned from the `infer` function but stored in a global variable instead:
```python
import time
import tqdm

input_shape = (1, 512, 512, 3)
output = None

@jax.jit
def infer(x):
    global output
    output = model.apply(params, x)  # note: no return statement

# Warm-up call to trigger compilation before timing.
output = infer(jnp.ones(input_shape))
del output

s = time.time()
for i in tqdm.tqdm(range(1000)):
    output = infer(jnp.ones(input_shape))
    output = None
e = time.time()
print(1000 / (e - s))
```
And the output was:

1228.4869997478181

The only difference between the two benchmarking blocks is that the first one returns the output, whereas the second stores it in a global variable. Yet the measured speeds differ by a ridiculous margin: roughly 36 fps versus 1228 fps. Am I doing something wrong here?
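If I understand JAX's asynchronous dispatch correctly, a fairer benchmark would also block on the result before stopping the timer. Here is a sketch of what I mean (assuming `block_until_ready()` is the intended synchronization API; `model` and `params` are from the snippet above):

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def infer(x):
    # Same as the first snippet: the result is returned from the
    # jitted function rather than assigned to a global.
    return model.apply(params, x)

x = jnp.ones((1, 512, 512, 3))

# Warm-up: compile once, outside the timed region.
infer(x).block_until_ready()

s = time.time()
for _ in range(1000):
    out = infer(x)
# JAX dispatches device work asynchronously; block on the last result
# so the timer covers actual execution rather than just dispatch.
out.block_until_ready()
e = time.time()
print(1000 / (e - s))
```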