Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions syftlib/src/main/java/org/openmined/syft/Syft.kt
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package org.openmined.syft

import android.util.Log
import io.reactivex.Completable
import io.reactivex.disposables.CompositeDisposable
import org.openmined.syft.networking.clients.HttpClient
import org.openmined.syft.networking.clients.SocketClient
import org.openmined.syft.networking.datamodels.syft.AuthenticationSuccess
import org.openmined.syft.networking.datamodels.syft.AuthenticationResponse
import org.openmined.syft.networking.datamodels.syft.CycleRequest
import org.openmined.syft.networking.datamodels.syft.CycleResponseData
import org.openmined.syft.networking.requests.CommunicationAPI
Expand Down Expand Up @@ -96,9 +97,14 @@ class Syft private constructor(
else {
compositeDisposable.add(socketClient.authenticate()
.compose(networkingSchedulers.applySingleSchedulers())
.subscribe { t: AuthenticationSuccess ->
if (!this::workerId.isInitialized)
setSyftWorkerId(t.workerId)
.subscribe { t: AuthenticationResponse ->
when (t) {
is AuthenticationResponse.AuthenticationSuccess ->
if (!this::workerId.isInitialized)
setSyftWorkerId(t.workerId)
is AuthenticationResponse.AuthenticationError ->
Log.d(TAG, t.errorMessage)
}
requestCycle(job)
}
)
Expand All @@ -125,7 +131,7 @@ class Syft private constructor(
private fun getUploadSpeed() = ""

private fun handleCycleReject(responseData: CycleResponseData.CycleReject) {
var jobId = SyftJob.JobID(responseData.modelName, responseData.version)
var jobId = SyftJob.JobID(responseData.modelName)
val job = workerJobs.getOrElse(jobId, {
jobId = SyftJob.JobID(responseData.modelName)
workerJobs.getValue(jobId)
Expand All @@ -143,7 +149,7 @@ class Syft private constructor(
}

private fun handleCycleAccept(responseData: CycleResponseData.CycleAccept) {
val jobId = SyftJob.JobID(responseData.modelName, responseData.version)
val jobId = SyftJob.JobID(responseData.modelName)
val job = workerJobs.getOrElse(jobId, {
workerJobs.getValue(SyftJob.JobID(responseData.modelName))
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import kotlinx.serialization.json.JsonConfiguration
import kotlinx.serialization.json.json
import org.openmined.syft.networking.datamodels.NetworkModels
import org.openmined.syft.networking.datamodels.SocketResponse
import org.openmined.syft.networking.datamodels.syft.AuthenticationSuccess
import org.openmined.syft.networking.datamodels.syft.AuthenticationResponse
import org.openmined.syft.networking.datamodels.syft.CycleRequest
import org.openmined.syft.networking.datamodels.syft.CycleResponseData
import org.openmined.syft.networking.datamodels.syft.ReportRequest
Expand All @@ -21,6 +21,7 @@ import org.openmined.syft.networking.datamodels.webRTC.JoinRoomResponse
import org.openmined.syft.networking.requests.MessageTypes
import org.openmined.syft.networking.requests.Protocol
import org.openmined.syft.networking.requests.REQUESTS
import org.openmined.syft.networking.requests.ResponseMessageTypes
import org.openmined.syft.networking.requests.SocketAPI
import org.openmined.syft.networking.requests.WebRTCMessageTypes
import org.openmined.syft.processes.SyftJob
Expand All @@ -46,46 +47,51 @@ class SocketClient(
private val messageProcessor = PublishProcessor.create<NetworkModels>()
private val compositeDisposable = CompositeDisposable()

override fun authenticate(): Single<AuthenticationSuccess> {
override fun authenticate(): Single<AuthenticationResponse> {
initiateSocketIfEmpty()
syftWebSocket.send(appendType(REQUESTS.AUTHENTICATION))
syftWebSocket.send(serializeNetworkModel(REQUESTS.AUTHENTICATION))
return messageProcessor.onBackpressureLatest()
.ofType(AuthenticationSuccess::class.java)
.ofType(AuthenticationResponse::class.java)
.firstOrError()
}

override fun getCycle(cycleRequest: CycleRequest): Single<CycleResponseData> {
syftWebSocket.send(appendType(REQUESTS.CYCLE_REQUEST, cycleRequest))
syftWebSocket.send(serializeNetworkModel(REQUESTS.CYCLE_REQUEST, cycleRequest))
return messageProcessor.onBackpressureBuffer()
.ofType(CycleResponseData::class.java)
.filter {
SyftJob.JobID(
cycleRequest.modelName,
cycleRequest.version
).matchWithResponse(it.modelName, it.version)
).matchWithResponse(it.modelName)
}.debounce(timeout.toLong(), TimeUnit.MILLISECONDS)
.firstOrError()
}

//todo handle backpressure and first or error
override fun report(reportRequest: ReportRequest): Single<ReportResponse> {
syftWebSocket.send(appendType(REQUESTS.REPORT, reportRequest))
syftWebSocket.send(serializeNetworkModel(REQUESTS.REPORT, reportRequest))
return messageProcessor.onBackpressureDrop()
.ofType(ReportResponse::class.java)
.firstOrError()
}

//todo handle backpressure and first or error
override fun joinRoom(joinRoomRequest: JoinRoomRequest): Single<JoinRoomResponse> {
syftWebSocket.send(appendType(WebRTCMessageTypes.WEBRTC_JOIN_ROOM, joinRoomRequest))
syftWebSocket.send(
serializeNetworkModel(
WebRTCMessageTypes.WEBRTC_JOIN_ROOM,
joinRoomRequest
)
)
return messageProcessor.onBackpressureBuffer()
.ofType(JoinRoomResponse::class.java)
.firstOrError()
}

//todo handle backpressure and first or error
override fun sendInternalMessage(internalMessageRequest: InternalMessageRequest): Single<InternalMessageResponse> {
syftWebSocket.send(appendType(REQUESTS.WEBRTC_INTERNAL, internalMessageRequest))
syftWebSocket.send(serializeNetworkModel(REQUESTS.WEBRTC_INTERNAL, internalMessageRequest))
return messageProcessor.onBackpressureBuffer()
.ofType(InternalMessageResponse::class.java)
.first(null)
Expand Down Expand Up @@ -120,9 +126,14 @@ class SocketClient(
return Json.parse(SocketResponse.serializer(), socketMessage)
}

private fun appendType(types: MessageTypes, data: NetworkModels? = null) = json {
private fun serializeNetworkModel(types: MessageTypes, data: NetworkModels? = null) = json {
TYPE to types.value
if (data != null)
DATA to data
if (data != null) {
if (types is ResponseMessageTypes)
DATA to types.serialize(data)
else
//todo change this appropriately when needed
DATA to data.toString()
}
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
package org.openmined.syft.networking.datamodels

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

@Serializable
data class ClientConfig(
//todo populate when defined
val modelName: String
@SerialName("name")
val modelName: String,
@SerialName("version")
val modelVersion: String,
@SerialName("batch_size")
val batchSize: Int,
val lr: Float,
@SerialName("max_updates")
val maxUpdates: Int
)
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import kotlinx.serialization.Decoder
import kotlinx.serialization.Encoder
import kotlinx.serialization.KSerializer
import kotlinx.serialization.SerialDescriptor
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.SerializationException
import kotlinx.serialization.Serializer
Expand All @@ -19,6 +20,7 @@ private const val TAG = "SocketSerializer"

@Serializable(with = SocketSerializer::class)
data class SocketResponse(
@SerialName("type")
val typesResponse: ResponseMessageTypes,
val data: NetworkModels
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package org.openmined.syft.networking.datamodels.syft

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import org.openmined.syft.networking.datamodels.NetworkModels

const val AUTH_TYPE = "federated/authenticate"
const val AUTH_SUCCESS = "success"
const val AUTH_FAILURE = "rejected"

@Serializable
data class AuthenticationRequest(
@SerialName("auth_token")
val authToken: String = ""
) : NetworkModels()

@Serializable
sealed class AuthenticationResponse : NetworkModels() {

@SerialName(AUTH_SUCCESS)
@Serializable
data class AuthenticationSuccess(
@SerialName("worker_id")
val workerId: String
) : AuthenticationResponse()

@SerialName(AUTH_FAILURE)
@Serializable
data class AuthenticationError(
@SerialName("error")
val errorMessage: String
) : AuthenticationResponse()
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -11,34 +11,29 @@ const val CYCLE_REJECT = "rejected"

@Serializable
sealed class CycleResponseData : NetworkModels() {

@SerialName("model")

abstract val modelName: String
@SerialName("version")
abstract val version: String

@SerialName(CYCLE_ACCEPT)
@Serializable
data class CycleAccept(
@SerialName("model")
override val modelName: String,
override val version: String,
@SerialName("request_key")
val requestKey: String,
@SerialName("training_plan")
val trainingPlanID: String,
val plans: HashMap<String, String>,
@SerialName("client_config")
val clientConfig: ClientConfig,
@SerialName("protocols")
val protocolID: String,
val protocols: HashMap<String, String>,
@SerialName("model_id")
val modelId: String
) : CycleResponseData()

@SerialName(CYCLE_REJECT)
@Serializable
data class CycleReject(
@SerialName("model")
override val modelName: String,
override val version: String,
val timeout: Int
) : CycleResponseData()
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
package org.openmined.syft.networking.requests

import io.reactivex.Single
import org.openmined.syft.networking.datamodels.syft.AuthenticationSuccess
import org.openmined.syft.networking.datamodels.syft.AuthenticationResponse
import org.openmined.syft.networking.datamodels.syft.CycleRequest
import org.openmined.syft.networking.datamodels.syft.CycleResponseData
import org.openmined.syft.networking.datamodels.syft.ReportRequest
import org.openmined.syft.networking.datamodels.syft.ReportResponse


interface CommunicationAPI {
fun authenticate(): Single<AuthenticationSuccess>
fun authenticate(): Single<AuthenticationResponse>

fun getCycle(cycleRequest: CycleRequest): Single<CycleResponseData>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package org.openmined.syft.networking.requests
import io.reactivex.Single
import okhttp3.ResponseBody
import org.openmined.syft.networking.datamodels.syft.AUTH_TYPE
import org.openmined.syft.networking.datamodels.syft.AuthenticationSuccess
import org.openmined.syft.networking.datamodels.syft.AuthenticationResponse
import org.openmined.syft.networking.datamodels.syft.CYCLE_TYPE
import org.openmined.syft.networking.datamodels.syft.CycleRequest
import org.openmined.syft.networking.datamodels.syft.CycleResponseData
Expand Down Expand Up @@ -45,7 +45,7 @@ interface HttpAPI : CommunicationAPI {
): Single<Response<ResponseBody>>

@GET(AUTH_TYPE)
override fun authenticate(): Single<AuthenticationSuccess>
override fun authenticate(): Single<AuthenticationResponse>

@POST(CYCLE_TYPE)
override fun getCycle(@Body cycleRequest: CycleRequest): Single<CycleResponseData>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,34 @@ package org.openmined.syft.networking.requests
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonConfiguration
import kotlinx.serialization.json.JsonElement
import org.openmined.syft.networking.datamodels.NetworkModels
import org.openmined.syft.networking.datamodels.syft.AUTH_TYPE
import org.openmined.syft.networking.datamodels.syft.AuthenticationSuccess
import org.openmined.syft.networking.datamodels.syft.AuthenticationRequest
import org.openmined.syft.networking.datamodels.syft.AuthenticationResponse
import org.openmined.syft.networking.datamodels.syft.CYCLE_TYPE
import org.openmined.syft.networking.datamodels.syft.CycleRequest
import org.openmined.syft.networking.datamodels.syft.CycleResponseData
import org.openmined.syft.networking.datamodels.webRTC.NEW_PEER_TYPE
import org.openmined.syft.networking.datamodels.NetworkModels
import org.openmined.syft.networking.datamodels.syft.REPORT_TYPE
import org.openmined.syft.networking.datamodels.syft.ReportRequest
import org.openmined.syft.networking.datamodels.syft.ReportResponse
import org.openmined.syft.networking.datamodels.webRTC.WEBRTC_INTERNAL_TYPE
import org.openmined.syft.networking.datamodels.webRTC.InternalMessageRequest
import org.openmined.syft.networking.datamodels.webRTC.InternalMessageResponse
import org.openmined.syft.networking.datamodels.webRTC.NEW_PEER_TYPE
import org.openmined.syft.networking.datamodels.webRTC.NewPeer
import org.openmined.syft.networking.datamodels.webRTC.WEBRTC_INTERNAL_TYPE

enum class REQUESTS(override val value: String) : ResponseMessageTypes {

AUTHENTICATION(AUTH_TYPE) {
override val jsonParser = Json(JsonConfiguration.Stable)
override val jsonParser = Json(JsonConfiguration.Stable.copy(classDiscriminator = "status"))
override fun parseJson(jsonString: String): NetworkModels =
jsonParser.parse(AuthenticationSuccess.serializer(), jsonString)
jsonParser.parse(AuthenticationResponse.serializer(), jsonString)

override fun serialize(obj: NetworkModels) =
jsonParser.toJson(AuthenticationSuccess.serializer(), obj as AuthenticationSuccess)
jsonParser.toJson(
AuthenticationRequest.serializer(),
obj as AuthenticationRequest
)
},

CYCLE_REQUEST(CYCLE_TYPE) {
Expand All @@ -35,15 +42,15 @@ enum class REQUESTS(override val value: String) : ResponseMessageTypes {
jsonParser.parse(CycleResponseData.serializer(), jsonString)

override fun serialize(obj: NetworkModels) =
jsonParser.toJson(CycleResponseData.serializer(), obj as CycleResponseData)
jsonParser.toJson(CycleRequest.serializer(), obj as CycleRequest)
},
REPORT(REPORT_TYPE) {
override val jsonParser = Json(JsonConfiguration.Stable)
override fun parseJson(jsonString: String): NetworkModels =
jsonParser.parse(ReportResponse.serializer(), jsonString)

override fun serialize(obj: NetworkModels) =
jsonParser.toJson(ReportResponse.serializer(), obj as ReportResponse)
jsonParser.toJson(ReportRequest.serializer(), obj as ReportRequest)
},
WEBRTC_INTERNAL(WEBRTC_INTERNAL_TYPE) {
override val jsonParser: Json
Expand All @@ -53,7 +60,10 @@ enum class REQUESTS(override val value: String) : ResponseMessageTypes {
jsonParser.parse(InternalMessageResponse.serializer(), jsonString)

override fun serialize(obj: NetworkModels): JsonElement =
jsonParser.toJson(InternalMessageResponse.serializer(), obj as InternalMessageResponse)
jsonParser.toJson(
InternalMessageRequest.serializer(),
obj as InternalMessageRequest
)
},
WEBRTC_PEER(NEW_PEER_TYPE) {
override val jsonParser: Json
Expand All @@ -63,7 +73,10 @@ enum class REQUESTS(override val value: String) : ResponseMessageTypes {
jsonParser.parse(NewPeer.serializer(), jsonString)

override fun serialize(obj: NetworkModels): JsonElement =
jsonParser.toJson(InternalMessageResponse.serializer(), obj as InternalMessageResponse)
jsonParser.toJson(
InternalMessageRequest.serializer(),
obj as InternalMessageRequest
)

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ class SyftJob(


data class JobID(val modelName: String, val version: String? = null) {
fun matchWithResponse(modelName: String, version: String) =
fun matchWithResponse(modelName: String, version: String? = null) =
if (this.version.isNullOrEmpty())
this.modelName == modelName
else
Expand Down