Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/with-grid/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ <h1>syft.js/grid.js testing</h1>
<div id="acc_graph"></div>
</div>
</div>

<div id="status"></div>
</div>

<div id="app">
Expand Down
43 changes: 33 additions & 10 deletions examples/with-grid/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,32 @@ startButton.onclick = () => {
startFL(gridServer.value, modelName, modelVersion);
};

const updateStatus = (message) => {
const cont = document.getElementById('status');
cont.innerHTML = message + '<br>' + cont.innerHTML;
};

const startFL = async (url, modelName, modelVersion) => {
const worker = new Syft({ url, verbose: true });
const job = await worker.newJob({ modelName, modelVersion });

job.start();

job.on('accepted', async ({ model, clientConfig }) => {
updateStatus("Accepted into cycle!");

// Load data
console.log('Loading data...');
updateStatus("Loading data...");
const mnist = new MnistData();
await mnist.load();
const trainDataset = mnist.getTrainData();
const data = trainDataset.xs;
const targets = trainDataset.labels;
console.log('Data loaded');
updateStatus("Data loaded.");

// Prepare randomized indices for data batching.
const indices = Array.from({length: data.shape[0]}, (v, i) => i);
tf.util.shuffle(indices);

// Prepare train parameters.
const batchSize = clientConfig.batch_size;
Expand All @@ -96,8 +107,9 @@ const startFL = async (url, modelName, modelVersion) => {
for (let update = 0, batch = 0, epoch = 0; update < numUpdates; update++) {
// Slice a batch.
const chunkSize = Math.min(batchSize, data.shape[0] - batch * batchSize);
const dataBatch = data.slice(batch * batchSize, chunkSize);
const targetBatch = targets.slice(batch * batchSize, chunkSize);
const indicesBatch = indices.slice(batch * batchSize, batch * batchSize + chunkSize);
const dataBatch = data.gather(indicesBatch);
const targetBatch = targets.gather(indicesBatch);

// Execute the plan and get updated model params back.
let [loss, acc, ...updatedModelParams] = await job.plans[
Expand Down Expand Up @@ -139,6 +151,10 @@ const startFL = async (url, modelName, modelVersion) => {
targetBatch.dispose();
}

// Free GPU memory.
data.dispose();
targets.dispose();

// TODO protocol execution
// job.protocols['secure_aggregation'].execute();

Expand All @@ -147,19 +163,26 @@ const startFL = async (url, modelName, modelVersion) => {

// Report diff.
await job.report(modelDiff);
console.log('Done!');
updateStatus('Cycle is done!');

// Try again.
setTimeout(startFL, 1000, url, modelName, modelVersion);
});

job.on('rejected', ({ timeout }) => {
// Handle the job rejection
console.log('We have been rejected by PyGrid to participate in the job.');
const msUntilRetry = timeout * 1000;
// Try to join the job again in "msUntilRetry" milliseconds
setTimeout(job.start.bind(job), msUntilRetry);
if (timeout) {
const msUntilRetry = timeout * 1000;
// Try to join the job again in "msUntilRetry" milliseconds
updateStatus(`Rejected from cycle, retry in ${timeout}`);
setTimeout(job.start.bind(job), msUntilRetry);
} else {
updateStatus(`Rejected from cycle with no timeout, assuming Model training is complete.`);
}
});

job.on('error', err => {
console.log('Error', err);
updateStatus(`Error: ${err.message}`);
});
};

Expand Down
1 change: 0 additions & 1 deletion src/_helpers.js
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ export const pickTensors = tree => {
return objects;
};


export const torchToTF = (command, kwargs) => {
const logger = new Logger();
const cmd_map = {
Expand Down
14 changes: 12 additions & 2 deletions src/grid-api-client.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ import Logger from './logger';
const HTTP_PATH_VERB = {
'federated/get-plan': 'GET',
'federated/get-model': 'GET',
'federated/cycle-request': 'POST'
'federated/get-protocol': 'GET',
'federated/cycle-request': 'POST',
'federated/report': 'POST'
};

export default class GridAPIClient {
Expand All @@ -23,6 +25,9 @@ export default class GridAPIClient {
this.ws = null;
this.logger = new Logger('grid', true);
this.responseTimeout = 10000;

this._handleWsError = this._handleWsError.bind(this);
this._handleWsClose = this._handleWsClose.bind(this);
}

async authenticate(authToken) {
Expand Down Expand Up @@ -212,7 +217,12 @@ export default class GridAPIClient {

const data = JSON.parse(event.data);
this.logger.log('Received message', data);
resolve(data);
if (data.type !== message.type) {
// TODO do it differently
this.logger.log('Received invalid response type, ignoring');
} else {
resolve(data.data);
}
};

this.ws.onerror = event => {
Expand Down
81 changes: 43 additions & 38 deletions src/job.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import EventObserver from './events';
import { protobuf, unserialize } from './protobuf';
import { CYCLE_STATUS_ACCEPTED, CYCLE_STATUS_REJECTED } from './_constants';
import { GRID_UNKNOWN_CYCLE_STATUS } from './_errors';
import SyftModel from './syft_model';
import SyftModel from './syft-model';
import Logger from './logger';

export default class Job {
Expand Down Expand Up @@ -79,48 +79,53 @@ export default class Job {
}

async start() {
// speed test
const { ping, download, upload } = await this.grid.getConnectionSpeed();
try {
// speed test
const { ping, download, upload } = await this.grid.getConnectionSpeed();

// request cycle
const cycleParams = await this.grid.requestCycle(
this.worker.worker_id,
this.modelName,
this.modelVersion,
ping,
download,
upload
);
// request cycle
const cycleParams = await this.grid.requestCycle(
this.worker.worker_id,
this.modelName,
this.modelVersion,
ping,
download,
upload
);

switch (cycleParams.status) {
case CYCLE_STATUS_ACCEPTED:
// load model, plans, protocols, etc.
this.logger.log(
`Accepted into cycle with params: ${JSON.stringify(
cycleParams,
null,
2
)}`
);
await this.initCycle(cycleParams);

switch (cycleParams.status) {
case CYCLE_STATUS_ACCEPTED:
// load model, plans, protocols, etc.
this.logger.log(
`Accepted into cycle with params: ${JSON.stringify(
cycleParams,
null,
2
)}`
);
await this.initCycle(cycleParams);
this.observer.broadcast('accepted', {
model: this.model,
clientConfig: this.clientConfig
});
break;

this.observer.broadcast('accepted', {
model: this.model,
clientConfig: this.clientConfig
});
break;
case CYCLE_STATUS_REJECTED:
this.logger.log(
`Rejected from cycle with timeout: ${cycleParams.timeout}`
);
this.observer.broadcast('rejected', {
timeout: cycleParams.timeout
});
break;

case CYCLE_STATUS_REJECTED:
this.logger.log(
`Rejected from cycle with timeout: ${cycleParams.timeout}`
);
this.observer.broadcast('rejected', {
timeout: cycleParams.timeout
});
break;
default:
throw new Error(GRID_UNKNOWN_CYCLE_STATUS(cycleParams.status));
}

default:
throw new Error(GRID_UNKNOWN_CYCLE_STATUS(cycleParams.status));
} catch (error) {
this.observer.broadcast('error', error);
}
}

Expand Down
42 changes: 42 additions & 0 deletions src/object-registry.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import * as tf from '@tensorflow/tfjs-core';

export default class ObjectRegistry {
constructor() {
this.objects = {};
this.gc = {};
}

set(id, obj, gc = false) {
this.objects[id] = obj;
this.gc[id] = gc;
}

setGc(id, gc) {
this.gc[id] = gc;
}

get(id) {
return this.objects[id];
}

has(id) {
return Object.hasOwnProperty.call(this.objects, id);
}

clear() {
for (let key of Object.keys(this.objects)) {
if (this.gc[key] && this.objects[key] instanceof tf.Tensor) {
this.objects[key].dispose();
}
}
this.objects = {};
this.gc = {};
}

load(objectRegistry) {
for (let key of Object.keys(objectRegistry.objects)) {
this.set(key, objectRegistry.get(key));
}
}
}

69 changes: 35 additions & 34 deletions src/syft_model.js → src/syft-model.js
Original file line number Diff line number Diff line change
@@ -1,34 +1,35 @@
import { unserialize, protobuf, serialize } from './protobuf';
import { State } from './types/plan';
import { TorchTensor } from './types/torch';

export default class SyftModel {
constructor({ worker, modelData }) {
const state = unserialize(
worker,
modelData,
protobuf.syft_proto.execution.v1.State
);
this.worker = worker;
this.params = state.getTfTensors();
}

async createSerializedDiff(updatedModelParams) {
const modelDiff = [];
for (let i = 0; i < updatedModelParams.length; i++) {
modelDiff.push(this.params[i].sub(updatedModelParams[i]));
}

const tensors = [];
for (let param of modelDiff) {
tensors.push(await TorchTensor.fromTfTensor(param));
}
const state = new State([], tensors);
const bin = serialize(this.worker, state);

// free memory
tensors.forEach(t => t._tfTensor.dispose());

return bin;
}
}
import { unserialize, protobuf, serialize } from './protobuf';
import { State } from './types/plan';
import { TorchTensor } from './types/torch';
import Placeholder from './types/placeholder';

export default class SyftModel {
constructor({ worker, modelData }) {
const state = unserialize(
worker,
modelData,
protobuf.syft_proto.execution.v1.State
);
this.worker = worker;
this.params = state.getTfTensors();
}

async createSerializedDiff(updatedModelParams) {
const
placeholders = [],
tensors = [];

for (let i = 0; i < updatedModelParams.length; i++) {
let paramDiff = this.params[i].sub(updatedModelParams[i]);
placeholders.push(new Placeholder(i, [`#${i}`, `#state-${i}`]));
tensors.push(await TorchTensor.fromTfTensor(paramDiff));
}
const state = new State(placeholders, tensors);
const bin = serialize(this.worker, state);

// free memory
tensors.forEach(t => t._tfTensor.dispose());

return bin;
}
}
3 changes: 2 additions & 1 deletion src/syft.js
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import WebRTCClient from './webrtc';
import { protobuf, unserialize } from './protobuf';
import GridAPIClient from './grid-api-client';
import Job from './job';
import ObjectRegistry from './object-registry';

export default class Syft {
/* ----- CONSTRUCTOR ----- */
Expand All @@ -28,7 +29,7 @@ export default class Syft {
this.gridClient = new GridAPIClient({ url, allowInsecureUrl: verbose });

// objects registry
this.objects = {};
this.objects = new ObjectRegistry();

// For creating event listeners
this.observer = new EventObserver();
Expand Down
Loading