Sync pairing will now always happen in parallel for direct and relayed and reduced amount of occupied threads.

This commit is contained in:
Koen J
2025-10-03 14:11:48 +02:00
parent 642d218c54
commit 137ba85538
4 changed files with 131 additions and 104 deletions
@@ -216,10 +216,9 @@ private fun ByteArray.toInetAddress(): InetAddress {
return InetAddress.getByAddress(this);
}
fun getConnectedSocket(attemptAddresses: List<InetAddress>, port: Int): Socket? {
fun getConnectedSocket(attemptAddresses: List<InetAddress>, port: Int, timeoutMs: Int = 10_000): Socket? {
ensureNotMainThread()
val timeout = 10000
val addresses = if(!Settings.instance.casting.allowIpv6) attemptAddresses.filterIsInstance<Inet4Address>() else attemptAddresses;
if(addresses.isEmpty())
throw IllegalStateException("No valid addresses found (ipv6: ${(if(Settings.instance.casting.allowIpv6) "enabled" else "disabled")})");
@@ -232,7 +231,7 @@ fun getConnectedSocket(attemptAddresses: List<InetAddress>, port: Int): Socket?
val socket = Socket()
try {
return socket.apply { this.connect(InetSocketAddress(addresses[0], port), timeout) }
return socket.apply { this.connect(InetSocketAddress(addresses[0], port), timeoutMs) }
} catch (e: Throwable) {
Log.i("getConnectedSocket", "Failed to connect to: ${addresses[0]}", e)
socket.close()
@@ -263,7 +262,7 @@ fun getConnectedSocket(attemptAddresses: List<InetAddress>, port: Int): Socket?
}
}
socket.connect(InetSocketAddress(address, port), timeout);
socket.connect(InetSocketAddress(address, port), timeoutMs);
synchronized(syncObject) {
if (connectedSocket == null) {
@@ -110,7 +110,7 @@ class SyncPairActivity : AppCompatActivity() {
lifecycleScope.launch(Dispatchers.IO) {
try {
StateSync.instance.syncService?.connect(deviceInfo) { complete, message ->
StateSync.instance.syncService?.connect(deviceInfo, true) { complete, message ->
lifecycleScope.launch(Dispatchers.Main) {
if (complete != null) {
if (complete) {
@@ -6,6 +6,7 @@ import android.net.nsd.NsdServiceInfo
import android.os.Build
import android.util.Log
import com.futo.platformplayer.Settings
import com.futo.platformplayer.ensureNotMainThread
import com.futo.platformplayer.generateReadablePassword
import com.futo.platformplayer.getConnectedSocket
import com.futo.platformplayer.logging.Logger
@@ -17,14 +18,23 @@ import com.futo.polycentric.core.base64UrlToByteArray
import com.futo.polycentric.core.toBase64
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.cancel
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.selects.select
import kotlinx.coroutines.withContext
import java.net.InetAddress
import java.net.InetSocketAddress
import java.net.ServerSocket
import java.net.Socket
import java.nio.ByteBuffer
import java.nio.channels.ClosedChannelException
import java.nio.channels.ServerSocketChannel
import java.nio.channels.SocketChannel
import java.util.Base64
import java.util.Locale
import kotlin.math.min
@@ -64,11 +74,7 @@ class SyncService(
private val database: ISyncDatabaseProvider,
private val settings: SyncServiceSettings = SyncServiceSettings()
) {
private var _serverSocket: ServerSocket? = null
private var _thread: Thread? = null
private var _connectThread: Thread? = null
private var _mdnsThread: Thread? = null
@Volatile private var _started = false
private var _serverSocket: ServerSocketChannel? = null
private val _sessions: MutableMap<String, SyncSession> = mutableMapOf()
private val _lastConnectTimesMdns: MutableMap<String, Long> = mutableMapOf()
private val _lastConnectTimesIp: MutableMap<String, Long> = mutableMapOf()
@@ -82,10 +88,10 @@ class SyncService(
private val _pairingCode: String? = generateReadablePassword(8)
val pairingCode: String? get() = _pairingCode
private var _relaySession: SyncSocketSession? = null
private var _threadRelay: Thread? = null
private val _remotePendingStatusUpdate = mutableMapOf<String, (complete: Boolean?, message: String) -> Unit>()
private val _remotePendingStatusUpdateRelayed = mutableMapOf<String, (complete: Boolean?, message: String) -> Unit>()
private val _remotePendingStatusUpdateDirect = mutableMapOf<String, (complete: Boolean?, message: String) -> Unit>()
private var _nsdManager: NsdManager? = null
private var _scope: CoroutineScope? = null
@Volatile private var _scope: CoroutineScope? = null
private val _mdnsCache = mutableMapOf<String, SyncDeviceInfo>()
private var _discoveryListener: NsdManager.DiscoveryListener = object : NsdManager.DiscoveryListener {
override fun onDiscoveryStarted(regType: String) {
@@ -216,11 +222,12 @@ class SyncService(
var authorizePrompt: ((String, (Boolean) -> Unit) -> Unit)? = null
fun start(context: Context) {
if (_started) {
Logger.i(TAG, "Already started.")
if (_scope != null) {
Log.i(TAG, "Already started.")
return
}
_started = true
Log.i(TAG, "Start SyncService.")
_scope = CoroutineScope(Dispatchers.IO)
try {
@@ -294,27 +301,30 @@ class SyncService(
private fun startListener() {
serverSocketFailedToStart = false
serverSocketStarted = false
_thread = Thread {
_scope?.launch(Dispatchers.IO) {
try {
val serverSocket = ServerSocket(settings.listenerPort)
val serverSocket = ServerSocketChannel.open()
serverSocket.socket().bind(InetSocketAddress("0.0.0.0", settings.listenerPort))
_serverSocket = serverSocket
serverSocketStarted = true
Log.i(TAG, "Running on port ${settings.listenerPort} (TCP)")
serverSocketStarted = true
while (_started) {
while (isActive) {
val socket = serverSocket.accept()
val session = createSocketSession(socket, true)
//TODO: Switch to SocketChannel?
val session = createSocketSession(socket.socket(), true)
session.startAsResponder()
}
serverSocketStarted = false
} catch (e: ClosedChannelException) {
// normal shutdown
} catch (e: Throwable) {
Logger.e(TAG, "Failed to bind server socket to port ${settings.listenerPort}", e)
Log.e(TAG, "Failed to bind server socket to port ${settings.listenerPort}", e)
serverSocketFailedToStart = true
} finally {
serverSocketStarted = false
}
}.apply { start() }
}
}
private fun startMdnsRetryLoop() {
@@ -322,43 +332,44 @@ class SyncService(
discoverServices(serviceName, NsdManager.PROTOCOL_DNS_SD, _discoveryListener)
}
_mdnsThread = Thread {
while (_started) {
_scope?.launch(Dispatchers.IO) {
while (isActive) {
try {
val now = System.currentTimeMillis()
synchronized(_mdnsCache) {
for ((pkey, info) in _mdnsCache) {
if (!database.isAuthorized(pkey) || getLinkType(pkey) == LinkType.Direct) continue
val pairs = synchronized (_mdnsCache) { _mdnsCache.toList() }
for ((pkey, info) in pairs) {
if (!database.isAuthorized(pkey) || getLinkType(pkey) == LinkType.Direct) continue
val last = synchronized(_lastConnectTimesMdns) {
_lastConnectTimesMdns[pkey] ?: 0L
}
if (now - last > 30_000L) {
val last = synchronized(_lastConnectTimesMdns) {
_lastConnectTimesMdns[pkey] ?: 0L
}
if (now - last > 30_000L) {
synchronized(_lastConnectTimesMdns) {
_lastConnectTimesMdns[pkey] = now
try {
Logger.i(TAG, "MDNS-retry: connecting to $pkey")
connect(info)
} catch (ex: Throwable) {
Logger.w(TAG, "MDNS retry failed for $pkey", ex)
}
}
try {
Log.i(TAG, "MDNS-retry: connecting to $pkey")
connect(info)
if (!isActive) break
} catch (ex: Throwable) {
Log.w(TAG, "MDNS retry failed for $pkey", ex)
}
}
}
} catch (ex: Throwable) {
Logger.e(TAG, "Error in MDNS retry loop", ex)
Log.e(TAG, "Error in MDNS retry loop", ex)
}
Thread.sleep(5000)
delay(5000)
}
}.apply { start() }
}
}
private fun startConnectLastLoop() {
_connectThread = Thread {
_scope?.launch(Dispatchers.IO) {
Log.i(TAG, "Running auto reconnector")
while (_started) {
val authorizedDevices = database.getAllAuthorizedDevices() ?: arrayOf()
while (isActive) {
val authorizedDevices = database.getAllAuthorizedDevices()?.toList() ?: listOf()
val addressesToConnect = authorizedDevices.mapNotNull {
val connectedDirectly = getLinkType(it) == LinkType.Direct
if (connectedDirectly) {
@@ -382,26 +393,26 @@ class SyncService(
_lastConnectTimesIp[connectPair.first] = now
}
Logger.i(TAG, "Attempting to connect to authorized device by last known IP '${connectPair.first}' with pkey=${connectPair.first}")
Log.i(TAG, "Attempting to connect to authorized device by last known IP '${connectPair.first}' with pkey=${connectPair.first}")
connect(arrayOf(connectPair.second), settings.listenerPort, connectPair.first, null)
}
} catch (e: Throwable) {
Logger.i(TAG, "Failed to connect to " + connectPair.first, e)
Log.i(TAG, "Failed to connect to " + connectPair.first, e)
}
}
Thread.sleep(5000)
delay(5000)
}
}.apply { start() }
}
}
private fun startRelayLoop() {
relayConnected = false
_threadRelay = Thread {
_scope?.launch(Dispatchers.IO) {
try {
var backoffs: Array<Long> = arrayOf(1000, 5000, 10000, 20000)
var backoffIndex = 0;
while (_started) {
while (isActive) {
try {
Log.i(TAG, "Starting relay session...")
relayConnected = false
@@ -465,7 +476,7 @@ class SyncService(
Thread {
try {
while (_started && !socketClosed) {
while (isActive && !socketClosed) {
val unconnectedAuthorizedDevices =
database.getAllAuthorizedDevices()
?.filter {
@@ -503,27 +514,14 @@ class SyncService(
connectionInfo.ipv4Addresses
.filter { it != connectionInfo.remoteIp }
if (getLinkType(targetKey) != LinkType.Direct && connectionInfo.allowLocalDirect && Settings.instance.synchronization.connectLocalDirectThroughRelay) {
Thread {
launch(Dispatchers.IO) {
try {
Log.v(
TAG,
"Attempting to connect directly, locally to '$targetKey'."
)
connect(
potentialLocalAddresses.map { it }
.toTypedArray(),
settings.listenerPort,
targetKey,
null
)
Log.v(TAG, "Attempting to connect directly, locally to '$targetKey'.")
connect(potentialLocalAddresses.map { it }.toTypedArray(), settings.listenerPort, targetKey, null)
} catch (e: Throwable) {
Log.e(
TAG,
"Failed to start direct connection using connection info with $targetKey.",
e
)
Log.e(TAG, "Failed to start direct connection using connection info with $targetKey.", e)
}
}.start()
}
}
if (connectionInfo.allowRemoteDirect) {
@@ -587,7 +585,7 @@ class SyncService(
} catch (ex: Throwable) {
Log.i(TAG, "Unhandled exception in relay loop.", ex)
}
}.apply { start() }
}
}
private fun createSocketSession(socket: Socket, isResponder: Boolean): SyncSocketSession {
@@ -699,14 +697,21 @@ class SyncService(
return _pairingCode == pairingCode
}
private fun sendRemotePendingStatusUpdate(remotePublicKey: String, complete: Boolean, message: String) {
synchronized(_remotePendingStatusUpdateDirect) {
_remotePendingStatusUpdateDirect.remove(remotePublicKey)?.invoke(complete, message)
}
synchronized(_remotePendingStatusUpdateRelayed) {
_remotePendingStatusUpdateRelayed.remove(remotePublicKey)?.invoke(complete, message)
}
}
private fun createNewSyncSession(rpk: String, remoteDeviceName: String?): SyncSession {
val remotePublicKey = rpk.base64ToByteArray().toBase64()
return SyncSession(
remotePublicKey,
onAuthorized = { it, isNewlyAuthorized, isNewSession ->
synchronized(_remotePendingStatusUpdate) {
_remotePendingStatusUpdate.remove(remotePublicKey)?.invoke(true, "Authorized")
}
sendRemotePendingStatusUpdate(remotePublicKey, true, "Authorized")
if (isNewSession) {
it.remoteDeviceName?.let { remoteDeviceName ->
@@ -719,10 +724,7 @@ class SyncService(
onAuthorized?.invoke(it, isNewlyAuthorized, isNewSession)
},
onUnauthorized = {
synchronized(_remotePendingStatusUpdate) {
_remotePendingStatusUpdate.remove(remotePublicKey)?.invoke(false, "Unauthorized")
}
sendRemotePendingStatusUpdate(remotePublicKey, false, "Unauthorized")
onUnauthorized?.invoke(it)
},
onConnectedChanged = { it, connected ->
@@ -733,9 +735,7 @@ class SyncService(
Logger.i(TAG, "$remotePublicKey closed")
removeSession(it.remotePublicKey)
synchronized(_remotePendingStatusUpdate) {
_remotePendingStatusUpdate.remove(remotePublicKey)?.invoke(false, "Connection closed")
}
sendRemotePendingStatusUpdate(remotePublicKey, false, "Connection closed")
onClose?.invoke(it)
},
@@ -757,42 +757,67 @@ class SyncService(
fun getAllAuthorizedDevices(): Array<String>? = database.getAllAuthorizedDevices()
fun removeAuthorizedDevice(publicKey: String) = database.removeAuthorizedDevice(publicKey)
fun connect(deviceInfo: SyncDeviceInfo, onStatusUpdate: ((complete: Boolean?, message: String) -> Unit)? = null) {
try {
connect(deviceInfo.addresses, deviceInfo.port, deviceInfo.publicKey, deviceInfo.pairingCode, onStatusUpdate)
} catch (e: Throwable) {
Logger.e(TAG, "Failed to connect directly", e)
val relaySession = _relaySession
if (relaySession != null && Settings.instance.synchronization.pairThroughRelay) {
onStatusUpdate?.invoke(null, "Connecting via relay...")
runBlocking {
if (onStatusUpdate != null) {
synchronized(_remotePendingStatusUpdate) {
_remotePendingStatusUpdate[deviceInfo.publicKey.base64ToByteArray().toBase64()] = onStatusUpdate
}
}
relaySession.startRelayedChannel(deviceInfo.publicKey.base64ToByteArray().toBase64(), appId, deviceInfo.pairingCode)
suspend fun connect(deviceInfo: SyncDeviceInfo, alsoTryRelayed: Boolean = false, timeout_ms: Int = 5_000, onStatusUpdate: ((complete: Boolean?, message: String) -> Unit)? = null) {
val rs = _relaySession
val startTime = System.currentTimeMillis()
if (alsoTryRelayed && rs != null && settings.relayPairAllowed) {
onStatusUpdate?.invoke(null, "Connecting via relay...")
if (onStatusUpdate != null) {
synchronized(_remotePendingStatusUpdateRelayed) {
_remotePendingStatusUpdateRelayed[deviceInfo.publicKey.base64ToByteArray().toBase64()] = onStatusUpdate
}
} else {
throw e
}
//TODO: Do not try relayed channel here only for pairing mode?
rs.startRelayedChannel(deviceInfo.publicKey.base64ToByteArray().toBase64(), appId, deviceInfo.pairingCode)
}
try {
connect(deviceInfo.addresses, deviceInfo.port, deviceInfo.publicKey, deviceInfo.pairingCode, onStatusUpdate, timeout_ms)
} catch (e: Throwable) {
Log.e(TAG, "Failed to connect directly", e)
val waitTime_ms = timeout_ms - (System.currentTimeMillis() - startTime)
if (waitTime_ms > 0)
delay(waitTime_ms)
onStatusUpdate?.invoke(false, "Failed to connect.")
synchronized(_remotePendingStatusUpdateRelayed) {
_remotePendingStatusUpdateRelayed.remove(deviceInfo.publicKey.base64ToByteArray().toBase64())
}
}
}
fun connect(addresses: Array<String>, port: Int, publicKey: String, pairingCode: String?, onStatusUpdate: ((complete: Boolean?, message: String) -> Unit)? = null): SyncSocketSession {
suspend fun connect(addresses: Array<String>, port: Int, publicKey: String, pairingCode: String?, onStatusUpdate: ((complete: Boolean?, message: String) -> Unit)? = null, timeout_ms: Int = 10_000): SyncSocketSession {
val startTime_ms = System.currentTimeMillis()
onStatusUpdate?.invoke(null, "Connecting directly...")
val socket = getConnectedSocket(addresses.map { InetAddress.getByName(it) }, port) ?: throw Exception("Failed to connect")
val socket = getConnectedSocket(addresses.map { InetAddress.getByName(it) }, port, timeout_ms) ?: throw Exception("Failed to connect")
onStatusUpdate?.invoke(null, "Handshaking...")
val session = createSocketSession(socket, false)
if (onStatusUpdate != null) {
synchronized(_remotePendingStatusUpdate) {
_remotePendingStatusUpdate[publicKey.base64ToByteArray().toBase64()] = onStatusUpdate
synchronized(_remotePendingStatusUpdateDirect) {
_remotePendingStatusUpdateDirect[publicKey.base64ToByteArray().toBase64()] = onStatusUpdate
}
}
session.startAsInitiator(publicKey, appId, pairingCode)
while (timeout_ms - (startTime_ms - System.currentTimeMillis()) > 0 && !session.isAuthorized && session.started) {
delay(100)
}
if (!session.isAuthorized) {
Log.i(TAG, "Session is not authorized after timeout, cancelling connection.")
session.stop()
onStatusUpdate?.invoke(false, "Session not authorized.")
synchronized(_remotePendingStatusUpdateDirect) {
_remotePendingStatusUpdateDirect.remove(publicKey.base64ToByteArray().toBase64())
}
}
return session
}
@@ -811,6 +836,8 @@ class SyncService(
synchronized(_sessions) {
_sessions.clear()
}
_remotePendingStatusUpdateDirect.clear()
_remotePendingStatusUpdateRelayed.clear()
}
private fun getDeviceName(): String {
@@ -56,6 +56,7 @@ class SyncSocketSession {
private var _remotePublicKey: String? = null
val remotePublicKey: String? get() = _remotePublicKey
private var _started: Boolean = false
val started get() = _started
private val _localKeyPair: DHState
private var _thread: Thread? = null
private var _localPublicKey: String