
Inference function with global variables consumes less GPU RAM and runs faster than purely functional code without global usage #3251

@TheSeriousProgrammer

Description



System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
    Linux 6.1.38-1-MANJARO # 1 SMP PREEMPT_DYNAMIC Wed Jul 5 23:49:30 UTC 2023 x86_64 GNU/Linux
  • Flax, jax, jaxlib versions (obtained with pip show flax jax jaxlib):
Name: flax
Version: 0.7.0
Summary: Flax: A neural network library for JAX designed for flexibility
Home-page: 
Author: 
Author-email: Flax team <[email protected]>
License: 
Location: /home/captainamerica/Programming/TinySeg-Net/venv311/lib/python3.11/site-packages
Requires: jax, msgpack, numpy, optax, orbax-checkpoint, PyYAML, rich, tensorstore, typing-extensions
Required-by: 
---
Name: jax
Version: 0.4.13
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: [email protected]
License: Apache-2.0
Location: /home/captainamerica/Programming/TinySeg-Net/venv311/lib/python3.11/site-packages
Requires: ml-dtypes, numpy, opt-einsum, scipy
Required-by: chex, flax, objax, optax, orbax-checkpoint
---
Name: jaxlib
Version: 0.4.13+cuda12.cudnn89
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: [email protected]
License: Apache-2.0
Location: /home/captainamerica/Programming/TinySeg-Net/venv311/lib/python3.11/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: chex, objax, optax, orbax-checkpoint
  • Python version: 3.11.3
  • GPU/TPU model and memory:
    NVIDIA GeForce RTX 3070
  • CUDA version (if applicable):
    12.2

Problem you have encountered:

I have the following snippet from some free-time research:

import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Callable, NamedTuple

class MultiplierOut(NamedTuple):

    gate1: jnp.ndarray
    gate2: jnp.ndarray
    gate3: jnp.ndarray

class ConvBlock(nn.Module):

    inp_channels: int
    out_channels: int
    kernel_size: int
    groups: int = 1
    dilation: int = 1
    padding: str = "SAME"
    activation: Callable = nn.gelu

    def setup(self):
        self.conv_layer = nn.Conv(
            self.out_channels, 
            [self.kernel_size, self.kernel_size], 
            padding = self.padding,
            kernel_dilation = self.dilation,
            feature_group_count = self.groups, 
        ) 
        #self.activation = nn.gelu
        #self.activation = getattr(nn, self.activation)
        self.batch_norm = nn.BatchNorm(use_running_average=True)
    
    @nn.compact
    def __call__(self, x):
        x = self.conv_layer(x)
        x = self.batch_norm(x)
        #x = nn.gelu(x)
        #help(self.activation)
        x = self.activation(x)
        #x = self.activation(x)
        return x

class InvertedResBlockWithMultiplier(nn.Module):

    inp_channels: int
    interim_channels: int
    out_channels: int

    def setup(self):

        self.conv1 = ConvBlock(
            self.inp_channels, 
            self.interim_channels, 
            kernel_size=1
        )

        self.conv2 = ConvBlock(
            self.interim_channels,
            self.interim_channels,
            kernel_size=3,
            groups = self.interim_channels,
            dilation=2
        )

        self.conv3 = ConvBlock(
            self.interim_channels,
            self.out_channels,
            kernel_size = 1,
        )

    @nn.compact
    def __call__(self, x, gate_out: MultiplierOut):
        assert self.interim_channels%self.interim_channels==0

        conv1_out = self.conv1(x)
        conv1_gated = jnp.multiply(gate_out.gate1, conv1_out)

        #print("Conv2 inp", x.shape)
        #print("Conv2 groups", self.conv2.groups)
        conv2_out = self.conv2(conv1_out)
        conv2_gated = jnp.multiply(gate_out.gate2, conv2_out)

        conv3_out = self.conv3(conv2_out)
        conv3_gated = jnp.multiply(gate_out.gate3, conv3_out)

        return conv3_gated

class MultiplierBlock(nn.Module):
    inp_channels: int
    gate_1_channels: int
    gate_2_channels: int
    gate_3_channels: int

    def setup(self):
        self.gate1 = ConvBlock(
            self.inp_channels, 
            self.gate_1_channels, 
            kernel_size=5, 
            groups=self.inp_channels
        )

        self.gate2 = ConvBlock(
            self.inp_channels, 
            self.gate_2_channels, 
            kernel_size=5, 
            groups=self.inp_channels
        )
        
        self.gate3 = ConvBlock(
            self.inp_channels, 
            self.gate_3_channels, 
            kernel_size=5, 
            groups=self.inp_channels
        )

    @nn.compact
    def __call__(self, x):

        #print(self.inp_channels, self.gate_1_channels, self.gate_2_channels, self.gate_3_channels)

        gate1_out = self.gate1(x)
        gate2_out = self.gate2(x)
        gate3_out = self.gate3(x)

        return MultiplierOut(gate1=gate1_out, gate2=gate2_out, gate3=gate3_out)

class MagnifyBlock(nn.Module):
    inp_channels: int
    interim_channels: int
    out_channels: int
    reduce: bool = True
    residual: bool = True

    def setup(self):
        self.gate_block = MultiplierBlock(
            self.inp_channels,
            self.interim_channels,
            self.interim_channels,
            self.out_channels
        )

        self.value_block = ConvBlock(
            self.inp_channels,
            self.out_channels,
            kernel_size = 1
        )

        self.context_block = InvertedResBlockWithMultiplier(
            self.inp_channels,
            self.interim_channels,
            self.out_channels
        )

    @nn.compact
    def __call__(self, x):

        gate_out = self.gate_block(x)
        value_out = self.value_block(x)        

        #print(value_out.shape, gate_out.gate1.shape, gate_out.gate2.shape, gate_out.gate3.shape)
        context_out = self.context_block(value_out, gate_out)

        output = value_out * nn.sigmoid(context_out)
        #output = nn.multiply(value_out, nn.sigmoid(context_out))

        if self.residual:
            output = output + value_out

        if self.reduce:
            output = nn.max_pool(output, window_shape=(2, 2), strides=(2, 2))

        return output

class MagnifyMobileNet(nn.Module):
    inp_num_classes: int

    def setup(self):
        print(self.inp_num_classes)
        self.inp_batch_norm = nn.BatchNorm(use_running_average = True)
        
        self.first_conv_block = ConvBlock(3, 64, kernel_size = 3)
        self.first_magnify_block = MagnifyBlock(64, 128, 64, reduce = False)
        self.second_magnify_block = MagnifyBlock(64, 128, 64, reduce = False)
        self.third_magnify_block = MagnifyBlock(64, 128, 64)
        self.fourth_magnify_block = MagnifyBlock(64, 128, 64, reduce = False)
        self.fifth_magnify_block = MagnifyBlock(64, 128, 64, reduce = False)
        self.sixth_magnify_block = MagnifyBlock(64, 128, 64)

        self.output_conv1 = ConvBlock(64, 32, kernel_size = 3)
        self.output_conv2 = ConvBlock(
            inp_channels = 32, 
            out_channels = 8, 
            kernel_size = 3, 
            activation = nn.sigmoid,
        )
        

    @nn.compact
    def __call__(self, x):

        input_shape = x.shape
        x = self.inp_batch_norm(x)
        #print("first_block_in", x.shape)
        x = self.first_conv_block(x)
        #print("first_block_out", x.shape)
        x = self.second_magnify_block(x)
        x = self.third_magnify_block(x)
        x = self.fourth_magnify_block(x)
        x = self.fifth_magnify_block(x)
        x = self.sixth_magnify_block(x)

        x = self.output_conv1(x)
        x = self.output_conv2(x)


        x = jnp.sum(x, axis = -1, keepdims=True) / 8
        #x = nn.sigmoid(x)
        """
        x = jax.image.resize(
            x, 
            (
                x.shape[0],
                input_shape[1], 
                input_shape[2],
                x.shape[3]
            ),
            jax.image.ResizeMethod.CUBIC
        )
        """

        return x

# Create the model and apply initialization.
rng = jax.random.PRNGKey(0)
input_shape = (1, 512, 512, 3)

model = MagnifyMobileNet(inp_num_classes=3)

params = model.init(rng, jnp.ones(input_shape))
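
For reference, here is a quick sanity check on the size of the initialized variables, since the title mentions GPU RAM. This is a minimal sketch of my own (not part of the benchmark itself), using jax.tree_util on the params pytree returned above:

# Count all initialized values (params + batch_stats) as a rough proxy
# for the model's parameter memory footprint.
num_values = sum(p.size for p in jax.tree_util.tree_leaves(params))
print(f"total variables: {num_values:,} (~{num_values * 4 / 1e6:.1f} MB at float32)")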

Before proceeding further, I wanted to know the throughput of the new model.

For that I used the code below:

import time

import tqdm

input_shape = (1, 512, 512, 3)
output = None

@jax.jit
def infer(x):
    output = model.apply(params, x)
    return output

# Warm-up call so compilation time is excluded from the measurement.
output = infer(jnp.ones(input_shape))
del output

s = time.time()
for i in tqdm.tqdm(range(1000)):
    output = infer(jnp.ones(input_shape))
    output = None
    #output = model.apply(params, jnp.ones(input_shape))
e = time.time()
print(1000/(e-s))

The output was

36.772840093497244

I was a bit skeptical, so I tried another snippet in which the output is not returned by the infer function but stored in a global variable instead:

import time

import tqdm

input_shape = (1, 512, 512, 3)
output = None

@jax.jit
def infer(x):
    global output
    output = model.apply(params, x)

# Warm-up call so compilation time is excluded from the measurement.
output = infer(jnp.ones(input_shape))
del output

s = time.time()
for i in tqdm.tqdm(range(1000)):
    output = infer(jnp.ones(input_shape))
    output = None
    #output = model.apply(params, jnp.ones(input_shape))
e = time.time()
print(1000/(e-s))

And the output was

1228.4869997478181

The only difference between the two benchmarking blocks is that the first one returns the output, whereas the second stores it in a global variable. The measured speeds differ by a ridiculous margin: roughly 36 FPS versus 1228 FPS. Am I doing something wrong here?
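
For comparison, here is a variant of the first loop that explicitly synchronizes before reading the clock. This is only a minimal sketch, assuming JAX's asynchronous dispatch affects the timing; infer_sync is just a renamed copy of the jitted apply, and model, params and input_shape are taken from the snippets above:

import time

import jax
import jax.numpy as jnp

@jax.jit
def infer_sync(x):
    return model.apply(params, x)

x = jnp.ones(input_shape)
infer_sync(x).block_until_ready()  # warm-up / compilation

s = time.time()
for _ in range(1000):
    out = infer_sync(x)
out.block_until_ready()  # wait for the last dispatched computation to finish
e = time.time()
print(1000 / (e - s))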
