@@ -37,12 +37,16 @@ public class PyTorchStateStreamer {
3737
3838 private static final Logger logger = LogManager .getLogger (PyTorchStateStreamer .class );
3939
40+ /** The size, in bytes, of the prelude (the model size written as an int) that precedes the model definition */
41+ private static final int NUM_BYTES_IN_PRELUDE = 4 ;
42+
4043 private final OriginSettingClient client ;
4144 private final ExecutorService executorService ;
4245 private final NamedXContentRegistry xContentRegistry ;
4346 private volatile boolean isCancelled ;
4447 private volatile int modelSize = -1 ;
45- private final AtomicInteger bytesWritten = new AtomicInteger ();
48+ // model bytes only, does not include the prelude
49+ private final AtomicInteger modelBytesWritten = new AtomicInteger ();
4650
4751 public PyTorchStateStreamer (Client client , ExecutorService executorService , NamedXContentRegistry xContentRegistry ) {
4852 this .client = new OriginSettingClient (Objects .requireNonNull (client ), ML_ORIGIN );
@@ -59,7 +63,7 @@ public void cancel() {
5963
6064 /**
6165 * First writes the size of the model so the native process can
62- * allocated memory then writes the chunks of binary state.
66+ * allocate memory then writes the chunks of binary state.
6367 *
6468 * @param modelId The model to write
6569 * @param index The index to search for the model
@@ -72,11 +76,11 @@ public void writeStateToStream(String modelId, String index, OutputStream restor
7276 restorer .setSearchSize (1 );
7377 restorer .restoreModelDefinition (doc -> writeChunk (doc , restoreStream ), success -> {
7478 logger .debug ("model [{}] state restored in [{}] documents from index [{}]" , modelId , restorer .getNumDocsWritten (), index );
75- if (bytesWritten .get () != modelSize ) {
79+ if (modelBytesWritten .get () != modelSize ) {
7680 logger .error (
7781 "model [{}] restored state size [{}] does not equal the expected model size [{}]" ,
7882 modelId ,
79- bytesWritten ,
83+ modelBytesWritten ,
8084 modelSize
8185 );
8286 }
@@ -96,7 +100,7 @@ private boolean writeChunk(TrainedModelDefinitionDoc doc, OutputStream outputStr
96100 // The array backing the BytesReference may be bigger than what is
97101 // referred to so write only what is after the offset
98102 outputStream .write (doc .getBinaryData ().array (), doc .getBinaryData ().arrayOffset (), doc .getBinaryData ().length ());
99- bytesWritten .addAndGet (doc .getBinaryData ().length ());
103+ modelBytesWritten .addAndGet (doc .getBinaryData ().length ());
100104 return true ;
101105 }
102106
@@ -139,12 +143,10 @@ private int writeModelSize(String modelId, Long modelSizeBytes, OutputStream out
139143 throw new IllegalStateException (message );
140144 }
141145
142- final int NUM_BYTES = 4 ;
143- ByteBuffer lengthBuffer = ByteBuffer .allocate (NUM_BYTES );
146+ ByteBuffer lengthBuffer = ByteBuffer .allocate (NUM_BYTES_IN_PRELUDE );
144147 lengthBuffer .putInt (modelSizeBytes .intValue ());
145148 outputStream .write (lengthBuffer .array ());
146149
147- bytesWritten .addAndGet (NUM_BYTES );
148150 return modelSizeBytes .intValue ();
149151 }
150152}
0 commit comments