Skip to content

Commit d0b3047

Browse files
authored
py: (#82)
1.mlp 2.container:ModuleList,Sequential 3.transformer
1 parent eef3593 commit d0b3047

File tree

10 files changed

+424
-81
lines changed

10 files changed

+424
-81
lines changed
Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
from .module import Module, Sequential
1+
from .module import Module
2+
from .container import Sequential, ModuleList
23
from .linear import Linear
34
from .sparse import Embedding
5+
46
__all__ = [
57
"Module",
6-
"Linear",
7-
"Sequential",
8+
"Sequential","ModuleList",
89
"Embedding",
10+
"Linear",
911
]
Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
# 这个代码直接copy自pytorch https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/container.py
2+
# 其中实现了什么,我本人也没仔细研究,如果有问题请咨询AI或者查看pytorch的文档
3+
4+
from __future__ import annotations
5+
6+
import operator
7+
from collections import abc as container_abcs, OrderedDict
8+
from itertools import chain, islice
9+
from typing import Any, Optional, overload, TYPE_CHECKING, TypeVar, Union
10+
from typing_extensions import Self
11+
12+
if TYPE_CHECKING:
13+
from collections.abc import Iterable, Iterator, Mapping
14+
15+
from .module import Module
16+
17+
__all__ = [
18+
"Sequential",
19+
"ModuleList",
20+
]
21+
22+
T = TypeVar("T", bound=Module)
23+
_V = TypeVar("_V")
24+
25+
26+
def _addindent(s_, numSpaces):
27+
s = s_.split("\n")
28+
# don't do anything for single-line stuff
29+
if len(s) == 1:
30+
return s_
31+
first = s.pop(0)
32+
s = [(numSpaces * " ") + line for line in s]
33+
s = "\n".join(s)
34+
s = first + "\n" + s
35+
return s
36+
37+
class Sequential(Module):
    """Container that runs its child modules one after another.

    Modelled on ``torch.nn.Sequential``; see
    https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/container.py

    Children are added through ``add_module``. Positional children are keyed
    by their position (``"0"``, ``"1"``, ...); children supplied through a
    single ``OrderedDict`` argument keep the keys of that dict.
    """

    _modules: dict[str, Module]  # type: ignore[assignment]

    @overload
    def __init__(self, *args: Module) -> None: ...

    @overload
    def __init__(self, arg: OrderedDict[str, Module]) -> None: ...

    def __init__(self, *args):
        """Build the container from modules or from one ``OrderedDict``."""
        super().__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def _get_item_by_idx(self, iterator: Iterable[_V], idx: int) -> _V:
        """Return the ``idx``-th item of ``iterator``.

        Negative indices count from the end.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        idx = operator.index(idx)
        if not -size <= idx < size:
            raise IndexError(f"index {idx} is out of range")
        idx %= size
        return next(islice(iterator, idx, None))

    def __getitem__(self, idx: Union[slice, int]) -> Union[Sequential, Module]:
        """Return one child (int index) or a new container (slice)."""
        if isinstance(idx, slice):
            return self.__class__(OrderedDict(list(self._modules.items())[idx]))
        else:
            return self._get_item_by_idx(self._modules.values(), idx)

    def __setitem__(self, idx: int, module: Module) -> None:
        """Replace the child at position ``idx`` while keeping its key."""
        key: str = self._get_item_by_idx(self._modules.keys(), idx)
        return setattr(self, key, module)

    def __delitem__(self, idx: Union[slice, int]) -> None:
        """Delete the child(ren) selected by ``idx`` and renumber the rest."""
        if isinstance(idx, slice):
            for key in list(self._modules.keys())[idx]:
                delattr(self, key)
        else:
            key = self._get_item_by_idx(self._modules.keys(), idx)
            delattr(self, key)
        # To preserve numbering, rebuild the mapping with keys "0".."n-1".
        str_indices = [str(i) for i in range(len(self._modules))]
        self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))

    def __len__(self) -> int:
        """Return the number of children."""
        return len(self._modules)

    def __add__(self, other) -> Sequential:
        """Return a new ``Sequential`` holding the children of both operands.

        Raises:
            ValueError: If ``other`` is not a ``Sequential``.
        """
        if isinstance(other, Sequential):
            ret = Sequential()
            for layer in self:
                ret.append(layer)
            for layer in other:
                ret.append(layer)
            return ret
        else:
            raise ValueError(
                "add operator supports only objects "
                f"of Sequential class, but {str(type(other))} is given."
            )

    def pop(self, key: Union[int, slice]) -> Module:
        """Remove the item(s) selected by ``key`` and return them."""
        v = self[key]
        del self[key]
        return v

    def __iadd__(self, other) -> Self:
        """Append the children of ``other`` to this container in place.

        Raises:
            ValueError: If ``other`` is not a ``Sequential``.
        """
        if isinstance(other, Sequential):
            offset = len(self)
            # Snapshot first: ``other`` may be ``self`` (``seq += seq``), and
            # adding children while iterating the live ``_modules`` view
            # raises RuntimeError.
            for i, module in enumerate(list(other)):
                self.add_module(str(i + offset), module)
            return self
        else:
            raise ValueError(
                "add operator supports only objects "
                f"of Sequential class, but {str(type(other))} is given."
            )

    def __mul__(self, other: int) -> Sequential:
        """Return a new ``Sequential`` with the children repeated ``other`` times.

        The same module objects are reused for every repetition.

        Raises:
            TypeError: If ``other`` is not an ``int``.
            ValueError: If ``other`` is not positive.
        """
        if not isinstance(other, int):
            raise TypeError(
                f"unsupported operand type(s) for *: {type(self)} and {type(other)}"
            )
        elif other <= 0:
            raise ValueError(
                f"Non-positive multiplication factor {other} for {type(self)}"
            )
        else:
            combined = Sequential()
            offset = 0
            for _ in range(other):
                for module in self:
                    combined.add_module(str(offset), module)
                    offset += 1
            return combined

    def __rmul__(self, other: int) -> Sequential:
        """Support ``int * Sequential`` by delegating to ``__mul__``."""
        return self.__mul__(other)

    def __imul__(self, other: int) -> Self:
        """Repeat the current children in place until there are ``other`` copies.

        Raises:
            TypeError: If ``other`` is not an ``int``.
            ValueError: If ``other`` is not positive.
        """
        if not isinstance(other, int):
            raise TypeError(
                f"unsupported operand type(s) for *: {type(self)} and {type(other)}"
            )
        elif other <= 0:
            raise ValueError(
                f"Non-positive multiplication factor {other} for {type(self)}"
            )
        else:
            len_original = len(self)
            offset = len(self)
            for _ in range(other - 1):
                for i in range(len_original):
                    # Looks children up by positional key "0".."len-1".
                    self.add_module(str(i + offset), self._modules[str(i)])
                offset += len_original
            return self

    def __dir__(self) -> list[str]:
        """Return ``dir()`` entries without the purely numeric child keys."""
        keys = super().__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    def __iter__(self) -> Iterator[Module]:
        """Iterate over the children in order."""
        return iter(self._modules.values())

    def forward(self, input):
        """Pass ``input`` through every child in order and return the result."""
        for module in self:
            input = module(input)
        return input

    def append(self, module: Module) -> Self:
        """Add ``module`` at the end and return ``self``."""
        self.add_module(str(len(self)), module)
        return self

    def insert(self, index: int, module: Module) -> Self:
        """Insert ``module`` before position ``index`` and return ``self``.

        Negative indices count from the end.

        Raises:
            AssertionError: If ``module`` is not a ``Module``.
            IndexError: If ``index`` is outside ``[-len(self), len(self)]``.
        """
        if not isinstance(module, Module):
            raise AssertionError(f"module should be of type: {Module}")
        n = len(self._modules)
        if not (-n <= index <= n):
            raise IndexError(f"Index out of range: {index}")
        if index < 0:
            index += n
        # Shift every child at or after ``index`` one slot to the right.
        for i in range(n, index, -1):
            self._modules[str(i)] = self._modules[str(i - 1)]
        self._modules[str(index)] = module
        return self

    def extend(self, sequential: Iterable[Module]) -> Self:
        """Append every module from ``sequential`` and return ``self``."""
        # Extending a container with itself would otherwise mutate
        # ``_modules`` while it is being iterated (RuntimeError).
        layers = list(sequential) if sequential is self else sequential
        for layer in layers:
            self.append(layer)
        return self
194+
195+
class ModuleList(Module):
    """List-like container of child modules.

    Modelled on ``torch.nn.ModuleList``. Children live in ``self._modules``
    under the string form of their position (``"0"``, ``"1"``, ...).
    """

    _modules: dict[str, Module]  # type: ignore[assignment]

    def __init__(self, modules: Optional[Iterable[Module]] = None) -> None:
        """Create the list, optionally filled from ``modules``."""
        super().__init__()
        if modules is not None:
            self += modules

    def _get_abs_string_index(self, idx):
        """Return the string key for ``idx``, resolving negative indices.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        idx = operator.index(idx)
        if not (-len(self) <= idx < len(self)):
            raise IndexError(f"index {idx} is out of range")
        if idx < 0:
            idx += len(self)
        return str(idx)

    @overload
    def __getitem__(self, idx: slice) -> ModuleList: ...

    @overload
    def __getitem__(self, idx: int) -> Module: ...

    def __getitem__(self, idx: Union[int, slice]) -> Union[Module, ModuleList]:
        """Return one child (int index) or a new list (slice)."""
        if isinstance(idx, slice):
            return self.__class__(list(self._modules.values())[idx])
        else:
            return self._modules[self._get_abs_string_index(idx)]

    def __setitem__(self, idx: int, module: Module) -> None:
        """Replace the child at position ``idx``."""
        idx = self._get_abs_string_index(idx)
        return setattr(self, str(idx), module)

    def __delitem__(self, idx: Union[int, slice]) -> None:
        """Delete the child(ren) selected by ``idx`` and renumber the rest."""
        if isinstance(idx, slice):
            for k in range(len(self._modules))[idx]:
                delattr(self, str(k))
        else:
            delattr(self, self._get_abs_string_index(idx))
        # To preserve numbering, self._modules is being reconstructed with
        # the modules that remain after deletion.
        str_indices = [str(i) for i in range(len(self._modules))]
        self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))

    def __len__(self) -> int:
        """Return the number of children."""
        return len(self._modules)

    def __iter__(self) -> Iterator[Module]:
        """Iterate over the children in order."""
        return iter(self._modules.values())

    def __iadd__(self, modules: Iterable[Module]) -> Self:
        """Append every module from ``modules`` in place (see ``extend``)."""
        return self.extend(modules)

    def __add__(self, other: Iterable[Module]) -> ModuleList:
        """Return a new ``ModuleList`` with the children of ``self`` then ``other``."""
        combined = ModuleList()
        for i, module in enumerate(chain(self, other)):
            combined.add_module(str(i), module)
        return combined

    def __repr__(self) -> str:
        """Return a custom repr for ModuleList that compresses repeated module representations."""
        list_of_reprs = [repr(item) for item in self]
        if not list_of_reprs:
            return self._get_name() + "()"

        # Group consecutive identical reprs into [start, end] index ranges.
        start_end_indices = [[0, 0]]
        repeated_blocks = [list_of_reprs[0]]
        for i, r in enumerate(list_of_reprs[1:], 1):
            if r == repeated_blocks[-1]:
                start_end_indices[-1][1] += 1
                continue

            start_end_indices.append([i, i])
            repeated_blocks.append(r)

        lines = []
        main_str = self._get_name() + "("
        for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
            local_repr = f"({start_id}): {b}"  # default repr

            if start_id != end_id:
                n = end_id - start_id + 1
                local_repr = f"({start_id}-{end_id}): {n} x {b}"

            local_repr = _addindent(local_repr, 2)
            lines.append(local_repr)

        # NOTE(review): the separator is reproduced exactly as it appears in
        # this view (newline + one space); upstream PyTorch uses two spaces.
        # Confirm against the real file before changing it.
        main_str += "\n " + "\n ".join(lines) + "\n"
        main_str += ")"
        return main_str

    def __dir__(self) -> list[str]:
        """Return ``dir()`` entries without the purely numeric child keys."""
        keys = super().__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    def insert(self, index: int, module: Module) -> None:
        """Insert ``module`` before position ``index``.

        Negative indices count from the end, matching ``Sequential.insert``.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self)]``.
                Previously a negative index failed with ``KeyError`` and an
                index past the end left a gap in the key numbering.
        """
        n = len(self._modules)
        index = operator.index(index)
        if not (-n <= index <= n):
            raise IndexError(f"Index out of range: {index}")
        if index < 0:
            index += n
        # Shift every child at or after ``index`` one slot to the right.
        for i in range(n, index, -1):
            self._modules[str(i)] = self._modules[str(i - 1)]
        self._modules[str(index)] = module

    def append(self, module: Module) -> Self:
        """Add ``module`` at the end and return ``self``."""
        self.add_module(str(len(self)), module)
        return self

    def pop(self, key: Union[int, slice]) -> Module:
        """Remove the item(s) selected by ``key`` and return them."""
        v = self[key]
        del self[key]
        return v

    def extend(self, modules: Iterable[Module]) -> Self:
        """Append every module from ``modules`` and return ``self``.

        Raises:
            TypeError: If ``modules`` is not iterable.
        """
        if not isinstance(modules, container_abcs.Iterable):
            raise TypeError(
                "ModuleList.extend should be called with an "
                "iterable, but got " + type(modules).__name__
            )
        if modules is self:
            # ``ml += ml`` / ``ml.extend(ml)``: snapshot so that ``_modules``
            # is not mutated while it is being iterated (RuntimeError).
            modules = list(self)
        offset = len(self)
        for i, module in enumerate(modules):
            self.add_module(str(offset + i), module)
        return self
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .gatedmlp import *
2+
3+
__all__ = [
4+
"GatedMLP",
5+
]
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from deepx.nn.functional import swish as swish_fn
2+
3+
# Lookup table from an activation name to the callable implementing it.
ACT2FN = {"silu": swish_fn}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from deepx.nn.modules import Module,Linear
2+
from .actfn import ACT2FN
3+
4+
class GatedMLP(Module):
    """Gated MLP block computing ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``.

    Args:
        config: Provides ``hidden_size``, ``intermediate_size``, ``mlp_bias``
            and ``hidden_act``. May be an object exposing them as attributes
            or a ``dict`` holding them as keys.

    Raises:
        KeyError: If ``hidden_act`` is not a key of ``ACT2FN``, or a required
            key is missing from a ``dict`` config.
        AttributeError: If a required attribute is missing from an object
            config.
    """

    def __init__(self, config: object):
        super().__init__()

        def option(name):
            # The signature used to advertise ``dict`` while the body only
            # read attributes, so a real dict crashed; accept both forms.
            if isinstance(config, dict):
                return config[name]
            return getattr(config, name)

        use_bias = option("mlp_bias")
        # Input layer size
        self.hidden_size = option("hidden_size")
        # Intermediate layer size
        self.intermediate_size = option("intermediate_size")
        # Gate projection layer
        self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        # Up projection layer
        self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        # Down projection layer
        self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias=use_bias)
        # Activation function, looked up by name
        self.act_fn = ACT2FN[option("hidden_act")]

    def forward(self, x):
        """Apply the gated MLP to ``x`` and return the down-projected result."""
        gate = self.gate_proj(x)
        up = self.up_proj(x)
        act = self.act_fn(gate)
        # Element-wise gating of the up projection by the activated gate.
        out = act * up
        out = self.down_proj(out)
        return out

0 commit comments

Comments
 (0)