Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,16 @@ public class PyTorchStateStreamer {

private static final Logger logger = LogManager.getLogger(PyTorchStateStreamer.class);

/** The size in bytes of the data written before the model definition: a single int holding the model size */
private static final int NUM_BYTES_IN_PRELUDE = 4;

private final OriginSettingClient client;
private final ExecutorService executorService;
private final NamedXContentRegistry xContentRegistry;
private volatile boolean isCancelled;
private volatile int modelSize = -1;
private final AtomicInteger bytesWritten = new AtomicInteger();
// Counts model bytes only; the prelude bytes are not included
private final AtomicInteger modelBytesWritten = new AtomicInteger();

public PyTorchStateStreamer(Client client, ExecutorService executorService, NamedXContentRegistry xContentRegistry) {
this.client = new OriginSettingClient(Objects.requireNonNull(client), ML_ORIGIN);
Expand All @@ -59,7 +63,7 @@ public void cancel() {

/**
* First writes the size of the model so the native process can
* allocated memory then writes the chunks of binary state.
* allocate memory then writes the chunks of binary state.
*
* @param modelId The model to write
* @param index The index to search for the model
Expand All @@ -72,11 +76,11 @@ public void writeStateToStream(String modelId, String index, OutputStream restor
restorer.setSearchSize(1);
restorer.restoreModelDefinition(doc -> writeChunk(doc, restoreStream), success -> {
logger.debug("model [{}] state restored in [{}] documents from index [{}]", modelId, restorer.getNumDocsWritten(), index);
if (bytesWritten.get() != modelSize) {
if (modelBytesWritten.get() != modelSize) {
logger.error(
"model [{}] restored state size [{}] does not equal the expected model size [{}]",
modelId,
bytesWritten,
modelBytesWritten,
modelSize
);
}
Expand All @@ -96,7 +100,7 @@ private boolean writeChunk(TrainedModelDefinitionDoc doc, OutputStream outputStr
// The array backing the BytesReference may be bigger than what is
// referred to so write only what is after the offset
outputStream.write(doc.getBinaryData().array(), doc.getBinaryData().arrayOffset(), doc.getBinaryData().length());
bytesWritten.addAndGet(doc.getBinaryData().length());
modelBytesWritten.addAndGet(doc.getBinaryData().length());
return true;
}

Expand Down Expand Up @@ -139,12 +143,10 @@ private int writeModelSize(String modelId, Long modelSizeBytes, OutputStream out
throw new IllegalStateException(message);
}

final int NUM_BYTES = 4;
ByteBuffer lengthBuffer = ByteBuffer.allocate(NUM_BYTES);
ByteBuffer lengthBuffer = ByteBuffer.allocate(NUM_BYTES_IN_PRELUDE);
lengthBuffer.putInt(modelSizeBytes.intValue());
outputStream.write(lengthBuffer.array());

bytesWritten.addAndGet(NUM_BYTES);
return modelSizeBytes.intValue();
}
}