1+ # ***************************************************************
2+ # Copyright (c) Jittor 2020, Author:
3+ # All Rights Reserved.
4+ # This file is subject to the terms and conditions defined in
5+ # file 'LICENSE.txt', which is part of this source code package.
6+ # ***************************************************************
7+ import jittor as jt
8+ import unittest
9+ import numpy as np
10+ from jittor import models
11+
12+ pass_this_test = False
13+ try :
14+ jt .dirty_fix_pytorch_runtime_error ()
15+ import torch
16+ import torchvision
17+ except Exception as e :
18+ pass_this_test = True
19+
20+ def get_error (a , b ):
21+ return np .abs (a - b ) / max (np .abs (a ), np .abs (b ), 1e-5 ) , np .abs (a - b )
22+
23+ def check (jt_mod , torch_mod , rtol = 1e-2 , atol = 1e-5 , mean_atol = 1e-5 ):
24+ pa = [ p for p in jt_mod .parameters () if not p .is_stop_grad () ]
25+ pb = list (torch_mod .parameters ())
26+ assert len (pa ) == len (pb )
27+ error_count = 0
28+ for a ,b in zip (pa , pb ):
29+ assert a .shape == list (b .shape ), (a .shape , b .shape , a .name ())
30+ stda , meana = np .std (a .numpy ()), np .mean (a .numpy ())
31+ stdb , meanb = np .std (b .detach ().numpy ()), np .mean (b .detach ().numpy ())
32+
33+ r_err , a_err = get_error (stda , stdb )
34+ if r_err > rtol and a_err > atol :
35+ error_count += 1
36+ print ("compare std error" , stda , stdb , r_err , a_err , a .name (), a .shape )
37+
38+ r_err , a_err = get_error (meana , meanb )
39+ if r_err > rtol and a_err > mean_atol :
40+ error_count += 1
41+ print ("compare mean error" , meana , meanb , r_err , a_err , a .name (), a .shape )
42+ assert error_count == 0
43+
44+ @unittest .skipIf (pass_this_test , f"pass init check, no torch found" )
45+ class TestInit (unittest .TestCase ):
46+ @classmethod
47+ def setUpClass (self ):
48+ jt .seed (0 )
49+ np .random .seed (0 )
50+ torch .manual_seed (0 )
51+
52+ def test_conv (self ):
53+ check (jt .nn .Conv (64 , 256 , 3 ), torch .nn .Conv2d (64 , 256 , 3 ), rtol = 1e-1 , mean_atol = 1e-3 )
54+
55+ def test_resnet (self ):
56+ check (models .resnet152 (), torchvision .models .resnet152 (), rtol = 2e-2 , mean_atol = 1e-2 )
57+
58+ if __name__ == "__main__" :
59+ unittest .main ()
0 commit comments