forked from amd/RyzenAI-SW
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path
user_script.py
More file actions
23 lines (16 loc) · 753 Bytes
/
user_script.py
File metadata and controls
23 lines (16 loc) · 753 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import torch
def load_pytorch_origin_model(torch_hub_model_path):
    """Load the pretrained ResNet-50 from the pinned torchvision hub repo.

    Args:
        torch_hub_model_path: Not used. Presumably accepted only to match the
            model-loader signature expected by the calling framework.

    Returns:
        The object returned by ``torch.hub.load`` for the ``resnet50`` entry
        point of ``pytorch/vision:v0.10.0`` with ``pretrained=True``. This may
        need network access the first time, before torch.hub has cached the
        repository and weights.
    """
    hub_repo, entry_point = "pytorch/vision:v0.10.0", "resnet50"
    model = torch.hub.load(hub_repo, entry_point, pretrained=True)
    return model
class DataLoader:
    """Indexable source of random, label-less input batches.

    ``loader[idx]`` returns ``(input_data, None)``. ``input_data`` is a freshly
    sampled ``torch.rand`` tensor of shape ``(batchsize, *input_shape)`` with
    values in [0, 1). ``idx`` does not seed the generator, so every access
    returns new random data.

    By default the loader is unbounded: any index is accepted and
    ``__getitem__`` never raises ``IndexError``. As a result, plain
    ``for batch in loader`` iteration never stops. Pass ``num_batches`` to
    bound it; indices outside ``[0, num_batches)`` then raise ``IndexError``,
    which ends iteration through the sequence protocol.
    """

    def __init__(
        self,
        batchsize,
        *,
        input_shape=(3, 224, 224),
        dtype=torch.float16,
        num_batches=None,
    ):
        """Configure the loader.

        Args:
            batchsize: Leading dimension of every generated tensor.
            input_shape: Per-sample shape appended after ``batchsize``. The
                default ``(3, 224, 224)`` matches the original hard-coded
                shape.
            dtype: Tensor dtype. The default ``torch.float16`` matches the
                original. NOTE(review): the hub ResNet-50 loads with float32
                weights, so confirm the consumer of this data expects fp16.
            num_batches: If ``None`` (the default, the original behaviour), the
                loader is unbounded. Otherwise it is the number of valid
                indices.
        """
        self.batchsize = batchsize
        self.input_shape = tuple(input_shape)
        self.dtype = dtype
        self.num_batches = num_batches

    def __getitem__(self, idx):
        """Return ``(random_batch, None)`` for ``idx``.

        Raises:
            IndexError: Only when ``num_batches`` is set and ``idx`` is outside
                ``[0, num_batches)``.
        """
        if self.num_batches is not None and not 0 <= idx < self.num_batches:
            raise IndexError(
                f"batch index {idx} out of range for {self.num_batches} batches"
            )
        input_data = torch.rand((self.batchsize, *self.input_shape), dtype=self.dtype)
        label = None  # no ground truth; the data is purely synthetic
        return input_data, label
def create_dataloader(data_dir, batchsize, *args, **kwargs):
    """Build the synthetic-data ``DataLoader`` used by this script.

    Args:
        data_dir: Ignored; nothing is read from disk.
        batchsize: Batch size forwarded to ``DataLoader``.
        *args: Ignored; accepted so extra positional arguments from the caller
            are tolerated.
        **kwargs: Ignored; accepted so extra keyword arguments from the caller
            are tolerated.

    Returns:
        A ``DataLoader`` whose items are ``(random_tensor, None)`` pairs.
    """
    loader = DataLoader(batchsize)
    return loader