55 "crypto/cipher"
66 "crypto/rand"
77 "errors"
8+ "fmt"
89 "io"
910 "net"
1011 "runtime"
@@ -95,19 +96,31 @@ func (i *ClientInstance) Handshake(conn net.Conn) (net.Conn, error) {
9596 nfsKey , encapsulatedNfsKey := i .nfsEKey .Encapsulate ()
9697 paddingLen := randBetween (100 , 1000 )
9798
98- clientHello := make ([]byte , 1 + 1184 + 1088 + 5 + paddingLen )
99- clientHello [0 ] = ClientCipher
100- copy (clientHello [1 :], pfsEKeyBytes )
101- copy (clientHello [1185 :], encapsulatedNfsKey )
102- EncodeHeader (clientHello [2273 :], int (paddingLen ))
103- rand .Read (clientHello [2278 :])
99+ clientHello := make ([]byte , 5 + 1 + 1184 + 1088 + 5 + paddingLen )
100+ EncodeHeader (clientHello , 1 , 1 + 1184 + 1088 )
101+ clientHello [5 ] = ClientCipher
102+ copy (clientHello [5 + 1 :], pfsEKeyBytes )
103+ copy (clientHello [5 + 1 + 1184 :], encapsulatedNfsKey )
104+ EncodeHeader (clientHello [5 + 1 + 1184 + 1088 :], 23 , int (paddingLen ))
105+ rand .Read (clientHello [5 + 1 + 1184 + 1088 + 5 :])
104106
105- if _ , err := c .Conn .Write (clientHello ); err != nil {
107+ if n , err := c .Conn .Write (clientHello ); n != len ( clientHello ) || err != nil {
106108 return nil , err
107109 }
108- // we can send more padding if needed
110+ // client can send more padding / NFS AEAD messages if needed
111+
112+ _ , t , l , err := ReadAndDecodeHeader (c .Conn )
113+ if err != nil {
114+ return nil , err
115+ }
116+ if t != 1 {
117+ return nil , fmt .Errorf ("unexpected type %v, expect server hello" , t )
118+ }
109119
110120 peerServerHello := make ([]byte , 1088 + 21 )
121+ if l != len (peerServerHello ) {
122+ return nil , fmt .Errorf ("unexpected length %v for server hello" , l )
123+ }
111124 if _ , err := io .ReadFull (c .Conn , peerServerHello ); err != nil {
112125 return nil , err
113126 }
@@ -122,7 +135,7 @@ func (i *ClientInstance) Handshake(conn net.Conn) (net.Conn, error) {
122135
123136 nonce := [12 ]byte {ClientCipher }
124137 VLESS , _ := NewAead (ClientCipher , c .baseKey , encapsulatedPfsKey , encapsulatedNfsKey ).Open (nil , nonce [:], c .ticket , pfsEKeyBytes )
125- if ! bytes .Equal (VLESS , []byte ("VLESS" )) { // TODO: more messages
138+ if ! bytes .Equal (VLESS , []byte ("VLESS" )) {
126139 return nil , errors .New ("invalid server" )
127140 }
128141
@@ -153,21 +166,22 @@ func (c *ClientConn) Write(b []byte) (int, error) {
153166 rand .Read (c .random )
154167 c .aead = NewAead (ClientCipher , c .baseKey , c .random , c .ticket )
155168 c .nonce = make ([]byte , 12 )
156- data = make ([]byte , 21 + 32 + 5 + len (b )+ 16 )
157- copy (data , c .ticket )
158- copy (data [21 :], c .random )
159- EncodeHeader (data [53 :], len (b )+ 16 )
160- c .aead .Seal (data [:58 ], c .nonce , b , data [53 :58 ])
169+ data = make ([]byte , 5 + 21 + 32 + 5 + len (b )+ 16 )
170+ EncodeHeader (data , 0 , 21 + 32 )
171+ copy (data [5 :], c .ticket )
172+ copy (data [5 + 21 :], c .random )
173+ EncodeHeader (data [5 + 21 + 32 :], 23 , len (b )+ 16 )
174+ c .aead .Seal (data [:5 + 21 + 32 + 5 ], c .nonce , b , data [5 + 21 + 32 :5 + 21 + 32 + 5 ])
161175 } else {
162176 data = make ([]byte , 5 + len (b )+ 16 )
163- EncodeHeader (data , len (b )+ 16 )
177+ EncodeHeader (data , 23 , len (b )+ 16 )
164178 c .aead .Seal (data [:5 ], c .nonce , b , data [:5 ])
165179 if bytes .Equal (c .nonce , MaxNonce ) {
166180 c .aead = NewAead (ClientCipher , c .baseKey , data [5 :], data [:5 ])
167181 }
168182 }
169183 IncreaseNonce (c .nonce )
170- if _ , err := c .Conn .Write (data ); err != nil {
184+ if n , err := c .Conn .Write (data ); n != len ( data ) || err != nil {
171185 return 0 , err
172186 }
173187 }
@@ -178,29 +192,44 @@ func (c *ClientConn) Read(b []byte) (int, error) {
178192 if len (b ) == 0 {
179193 return 0 , nil
180194 }
181- peerHeader := make ([]byte , 5 )
182195 if c .peerAead == nil {
183- if c .instance == nil {
196+ var t byte
197+ var l int
198+ var err error
199+ if c .instance == nil { // 1-RTT
184200 for {
185- if _ , err := io . ReadFull (c .Conn , peerHeader ); err != nil {
201+ if _ , t , l , err = ReadAndDecodeHeader (c .Conn ); err != nil {
186202 return 0 , err
187203 }
188- peerPaddingLen , _ := DecodeHeader (peerHeader )
189- if peerPaddingLen == 0 {
204+ if t != 23 {
190205 break
191206 }
192- if _ , err := io .ReadFull (c .Conn , make ([]byte , peerPaddingLen )); err != nil {
207+ if _ , err := io .ReadFull (c .Conn , make ([]byte , l )); err != nil {
193208 return 0 , err
194209 }
195210 }
196211 } else {
197- if _ , err := io .ReadFull (c .Conn , peerHeader ); err != nil {
212+ h := make ([]byte , 5 )
213+ if _ , err := io .ReadFull (c .Conn , h ); err != nil {
198214 return 0 , err
199215 }
216+ if t , l , err = DecodeHeader (h ); err != nil {
217+ c .instance .Lock ()
218+ if bytes .Equal (c .ticket , c .instance .ticket ) {
219+ c .instance .expire = time .Now () // expired
220+ }
221+ c .instance .Unlock ()
222+ return 0 , errors .New ("new handshake needed" )
223+ }
224+ }
225+ if t != 0 {
226+ return 0 , fmt .Errorf ("unexpected type %v, expect server random" , t )
200227 }
201228 peerRandom := make ([]byte , 32 )
202- copy (peerRandom , peerHeader )
203- if _ , err := io .ReadFull (c .Conn , peerRandom [5 :]); err != nil {
229+ if l != len (peerRandom ) {
230+ return 0 , fmt .Errorf ("unexpected length %v for server random" , l )
231+ }
232+ if _ , err := io .ReadFull (c .Conn , peerRandom ); err != nil {
204233 return 0 , err
205234 }
206235 if c .random == nil {
@@ -214,33 +243,26 @@ func (c *ClientConn) Read(b []byte) (int, error) {
214243 c .peerCache = c .peerCache [n :]
215244 return n , nil
216245 }
217- if _ , err := io .ReadFull (c .Conn , peerHeader ); err != nil {
218- return 0 , err
219- }
220- peerLength , err := DecodeHeader (peerHeader ) // 17~17000
246+ h , t , l , err := ReadAndDecodeHeader (c .Conn ) // l: 17~17000
221247 if err != nil {
222- if c .instance != nil {
223- c .instance .Lock ()
224- if bytes .Equal (c .ticket , c .instance .ticket ) {
225- c .instance .expire = time .Now () // expired
226- }
227- c .instance .Unlock ()
228- }
229248 return 0 , err
230249 }
231- peerData := make ([]byte , peerLength )
250+ if t != 23 {
251+ return 0 , fmt .Errorf ("unexpected type %v, expect encrypted data" , t )
252+ }
253+ peerData := make ([]byte , l )
232254 if _ , err := io .ReadFull (c .Conn , peerData ); err != nil {
233255 return 0 , err
234256 }
235- dst := peerData [:peerLength - 16 ]
257+ dst := peerData [:l - 16 ]
236258 if len (dst ) <= len (b ) {
237259 dst = b [:len (dst )] // avoids another copy()
238260 }
239261 var peerAead cipher.AEAD
240262 if bytes .Equal (c .peerNonce , MaxNonce ) {
241- peerAead = NewAead (ClientCipher , c .baseKey , peerData , peerHeader )
263+ peerAead = NewAead (ClientCipher , c .baseKey , peerData , h )
242264 }
243- _ , err = c .peerAead .Open (dst [:0 ], c .peerNonce , peerData , peerHeader )
265+ _ , err = c .peerAead .Open (dst [:0 ], c .peerNonce , peerData , h )
244266 if peerAead != nil {
245267 c .peerAead = peerAead
246268 }
0 commit comments