1+ # 这个代码直接copy自pytorch https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/container.py
2+ # 其中实现了什么,我本人也没仔细研究,如果有问题请咨询AI或者查看pytorch的文档
3+
4+ from __future__ import annotations
5+
6+ import operator
7+ from collections import abc as container_abcs , OrderedDict
8+ from itertools import chain , islice
9+ from typing import Any , Optional , overload , TYPE_CHECKING , TypeVar , Union
10+ from typing_extensions import Self
11+
12+ if TYPE_CHECKING :
13+ from collections .abc import Iterable , Iterator , Mapping
14+
15+ from .module import Module
16+
17+ __all__ = [
18+ "Sequential" ,
19+ "ModuleList" ,
20+ ]
21+
22+ T = TypeVar ("T" , bound = Module )
23+ _V = TypeVar ("_V" )
24+
25+
26+ def _addindent (s_ , numSpaces ):
27+ s = s_ .split ("\n " )
28+ # don't do anything for single-line stuff
29+ if len (s ) == 1 :
30+ return s_
31+ first = s .pop (0 )
32+ s = [(numSpaces * " " ) + line for line in s ]
33+ s = "\n " .join (s )
34+ s = first + "\n " + s
35+ return s
36+
37+ class Sequential (Module ):
38+ # 参考 https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/container.py
39+ _modules : dict [str , Module ] # type: ignore[assignment]
40+
41+ @overload
42+ def __init__ (self , * args : Module ) -> None : ...
43+
44+ @overload
45+ def __init__ (self , arg : OrderedDict [str , Module ]) -> None : ...
46+
47+ def __init__ (self , * args ):
48+ super ().__init__ ()
49+ if len (args ) == 1 and isinstance (args [0 ], OrderedDict ):
50+ for key , module in args [0 ].items ():
51+ self .add_module (key , module )
52+ else :
53+ for idx , module in enumerate (args ):
54+ self .add_module (str (idx ), module )
55+
56+ def _get_item_by_idx (self , iterator : Iterable [_V ], idx : int ) -> _V :
57+ """Get the idx-th item of the iterator."""
58+ size = len (self )
59+ idx = operator .index (idx )
60+ if not - size <= idx < size :
61+ raise IndexError (f"index { idx } is out of range" )
62+ idx %= size
63+ return next (islice (iterator , idx , None ))
64+
65+ def __getitem__ (self , idx : Union [slice , int ]) -> Union [Sequential , Module ]:
66+ if isinstance (idx , slice ):
67+ return self .__class__ (OrderedDict (list (self ._modules .items ())[idx ]))
68+ else :
69+ return self ._get_item_by_idx (self ._modules .values (), idx )
70+
71+ def __setitem__ (self , idx : int , module : Module ) -> None :
72+ key : str = self ._get_item_by_idx (self ._modules .keys (), idx )
73+ return setattr (self , key , module )
74+
75+ def __delitem__ (self , idx : Union [slice , int ]) -> None :
76+ if isinstance (idx , slice ):
77+ for key in list (self ._modules .keys ())[idx ]:
78+ delattr (self , key )
79+ else :
80+ key = self ._get_item_by_idx (self ._modules .keys (), idx )
81+ delattr (self , key )
82+ # To preserve numbering
83+ str_indices = [str (i ) for i in range (len (self ._modules ))]
84+ self ._modules = OrderedDict (list (zip (str_indices , self ._modules .values ())))
85+
86+
87+ def __len__ (self ) -> int :
88+ return len (self ._modules )
89+
90+ def __add__ (self , other ) -> Sequential :
91+ if isinstance (other , Sequential ):
92+ ret = Sequential ()
93+ for layer in self :
94+ ret .append (layer )
95+ for layer in other :
96+ ret .append (layer )
97+ return ret
98+ else :
99+ raise ValueError (
100+ "add operator supports only objects "
101+ f"of Sequential class, but { str (type (other ))} is given."
102+ )
103+
104+ def pop (self , key : Union [int , slice ]) -> Module :
105+ v = self [key ]
106+ del self [key ]
107+ return v
108+
109+ def __iadd__ (self , other ) -> Self :
110+ if isinstance (other , Sequential ):
111+ offset = len (self )
112+ for i , module in enumerate (other ):
113+ self .add_module (str (i + offset ), module )
114+ return self
115+ else :
116+ raise ValueError (
117+ "add operator supports only objects "
118+ f"of Sequential class, but { str (type (other ))} is given."
119+ )
120+
121+ def __mul__ (self , other : int ) -> Sequential :
122+ if not isinstance (other , int ):
123+ raise TypeError (
124+ f"unsupported operand type(s) for *: { type (self )} and { type (other )} "
125+ )
126+ elif other <= 0 :
127+ raise ValueError (
128+ f"Non-positive multiplication factor { other } for { type (self )} "
129+ )
130+ else :
131+ combined = Sequential ()
132+ offset = 0
133+ for _ in range (other ):
134+ for module in self :
135+ combined .add_module (str (offset ), module )
136+ offset += 1
137+ return combined
138+
139+ def __rmul__ (self , other : int ) -> Sequential :
140+ return self .__mul__ (other )
141+
142+ def __imul__ (self , other : int ) -> Self :
143+ if not isinstance (other , int ):
144+ raise TypeError (
145+ f"unsupported operand type(s) for *: { type (self )} and { type (other )} "
146+ )
147+ elif other <= 0 :
148+ raise ValueError (
149+ f"Non-positive multiplication factor { other } for { type (self )} "
150+ )
151+ else :
152+ len_original = len (self )
153+ offset = len (self )
154+ for _ in range (other - 1 ):
155+ for i in range (len_original ):
156+ self .add_module (str (i + offset ), self ._modules [str (i )])
157+ offset += len_original
158+ return self
159+
160+ def __dir__ (self ) -> list [str ]:
161+ keys = super ().__dir__ ()
162+ keys = [key for key in keys if not key .isdigit ()]
163+ return keys
164+
165+ def __iter__ (self ) -> Iterator [Module ]:
166+ return iter (self ._modules .values ())
167+
168+ def forward (self , input ):
169+ for module in self :
170+ input = module (input )
171+ return input
172+
173+ def append (self , module : Module ) -> Self :
174+ self .add_module (str (len (self )), module )
175+ return self
176+
177+ def insert (self , index : int , module : Module ) -> Self :
178+ if not isinstance (module , Module ):
179+ raise AssertionError (f"module should be of type: { Module } " )
180+ n = len (self ._modules )
181+ if not (- n <= index <= n ):
182+ raise IndexError (f"Index out of range: { index } " )
183+ if index < 0 :
184+ index += n
185+ for i in range (n , index , - 1 ):
186+ self ._modules [str (i )] = self ._modules [str (i - 1 )]
187+ self ._modules [str (index )] = module
188+ return self
189+
190+ def extend (self , sequential : Iterable [Module ]) -> Self :
191+ for layer in sequential :
192+ self .append (layer )
193+ return self
194+
195+ class ModuleList (Module ):
196+ _modules : dict [str , Module ] # type: ignore[assignment]
197+
198+ def __init__ (self , modules : Optional [Iterable [Module ]] = None ) -> None :
199+ super ().__init__ ()
200+ if modules is not None :
201+ self += modules
202+
203+ def _get_abs_string_index (self , idx ):
204+ """Get the absolute index for the list of modules."""
205+ idx = operator .index (idx )
206+ if not (- len (self ) <= idx < len (self )):
207+ raise IndexError (f"index { idx } is out of range" )
208+ if idx < 0 :
209+ idx += len (self )
210+ return str (idx )
211+
212+ @overload
213+ def __getitem__ (self , idx : slice ) -> ModuleList : ...
214+
215+ @overload
216+ def __getitem__ (self , idx : int ) -> Module : ...
217+
218+ def __getitem__ (self , idx : Union [int , slice ]) -> Union [Module , ModuleList ]:
219+ if isinstance (idx , slice ):
220+ return self .__class__ (list (self ._modules .values ())[idx ])
221+ else :
222+ return self ._modules [self ._get_abs_string_index (idx )]
223+
224+ def __setitem__ (self , idx : int , module : Module ) -> None :
225+ idx = self ._get_abs_string_index (idx )
226+ return setattr (self , str (idx ), module )
227+
228+ def __delitem__ (self , idx : Union [int , slice ]) -> None :
229+ if isinstance (idx , slice ):
230+ for k in range (len (self ._modules ))[idx ]:
231+ delattr (self , str (k ))
232+ else :
233+ delattr (self , self ._get_abs_string_index (idx ))
234+ # To preserve numbering, self._modules is being reconstructed with modules after deletion
235+ str_indices = [str (i ) for i in range (len (self ._modules ))]
236+ self ._modules = OrderedDict (list (zip (str_indices , self ._modules .values ())))
237+
238+ def __len__ (self ) -> int :
239+ return len (self ._modules )
240+
241+ def __iter__ (self ) -> Iterator [Module ]:
242+ return iter (self ._modules .values ())
243+
244+ def __iadd__ (self , modules : Iterable [Module ]) -> Self :
245+ return self .extend (modules )
246+
247+ def __add__ (self , other : Iterable [Module ]) -> ModuleList :
248+ combined = ModuleList ()
249+ for i , module in enumerate (chain (self , other )):
250+ combined .add_module (str (i ), module )
251+ return combined
252+
253+ def __repr__ (self ) -> str :
254+ """Return a custom repr for ModuleList that compresses repeated module representations."""
255+ list_of_reprs = [repr (item ) for item in self ]
256+ if len (list_of_reprs ) == 0 :
257+ return self ._get_name () + "()"
258+
259+ start_end_indices = [[0 , 0 ]]
260+ repeated_blocks = [list_of_reprs [0 ]]
261+ for i , r in enumerate (list_of_reprs [1 :], 1 ):
262+ if r == repeated_blocks [- 1 ]:
263+ start_end_indices [- 1 ][1 ] += 1
264+ continue
265+
266+ start_end_indices .append ([i , i ])
267+ repeated_blocks .append (r )
268+
269+ lines = []
270+ main_str = self ._get_name () + "("
271+ for (start_id , end_id ), b in zip (start_end_indices , repeated_blocks ):
272+ local_repr = f"({ start_id } ): { b } " # default repr
273+
274+ if start_id != end_id :
275+ n = end_id - start_id + 1
276+ local_repr = f"({ start_id } -{ end_id } ): { n } x { b } "
277+
278+ local_repr = _addindent (local_repr , 2 )
279+ lines .append (local_repr )
280+
281+ main_str += "\n " + "\n " .join (lines ) + "\n "
282+ main_str += ")"
283+ return main_str
284+
285+ def __dir__ (self ) -> list [str ]:
286+ keys = super ().__dir__ ()
287+ keys = [key for key in keys if not key .isdigit ()]
288+ return keys
289+
290+ def insert (self , index : int , module : Module ) -> None :
291+ for i in range (len (self ._modules ), index , - 1 ):
292+ self ._modules [str (i )] = self ._modules [str (i - 1 )]
293+ self ._modules [str (index )] = module
294+
295+ def append (self , module : Module ) -> Self :
296+ self .add_module (str (len (self )), module )
297+ return self
298+
299+ def pop (self , key : Union [int , slice ]) -> Module :
300+ v = self [key ]
301+ del self [key ]
302+ return v
303+
304+ def extend (self , modules : Iterable [Module ]) -> Self :
305+ if not isinstance (modules , container_abcs .Iterable ):
306+ raise TypeError (
307+ "ModuleList.extend should be called with an "
308+ "iterable, but got " + type (modules ).__name__
309+ )
310+ offset = len (self )
311+ for i , module in enumerate (modules ):
312+ self .add_module (str (offset + i ), module )
313+ return self
0 commit comments