# Soft-min temperature for Soft-DTW; smaller gamma -> closer to the hard DTW min.
gamma = torch.tensor(0.1, device=device_id)

##################################
# SoftDTW operation in pytorch
##################################
25+
def softmin(args, gamma):
    """Soft minimum of a collection of (broadcastable) tensors.

    For gamma > 0, computes the smoothed minimum

        min(args) - gamma * log( sum_i exp((min(args) - a_i) / gamma) )

    which equals -gamma * log(sum_i exp(-a_i / gamma)) and tends to the hard
    elementwise minimum as gamma -> 0.  For gamma == 0 (or negative) it is
    exactly the hard minimum.

    args  : iterable of torch tensors with broadcast-compatible shapes
    gamma : scalar tensor (or float), the smoothing temperature
    """
    minargs = reduce(lambda a, b: torch.min(a, b), args)
    if gamma > 0:
        # Stabilized log-sum-exp: shifting by the hard min before
        # exponentiating avoids overflow for large values.
        minargs -= gamma * sum(((minargs - arg) / gamma).exp() for arg in args).log()
    return minargs
3031
32+
def SoftDTW_torch(x, y, gamma):
    """Pairwise Soft-DTW squared-distance values, computed in pure PyTorch.

    x : (N, n) tensor — N series of length n
    y : (M, m) tensor — M series of length m
    gamma : scalar tensor, soft-min temperature (see `softmin`)

    Returns the (N, M) tensor whose entry [i, j] is the Soft-DTW value
    between series x_i and y_j, via the classic O(n*m) dynamic program
    kept in a single rolling column of the DP table.

    NOTE(review): assumes x and y live on `device_id` (module global), since
    the DP border tensors are allocated there — confirm against the caller.
    """
    n, m = x.shape[1], y.shape[1]
    # Broadcast so that entry [i, j] compares series x_i against y_j.
    x, y = x[:, None, :], y[None, :, :]
    # rjm1[i] holds R[i, j-1], the previous column of the DP table:
    # R[0, 0] = 0 and the rest of the border is +inf.
    rjm1 = [torch.tensor(torch.inf, device=device_id) for _ in range(n + 1)]
    rjm1[0] = torch.tensor(0.0, device=device_id)
    torchinf = torch.tensor(torch.inf, device=device_id)
    for j in range(1, m + 1):
        rim1j = torchinf  # R[i-1, j] for the current column; border value.
        for i in range(1, n + 1):
            # R[i, j] = cost(i, j)
            #           + softmin(R[i, j-1], R[i-1, j-1], R[i-1, j])
            rij = (x[:, :, i - 1] - y[:, :, j - 1]) ** 2 + softmin(
                (rjm1[i], rjm1[i - 1], rim1j), gamma
            )
            # Slide the window: row i-1 of the current column is final now.
            rjm1[i - 1] = rim1j
            rim1j = rij
        rjm1[i] = rij  # store R[n, j] for the next column
    return rij
4549
4650
47-
#########################################
# reduction functions with torch and keops
#########################################
5154
55+
def fun_torch(x, y, gamma):
    """Reference reduction: sum_j exp(-SoftDTW(x_i, y_j)) for every i,
    computed with the pure-PyTorch Soft-DTW kernel."""
    Sxy = SoftDTW_torch(x, y, gamma)
    Kxy = (-Sxy).exp()
    return Kxy.sum(dim=1)
5660
61+
def fun_keops(x, y, gamma):
    """Same reduction as `fun_torch`, expressed as a single KeOps Genred
    formula (Sum over j of Exp(-SoftDTW_SqDist(x, y, gamma)))."""
    n, m = x.shape[1], y.shape[1]
    formula = "Exp(-SoftDTW_SqDist(x,y,gamma))"
    aliases = [f"x=Vi({n})", f"y=Vj({m})", "gamma=Pm(1)"]
    Kxy = Genred(formula, aliases, reduction_op="Sum", axis=1)
    # Genred expects the parameter as a (1, 1) tensor.
    return Kxy(x, y, gamma.view((1, 1)))
68+
6369
def fun_lazytensor(x, y, gamma):
    """Same reduction as `fun_torch`, via the KeOps LazyTensor frontend
    and its built-in `softdtw_sqdist` operation."""
    x = LazyTensor(x[:, None, :])
    y = LazyTensor(y[None, :, :])
    sdtw = x.softdtw_sqdist(y, gamma)
    K = (-sdtw).exp()
    return K.sum(axis=1)
7076
77+
##################################
# test
##################################

# funs = (fun_torch, fun_keops, fun_lazytensor)
funs = (fun_torch, fun_lazytensor)
out = []
for fun in funs:
    print("**************************")
    print("Testing " + fun.__name__)
    if do_warmup:
        # Two warm-up calls on a small slice: triggers KeOps compilation /
        # CUDA init so the timed run below measures steady-state performance.
        fun(x[:100, :], y[:100, :], gamma)
        fun(x[:100, :], y[:100, :], gamma)
    start = time.time()
    out.append(fun(x, y, gamma).squeeze())
    end = time.time()
@@ -89,8 +96,11 @@ def fun_lazytensor(x, y, gamma):
print("******")

if len(out) > 1:
    # Compare every implementation against the first one (the reference).
    for k in range(1, len(out)):
        print(
            f"relative error {funs[k].__name__} vs {funs[0].__name__}:",
            (torch.norm(out[0] - out[k]) / torch.norm(out[0])).item(),
        )
94104
95105
96106if test_grad :
@@ -105,7 +115,8 @@ def fun_lazytensor(x, y, gamma):
105115 print ("time for " + fun .__name__ + " (grad):" , end - start )
106116
107117 if len (out_g ) > 1 :
108- for k in range (1 ,len (out )):
109- print (f"relative error grad { funs [k ].__name__ } vs { funs [0 ].__name__ } :" , (torch .norm (out_g [0 ] - out_g [k ]) / torch .norm (out_g [0 ])).item ())
110-
111-
118+ for k in range (1 , len (out )):
119+ print (
120+ f"relative error grad { funs [k ].__name__ } vs { funs [0 ].__name__ } :" ,
121+ (torch .norm (out_g [0 ] - out_g [k ]) / torch .norm (out_g [0 ])).item (),
122+ )
0 commit comments