Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ allprojects {
url "https://dl.bintray.com/openmined/KotlinSyft"
}
maven { url 'https://jitpack.io' }
//latest robolectric
maven { url "https://oss.sonatype.org/content/repositories/snapshots" }
}
}

Expand Down
2 changes: 1 addition & 1 deletion buildSrc/src/main/java/Dependencies.kt
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ object Versions {
const val adxCore = "1.1.0"
const val adxRunner = "1.1.1"
const val adxExtTruth = "1.1.0"
const val robolectric = "4.3"
const val robolectric = "4.4-SNAPSHOT"

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class MnistViewModel(
configuration: SyftConfiguration,
private val mnistDataRepository: MNISTDataRepository
) : ViewModel() {
private val syftWorker = Syft.getInstance(authToken, configuration)
private val syftWorker = Syft.getInstance(configuration,authToken)
private val mnistJob = syftWorker.newJob("mnist", "1.0.0")

val logger
Expand Down
2 changes: 1 addition & 1 deletion syftlib/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,11 @@ dependencies {
implementation SyftlibDependencies.torchVisionAndroid

androidTestImplementation CommonDependencies.adxExtJunit
androidTestImplementation CommonDependencies.adxTest
androidTestImplementation CommonDependencies.adxRunner
androidTestImplementation CommonDependencies.adxExtTruth
androidTestImplementation CommonDependencies.espresso

testImplementation CommonDependencies.adxTest
testImplementation CommonDependencies.robolectric
testImplementation CommonDependencies.junit
testImplementation CommonDependencies.mockitoCore
Expand Down
38 changes: 24 additions & 14 deletions syftlib/src/main/java/org/openmined/syft/Syft.kt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import org.openmined.syft.networking.datamodels.syft.AuthenticationRequest
import org.openmined.syft.networking.datamodels.syft.AuthenticationResponse
import org.openmined.syft.networking.datamodels.syft.CycleRequest
import org.openmined.syft.networking.datamodels.syft.CycleResponseData
import java.lang.Exception
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicBoolean

Expand All @@ -21,25 +22,29 @@ private const val TAG = "Syft"

@ExperimentalUnsignedTypes
class Syft internal constructor(
private val authToken: String,
private val syftConfig: SyftConfiguration,
private val deviceMonitor: DeviceMonitor
) : Disposable {
private val deviceMonitor: DeviceMonitor,
private val authToken: String?
) : Disposable {
companion object {
@Volatile
private var INSTANCE: Syft? = null

fun getInstance(
authToken: String,
syftConfiguration: SyftConfiguration
): Syft =
INSTANCE ?: synchronized(this) {
INSTANCE ?: Syft(
authToken,
syftConfiguration,
DeviceMonitor.construct(syftConfiguration)
).also { INSTANCE = it }
}
syftConfiguration: SyftConfiguration,
authToken: String? = null
): Syft {
return INSTANCE ?: synchronized(this) {
INSTANCE?.let {
if (it.syftConfig == syftConfiguration && it.authToken == authToken) it
else throw ExceptionInInitializerError("syft worker initialised with different parameters. Dispose previous worker")
} ?: Syft(
syftConfiguration,
DeviceMonitor.construct(syftConfiguration),
authToken
).also { INSTANCE = it }
}
}
}

private val workerJobs = ConcurrentHashMap<SyftJob.JobID, SyftJob>()
Expand All @@ -59,6 +64,9 @@ class Syft internal constructor(
this,
syftConfig
)
if (syftConfig.maxConcurrentJobs == workerJobs.size)
throw IndexOutOfBoundsException("maximum number of allowed jobs reached")

workerJobs[job.jobId] = job
job.subscribe(object : JobStatusSubscriber() {
override fun onComplete() {
Expand Down Expand Up @@ -213,8 +221,10 @@ class Syft internal constructor(
setSyftWorkerId(t.workerId)
executeCycleRequest(job)
}
is AuthenticationResponse.AuthenticationError ->
is AuthenticationResponse.AuthenticationError -> {
job.throwError(SecurityException(t.errorMessage))
Log.d(TAG, t.errorMessage)
}
}
}, {
job.throwError(it)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ class SyftConfiguration internal constructor(
val networkConstraints: List<Int>,
val transportMedium: Int,
val cacheTimeOut: Long,
val maxConcurrentJobs: Int,
private val socketClient: SocketClient,
private val httpClient: HttpClient,
private val maxConcurrentJobs: Int,
private val messagingClient: NetworkingClients
) {
companion object {
Expand Down Expand Up @@ -55,7 +55,7 @@ class SyftConfiguration internal constructor(
get() = Schedulers.single()
}
private var socketClient = SocketClient(baseUrl, 20000u, networkingSchedulers)
private var httpClient = HttpClient(baseUrl)
private var httpClient = HttpClient.initialize(baseUrl)
private var filesDir = context.filesDir
private var batteryCheckEnabled = true
private var maxConcurrentJobs: Int = 1
Expand All @@ -80,9 +80,9 @@ class SyftConfiguration internal constructor(
constraintList,
networkTransportMedium,
cacheTimeOut,
maxConcurrentJobs,
socketClient,
httpClient,
maxConcurrentJobs,
messagingClient
)
}
Expand Down
11 changes: 7 additions & 4 deletions syftlib/src/main/java/org/openmined/syft/execution/SyftJob.kt
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,13 @@ class SyftJob(
override fun isDisposed() = isDisposed.get()

override fun dispose() {
jobStatusProcessor.onComplete()
networkDisposable.clear()
isDisposed.set(true)
Log.d(TAG,"job $jobId disposed")
if (!isDisposed()) {
jobStatusProcessor.onComplete()
networkDisposable.clear()
isDisposed.set(true)
Log.d(TAG, "job $jobId disposed")
} else
Log.d(TAG, "job $jobId already disposed")
}

private fun getDownloadables(workerId: String, request: String): List<Single<String>> {
Expand Down
24 changes: 17 additions & 7 deletions syftlib/src/main/java/org/openmined/syft/monitor/DeviceMonitor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,26 @@ import io.reactivex.disposables.Disposable
import org.openmined.syft.domain.SyftConfiguration
import org.openmined.syft.monitor.battery.BatteryStatusRepository
import org.openmined.syft.monitor.network.NetworkStatusRepository
import org.openmined.syft.threading.ProcessSchedulers
import java.util.concurrent.atomic.AtomicBoolean

private const val TAG = "device monitor"

@ExperimentalUnsignedTypes
class DeviceMonitor(
private val networkStatusRepository: NetworkStatusRepository,
private val batteryStatusRepository: BatteryStatusRepository
private val batteryStatusRepository: BatteryStatusRepository,
private val processSchedulers: ProcessSchedulers
) : Disposable {

companion object {
fun construct(syftConfig: SyftConfiguration): DeviceMonitor {
fun construct(
syftConfig: SyftConfiguration
): DeviceMonitor {
return DeviceMonitor(
NetworkStatusRepository.initialize(syftConfig),
BatteryStatusRepository.initialize(syftConfig)
BatteryStatusRepository.initialize(syftConfig),
syftConfig.networkingSchedulers
)
}
}
Expand Down Expand Up @@ -65,6 +70,8 @@ class DeviceMonitor(

compositeDisposable.add(
statusListener
.subscribeOn(processSchedulers.computeThreadScheduler)
.observeOn(processSchedulers.calleeThreadScheduler)
.subscribe {
when (it) {
is StateChangeMessage.Charging -> {
Expand All @@ -91,9 +98,12 @@ class DeviceMonitor(

override fun dispose() {
compositeDisposable.clear()
networkStatusRepository.unsubscribeStateChange()
batteryStatusRepository.unsubscribeStateChange()
isDisposed.set(true)
Log.d(TAG,"disposed device monitor")
if (!isDisposed()) {
networkStatusRepository.unsubscribeStateChange()
batteryStatusRepository.unsubscribeStateChange()
isDisposed.set(true)
Log.d(TAG, "disposed device monitor")
} else
Log.d(TAG,"device monitor already disposed")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,23 @@ import retrofit2.Retrofit
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory

/**
* Creates a retrofit api client for PyGrid endpoints.
*
* @property baseUrl url of the server hosting the pyGrid instance.
* @see org.openmined.syft.networking.requests.HttpAPI for endpoint description.
* @property apiClient the retrofit api client
*/
class HttpClient(baseUrl: String) {
val apiClient: HttpAPI = Retrofit.Builder()
.addCallAdapterFactory(RxJava2CallAdapterFactory.create())
.addConverterFactory(Json.asConverterFactory("application/json".toMediaType()))
.baseUrl("${NetworkingProtocol.HTTP}://$baseUrl")
.build().create(HttpAPI::class.java)

class HttpClient(val apiClient: HttpAPI) {
companion object {
/**
* Creates a retrofit api client for PyGrid endpoints.
*
* @param baseUrl url of the server hosting the pyGrid instance.
* @see org.openmined.syft.networking.requests.HttpAPI for endpoint description.
*/
fun initialize(baseUrl: String): HttpClient {
val apiClient: HttpAPI = Retrofit.Builder()
.addCallAdapterFactory(RxJava2CallAdapterFactory.create())
.addConverterFactory(Json.asConverterFactory("application/json".toMediaType()))
.baseUrl("${NetworkingProtocol.HTTP}://$baseUrl")
.build().create(HttpAPI::class.java)
return HttpClient(apiClient)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,13 @@ class SocketClient(

// Free resources
override fun dispose() {
syftWebSocket.close()
compositeDisposable.clear()
socketClientSubscribed.set(false)
Log.d(TAG,"Socket Client Disposed")
if (isDisposed) {
syftWebSocket.close()
socketClientSubscribed.set(false)
Log.d(TAG, "Socket Client Disposed")
} else
Log.d(TAG,"socket client already disposed")
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ const val AUTH_FAILURE = "rejected"
@Serializable
data class AuthenticationRequest(
@SerialName("auth_token")
val authToken: String
val authToken: String? = null
) : NetworkModels()

@Serializable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import kotlinx.serialization.Serializable
import org.openmined.syft.networking.datamodels.NetworkModels

@Serializable
class SpeedCheckResponse(
data class SpeedCheckResponse(
@SerialName("error")
val error: String? = null
) : NetworkModels()

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package org.openmined.syft.integration

import android.content.Context
import android.content.Intent
import android.net.ConnectivityManager
import android.net.NetworkCapabilities
import android.os.BatteryManager
import androidx.test.core.app.ApplicationProvider
import io.reactivex.Scheduler
import io.reactivex.schedulers.Schedulers
import org.junit.Before
import org.junit.runner.RunWith
import org.openmined.syft.threading.ProcessSchedulers
import org.robolectric.RobolectricTestRunner
import org.robolectric.Shadows
import org.robolectric.annotation.LooperMode.Mode.PAUSED
import org.robolectric.annotation.LooperMode
import org.robolectric.shadow.api.Shadow
import org.robolectric.shadows.ShadowConnectivityManager
import org.robolectric.shadows.ShadowNetworkCapabilities

@ExperimentalUnsignedTypes
@RunWith(RobolectricTestRunner::class)
@LooperMode(PAUSED)
abstract class AbstractSyftWorkerTest {

protected val context: Context = ApplicationProvider.getApplicationContext()
protected val networkConstraints = listOf(
NetworkCapabilities.NET_CAPABILITY_INTERNET,
NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED,
NetworkCapabilities.NET_CAPABILITY_NOT_METERED
)
protected val networkingSchedulers = object : ProcessSchedulers {
override val computeThreadScheduler: Scheduler
get() = Schedulers.trampoline()
override val calleeThreadScheduler: Scheduler
get() = Schedulers.trampoline()
}
protected val computeSchedulers = object : ProcessSchedulers {
override val computeThreadScheduler: Scheduler
get() = Schedulers.trampoline()
override val calleeThreadScheduler: Scheduler
get() = Schedulers.trampoline()
}

@Before
open fun initialiseContext() {
val networkManager = getConnectivityManager()
val networkCapability = ShadowNetworkCapabilities.newInstance()
val shadowNC = Shadows.shadowOf(networkCapability)
shadowNC.addTransportType(ConnectivityManager.TYPE_WIFI)
networkConstraints.forEach { shadowNC.addCapability(it) }
getShadowConnectivityManager()
.setNetworkCapabilities(networkManager.activeNetwork, networkCapability)

val batteryStatus = Shadow.newInstanceOf(Intent::class.java)
batteryStatus.action = Intent.ACTION_BATTERY_CHANGED
batteryStatus.putExtra(BatteryManager.EXTRA_LEVEL, 1000)
batteryStatus.putExtra(BatteryManager.EXTRA_SCALE, 4000)
batteryStatus.putExtra(BatteryManager.EXTRA_STATUS, BatteryManager.BATTERY_STATUS_CHARGING)
context.sendStickyBroadcast(batteryStatus)
}

fun getConnectivityManager(): ConnectivityManager {
return context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager
}

fun getShadowConnectivityManager(): ShadowConnectivityManager {
return Shadows.shadowOf(getConnectivityManager())
}
}
Loading