Skip to content

torch_wrapper doesn't like nullable arrays #1237

Open
@green-cabbage

Description

@green-cabbage

This is a continuation of the discussion on Mattermost. While trying to implement our analysis's simple DNN with `torch_wrapper`, I noticed that the inputs to the DNN in `torch_wrapper` must not be nullable arrays; otherwise, `ValueError: buffer is smaller than requested size` is raised. This is solved by applying `input_arr = ak.fill_none(input_arr, value=0)`. It is important to note that `input_arr` contains no actual `None` values: merely having a nullable (option) type, shown as `?` in the array's type, is enough to make the code crash. At the recommendation of @lgray, I have tried using `ak.drop_none` or `ak.to_packed` as alternatives to the `ak.fill_none` solution, but neither of them works. Here's my simple reproducer:

from hist import Hist
import dask
import awkward as ak
import hist.dask as hda
from coffea import processor
from coffea.nanoevents.methods import candidate
from coffea.dataset_tools import (
    apply_to_fileset,
    max_chunks,
    preprocess,
)
from distributed import Client
import dask_awkward as dak
import numpy as np
from coffea.nanoevents import NanoEventsFactory
from coffea.nanoevents.schemas import PFNanoAODSchema
import awkward as ak
import dask_awkward as dak
import numpy as np

#understand coffea pytorch

import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Small fully-connected binary classifier used as the reproducer's DNN.

    Three hidden stages, each Linear -> BatchNorm1d -> tanh -> Dropout(0.2),
    with 128, 64 and 32 units, followed by a single-unit Linear layer and a
    sigmoid, so every output lies in [0, 1].

    Args:
        input_shape: number of input features per row (the ``in_features``
            of the first Linear layer).
    """

    def __init__(self, input_shape):
        super().__init__()
        self.fc1 = nn.Linear(input_shape, 128)
        self.bn1 = nn.BatchNorm1d(128)
        self.dropout1 = nn.Dropout(0.2)
        self.fc2 = nn.Linear(128, 64)
        self.bn2 = nn.BatchNorm1d(64)
        self.dropout2 = nn.Dropout(0.2)
        self.fc3 = nn.Linear(64, 32)
        self.bn3 = nn.BatchNorm1d(32)
        self.dropout3 = nn.Dropout(0.2)
        self.output = nn.Linear(32, 1)

    def forward(self, features):
        """Return the sigmoid score for a batch of feature rows.

        Args:
            features: float tensor, assumed 2-D ``(batch, input_shape)`` as in
                the tracing example (``torch.rand(100, n_feat)``).

        Returns:
            Tensor of shape ``(batch, 1)`` with values in [0, 1].
        """
        # torch.tanh / torch.sigmoid replace the deprecated F.tanh / F.sigmoid;
        # the numerical result is the same.
        x = self.dropout1(torch.tanh(self.bn1(self.fc1(features))))
        x = self.dropout2(torch.tanh(self.bn2(self.fc2(x))))
        x = self.dropout3(torch.tanh(self.bn3(self.fc3(x))))
        return torch.sigmoid(self.output(x))


# Lazily read the reproducer's input events (part000.parquet, attached to the
# issue) and keep only a small slice for a quick test.
data1 = dak.from_parquet("part000.parquet")
events = data1[:3]

# Build an untrained model with one input per feature and save a TorchScript
# trace of it; the file name matches the one passed to DNNWrapper below.
n_feat = 3
model = Net(n_feat)
model.eval()  # trace in inference mode (BatchNorm running stats, no dropout)
example_input = torch.rand(100, n_feat)  # named to avoid shadowing builtin input()
torch.jit.trace(model, example_input).save("test_model.pt")





from coffea.ml_tools.torch_wrapper import torch_wrapper

class DNNWrapper(torch_wrapper):
    """coffea ``torch_wrapper`` around the TorchScript model saved by this script.

    Works around the failure described in this issue: an input whose awkward
    type is nullable (option type, ``?`` in the type string) makes the wrapper
    fail with ``ValueError: buffer is smaller than requested size``, even when
    no element is actually None.
    """

    def _create_model(self):
        """Load the TorchScript model from ``self.torch_jit`` in eval mode."""
        model = torch.jit.load(self.torch_jit)
        model.eval()
        return model

    def prepare_awkward(self, arr):
        """Convert ``arr`` into the (args, kwargs) passed to the model.

        Args:
            arr: awkward array of feature rows whose inner dimension matches
                the model's input size; presumably (n_events, n_feat) with a
                possibly nullable innermost type -- confirm against the caller.

        Returns:
            ``([float32 array], {})``: one positional argument, no kwargs.
        """
        # Filling None replaces missing values. More importantly, it removes
        # the option ("?") type from the innermost axis; the nullable type
        # alone is what triggers the crash. The issue reports that
        # ak.drop_none and ak.to_packed do not fix it.
        default_none_val = 0
        arr = ak.fill_none(arr, value=default_none_val)

        # Force float32 to match the float32 example tensor the model was traced with.
        return [
            ak.values_astype(arr, "float32"),
        ], {}


# Fold the 3 event-level variables into a single array with one row per event.
# [:, np.newaxis] turns each per-event scalar into a length-1 list, so the
# three can be concatenated along axis=1.
input_arr = ak.concatenate(
    [
        events.dimuon_mass[:, np.newaxis],
        events.mu2_pt[:, np.newaxis],
        events.mu1_pt[:, np.newaxis],
    ],
    axis=1,
)
print(input_arr.compute())
dwrap = DNNWrapper("test_model.pt")
dnn_score = dwrap(input_arr)
print(dnn_score)  # Lazily evaluated dask array; use it directly for histogram filling
print(dnn_score.compute())  # Eagerly evaluated result
print("Success!")

Input parquet file with the events read by the reproducer (unzip it first): part000.zip

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug: Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions