diff --git a/buildSrc/src/main/java/Dependencies.kt b/buildSrc/src/main/java/Dependencies.kt index 736bb51a..52d6fb0f 100644 --- a/buildSrc/src/main/java/Dependencies.kt +++ b/buildSrc/src/main/java/Dependencies.kt @@ -24,6 +24,8 @@ object Versions { const val okhttp = "4.3.1" const val protobuf = "3.11.0" const val syftProto = "0.0.8" + const val retrofit = "2.7.1" + const val kotlinConverter = "0.4.0" // release management const val netflixPublishing = "14.0.0" @@ -45,15 +47,19 @@ object ProjectDependencies { const val androidGradlePlugin = "com.android.tools.build:gradle:${Versions.gradle}" const val kotlinGradlePlugin = "org.jetbrains.kotlin:kotlin-gradle-plugin:${Versions.kotlin}" const val kotlinSerialization = "org.jetbrains.kotlin:kotlin-serialization:${Versions.kotlin}" - const val netflixPublishingPlugin = "com.netflix.nebula:nebula-publishing-plugin:${Versions.netflixPublishing}" - const val netflixReleasePlugin = "com.netflix.nebula:nebula-release-plugin:${Versions.netflixRelease}" - const val netflixBintrayPlugin = "com.netflix.nebula:nebula-bintray-plugin:${Versions.netflixBintray}" + const val netflixPublishingPlugin = "com.netflix.nebula:nebula-publishing-plugin:" + + Versions.netflixPublishing + const val netflixReleasePlugin = "com.netflix.nebula:nebula-release-plugin:" + + Versions.netflixRelease + const val netflixBintrayPlugin = "com.netflix.nebula:nebula-bintray-plugin:" + + Versions.netflixBintray } object CommonDependencies { const val appCompat = "androidx.appcompat:appcompat:${Versions.appCompat}" const val coreKtx = "androidx.core:core-ktx:${Versions.coreKtx}" const val kotlinSerialization = "org.jetbrains.kotlinx:kotlinx-serialization-runtime:${Versions.kotlinSerialization}" + const val kotlinSerializationFactory = "com.jakewharton.retrofit:retrofit2-kotlinx-serialization-converter:${Versions.kotlinConverter}" const val rxJava = "io.reactivex.rxjava2:rxjava:${Versions.rxJava}" const val rxAndroid = "io.reactivex.rxjava2:rxandroid:${Versions.rxAndroid}" const val espresso = "androidx.test.espresso:espresso-core:${Versions.espresso}" @@ -74,4 +80,6 @@ object SyftlibDependencies { const val syftProto = "org.openmined.kotlinsyft:syft-proto-jvm:${Versions.syftProto}" const val protobuf = "com.google.protobuf:protobuf-java:${Versions.protobuf}" const val junitJupiter = "org.junit.jupiter:junit-jupiter:${Versions.junitJupiter}" + const val retrofit = "com.squareup.retrofit2:retrofit:${Versions.retrofit}" + const val retrofitAdapter = "com.squareup.retrofit2:adapter-rxjava2:${Versions.retrofit}" } diff --git a/demo-app/src/main/java/org/openmined/syft/demo/StandaloneDemo.kt b/demo-app/src/main/java/org/openmined/syft/demo/StandaloneDemo.kt index 394cbcae..86ef1fce 100644 --- a/demo-app/src/main/java/org/openmined/syft/demo/StandaloneDemo.kt +++ b/demo-app/src/main/java/org/openmined/syft/demo/StandaloneDemo.kt @@ -1,23 +1,32 @@ package org.openmined.syft.demo import io.reactivex.Scheduler +import io.reactivex.android.schedulers.AndroidSchedulers import io.reactivex.schedulers.Schedulers import org.openmined.syft.Syft -import org.openmined.syft.networking.clients.SignallingClient -import org.openmined.syft.networking.requests.Protocol +import org.openmined.syft.networking.clients.HttpClient +import org.openmined.syft.networking.clients.SocketClient import org.openmined.syft.threading.ProcessSchedulers @ExperimentalUnsignedTypes fun main() { - val syft = Syft.getInstance(SignallingClient( - Protocol.WSS, - "echo.websocket.org", - 2000u - ), object : ProcessSchedulers { + val networkingSchedulers = object : ProcessSchedulers { + override val computeThreadScheduler: Scheduler + get() = Schedulers.io() + override val calleeThreadScheduler: Scheduler + get() = AndroidSchedulers.mainThread() + } + val computeSchedulers = object : ProcessSchedulers { override val computeThreadScheduler: Scheduler get() = Schedulers.computation() override val calleeThreadScheduler: Scheduler get() = Schedulers.single() } + val syft = Syft.getInstance( + SocketClient( + "echo.websocket.org", + 2000u + , computeSchedulers + ), HttpClient("echo.websocket.org"), computeSchedulers, networkingSchedulers ) } diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 28b5cd58..7aa9513d 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,6 +1,6 @@ -#Sat Jan 11 12:13:19 GMT 2020 +#Fri Feb 28 13:29:55 CET 2020 distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-5.4.1-all.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-5.6.4-all.zip diff --git a/syftlib/build.gradle b/syftlib/build.gradle index 5c2bb972..b83effd5 100644 --- a/syftlib/build.gradle +++ b/syftlib/build.gradle @@ -49,11 +49,13 @@ dependencies { implementation CommonDependencies.appCompat implementation CommonDependencies.coreKtx implementation CommonDependencies.kotlinSerialization - implementation SyftlibDependencies.webrtc implementation CommonDependencies.rxJava implementation CommonDependencies.rxAndroid implementation SyftlibDependencies.okhttp + implementation CommonDependencies.kotlinSerializationFactory + implementation SyftlibDependencies.retrofitAdapter + implementation SyftlibDependencies.retrofit implementation SyftlibDependencies.syftProto implementation SyftlibDependencies.protobuf diff --git a/syftlib/src/main/java/org/openmined/syft/Syft.kt b/syftlib/src/main/java/org/openmined/syft/Syft.kt index 8640d319..d263cdf0 100644 --- a/syftlib/src/main/java/org/openmined/syft/Syft.kt +++ b/syftlib/src/main/java/org/openmined/syft/Syft.kt @@ -1,85 +1,155 @@ package org.openmined.syft -import android.util.Log +import io.reactivex.Completable import io.reactivex.disposables.CompositeDisposable -import org.openmined.syft.networking.clients.NetworkMessage -import org.openmined.syft.networking.clients.SignallingClient -import org.openmined.syft.networking.datamodels.AuthenticationSuccess -import org.openmined.syft.networking.datamodels.CycleResponseData -import org.openmined.syft.networking.datamodels.SocketResponse -import org.openmined.syft.networking.requests.CommunicationDataFactory -import org.openmined.syft.networking.requests.REQUESTS +import org.openmined.syft.networking.clients.HttpClient +import org.openmined.syft.networking.clients.SocketClient +import org.openmined.syft.networking.datamodels.syft.AuthenticationSuccess +import org.openmined.syft.networking.datamodels.syft.CycleRequest +import org.openmined.syft.networking.datamodels.syft.CycleResponseData +import org.openmined.syft.networking.requests.CommunicationAPI +import org.openmined.syft.networking.requests.HttpAPI +import org.openmined.syft.processes.JobStatusSubscriber +import org.openmined.syft.processes.SyftJob import org.openmined.syft.threading.ProcessSchedulers +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.TimeUnit private const val TAG = "Syft" @ExperimentalUnsignedTypes class Syft private constructor( - private val signallingClient: SignallingClient, - private val schedulers: ProcessSchedulers + private val socketClient: SocketClient, + private val httpClient: HttpClient, + //todo this will be removed by syft configuration class + private val computeSchedulers: ProcessSchedulers, + //todo change this to read from syft configuration + private val networkingSchedulers: ProcessSchedulers + ) { companion object { @Volatile private var INSTANCE: Syft? = null - fun getInstance(signallingClient: SignallingClient, schedulers: ProcessSchedulers): Syft = + fun getInstance( + socketClient: SocketClient, + httpClient: HttpClient, + networkingSchedulers: ProcessSchedulers, + //todo this will be removed by syft configuration class + computeSchedulers: ProcessSchedulers + ): Syft = INSTANCE ?: synchronized(this) { INSTANCE ?: Syft( - signallingClient, - schedulers + socketClient, + httpClient, + networkingSchedulers, + computeSchedulers ).also { INSTANCE = it } } } - private lateinit var workerId: String + private val workerJobs = ConcurrentHashMap() + private val compositeDisposable = CompositeDisposable() - private val workerJobs = mutableListOf() - private val compositeDisposable = CompositeDisposable().add( - signallingClient.start() - .map { - when (it) { - is NetworkMessage.SocketOpen -> - signallingClient.send(REQUESTS.AUTHENTICATION) + //todo decide if this can be changed by pygrid or will remain same irrespective of the requests we make + @Volatile + lateinit var workerId: String - is NetworkMessage.SocketClosed -> Log.d( - TAG, - "Socket was closed successfully" - ) - is NetworkMessage.SocketError -> Log.e(TAG, "socket error", it.throwable) - is NetworkMessage.MessageReceived -> handleResponse( - CommunicationDataFactory.deserializeSocket( - it.message - ) + fun newJob( + model: String, + version: String? = null + ): SyftJob { + val job = SyftJob(this, computeSchedulers, networkingSchedulers, model, version) + val jobId = SyftJob.JobID(model, version) + workerJobs[jobId] = job + job.subscribe(object : JobStatusSubscriber() { + override fun onComplete() { + workerJobs.remove(jobId) + } + + override fun onError(throwable: Throwable) { + workerJobs.remove(jobId) + } + }, networkingSchedulers) + + return job + } + + fun requestCycle(job: SyftJob) { + if (this::workerId.isInitialized) + socketClient.getCycle( + CycleRequest( + workerId, + job.modelName, + job.version, + getPing(), + getDownloadSpeed(), + getUploadSpeed() ) - is NetworkMessage.MessageSent -> println("Message sent successfully") + ).compose(networkingSchedulers.applySingleSchedulers()) + .subscribe { response: CycleResponseData -> + when (response) { + is CycleResponseData.CycleAccept -> handleCycleAccept(response) + is CycleResponseData.CycleReject -> handleCycleReject(response) + } } - } - .subscribeOn(schedulers.computeThreadScheduler) - .observeOn(schedulers.calleeThreadScheduler) - .subscribe() - ) - - fun newJob(modelName: String, version: String): SyftJob { - val job = SyftJob(modelName, version) - signallingClient.send( - REQUESTS.CYCLE_REQUEST, - CommunicationDataFactory.requestCycle(workerId, job, "", "", "") - ) - workerJobs.add(job) - return job + else { + compositeDisposable.add(socketClient.authenticate() + .compose(networkingSchedulers.applySingleSchedulers()) + .subscribe { t: AuthenticationSuccess -> + if (!this::workerId.isInitialized) + setSyftWorkerId(t.workerId) + requestCycle(job) + } + ) + } } - private fun handleResponse(response: SocketResponse) { - when (response.data) { - is AuthenticationSuccess -> - this.workerId = response.data.workerId - is CycleResponseData -> { - when (response.data) { - is CycleResponseData.CycleAccept -> "accept here" - is CycleResponseData.CycleReject -> "set timeout for job" - } - } + fun getDownloader(): HttpAPI = httpClient.apiClient - } + //todo decide this based on configuration + fun getSignallingClient(): CommunicationAPI = socketClient + + fun getWebRTCSignallingClient(): SocketClient = socketClient + + @Synchronized + private fun setSyftWorkerId(workerId: String) { + if (!this::workerId.isInitialized) + this.workerId = workerId + else if (workerJobs.isEmpty()) + this.workerId = workerId + } + + private fun getPing() = "" + private fun getDownloadSpeed() = "" + private fun getUploadSpeed() = "" + + private fun handleCycleReject(responseData: CycleResponseData.CycleReject) { + var jobId = SyftJob.JobID(responseData.modelName, responseData.version) + val job = workerJobs.getOrElse(jobId, { + jobId = SyftJob.JobID(responseData.modelName) + workerJobs.getValue(jobId) + }) + job.cycleStatus.set(SyftJob.CycleStatus.REJECT) + compositeDisposable.add( + Completable + .timer(responseData.timeout.toLong(), TimeUnit.MILLISECONDS) + .compose(networkingSchedulers.applyCompletableSchedulers()) + .subscribe { + job.cycleStatus.set(SyftJob.CycleStatus.APPLY) + job.start() + } + ) } + + private fun handleCycleAccept(responseData: CycleResponseData.CycleAccept) { + val jobId = SyftJob.JobID(responseData.modelName, responseData.version) + val job = workerJobs.getOrElse(jobId, { + workerJobs.getValue(SyftJob.JobID(responseData.modelName)) + }) + job.setRequestKey(responseData) + job.downloadData() + } + + } diff --git a/syftlib/src/main/java/org/openmined/syft/SyftJob.kt b/syftlib/src/main/java/org/openmined/syft/SyftJob.kt deleted file mode 100644 index 55bff05a..00000000 --- a/syftlib/src/main/java/org/openmined/syft/SyftJob.kt +++ /dev/null @@ -1,38 +0,0 @@ -package org.openmined.syft - - -import kotlinx.serialization.Serializable - -@Serializable -class SyftJob(val modelName: String, val version: String? = null) { - - /** - * create a worker job - */ - fun start() { - - } - - /** - * Run this once the PyGrid accepts the worker - * all the requisite plans, protocols are downloaded - */ - fun executeTrainingPlan() { - - } - - /** - * if not empty execute protocol after training - */ - fun executeProtocol() { - - } - - /** - * report the results back to PyGrid - */ - fun report() { - - } - -} \ No newline at end of file diff --git a/syftlib/src/main/java/org/openmined/syft/networking/clients/HttpClient.kt b/syftlib/src/main/java/org/openmined/syft/networking/clients/HttpClient.kt new file mode 100644 index 00000000..48b63682 --- /dev/null +++ b/syftlib/src/main/java/org/openmined/syft/networking/clients/HttpClient.kt @@ -0,0 +1,17 @@ +package org.openmined.syft.networking.clients + +import com.jakewharton.retrofit2.converter.kotlinx.serialization.asConverterFactory +import kotlinx.serialization.json.Json +import okhttp3.MediaType.Companion.toMediaType +import org.openmined.syft.networking.requests.HttpAPI +import retrofit2.Retrofit +import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory + +class HttpClient(baseUrl: String) { + val apiClient: HttpAPI = Retrofit.Builder() + .addCallAdapterFactory(RxJava2CallAdapterFactory.create()) + .addConverterFactory(Json.asConverterFactory("application/json".toMediaType())) + .baseUrl(baseUrl) + .build().create(HttpAPI::class.java) + +} \ No newline at end of file diff --git a/syftlib/src/main/java/org/openmined/syft/networking/clients/NetworkMessage.kt b/syftlib/src/main/java/org/openmined/syft/networking/clients/NetworkMessage.kt index 1c6e683a..3e02589f 100644 --- a/syftlib/src/main/java/org/openmined/syft/networking/clients/NetworkMessage.kt +++ b/syftlib/src/main/java/org/openmined/syft/networking/clients/NetworkMessage.kt @@ -1,9 +1,7 @@ package org.openmined.syft.networking.clients -sealed class NetworkMessage() { - object SocketClosed : NetworkMessage() +sealed class NetworkMessage { object SocketOpen : NetworkMessage() data class SocketError(val throwable: Throwable) : NetworkMessage() - object MessageSent : NetworkMessage() data class MessageReceived(val message: String) : NetworkMessage() } \ No newline at end of file diff --git a/syftlib/src/main/java/org/openmined/syft/networking/clients/SocketClient.kt b/syftlib/src/main/java/org/openmined/syft/networking/clients/SocketClient.kt new file mode 100644 index 00000000..d6a69491 --- /dev/null +++ b/syftlib/src/main/java/org/openmined/syft/networking/clients/SocketClient.kt @@ -0,0 +1,128 @@ +package org.openmined.syft.networking.clients + +import android.util.Log +import io.reactivex.Single +import io.reactivex.disposables.CompositeDisposable +import io.reactivex.processors.PublishProcessor +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonConfiguration +import kotlinx.serialization.json.json +import org.openmined.syft.networking.datamodels.NetworkModels +import org.openmined.syft.networking.datamodels.SocketResponse +import org.openmined.syft.networking.datamodels.syft.AuthenticationSuccess +import org.openmined.syft.networking.datamodels.syft.CycleRequest +import org.openmined.syft.networking.datamodels.syft.CycleResponseData +import org.openmined.syft.networking.datamodels.syft.ReportRequest +import org.openmined.syft.networking.datamodels.syft.ReportResponse +import org.openmined.syft.networking.datamodels.webRTC.InternalMessageRequest +import org.openmined.syft.networking.datamodels.webRTC.InternalMessageResponse +import org.openmined.syft.networking.datamodels.webRTC.JoinRoomRequest +import org.openmined.syft.networking.datamodels.webRTC.JoinRoomResponse +import org.openmined.syft.networking.requests.MessageTypes +import org.openmined.syft.networking.requests.Protocol +import org.openmined.syft.networking.requests.REQUESTS +import org.openmined.syft.networking.requests.SocketAPI +import org.openmined.syft.networking.requests.WebRTCMessageTypes +import org.openmined.syft.processes.SyftJob +import org.openmined.syft.threading.ProcessSchedulers +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicBoolean + +private const val TAG = "SocketClient" + +@ExperimentalUnsignedTypes +class SocketClient( + baseUrl: String, + private val timeout: UInt = 20000u, + private val schedulers: ProcessSchedulers +) : SocketAPI { + + //Choosing stable kotlin serialization over default + private val Json = Json(JsonConfiguration.Stable) + + private val syftWebSocket = SyftWebSocket(Protocol.WSS, baseUrl, timeout) + @Volatile + private var socketClientSubscribed = AtomicBoolean(false) + private val messageProcessor = PublishProcessor.create() + private val compositeDisposable = CompositeDisposable() + + override fun authenticate(): Single { + initiateSocketIfEmpty() + syftWebSocket.send(appendType(REQUESTS.AUTHENTICATION)) + return messageProcessor.onBackpressureLatest() + .ofType(AuthenticationSuccess::class.java) + .firstOrError() + } + + override fun getCycle(cycleRequest: CycleRequest): Single { + syftWebSocket.send(appendType(REQUESTS.CYCLE_REQUEST, cycleRequest)) + return messageProcessor.onBackpressureBuffer() + .ofType(CycleResponseData::class.java) + .filter { + SyftJob.JobID( + cycleRequest.modelName, + cycleRequest.version + ).matchWithResponse(it.modelName, it.version) + }.debounce(timeout.toLong(), TimeUnit.MILLISECONDS) + .firstOrError() + } + + //todo handle backpressure and first or error + override fun report(reportRequest: ReportRequest): Single { + syftWebSocket.send(appendType(REQUESTS.REPORT, reportRequest)) + return messageProcessor.onBackpressureDrop() + .ofType(ReportResponse::class.java) + .firstOrError() + } + + //todo handle backpressure and first or error + override fun joinRoom(joinRoomRequest: JoinRoomRequest): Single { + syftWebSocket.send(appendType(WebRTCMessageTypes.WEBRTC_JOIN_ROOM, joinRoomRequest)) + return messageProcessor.onBackpressureBuffer() + .ofType(JoinRoomResponse::class.java) + .firstOrError() + } + + //todo handle backpressure and first or error + override fun sendInternalMessage(internalMessageRequest: InternalMessageRequest): Single { + syftWebSocket.send(appendType(REQUESTS.WEBRTC_INTERNAL, internalMessageRequest)) + return messageProcessor.onBackpressureBuffer() + .ofType(InternalMessageResponse::class.java) + .first(null) + } + + private fun initiateSocketIfEmpty() { + if (socketClientSubscribed.get()) + return + compositeDisposable.add(syftWebSocket.start() + .map { + when (it) { + is NetworkMessage.SocketOpen -> authenticate() + is NetworkMessage.SocketError -> Log.e( + TAG, + "socket error", + it.throwable + ) + is NetworkMessage.MessageReceived -> emitMessage(deserializeSocket(it.message)) + } + } + .subscribeOn(schedulers.computeThreadScheduler) + .observeOn(schedulers.calleeThreadScheduler) + .subscribe()) + socketClientSubscribed.set(true) + } + + private fun emitMessage(response: SocketResponse) { + messageProcessor.offer(response.data) + } + + private fun deserializeSocket(socketMessage: String): SocketResponse { + return Json.parse(SocketResponse.serializer(), socketMessage) + } + + private fun appendType(types: MessageTypes, data: NetworkModels? = null) = json { + TYPE to types.value + if (data != null) + DATA to data + } +} \ No newline at end of file diff --git a/syftlib/src/main/java/org/openmined/syft/networking/clients/SignallingClient.kt b/syftlib/src/main/java/org/openmined/syft/networking/clients/SyftWebSocket.kt similarity index 58% rename from syftlib/src/main/java/org/openmined/syft/networking/clients/SignallingClient.kt rename to syftlib/src/main/java/org/openmined/syft/networking/clients/SyftWebSocket.kt index 5f52a77b..8e3daa87 100644 --- a/syftlib/src/main/java/org/openmined/syft/networking/clients/SignallingClient.kt +++ b/syftlib/src/main/java/org/openmined/syft/networking/clients/SyftWebSocket.kt @@ -3,42 +3,39 @@ package org.openmined.syft.networking.clients import io.reactivex.Flowable import io.reactivex.processors.PublishProcessor import kotlinx.serialization.json.JsonObject -import kotlinx.serialization.json.json import okhttp3.OkHttpClient import okhttp3.Request import okhttp3.Response import okhttp3.WebSocket import okhttp3.WebSocketListener -import org.openmined.syft.networking.requests.MessageTypes import org.openmined.syft.networking.requests.Protocol import java.util.concurrent.TimeUnit +const val TYPE = "type" +const val DATA = "data" private const val SOCKET_CLOSE_CLIENT = 1000 -private const val TYPE = "type" -private const val DATA = "data" @ExperimentalUnsignedTypes -class SignallingClient( - private val protocol: Protocol, - private val address: String, - private val keepAliveTimeout: UInt = 20000u +class SyftWebSocket( + protocol: Protocol, + address: String, + keepAliveTimeout: UInt ) { - private lateinit var request: Request - private lateinit var client: OkHttpClient - private lateinit var webSocket: WebSocket - private val syftSocketListener = SyftSocketListener() + private var request = Request.Builder() + .url("$protocol://$address") + .build() + private var client = OkHttpClient.Builder() + .pingInterval(keepAliveTimeout.toLong(), TimeUnit.MILLISECONDS) + .build() + private val syftSocketListener = SyftSocketListener() private val statusPublishProcessor: PublishProcessor = PublishProcessor.create() + private lateinit var webSocket: WebSocket + fun start(): Flowable { - client = OkHttpClient.Builder() - .pingInterval(keepAliveTimeout.toLong(), TimeUnit.MILLISECONDS) - .build() - request = Request.Builder() - .url("$protocol://$address") - .build() connect() return statusPublishProcessor.onBackpressureBuffer() } @@ -46,23 +43,10 @@ class SignallingClient( /** * Send the data over the Socket connection to PyGrid */ - fun send(typesResponse: MessageTypes, data: JsonObject? = null) { - val message = json { - TYPE to typesResponse.value - if (data != null) - DATA to data - }.toString() - - if (webSocket.send(message)) { - statusPublishProcessor.offer(NetworkMessage.MessageSent) - } - } + fun send(message: JsonObject) = webSocket.send(message.toString()) + + fun close() = webSocket.close(SOCKET_CLOSE_CLIENT, "Socket closed by client") - fun close() { - if (webSocket.close(SOCKET_CLOSE_CLIENT, "Socket closed by client")) { - statusPublishProcessor.offer(NetworkMessage.SocketClosed) - } - } private fun connect() { webSocket = client.newWebSocket(request, syftSocketListener) @@ -72,7 +56,7 @@ class SignallingClient( override fun onOpen(webSocket: WebSocket, response: Response) { super.onOpen(webSocket, response) - this@SignallingClient.webSocket = webSocket + this@SyftWebSocket.webSocket = webSocket statusPublishProcessor.offer(NetworkMessage.SocketOpen) } diff --git a/syftlib/src/main/java/org/openmined/syft/networking/clients/WebRTC.kt b/syftlib/src/main/java/org/openmined/syft/networking/clients/WebRTCClient.kt similarity index 91% rename from syftlib/src/main/java/org/openmined/syft/networking/clients/WebRTC.kt rename to syftlib/src/main/java/org/openmined/syft/networking/clients/WebRTCClient.kt index ba137551..44504fe2 100644 --- a/syftlib/src/main/java/org/openmined/syft/networking/clients/WebRTC.kt +++ b/syftlib/src/main/java/org/openmined/syft/networking/clients/WebRTCClient.kt @@ -1,8 +1,12 @@ package org.openmined.syft.networking.clients import android.util.Log -import org.openmined.syft.networking.requests.CommunicationDataFactory +import io.reactivex.disposables.CompositeDisposable +import org.openmined.syft.networking.datamodels.webRTC.InternalMessageRequest +import org.openmined.syft.networking.datamodels.webRTC.JoinRoomRequest +import org.openmined.syft.networking.datamodels.webRTC.JoinRoomResponse import org.openmined.syft.networking.requests.WebRTCMessageTypes +import org.openmined.syft.threading.ProcessSchedulers import org.webrtc.DataChannel import org.webrtc.IceCandidate import org.webrtc.MediaStream @@ -23,10 +27,12 @@ private const val TAG = "WebRTCClient" internal class WebRTCClient( private val peerConnectionFactory: PeerConnectionFactory, private val peerConfig: PeerConnection.RTCConfiguration, - private val signallingClient: SignallingClient + private val socketClient: SocketClient, + private val schedulers: ProcessSchedulers ) { private val peers = HashMap() + private val compositeDisposable = CompositeDisposable() private lateinit var workerId: String private lateinit var scopeId: String @@ -35,9 +41,9 @@ internal class WebRTCClient( this.workerId = workerId this.scopeId = scopeId - signallingClient.send( - WebRTCMessageTypes.WEBRTC_JOIN_ROOM, - CommunicationDataFactory.joinRoom(workerId, scopeId) + compositeDisposable.add(socketClient.joinRoom(JoinRoomRequest(workerId, scopeId)) + .compose(schedulers.applySingleSchedulers()) + .subscribe { _: JoinRoomResponse?, _: Throwable? -> } ) } @@ -111,9 +117,16 @@ internal class WebRTCClient( private fun sendInternalMessage(type: WebRTCMessageTypes, message: String, target: String) { if (target != workerId) { Log.d(TAG, "Sending Internal WebRTC message via PyGrid") - this.signallingClient.send( - WebRTCMessageTypes.WEBRTC_INTERNAL_MESSAGE, - CommunicationDataFactory.internalMessage(workerId, scopeId, target, type, message) + compositeDisposable.add( + this.socketClient.sendInternalMessage( + InternalMessageRequest( + workerId, + scopeId, + target, + type.value, + message + ) + ).compose(schedulers.applySingleSchedulers()).subscribe() ) } } @@ -140,6 +153,7 @@ internal class WebRTCClient( } } + //todo shift this to reactive version fun receiveInternalMessage( types: WebRTCMessageTypes, newWorkerId: String, diff --git a/syftlib/src/main/java/org/openmined/syft/networking/datamodels/ModelConfig.kt b/syftlib/src/main/java/org/openmined/syft/networking/datamodels/ClientConfig.kt similarity index 87% rename from syftlib/src/main/java/org/openmined/syft/networking/datamodels/ModelConfig.kt rename to syftlib/src/main/java/org/openmined/syft/networking/datamodels/ClientConfig.kt index 632a81f4..7be034b4 100644 --- a/syftlib/src/main/java/org/openmined/syft/networking/datamodels/ModelConfig.kt +++ b/syftlib/src/main/java/org/openmined/syft/networking/datamodels/ClientConfig.kt @@ -3,7 +3,7 @@ package org.openmined.syft.networking.datamodels import kotlinx.serialization.Serializable @Serializable -data class ModelConfig( +data class ClientConfig( //todo populate when defined val modelName: String ) diff --git a/syftlib/src/main/java/org/openmined/syft/networking/datamodels/CycleResponseData.kt b/syftlib/src/main/java/org/openmined/syft/networking/datamodels/CycleResponseData.kt deleted file mode 100644 index dadc553f..00000000 --- a/syftlib/src/main/java/org/openmined/syft/networking/datamodels/CycleResponseData.kt +++ /dev/null @@ -1,33 +0,0 @@ -package org.openmined.syft.networking.datamodels - -import kotlinx.serialization.SerialName -import kotlinx.serialization.Serializable - -const val CYCLE_TYPE = "federated/cycle-request" -const val CYCLE_ACCEPT = "accepted" -const val CYCLE_REJECT = "rejected" - -@Serializable -sealed class CycleResponseData : NetworkModels() { - - @SerialName(CYCLE_ACCEPT) - @Serializable - data class CycleAccept( - @SerialName("request_key") - val requestKey: String, - @SerialName("training_plan") - val trainingPlanID: String, - @SerialName("model_config") - val modelConfig: ModelConfig, - @SerialName("protocol") - val protocolID: String, - @SerialName("model") - val modelId: String - ) : CycleResponseData() - - @SerialName(CYCLE_REJECT) - @Serializable - data class CycleReject( - val timeout: Int - ) : CycleResponseData() -} diff --git a/syftlib/src/main/java/org/openmined/syft/networking/datamodels/ReportStatus.kt b/syftlib/src/main/java/org/openmined/syft/networking/datamodels/ReportStatus.kt deleted file mode 100644 index eed63631..00000000 --- a/syftlib/src/main/java/org/openmined/syft/networking/datamodels/ReportStatus.kt +++ /dev/null @@ -1,10 +0,0 @@ -package org.openmined.syft.networking.datamodels - -import kotlinx.serialization.Serializable - -const val REPORT_TYPE = "federated/report" - -@Serializable -data class ReportStatus( - val status: String -) : NetworkModels() \ No newline at end of file diff --git a/syftlib/src/main/java/org/openmined/syft/networking/datamodels/WebRTCInternalMessage.kt b/syftlib/src/main/java/org/openmined/syft/networking/datamodels/WebRTCInternalMessage.kt deleted file mode 100644 index 119ca904..00000000 --- a/syftlib/src/main/java/org/openmined/syft/networking/datamodels/WebRTCInternalMessage.kt +++ /dev/null @@ -1,15 +0,0 @@ -package org.openmined.syft.networking.datamodels - -import kotlinx.serialization.SerialName -import kotlinx.serialization.Serializable - -const val WEBRTC_INTERNAL_TYPE = "webrtc_internal" - -@Serializable -data class WebRTCInternalMessage( - val type: String, - @SerialName("worker_id") - val newWorkerId: String, - @SerialName("sdp_string") - val sessionDescription: String -) : NetworkModels() \ No newline at end of file diff --git a/syftlib/src/main/java/org/openmined/syft/networking/datamodels/AuthenticationSuccess.kt b/syftlib/src/main/java/org/openmined/syft/networking/datamodels/syft/AuthenticationSuccess.kt similarity index 68% rename from syftlib/src/main/java/org/openmined/syft/networking/datamodels/AuthenticationSuccess.kt rename to syftlib/src/main/java/org/openmined/syft/networking/datamodels/syft/AuthenticationSuccess.kt index de0dba17..7e192d64 100644 --- a/syftlib/src/main/java/org/openmined/syft/networking/datamodels/AuthenticationSuccess.kt +++ b/syftlib/src/main/java/org/openmined/syft/networking/datamodels/syft/AuthenticationSuccess.kt @@ -1,7 +1,8 @@ -package org.openmined.syft.networking.datamodels +package org.openmined.syft.networking.datamodels.syft import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable +import org.openmined.syft.networking.datamodels.NetworkModels const val AUTH_TYPE = "federated/authenticate" diff --git a/syftlib/src/main/java/org/openmined/syft/networking/datamodels/syft/CycleDataModels.kt b/syftlib/src/main/java/org/openmined/syft/networking/datamodels/syft/CycleDataModels.kt new file mode 100644 index 00000000..80dacb01 --- /dev/null +++ b/syftlib/src/main/java/org/openmined/syft/networking/datamodels/syft/CycleDataModels.kt @@ -0,0 +1,58 @@ +package org.openmined.syft.networking.datamodels.syft + +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import org.openmined.syft.networking.datamodels.ClientConfig +import org.openmined.syft.networking.datamodels.NetworkModels + +const val CYCLE_TYPE = "federated/cycle-request" +const val CYCLE_ACCEPT = "accepted" +const val CYCLE_REJECT = "rejected" + +@Serializable +sealed class CycleResponseData : NetworkModels() { + + @SerialName("model") + abstract val modelName: String + @SerialName("version") + abstract val version: String + + @SerialName(CYCLE_ACCEPT) + @Serializable + data class CycleAccept( + override val modelName: String, + override val version: String, + @SerialName("request_key") + val requestKey: String, + @SerialName("training_plan") + val trainingPlanID: String, + @SerialName("client_config") + val clientConfig: ClientConfig, + @SerialName("protocols") + val protocolID: String, + @SerialName("model_id") + val modelId: String + ) : CycleResponseData() + + @SerialName(CYCLE_REJECT) + @Serializable + data class CycleReject( + override val modelName: String, + override val version: String, + val timeout: Int + ) : CycleResponseData() +} + +@Serializable +data class CycleRequest( + @SerialName("worker_id") + val workerId: String, + @SerialName("model") + val modelName: String, + val version: String? = null, + val ping: String, + @SerialName("download") + val downloadSpeed: String, + @SerialName("upload") + val uploadSpeed: String +) : NetworkModels() diff --git a/syftlib/src/main/java/org/openmined/syft/networking/datamodels/syft/ReportDataModels.kt b/syftlib/src/main/java/org/openmined/syft/networking/datamodels/syft/ReportDataModels.kt new file mode 100644 index 00000000..128098d1 --- /dev/null +++ b/syftlib/src/main/java/org/openmined/syft/networking/datamodels/syft/ReportDataModels.kt @@ -0,0 +1,21 @@ +package org.openmined.syft.networking.datamodels.syft + +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import org.openmined.syft.networking.datamodels.NetworkModels + +const val REPORT_TYPE = "federated/report" + +@Serializable +data class ReportResponse( + val status: String +) : NetworkModels() + +@Serializable +data class ReportRequest( + @SerialName("worker_id") + val workerId: String, + @SerialName("request_key") + val requestKey: String, + val diff: String +) : NetworkModels() \ No newline at end of file diff --git a/syftlib/src/main/java/org/openmined/syft/networking/datamodels/webRTC/InternalMessageResponse.kt b/syftlib/src/main/java/org/openmined/syft/networking/datamodels/webRTC/InternalMessageResponse.kt new file mode 100644 index 00000000..502c8836 --- /dev/null +++ b/syftlib/src/main/java/org/openmined/syft/networking/datamodels/webRTC/InternalMessageResponse.kt @@ -0,0 +1,27 @@ +package org.openmined.syft.networking.datamodels.webRTC + +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import org.openmined.syft.networking.datamodels.NetworkModels + +const val WEBRTC_INTERNAL_TYPE = "webrtc_internal" + +@Serializable +data class InternalMessageResponse( + val type: String, + @SerialName("worker_id") + val newWorkerId: String, + @SerialName("sdp_string") + val sessionDescription: String +) : NetworkModels() + +@Serializable +data class InternalMessageRequest( + @SerialName("worker_id") + val workerId: String, + @SerialName("scope_id") + val scopeId: String, + val target: String, + val type: String, + val message: String +) : NetworkModels() \ No newline at end of file diff --git a/syftlib/src/main/java/org/openmined/syft/networking/datamodels/webRTC/JoinRoom.kt b/syftlib/src/main/java/org/openmined/syft/networking/datamodels/webRTC/JoinRoom.kt new file mode 100644 index 00000000..2415e6d4 --- /dev/null +++ b/syftlib/src/main/java/org/openmined/syft/networking/datamodels/webRTC/JoinRoom.kt @@ -0,0 +1,21 @@ +package org.openmined.syft.networking.datamodels.webRTC + +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import org.openmined.syft.networking.datamodels.NetworkModels + +@Serializable +data class JoinRoomRequest( + @SerialName("worker_id") + val workerId: String, + @SerialName("scope_id") + val scopeId: String +) : NetworkModels() + +@Serializable +data class JoinRoomResponse( + @SerialName("worker_id") + val workerId: String, + @SerialName("scope_id") + val scopeId: String +) : NetworkModels() \ No newline at end of file diff --git a/syftlib/src/main/java/org/openmined/syft/networking/datamodels/WebRTCNewPeer.kt b/syftlib/src/main/java/org/openmined/syft/networking/datamodels/webRTC/NewPeer.kt similarity index 54% rename from syftlib/src/main/java/org/openmined/syft/networking/datamodels/WebRTCNewPeer.kt rename to syftlib/src/main/java/org/openmined/syft/networking/datamodels/webRTC/NewPeer.kt index c611b74a..f91e80b2 100644 --- a/syftlib/src/main/java/org/openmined/syft/networking/datamodels/WebRTCNewPeer.kt +++ b/syftlib/src/main/java/org/openmined/syft/networking/datamodels/webRTC/NewPeer.kt @@ -1,12 +1,13 @@ -package org.openmined.syft.networking.datamodels +package org.openmined.syft.networking.datamodels.webRTC import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable +import org.openmined.syft.networking.datamodels.NetworkModels const val NEW_PEER_TYPE = "peer" @Serializable -data class WebRTCNewPeer( +data class NewPeer( @SerialName("worker_id") val workerId: String -) : NetworkModels() \ No newline at end of file +) : NetworkModels() diff --git a/syftlib/src/main/java/org/openmined/syft/networking/requests/CommunicationAPI.kt b/syftlib/src/main/java/org/openmined/syft/networking/requests/CommunicationAPI.kt new file mode 100644 index 00000000..534ef666 --- /dev/null +++ b/syftlib/src/main/java/org/openmined/syft/networking/requests/CommunicationAPI.kt @@ -0,0 +1,17 @@ +package org.openmined.syft.networking.requests + +import io.reactivex.Single +import org.openmined.syft.networking.datamodels.syft.AuthenticationSuccess +import org.openmined.syft.networking.datamodels.syft.CycleRequest +import org.openmined.syft.networking.datamodels.syft.CycleResponseData +import org.openmined.syft.networking.datamodels.syft.ReportRequest +import org.openmined.syft.networking.datamodels.syft.ReportResponse + + +interface CommunicationAPI { + fun authenticate(): Single + + fun getCycle(cycleRequest: CycleRequest): Single + + fun report(reportRequest: ReportRequest): Single +} \ No newline at end of file diff --git a/syftlib/src/main/java/org/openmined/syft/networking/requests/HttpAPI.kt b/syftlib/src/main/java/org/openmined/syft/networking/requests/HttpAPI.kt new file mode 100644 index 00000000..127fcf31 --- /dev/null +++ b/syftlib/src/main/java/org/openmined/syft/networking/requests/HttpAPI.kt @@ -0,0 +1,55 @@ +package org.openmined.syft.networking.requests + +import io.reactivex.Single +import okhttp3.ResponseBody +import org.openmined.syft.networking.datamodels.syft.AUTH_TYPE +import org.openmined.syft.networking.datamodels.syft.AuthenticationSuccess +import org.openmined.syft.networking.datamodels.syft.CYCLE_TYPE +import org.openmined.syft.networking.datamodels.syft.CycleRequest +import org.openmined.syft.networking.datamodels.syft.CycleResponseData +import org.openmined.syft.networking.datamodels.syft.REPORT_TYPE +import org.openmined.syft.networking.datamodels.syft.ReportRequest +import org.openmined.syft.networking.datamodels.syft.ReportResponse +import retrofit2.Response +import retrofit2.http.Body +import retrofit2.http.GET +import retrofit2.http.POST +import retrofit2.http.Query +import retrofit2.http.Streaming + +interface HttpAPI : CommunicationAPI { + + @Streaming + @GET("/federated/get-plan") + fun downloadPlan( + @Query("worker_id") workerId: String, + @Query("request_key") requestKey: String, + @Query("plan_id") planId: String, + @Query("receive_operations_as") op_type: String + ): Single> + + @Streaming + @GET("/federated/get-protocol") + fun downloadProtocol( + @Query("worker_id") workerId: String, + @Query("request_key") requestKey: String, + @Query("protocol_id") protocolId: String + ): Single> + + @Streaming + @GET("/federated/get-model") + fun downloadModel( + @Query("worker_id") workerId: String, + @Query("request_key") requestKey: String, + @Query("model_id") modelId: String + ): Single> + + @GET(AUTH_TYPE) + override fun authenticate(): Single + + @POST(CYCLE_TYPE) + override fun getCycle(@Body cycleRequest: CycleRequest): Single + + @POST(REPORT_TYPE) + override fun report(@Body reportRequest: ReportRequest): Single +} \ No newline at end of file diff --git a/syftlib/src/main/java/org/openmined/syft/networking/requests/RequestBuilder.kt b/syftlib/src/main/java/org/openmined/syft/networking/requests/RequestBuilder.kt deleted file mode 100644 index 2a695eab..00000000 --- a/syftlib/src/main/java/org/openmined/syft/networking/requests/RequestBuilder.kt +++ /dev/null @@ -1,73 +0,0 @@ -package org.openmined.syft.networking.requests - -import kotlinx.serialization.json.Json -import kotlinx.serialization.json.JsonConfiguration -import kotlinx.serialization.json.JsonObject -import kotlinx.serialization.json.json -import org.openmined.syft.SyftJob -import org.openmined.syft.networking.datamodels.SocketResponse - -class CommunicationDataFactory { - companion object DataFactory { - /** - * The data curators for http and web socket requests - */ - - //Choosing stable kotlin serialization over default - private val Json = Json(JsonConfiguration.Stable) - - fun requestCycle( - workerId: String, - syftJob: SyftJob, - ping: String, - download: String, - upload: String - ): JsonObject { - return json { - "workerId" to workerId - "model" to syftJob.modelName - "ping" to ping - "download" to download - "upload" to upload - if (syftJob.version != null) - "version" to syftJob.version - } - } - - fun report(workerId: String, requestKey: String, diff: String): JsonObject { - return json { - "workerId" to workerId - "request_key" to requestKey - "diff" to diff - } - } - - fun joinRoom(workerId: String, scopeId: String): JsonObject { - return json { - "workerId" to workerId - "scopeId" to scopeId - } - } - - fun internalMessage( - workerId: String, - scopeId: String, - target: String, - type: WebRTCMessageTypes, - message: String - ): JsonObject { - return json { - "workerId" to workerId - "scopeId" to scopeId - "to" to target - "type" to type.value - "data" to message - } - } - - fun deserializeSocket(socketMessage: String): SocketResponse { - return Json.parse(SocketResponse.serializer(), socketMessage) - } - } -} - diff --git a/syftlib/src/main/java/org/openmined/syft/networking/requests/ResponseRequestTypes.kt b/syftlib/src/main/java/org/openmined/syft/networking/requests/ResponseRequestTypes.kt index fd8a5e03..1e731157 100644 --- a/syftlib/src/main/java/org/openmined/syft/networking/requests/ResponseRequestTypes.kt +++ b/syftlib/src/main/java/org/openmined/syft/networking/requests/ResponseRequestTypes.kt @@ -3,17 +3,17 @@ package org.openmined.syft.networking.requests import kotlinx.serialization.json.Json import kotlinx.serialization.json.JsonConfiguration import kotlinx.serialization.json.JsonElement -import org.openmined.syft.networking.datamodels.AUTH_TYPE -import org.openmined.syft.networking.datamodels.AuthenticationSuccess -import org.openmined.syft.networking.datamodels.CYCLE_TYPE -import org.openmined.syft.networking.datamodels.CycleResponseData -import org.openmined.syft.networking.datamodels.NEW_PEER_TYPE +import org.openmined.syft.networking.datamodels.syft.AUTH_TYPE +import org.openmined.syft.networking.datamodels.syft.AuthenticationSuccess +import org.openmined.syft.networking.datamodels.syft.CYCLE_TYPE +import org.openmined.syft.networking.datamodels.syft.CycleResponseData +import org.openmined.syft.networking.datamodels.webRTC.NEW_PEER_TYPE import org.openmined.syft.networking.datamodels.NetworkModels -import org.openmined.syft.networking.datamodels.REPORT_TYPE -import org.openmined.syft.networking.datamodels.ReportStatus -import org.openmined.syft.networking.datamodels.WEBRTC_INTERNAL_TYPE -import org.openmined.syft.networking.datamodels.WebRTCInternalMessage -import org.openmined.syft.networking.datamodels.WebRTCNewPeer +import org.openmined.syft.networking.datamodels.syft.REPORT_TYPE +import org.openmined.syft.networking.datamodels.syft.ReportResponse +import org.openmined.syft.networking.datamodels.webRTC.WEBRTC_INTERNAL_TYPE +import org.openmined.syft.networking.datamodels.webRTC.InternalMessageResponse +import org.openmined.syft.networking.datamodels.webRTC.NewPeer enum class REQUESTS(override val value: String) : ResponseMessageTypes { @@ -40,30 +40,30 @@ enum class REQUESTS(override val value: String) : ResponseMessageTypes { REPORT(REPORT_TYPE) { override val jsonParser = Json(JsonConfiguration.Stable) override fun parseJson(jsonString: String): NetworkModels = - jsonParser.parse(ReportStatus.serializer(), jsonString) + jsonParser.parse(ReportResponse.serializer(), jsonString) override fun serialize(obj: NetworkModels) = - jsonParser.toJson(ReportStatus.serializer(), obj as ReportStatus) + jsonParser.toJson(ReportResponse.serializer(), obj as ReportResponse) }, WEBRTC_INTERNAL(WEBRTC_INTERNAL_TYPE) { override val jsonParser: Json get() = Json(JsonConfiguration.Stable) override fun parseJson(jsonString: String): NetworkModels = - jsonParser.parse(WebRTCInternalMessage.serializer(), jsonString) + jsonParser.parse(InternalMessageResponse.serializer(), jsonString) override fun serialize(obj: NetworkModels): JsonElement = - jsonParser.toJson(WebRTCInternalMessage.serializer(), obj as WebRTCInternalMessage) + jsonParser.toJson(InternalMessageResponse.serializer(), obj as InternalMessageResponse) }, WEBRTC_PEER(NEW_PEER_TYPE) { override val jsonParser: Json get() = Json(JsonConfiguration.Stable) override fun parseJson(jsonString: String): NetworkModels = - jsonParser.parse(WebRTCNewPeer.serializer(), jsonString) + jsonParser.parse(NewPeer.serializer(), jsonString) override fun serialize(obj: NetworkModels): JsonElement = - jsonParser.toJson(WebRTCInternalMessage.serializer(), obj as WebRTCInternalMessage) + jsonParser.toJson(InternalMessageResponse.serializer(), obj as InternalMessageResponse) } diff --git a/syftlib/src/main/java/org/openmined/syft/networking/requests/SocketAPI.kt b/syftlib/src/main/java/org/openmined/syft/networking/requests/SocketAPI.kt new file mode 100644 index 00000000..11a47871 --- /dev/null +++ b/syftlib/src/main/java/org/openmined/syft/networking/requests/SocketAPI.kt @@ -0,0 +1,14 @@ +package org.openmined.syft.networking.requests + +import io.reactivex.Single +import org.openmined.syft.networking.datamodels.webRTC.InternalMessageRequest +import org.openmined.syft.networking.datamodels.webRTC.InternalMessageResponse +import org.openmined.syft.networking.datamodels.webRTC.JoinRoomRequest +import org.openmined.syft.networking.datamodels.webRTC.JoinRoomResponse + +interface SocketAPI : CommunicationAPI { + + fun joinRoom(joinRoomRequest: JoinRoomRequest): Single + + fun sendInternalMessage(internalMessageRequest: InternalMessageRequest): Single +} \ No newline at end of file diff --git a/syftlib/src/main/java/org/openmined/syft/processes/JobStatusMessage.kt b/syftlib/src/main/java/org/openmined/syft/processes/JobStatusMessage.kt new file mode 100644 index 00000000..9c48c5df --- /dev/null +++ b/syftlib/src/main/java/org/openmined/syft/processes/JobStatusMessage.kt @@ -0,0 +1,6 @@ +package org.openmined.syft.processes + +sealed class JobStatusMessage { + object JobCycleAccepted : JobStatusMessage() + object JobReady : JobStatusMessage() +} \ No newline at end of file diff --git a/syftlib/src/main/java/org/openmined/syft/processes/JobStatusSubscriber.kt b/syftlib/src/main/java/org/openmined/syft/processes/JobStatusSubscriber.kt new file mode 100644 index 00000000..b45090d1 --- /dev/null +++ b/syftlib/src/main/java/org/openmined/syft/processes/JobStatusSubscriber.kt @@ -0,0 +1,14 @@ +package org.openmined.syft.processes + +open class JobStatusSubscriber { + open fun onReady() {} + open fun onComplete() {} + open fun onError(throwable: Throwable) {} + + fun onJobStatusMessage(jobStatusMessage: JobStatusMessage) { + when (jobStatusMessage) { + is JobStatusMessage.JobReady -> onReady() + //add all the other messages as and when needed + } + } +} \ No newline at end of file diff --git a/syftlib/src/main/java/org/openmined/syft/processes/Plan.kt b/syftlib/src/main/java/org/openmined/syft/processes/Plan.kt new file mode 100644 index 00000000..11f091a5 --- /dev/null +++ b/syftlib/src/main/java/org/openmined/syft/processes/Plan.kt @@ -0,0 +1,7 @@ +package org.openmined.syft.processes + +class Plan(val planId: String) { + lateinit var torchScriptLocation: String + fun execute(input: String, target: String) { + } +} \ No newline at end of file diff --git a/syftlib/src/main/java/org/openmined/syft/processes/Protocol.kt b/syftlib/src/main/java/org/openmined/syft/processes/Protocol.kt new file mode 100644 index 00000000..0fd86bc1 --- /dev/null +++ b/syftlib/src/main/java/org/openmined/syft/processes/Protocol.kt @@ -0,0 +1,7 @@ +package org.openmined.syft.processes + +class Protocol(val protocolId: String) { + lateinit var protocolFileLocation: String + fun execute() { + } +} \ No newline at end of file diff --git a/syftlib/src/main/java/org/openmined/syft/processes/SyftJob.kt b/syftlib/src/main/java/org/openmined/syft/processes/SyftJob.kt new file mode 100644 index 00000000..b789271a --- /dev/null +++ b/syftlib/src/main/java/org/openmined/syft/processes/SyftJob.kt @@ -0,0 +1,205 @@ +package org.openmined.syft.processes + +import android.util.Log +import io.reactivex.Single +import io.reactivex.disposables.CompositeDisposable +import io.reactivex.processors.PublishProcessor +import org.openmined.syft.Syft +import org.openmined.syft.networking.datamodels.syft.CycleResponseData +import org.openmined.syft.networking.datamodels.syft.ReportRequest +import org.openmined.syft.networking.datamodels.syft.ReportResponse +import org.openmined.syft.threading.ProcessSchedulers +import java.io.File +import java.io.InputStream +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicReference + +private const val TAG = "SyftJob" + +/** + * @param worker : The syft worker handling this job + * @param computeSchedulers : The threads on which networking and file saving occurs + * @param modelName : The model being trained or used in inference + * @param version : The version of the model with name modelName + */ +@ExperimentalUnsignedTypes +class SyftJob( + private val worker: Syft, + //todo change this to read from syft configuration + private val computeSchedulers: ProcessSchedulers, + //todo change this to read from syft configuration + private val networkingSchedulers: ProcessSchedulers, + val modelName: String, + val version: String? = null +) { + + + var cycleStatus = AtomicReference(CycleStatus.APPLY) + private var trainingParamsStatus = AtomicReference(DownloadStatus.NOT_STARTED) + + private lateinit var requestKey: String + + //todo need to filled based on the destination directory defined by syft configuration class + private val destinationDir = "" + private val modelFileLocation = "$destinationDir/model/$modelName" + private val plans = ConcurrentHashMap() + private val protocols = ConcurrentHashMap() + private val jobStatusProcessor: PublishProcessor = PublishProcessor.create() + private val compositeDisposable = CompositeDisposable() + + /** + * create a worker job + */ + fun start(subscriber: JobStatusSubscriber = JobStatusSubscriber()) { + //todo all this in syft.kt + //todo check for connection if doesn't exist establish one + //todo before calling this function syft should have checked the bandwidth etc requirements + if (cycleStatus.get() == CycleStatus.APPLY) { + Log.d(TAG, "job awaiting timer completion to resend the Cycle Request") + return + } + worker.requestCycle(this) + subscribe(subscriber, computeSchedulers) + } + + /** + * report the results back to PyGrid + */ + fun report(diff: String) { + compositeDisposable.add( + worker.getSignallingClient().report(ReportRequest(worker.workerId, requestKey, diff)) + .compose(networkingSchedulers.applySingleSchedulers()) + .subscribe { reportResponse: ReportResponse -> + Log.i(TAG, reportResponse.status) + }) + } + + fun subscribe( + subscriber: JobStatusSubscriber, + schedulers: ProcessSchedulers + ) { + compositeDisposable.add( + jobStatusProcessor.onBackpressureBuffer() + .compose(schedulers.applyFlowableSchedulers()) + .subscribe( + { message -> subscriber.onJobStatusMessage(message) }, + { error -> subscriber.onError(error) }, + { subscriber.onComplete() } + ) + ) + } + + @Synchronized + fun setRequestKey(responseData: CycleResponseData.CycleAccept) { + requestKey = responseData.requestKey + cycleStatus.set(CycleStatus.ACCEPTED) + jobStatusProcessor.offer(JobStatusMessage.JobCycleAccepted) + } + + //todo before downloading check for wifi connection again + fun downloadData() { + if (trainingParamsStatus.get() != DownloadStatus.NOT_STARTED) { + Log.d(TAG, "download already running") + return + } + trainingParamsStatus.set(DownloadStatus.RUNNING) + val downloadList = mutableListOf>() + + plans.forEach { (planId, plan) -> + //todo instead of hardcoding this will be defined by configuration class method and by plan class + plan.torchScriptLocation = "$destinationDir/plans/$planId" + downloadList.add(planDownloader(plan.torchScriptLocation, planId)) + } + protocols.forEach { (protocolId, protocol) -> + //todo instead of hardcoding this will be defined by configuration class method and by protocol class + protocol.protocolFileLocation = "$destinationDir/plans/$protocolId" + downloadList.add(protocolDownloader(protocol.protocolFileLocation, protocolId)) + } + downloadList.add(modelDownloader(modelName)) + + compositeDisposable.add(Single.zip(downloadList) { successMessages -> + successMessages.joinToString( + ",", + prefix = "files ", + postfix = "downloaded successfully" + ) + } + .compose(networkingSchedulers.applySingleSchedulers()) + .subscribe( + { successMsg: String -> + Log.d(TAG, successMsg) + trainingParamsStatus.set(DownloadStatus.COMPLETE) + jobStatusProcessor.offer(JobStatusMessage.JobReady) + }, + { e -> jobStatusProcessor.onError(e) } + ) + ) + } + + //We might want to make these public if needed later + private fun modelDownloader(modelName: String) = + worker.getDownloader().downloadModel(worker.workerId, requestKey, modelName).compose( + computeSchedulers.applySingleSchedulers() + ).flatMap { response -> + saveFile(response.body()?.byteStream(), modelFileLocation, modelName) + } + + + private fun planDownloader(destinationDir: String, planId: String) = + worker.getDownloader().downloadPlan( + worker.workerId, + requestKey, + planId, + "torchscript" + ).compose(computeSchedulers.applySingleSchedulers()) + .flatMap { response -> + saveFile(response.body()?.byteStream(), destinationDir, planId) + } + + private fun protocolDownloader(destinationDir: String, protocolId: String) = + worker.getDownloader().downloadProtocol( + worker.workerId, + requestKey, + protocolId + ).compose(computeSchedulers.applySingleSchedulers()) + .flatMap { response -> + saveFile(response.body()?.byteStream(), destinationDir, protocolId) + } + + private fun saveFile( + input: InputStream?, + destinationDir: String, + fileName: String + ): Single { + val destination = File(destinationDir) + if (!destination.exists()) + destination.mkdirs() + return Single.create { emitter -> + input?.let { + val file = File(destination,fileName) + file.outputStream().use { outputFile -> + input.copyTo(outputFile) + } + emitter.onSuccess(file.absolutePath) + } ?: emitter.onError(Exception("invalid response stream for downloaded file")) + } + } + + + data class JobID(val modelName: String, val version: String? = null) { + fun matchWithResponse(modelName: String, version: String) = + if (this.version.isNullOrEmpty()) + this.modelName == modelName + else + (this.modelName == modelName) && (this.version == version) + } + + enum class DownloadStatus { + NOT_STARTED, RUNNING, COMPLETE + } + + enum class CycleStatus { + APPLY, REJECT, ACCEPTED + } + +} diff --git a/syftlib/src/main/java/org/openmined/syft/threading/ProcessSchedulers.kt b/syftlib/src/main/java/org/openmined/syft/threading/ProcessSchedulers.kt index b8a60a0b..a8273e78 100644 --- a/syftlib/src/main/java/org/openmined/syft/threading/ProcessSchedulers.kt +++ b/syftlib/src/main/java/org/openmined/syft/threading/ProcessSchedulers.kt @@ -1,6 +1,9 @@ package org.openmined.syft.threading +import io.reactivex.Completable +import io.reactivex.Flowable import io.reactivex.Scheduler +import io.reactivex.Single interface ProcessSchedulers { @@ -15,4 +18,22 @@ interface ProcessSchedulers { * @sample calleeThreadScheduler AndroidSchedulers.MainThread() */ val calleeThreadScheduler: Scheduler + + fun applySingleSchedulers() = { singleObservable: Single -> + singleObservable + .subscribeOn(computeThreadScheduler) + .observeOn(calleeThreadScheduler) + } + + fun applyCompletableSchedulers() = { completable: Completable -> + completable + .subscribeOn(computeThreadScheduler) + .observeOn(calleeThreadScheduler) + } + + fun applyFlowableSchedulers() = { flowable: Flowable -> + flowable + .subscribeOn(computeThreadScheduler) + .observeOn(calleeThreadScheduler) + } } \ No newline at end of file diff --git a/syftlib/src/test/java/org/openmined/syft/SyftTest.kt b/syftlib/src/test/java/org/openmined/syft/SyftTest.kt deleted file mode 100644 index da7a3860..00000000 --- a/syftlib/src/test/java/org/openmined/syft/SyftTest.kt +++ /dev/null @@ -1,29 +0,0 @@ -package org.openmined.syft - -import com.nhaarman.mockitokotlin2.doReturn -import com.nhaarman.mockitokotlin2.mock -import com.nhaarman.mockitokotlin2.verify -import com.nhaarman.mockitokotlin2.whenever -import io.reactivex.Flowable -import io.reactivex.schedulers.Schedulers -import org.junit.jupiter.api.Test -import org.openmined.syft.networking.clients.NetworkMessage -import org.openmined.syft.networking.clients.SignallingClient -import org.openmined.syft.threading.ProcessSchedulers - -internal class SyftTest { - - @Test - @ExperimentalUnsignedTypes - fun `Given a syft object when start is invoked the the signalling client is started`() { - val signallingClient = mock() - whenever(signallingClient.start()).thenReturn(Flowable.just(NetworkMessage.SocketOpen)) - val schedulers = mock { - on { computeThreadScheduler } doReturn Schedulers.trampoline() - on { calleeThreadScheduler } doReturn Schedulers.trampoline() - } - - Syft.getInstance(signallingClient, schedulers) - verify(signallingClient).start() - } -} diff --git a/syftlib/src/test/java/org/openmined/syft/networking/RequestBuilderTest.kt b/syftlib/src/test/java/org/openmined/syft/networking/RequestBuilderTest.kt deleted file mode 100644 index 535d380d..00000000 --- a/syftlib/src/test/java/org/openmined/syft/networking/RequestBuilderTest.kt +++ /dev/null @@ -1,118 +0,0 @@ -package org.openmined.syft.networking - -import kotlinx.serialization.json.JsonElement -import kotlinx.serialization.json.json -import org.junit.jupiter.api.Test -import org.openmined.syft.networking.datamodels.AuthenticationSuccess -import org.openmined.syft.networking.datamodels.CycleResponseData -import org.openmined.syft.networking.datamodels.ModelConfig -import org.openmined.syft.networking.datamodels.ReportStatus -import org.openmined.syft.networking.datamodels.WebRTCInternalMessage -import org.openmined.syft.networking.datamodels.WebRTCNewPeer -import org.openmined.syft.networking.requests.CommunicationDataFactory -import org.openmined.syft.networking.requests.REQUESTS - -class RequestBuilderTest { - - private val authenticationSuccess = - appendType( - "federated/authenticate", - json { "worker_id" to "Test worker ID" }).toString() - - private val cycleResponseReject = - appendType("federated/cycle-request", - json { - "status" to "rejected" - "timeout" to 2700 - }).toString() - - private val cycleResponseAccept = - appendType("federated/cycle-request", - json { - "status" to "accepted" - "request_key" to "LONG HASH VALUE" - "training_plan" to "TRAINING ID" - "model_config" to json { "modelName" to "model test" } - "protocol" to "PROTOCOL ID" - "model" to "model ID" - }).toString() - - private val reportStatus = - appendType("federated/report", json { "status" to "success" }).toString() - - private val webRTCInternal = - appendType("webrtc_internal", json { - "type" to "candidate" - "worker_id" to "testing new worker" - "sdp_string" to "SDP" - }).toString() - - private val newPeer = - appendType("peer", json { "worker_id" to "new ID" }).toString() - @Test - fun `given authentication json is parsed into AuthenticationSuccess class`() { - - val deserializeObject = - CommunicationDataFactory.deserializeSocket(authenticationSuccess) - assert( - deserializeObject.data == AuthenticationSuccess( - "Test worker ID" - ) - ) - } - - @Test - fun `given cycle response as reject parse into CycleReject`() { - val deserializeObject = CommunicationDataFactory.deserializeSocket(cycleResponseReject) - val trueObject = CycleResponseData.CycleReject(2700) - assert(deserializeObject.data == trueObject) - assert(deserializeObject.typesResponse == REQUESTS.CYCLE_REQUEST) - } - - @Test - fun `given cycle response as accept parse into CycleAccept`() { - val deserializeObject = CommunicationDataFactory.deserializeSocket(cycleResponseAccept) - val trueObject = CycleResponseData.CycleAccept( - "LONG HASH VALUE", - "TRAINING ID", - ModelConfig("model test"), - "PROTOCOL ID", - "model ID" - ) - assert(deserializeObject.data == trueObject) - assert(deserializeObject.typesResponse == REQUESTS.CYCLE_REQUEST) - } - - @Test - fun `check report status`() { - val deserializeObject = CommunicationDataFactory.deserializeSocket(reportStatus) - val trueObject = ReportStatus("success") - assert(deserializeObject.data == trueObject) - assert(deserializeObject.typesResponse == REQUESTS.REPORT) - } - - @Test - fun `check webRTC internal message deserialization`() { - val deserializeObject = CommunicationDataFactory.deserializeSocket(webRTCInternal) - val trueObject = WebRTCInternalMessage("candidate", "testing new worker", "SDP") - assert(deserializeObject.data == trueObject) - assert(deserializeObject.typesResponse == REQUESTS.WEBRTC_INTERNAL) - } - - @Test - fun `check webRTC new peer message deserialization`() { - val deserializeObject = CommunicationDataFactory.deserializeSocket(newPeer) - val trueObject = WebRTCNewPeer("new ID") - assert(deserializeObject.data == trueObject) - assert(deserializeObject.typesResponse == REQUESTS.WEBRTC_PEER) - } - - - private fun appendType(type: String, obj: JsonElement): JsonElement { - return json { - "type" to type - "data" to obj - } - } - -} \ No newline at end of file diff --git a/syftlib/src/test/java/org/openmined/syft/networking/WebRTCClientTest.kt b/syftlib/src/test/java/org/openmined/syft/networking/WebRTCClientTest.kt deleted file mode 100644 index cb3ac444..00000000 --- a/syftlib/src/test/java/org/openmined/syft/networking/WebRTCClientTest.kt +++ /dev/null @@ -1,49 +0,0 @@ -package org.openmined.syft.networking - -import kotlinx.serialization.json.json -import org.junit.jupiter.api.BeforeEach -import org.junit.jupiter.api.Test -import org.mockito.InjectMocks -import org.mockito.Mock -import org.mockito.Mockito.verify -import org.mockito.MockitoAnnotations -import org.openmined.syft.networking.clients.SignallingClient -import org.openmined.syft.networking.clients.WebRTCClient -import org.openmined.syft.networking.requests.WebRTCMessageTypes -import org.webrtc.PeerConnection -import org.webrtc.PeerConnectionFactory - -private const val TAG = "WebRTC test" - -@ExperimentalUnsignedTypes -class WebRTCClientTest { - - @Mock - private lateinit var peerConnectionFactory: PeerConnectionFactory - @Mock - private lateinit var peerConfig: PeerConnection.RTCConfiguration - @Mock - private lateinit var signallingClient: SignallingClient - - @InjectMocks - private lateinit var cut: WebRTCClient - - @BeforeEach - fun setUp() { - MockitoAnnotations.initMocks(this) - } - - @Test - @ExperimentalUnsignedTypes - fun `Given a workerId and a scopeId when the client starts it sends it through the socket`() { - val workerId = "workerId" - val scopeId = "scopeId" - val expected = json { - "workerId" to workerId - "scopeId" to scopeId - } - cut.start(workerId, scopeId) - - verify(signallingClient).send(WebRTCMessageTypes.WEBRTC_JOIN_ROOM, expected) - } -}