Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions .idea/codeStyles/Project.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion demo-app/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ dependencies {
implementation CommonDependencies.kotlinSerialization
implementation CommonDependencies.rxJava
implementation CommonDependencies.rxAndroid

// TODO During the first stages of the project, include the library here. Later on we should use the library from the repository
implementation project(path: ':syftlib')
// implementation 'org.openmined.kotlinsyft:syftlib:0.2.0'
Expand Down
6 changes: 4 additions & 2 deletions demo-app/src/main/AndroidManifest.xml
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="org.openmined.syft.demo">

<uses-permission android:name="android.permission.INTERNET" />
<uses-permission android:name="android.permission.ACCESS_NETWORK_STATE" />

<application
android:allowBackup="true"
android:icon="@mipmap/ic_launcher"
android:label="@string/app_name"
android:roundIcon="@mipmap/ic_launcher_round"
android:supportsRtl="true"
android:usesCleartextTraffic="true"
android:theme="@style/AppTheme">
android:theme="@style/AppTheme"
android:usesCleartextTraffic="true">
<activity
android:name=".ui.MainActivity"
android:label="@string/title_activity_main"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ class LocalMNISTModuleDataSource constructor(

fun loadModule(): Module {
val module =
MessageProcessor().processTorchScript(resources.openRawResource(R.raw.tp_ts).readBytes())
MessageProcessor().processTorchScript(
resources.openRawResource(R.raw.tp_ts).readBytes()
)
val path = saveScript(module.obj)
Log.d("MainActivity", "TorchScript saved at $path")
return Module.load(path)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,20 @@ private const val TAG = "FederatedCycleViewModel"

@ExperimentalUnsignedTypes
class FederatedCycleViewModel(
socketClient: SocketClient,
httpClient: HttpClient,
baseurl: String,
authToken: String,
networkSchedulers: ProcessSchedulers,
computeSchedulers: ProcessSchedulers
) : ViewModel() {
private val syftWorker = Syft.getInstance(
socketClient, httpClient,
baseurl, authToken,
networkSchedulers, computeSchedulers
)
private val mnistJob = syftWorker.newJob("mnist","1.0.0")
private val mnistJob = syftWorker.newJob("mnist", "1.0.0")

fun startCycle() {
Log.d(TAG,"mnist job started")
val jobStatusSubscriber = object : JobStatusSubscriber(){
Log.d(TAG, "mnist job started")
val jobStatusSubscriber = object : JobStatusSubscriber() {
override fun onReady(model: String, clientConfig: ClientConfig) {
//todo training code goes here
}
Expand Down
13 changes: 2 additions & 11 deletions demo-app/src/main/java/org/openmined/syft/demo/ui/MainActivity.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,13 @@ import org.openmined.syft.demo.datasource.LocalMNISTModuleDataSource
import org.openmined.syft.demo.domain.MNISTDataRepository
import org.openmined.syft.demo.domain.MNISTModuleRepository
import org.openmined.syft.demo.domain.MNISTTrainer
import org.openmined.syft.networking.clients.HttpClient
import org.openmined.syft.networking.clients.SocketClient
import org.openmined.syft.threading.ProcessSchedulers

@ExperimentalUnsignedTypes
class MainActivity : AppCompatActivity() {

override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
setContentView(R.layout.activity_main)
val binding: ActivityMainBinding =
DataBindingUtil.setContentView(this, R.layout.activity_main)
setSupportActionBar(toolbar)
Expand All @@ -44,21 +41,15 @@ class MainActivity : AppCompatActivity() {
override val calleeThreadScheduler: Scheduler
get() = Schedulers.single()
}
val socketClient = SocketClient(
baseUrl,
2000u,
computeSchedulers
)
val httpClient = HttpClient(baseUrl)

val localModuleDataSource = LocalMNISTModuleDataSource(resources, filesDir)
val moduleRepository = MNISTModuleRepository(localModuleDataSource)
val localMNISTDataDataSource = LocalMNISTDataDataSource(resources)
val dataRepository = MNISTDataRepository(localMNISTDataDataSource)
val trainer = MNISTTrainer()
return MainViewModelFactory(
socketClient,
httpClient,
baseUrl,
"auth",
networkingSchedulers,
computeSchedulers
).create(FederatedCycleViewModel::class.java)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,20 @@ package org.openmined.syft.demo.ui

import androidx.lifecycle.ViewModel
import androidx.lifecycle.ViewModelProvider
import org.openmined.syft.demo.domain.MNISTModuleRepository
import org.openmined.syft.demo.domain.MNISTDataRepository
import org.openmined.syft.demo.domain.MNISTTrainer
import org.openmined.syft.networking.clients.HttpClient
import org.openmined.syft.networking.clients.SocketClient
import org.openmined.syft.threading.ProcessSchedulers

@ExperimentalUnsignedTypes
class MainViewModelFactory(
private val socketClient: SocketClient,
private val httpClient: HttpClient,
private val baseUrl: String,
private val authToken: String,
private val networkSchedulers: ProcessSchedulers,
private val computeSchedulers: ProcessSchedulers
) : ViewModelProvider.Factory {

override fun <T : ViewModel?> create(modelClass: Class<T>): T {
return FederatedCycleViewModel(
socketClient,
httpClient,
baseUrl,
authToken,
networkSchedulers,
computeSchedulers
) as T
Expand Down
44 changes: 27 additions & 17 deletions syftlib/src/main/java/org/openmined/syft/Syft.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import org.openmined.syft.networking.datamodels.syft.CycleRequest
import org.openmined.syft.networking.datamodels.syft.CycleResponseData
import org.openmined.syft.networking.requests.CommunicationAPI
import org.openmined.syft.networking.requests.HttpAPI
import org.openmined.syft.networking.requests.SocketAPI
import org.openmined.syft.processes.JobStatusSubscriber
import org.openmined.syft.processes.SyftJob
import org.openmined.syft.threading.ProcessSchedulers
Expand All @@ -17,9 +18,10 @@ import java.util.concurrent.ConcurrentHashMap
private const val TAG = "Syft"

@ExperimentalUnsignedTypes
class Syft private constructor(
private val socketClient: SocketClient,
private val httpClient: HttpClient,
class Syft internal constructor(
private val authToken: String,
private var socketClient: SocketClient,
private var httpClient: HttpClient,
//todo this will be removed by syft configuration class
private val computeSchedulers: ProcessSchedulers,
//todo change this to read from syft configuration
Expand All @@ -31,16 +33,17 @@ class Syft private constructor(
private var INSTANCE: Syft? = null

fun getInstance(
socketClient: SocketClient,
httpClient: HttpClient,
baseUrl: String,
authToken: String,
networkingSchedulers: ProcessSchedulers,
//todo this will be removed by syft configuration class
computeSchedulers: ProcessSchedulers
): Syft =
INSTANCE ?: synchronized(this) {
INSTANCE ?: Syft(
socketClient,
httpClient,
authToken,
SocketClient(baseUrl, 2000u, networkingSchedulers),
HttpClient(baseUrl),
networkingSchedulers,
computeSchedulers
).also { INSTANCE = it }
Expand Down Expand Up @@ -77,15 +80,15 @@ class Syft private constructor(
fun requestCycle(job: SyftJob) {
if (this::workerId.isInitialized)
socketClient.getCycle(
CycleRequest(
workerId,
job.modelName,
job.version,
getPing(),
getDownloadSpeed(),
getUploadSpeed()
)
).compose(networkingSchedulers.applySingleSchedulers())
CycleRequest(
workerId,
job.modelName,
job.version,
getPing(),
getDownloadSpeed(),
getUploadSpeed()
)
).compose(networkingSchedulers.applySingleSchedulers())
.subscribe { response: CycleResponseData ->
when (response) {
is CycleResponseData.CycleAccept -> handleCycleAccept(response)
Expand Down Expand Up @@ -115,8 +118,15 @@ class Syft private constructor(

//todo decide this based on configuration
fun getSignallingClient(): CommunicationAPI = socketClient
fun getWebRTCSignallingClient(): SocketAPI = socketClient

fun setHttpClient(httpClient: HttpClient) {
this.httpClient = httpClient
}

fun getWebRTCSignallingClient(): SocketClient = socketClient
fun setSocketClient(socketClient: SocketClient) {
this.socketClient = socketClient
}

@Synchronized
private fun setSyftWorkerId(workerId: String) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class SocketClient(
private val Json = Json(JsonConfiguration.Stable)

private val syftWebSocket = SyftWebSocket(Protocol.WSS, baseUrl, timeout)

@Volatile
private var socketClientSubscribed = AtomicBoolean(false)
private val messageProcessor = PublishProcessor.create<NetworkModels>()
Expand All @@ -56,7 +57,10 @@ class SocketClient(
}

override fun getCycle(cycleRequest: CycleRequest): Single<CycleResponseData> {
Log.d(TAG, "sending message: " + serializeNetworkModel(REQUESTS.CYCLE_REQUEST, cycleRequest))
Log.d(
TAG,
"sending message: " + serializeNetworkModel(REQUESTS.CYCLE_REQUEST, cycleRequest)
)
syftWebSocket.send(serializeNetworkModel(REQUESTS.CYCLE_REQUEST, cycleRequest))
return messageProcessor.onBackpressureBuffer()
.ofType(CycleResponseData::class.java)
Expand Down Expand Up @@ -114,7 +118,7 @@ class SocketClient(
it.throwable
)
is NetworkMessage.MessageReceived -> {
Log.d(TAG,"received the message "+it.message)
Log.d(TAG, "received the message " + it.message)
emitMessage(deserializeSocket(it.message))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ const val CYCLE_REJECT = "rejected"

@Serializable
sealed class CycleResponseData : NetworkModels() {

abstract val modelName: String
abstract val version: String?

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@ import retrofit2.http.Body
import retrofit2.http.GET
import retrofit2.http.POST
import retrofit2.http.Query
import retrofit2.http.Streaming

interface HttpAPI : CommunicationAPI {

// @Streaming
// @Streaming
@GET("/federated/get-plan")
fun downloadPlan(
@Query("worker_id") workerId: String,
Expand All @@ -28,15 +27,15 @@ interface HttpAPI : CommunicationAPI {
@Query("receive_operations_as") op_type: String
): Single<Response<ResponseBody>>

// @Streaming
// @Streaming
@GET("/federated/get-protocol")
fun downloadProtocol(
@Query("worker_id") workerId: String,
@Query("request_key") requestKey: String,
@Query("protocol_id") protocolId: String
): Single<Response<ResponseBody>>

// @Streaming
// @Streaming
@GET("/federated/get-model")
fun downloadModel(
@Query("worker_id") workerId: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,10 @@ sealed class Protocol {
return "http"
}
}

object HTTPS : Protocol() {
override fun toString(): String {
return "https"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ import org.openmined.syft.networking.datamodels.ClientConfig

sealed class JobStatusMessage {
class JobCycleRejected(val timeout: String) : JobStatusMessage()
class JobReady(val model: String, val clientConfig: ClientConfig) : JobStatusMessage()
class JobReady(val model: String, val clientConfig: ClientConfig) : JobStatusMessage()
}
Loading