From 56f2c8f128150b69294482f2310936e2b925279c Mon Sep 17 00:00:00 2001 From: Tong Guo <779222056@qq.com> Date: Tue, 29 Oct 2019 11:38:29 +0800 Subject: [PATCH 1/3] Update multisql_predictor.py --- models/multisql_predictor.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/models/multisql_predictor.py b/models/multisql_predictor.py index a800419..6114fda 100644 --- a/models/multisql_predictor.py +++ b/models/multisql_predictor.py @@ -16,15 +16,15 @@ def __init__(self, N_word, N_h, N_depth, gpu, use_hs): self.gpu = gpu self.use_hs = use_hs - self.q_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, + self.q_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), num_layers=N_depth, batch_first=True, dropout=0.3, bidirectional=True) - self.hs_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, + self.hs_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), num_layers=N_depth, batch_first=True, dropout=0.3, bidirectional=True) - self.mkw_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2, + self.mkw_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), num_layers=N_depth, batch_first=True, dropout=0.3, bidirectional=True) @@ -91,7 +91,10 @@ def forward(self, q_emb_var, q_len, hs_emb_var, hs_len, mkw_emb_var, mkw_len): def loss(self, score, truth): data = torch.from_numpy(np.array(truth)) - truth_var = Variable(data.cuda()) + if self.gpu: + truth_var = Variable(data.cuda()) + else: + truth_var = Variable(data) loss = self.CE(score, truth_var) return loss From bb620f4b1003f2836d11dc184c99ca806ab35f6f Mon Sep 17 00:00:00 2001 From: Tong Guo <779222056@qq.com> Date: Tue, 29 Oct 2019 14:52:01 +0800 Subject: [PATCH 2/3] Update multisql_predictor.py --- models/multisql_predictor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/models/multisql_predictor.py b/models/multisql_predictor.py index 6114fda..20c671a 100644 --- a/models/multisql_predictor.py +++ b/models/multisql_predictor.py @@ -95,6 +95,7 @@ def loss(self, score, truth): truth_var = Variable(data.cuda()) else: truth_var = Variable(data) + truth_var = torch._cast_Long(truth_var) loss = self.CE(score, truth_var) return loss From 50f2c3f871a2a79b9bc68f5e78c2051011237d19 Mon Sep 17 00:00:00 2001 From: Tong Guo <779222056@qq.com> Date: Tue, 29 Oct 2019 14:55:46 +0800 Subject: [PATCH 3/3] Update multisql_predictor.py --- models/multisql_predictor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/multisql_predictor.py b/models/multisql_predictor.py index 20c671a..b8ad65e 100644 --- a/models/multisql_predictor.py +++ b/models/multisql_predictor.py @@ -35,7 +35,7 @@ def __init__(self, N_word, N_h, N_depth, gpu, use_hs): self.multi_out_c = nn.Linear(N_h, N_h) self.multi_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 1)) - self.softmax = nn.Softmax() #dim=1 + self.softmax = nn.Softmax(dim=1) #dim=1 self.CE = nn.CrossEntropyLoss() self.log_softmax = nn.LogSoftmax() self.mlsml = nn.MultiLabelSoftMarginLoss()