-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathDecoder.py
More file actions
21 lines (18 loc) · 746 Bytes
/
Decoder.py
File metadata and controls
21 lines (18 loc) · 746 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from chainer import Chain, functions, links
class Decoder(Chain):
    """Single-step LSTM decoder for a sequence-to-sequence model.

    Each call embeds one batch of token IDs, advances the LSTM by one step
    using the previous cell/hidden states, and returns unnormalized scores
    over the vocabulary together with the updated states.

    Args:
        vocab_size: Number of token IDs accepted by the embedding table; also
            the width of the returned score vectors.
        embed_size: Width of the token embeddings and of the intermediate
            feature fed to the output layer.
        hidden_size: Width of the LSTM cell and hidden states.
    """

    def __init__(self, vocab_size, embed_size, hidden_size):
        super(Decoder, self).__init__()
        # Register child links inside init_scope (Chainer >= 2). Passing them
        # as keyword arguments to Chain.__init__ is deprecated. The attribute
        # names match the old keyword names, so link/parameter names are the
        # same as before.
        with self.init_scope():
            # token ID -> embedding vector
            self.ye = links.EmbedID(vocab_size, embed_size)
            # embedding -> LSTM gate pre-activations; functions.lstm expects an
            # input with 4 * hidden_size channels (one slice per gate)
            self.eh = links.Linear(embed_size, 4 * hidden_size)
            # previous hidden state -> LSTM gate pre-activations
            self.hh = links.Linear(hidden_size, 4 * hidden_size)
            # new hidden state -> embed_size-wide feature
            self.hf = links.Linear(hidden_size, embed_size)
            # feature -> one score per vocabulary entry (no softmax applied)
            self.weight_jy = links.Linear(embed_size, vocab_size)

    def __call__(self, y, c, h):
        """Run one decoding step.

        Args:
            y: Token IDs for the current step. Presumably a (batch,) int32
                array; TODO confirm against the caller.
            c: Previous LSTM cell state (hidden_size wide).
            h: Previous LSTM hidden state (hidden_size wide). It must be an
                array, not None, because it is passed straight to ``self.hh``.

        Returns:
            Tuple ``(scores, c, h)``. ``scores`` are the unnormalized
            vocabulary scores from ``weight_jy``. ``c`` and ``h`` are the
            updated cell and hidden states produced by ``functions.lstm``.
        """
        e = functions.tanh(self.ye(y))
        # Gate pre-activations combine the current input and the recurrent term.
        c, h = functions.lstm(c, self.eh(e) + self.hh(h))
        f = functions.tanh(self.hf(h))
        return self.weight_jy(f), c, h