Sync pairing will now always happen in parallel for direct and relayed and reduced amount of occupied threads.

This commit is contained in:
Koen J
2025-10-03 14:11:48 +02:00
parent 642d218c54
commit 137ba85538
4 changed files with 131 additions and 104 deletions
@@ -216,10 +216,9 @@ private fun ByteArray.toInetAddress(): InetAddress {
return InetAddress.getByAddress(this); return InetAddress.getByAddress(this);
} }
fun getConnectedSocket(attemptAddresses: List<InetAddress>, port: Int): Socket? { fun getConnectedSocket(attemptAddresses: List<InetAddress>, port: Int, timeoutMs: Int = 10_000): Socket? {
ensureNotMainThread() ensureNotMainThread()
val timeout = 10000
val addresses = if(!Settings.instance.casting.allowIpv6) attemptAddresses.filterIsInstance<Inet4Address>() else attemptAddresses; val addresses = if(!Settings.instance.casting.allowIpv6) attemptAddresses.filterIsInstance<Inet4Address>() else attemptAddresses;
if(addresses.isEmpty()) if(addresses.isEmpty())
throw IllegalStateException("No valid addresses found (ipv6: ${(if(Settings.instance.casting.allowIpv6) "enabled" else "disabled")})"); throw IllegalStateException("No valid addresses found (ipv6: ${(if(Settings.instance.casting.allowIpv6) "enabled" else "disabled")})");
@@ -232,7 +231,7 @@ fun getConnectedSocket(attemptAddresses: List<InetAddress>, port: Int): Socket?
val socket = Socket() val socket = Socket()
try { try {
return socket.apply { this.connect(InetSocketAddress(addresses[0], port), timeout) } return socket.apply { this.connect(InetSocketAddress(addresses[0], port), timeoutMs) }
} catch (e: Throwable) { } catch (e: Throwable) {
Log.i("getConnectedSocket", "Failed to connect to: ${addresses[0]}", e) Log.i("getConnectedSocket", "Failed to connect to: ${addresses[0]}", e)
socket.close() socket.close()
@@ -263,7 +262,7 @@ fun getConnectedSocket(attemptAddresses: List<InetAddress>, port: Int): Socket?
} }
} }
socket.connect(InetSocketAddress(address, port), timeout); socket.connect(InetSocketAddress(address, port), timeoutMs);
synchronized(syncObject) { synchronized(syncObject) {
if (connectedSocket == null) { if (connectedSocket == null) {
@@ -110,7 +110,7 @@ class SyncPairActivity : AppCompatActivity() {
lifecycleScope.launch(Dispatchers.IO) { lifecycleScope.launch(Dispatchers.IO) {
try { try {
StateSync.instance.syncService?.connect(deviceInfo) { complete, message -> StateSync.instance.syncService?.connect(deviceInfo, true) { complete, message ->
lifecycleScope.launch(Dispatchers.Main) { lifecycleScope.launch(Dispatchers.Main) {
if (complete != null) { if (complete != null) {
if (complete) { if (complete) {
@@ -6,6 +6,7 @@ import android.net.nsd.NsdServiceInfo
import android.os.Build import android.os.Build
import android.util.Log import android.util.Log
import com.futo.platformplayer.Settings import com.futo.platformplayer.Settings
import com.futo.platformplayer.ensureNotMainThread
import com.futo.platformplayer.generateReadablePassword import com.futo.platformplayer.generateReadablePassword
import com.futo.platformplayer.getConnectedSocket import com.futo.platformplayer.getConnectedSocket
import com.futo.platformplayer.logging.Logger import com.futo.platformplayer.logging.Logger
@@ -17,14 +18,23 @@ import com.futo.polycentric.core.base64UrlToByteArray
import com.futo.polycentric.core.toBase64 import com.futo.polycentric.core.toBase64
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.cancel import kotlinx.coroutines.cancel
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.selects.select
import kotlinx.coroutines.withContext
import java.net.InetAddress import java.net.InetAddress
import java.net.InetSocketAddress import java.net.InetSocketAddress
import java.net.ServerSocket import java.net.ServerSocket
import java.net.Socket import java.net.Socket
import java.nio.ByteBuffer import java.nio.ByteBuffer
import java.nio.channels.ClosedChannelException
import java.nio.channels.ServerSocketChannel
import java.nio.channels.SocketChannel
import java.util.Base64 import java.util.Base64
import java.util.Locale import java.util.Locale
import kotlin.math.min import kotlin.math.min
@@ -64,11 +74,7 @@ class SyncService(
private val database: ISyncDatabaseProvider, private val database: ISyncDatabaseProvider,
private val settings: SyncServiceSettings = SyncServiceSettings() private val settings: SyncServiceSettings = SyncServiceSettings()
) { ) {
private var _serverSocket: ServerSocket? = null private var _serverSocket: ServerSocketChannel? = null
private var _thread: Thread? = null
private var _connectThread: Thread? = null
private var _mdnsThread: Thread? = null
@Volatile private var _started = false
private val _sessions: MutableMap<String, SyncSession> = mutableMapOf() private val _sessions: MutableMap<String, SyncSession> = mutableMapOf()
private val _lastConnectTimesMdns: MutableMap<String, Long> = mutableMapOf() private val _lastConnectTimesMdns: MutableMap<String, Long> = mutableMapOf()
private val _lastConnectTimesIp: MutableMap<String, Long> = mutableMapOf() private val _lastConnectTimesIp: MutableMap<String, Long> = mutableMapOf()
@@ -82,10 +88,10 @@ class SyncService(
private val _pairingCode: String? = generateReadablePassword(8) private val _pairingCode: String? = generateReadablePassword(8)
val pairingCode: String? get() = _pairingCode val pairingCode: String? get() = _pairingCode
private var _relaySession: SyncSocketSession? = null private var _relaySession: SyncSocketSession? = null
private var _threadRelay: Thread? = null private val _remotePendingStatusUpdateRelayed = mutableMapOf<String, (complete: Boolean?, message: String) -> Unit>()
private val _remotePendingStatusUpdate = mutableMapOf<String, (complete: Boolean?, message: String) -> Unit>() private val _remotePendingStatusUpdateDirect = mutableMapOf<String, (complete: Boolean?, message: String) -> Unit>()
private var _nsdManager: NsdManager? = null private var _nsdManager: NsdManager? = null
private var _scope: CoroutineScope? = null @Volatile private var _scope: CoroutineScope? = null
private val _mdnsCache = mutableMapOf<String, SyncDeviceInfo>() private val _mdnsCache = mutableMapOf<String, SyncDeviceInfo>()
private var _discoveryListener: NsdManager.DiscoveryListener = object : NsdManager.DiscoveryListener { private var _discoveryListener: NsdManager.DiscoveryListener = object : NsdManager.DiscoveryListener {
override fun onDiscoveryStarted(regType: String) { override fun onDiscoveryStarted(regType: String) {
@@ -216,11 +222,12 @@ class SyncService(
var authorizePrompt: ((String, (Boolean) -> Unit) -> Unit)? = null var authorizePrompt: ((String, (Boolean) -> Unit) -> Unit)? = null
fun start(context: Context) { fun start(context: Context) {
if (_started) { if (_scope != null) {
Logger.i(TAG, "Already started.") Log.i(TAG, "Already started.")
return return
} }
_started = true
Log.i(TAG, "Start SyncService.")
_scope = CoroutineScope(Dispatchers.IO) _scope = CoroutineScope(Dispatchers.IO)
try { try {
@@ -294,27 +301,30 @@ class SyncService(
private fun startListener() { private fun startListener() {
serverSocketFailedToStart = false serverSocketFailedToStart = false
serverSocketStarted = false serverSocketStarted = false
_thread = Thread { _scope?.launch(Dispatchers.IO) {
try { try {
val serverSocket = ServerSocket(settings.listenerPort) val serverSocket = ServerSocketChannel.open()
serverSocket.socket().bind(InetSocketAddress("0.0.0.0", settings.listenerPort))
_serverSocket = serverSocket _serverSocket = serverSocket
serverSocketStarted = true
Log.i(TAG, "Running on port ${settings.listenerPort} (TCP)") Log.i(TAG, "Running on port ${settings.listenerPort} (TCP)")
serverSocketStarted = true
while (_started) { while (isActive) {
val socket = serverSocket.accept() val socket = serverSocket.accept()
val session = createSocketSession(socket, true) //TODO: Switch to SocketChannel?
val session = createSocketSession(socket.socket(), true)
session.startAsResponder() session.startAsResponder()
} }
} catch (e: ClosedChannelException) {
serverSocketStarted = false // normal shutdown
} catch (e: Throwable) { } catch (e: Throwable) {
Logger.e(TAG, "Failed to bind server socket to port ${settings.listenerPort}", e) Log.e(TAG, "Failed to bind server socket to port ${settings.listenerPort}", e)
serverSocketFailedToStart = true serverSocketFailedToStart = true
} finally {
serverSocketStarted = false serverSocketStarted = false
} }
}.apply { start() } }
} }
private fun startMdnsRetryLoop() { private fun startMdnsRetryLoop() {
@@ -322,43 +332,44 @@ class SyncService(
discoverServices(serviceName, NsdManager.PROTOCOL_DNS_SD, _discoveryListener) discoverServices(serviceName, NsdManager.PROTOCOL_DNS_SD, _discoveryListener)
} }
_mdnsThread = Thread { _scope?.launch(Dispatchers.IO) {
while (_started) { while (isActive) {
try { try {
val now = System.currentTimeMillis() val now = System.currentTimeMillis()
synchronized(_mdnsCache) { val pairs = synchronized (_mdnsCache) { _mdnsCache.toList() }
for ((pkey, info) in _mdnsCache) { for ((pkey, info) in pairs) {
if (!database.isAuthorized(pkey) || getLinkType(pkey) == LinkType.Direct) continue if (!database.isAuthorized(pkey) || getLinkType(pkey) == LinkType.Direct) continue
val last = synchronized(_lastConnectTimesMdns) { val last = synchronized(_lastConnectTimesMdns) {
_lastConnectTimesMdns[pkey] ?: 0L _lastConnectTimesMdns[pkey] ?: 0L
} }
if (now - last > 30_000L) { if (now - last > 30_000L) {
synchronized(_lastConnectTimesMdns) {
_lastConnectTimesMdns[pkey] = now _lastConnectTimesMdns[pkey] = now
try { }
Logger.i(TAG, "MDNS-retry: connecting to $pkey") try {
connect(info) Log.i(TAG, "MDNS-retry: connecting to $pkey")
} catch (ex: Throwable) { connect(info)
Logger.w(TAG, "MDNS retry failed for $pkey", ex) if (!isActive) break
} } catch (ex: Throwable) {
Log.w(TAG, "MDNS retry failed for $pkey", ex)
} }
} }
} }
} catch (ex: Throwable) { } catch (ex: Throwable) {
Logger.e(TAG, "Error in MDNS retry loop", ex) Log.e(TAG, "Error in MDNS retry loop", ex)
} }
Thread.sleep(5000) delay(5000)
} }
}.apply { start() } }
} }
private fun startConnectLastLoop() { private fun startConnectLastLoop() {
_connectThread = Thread { _scope?.launch(Dispatchers.IO) {
Log.i(TAG, "Running auto reconnector") Log.i(TAG, "Running auto reconnector")
while (_started) { while (isActive) {
val authorizedDevices = database.getAllAuthorizedDevices() ?: arrayOf() val authorizedDevices = database.getAllAuthorizedDevices()?.toList() ?: listOf()
val addressesToConnect = authorizedDevices.mapNotNull { val addressesToConnect = authorizedDevices.mapNotNull {
val connectedDirectly = getLinkType(it) == LinkType.Direct val connectedDirectly = getLinkType(it) == LinkType.Direct
if (connectedDirectly) { if (connectedDirectly) {
@@ -382,26 +393,26 @@ class SyncService(
_lastConnectTimesIp[connectPair.first] = now _lastConnectTimesIp[connectPair.first] = now
} }
Logger.i(TAG, "Attempting to connect to authorized device by last known IP '${connectPair.first}' with pkey=${connectPair.first}") Log.i(TAG, "Attempting to connect to authorized device by last known IP '${connectPair.first}' with pkey=${connectPair.first}")
connect(arrayOf(connectPair.second), settings.listenerPort, connectPair.first, null) connect(arrayOf(connectPair.second), settings.listenerPort, connectPair.first, null)
} }
} catch (e: Throwable) { } catch (e: Throwable) {
Logger.i(TAG, "Failed to connect to " + connectPair.first, e) Log.i(TAG, "Failed to connect to " + connectPair.first, e)
} }
} }
Thread.sleep(5000) delay(5000)
} }
}.apply { start() } }
} }
private fun startRelayLoop() { private fun startRelayLoop() {
relayConnected = false relayConnected = false
_threadRelay = Thread { _scope?.launch(Dispatchers.IO) {
try { try {
var backoffs: Array<Long> = arrayOf(1000, 5000, 10000, 20000) var backoffs: Array<Long> = arrayOf(1000, 5000, 10000, 20000)
var backoffIndex = 0; var backoffIndex = 0;
while (_started) { while (isActive) {
try { try {
Log.i(TAG, "Starting relay session...") Log.i(TAG, "Starting relay session...")
relayConnected = false relayConnected = false
@@ -465,7 +476,7 @@ class SyncService(
Thread { Thread {
try { try {
while (_started && !socketClosed) { while (isActive && !socketClosed) {
val unconnectedAuthorizedDevices = val unconnectedAuthorizedDevices =
database.getAllAuthorizedDevices() database.getAllAuthorizedDevices()
?.filter { ?.filter {
@@ -503,27 +514,14 @@ class SyncService(
connectionInfo.ipv4Addresses connectionInfo.ipv4Addresses
.filter { it != connectionInfo.remoteIp } .filter { it != connectionInfo.remoteIp }
if (getLinkType(targetKey) != LinkType.Direct && connectionInfo.allowLocalDirect && Settings.instance.synchronization.connectLocalDirectThroughRelay) { if (getLinkType(targetKey) != LinkType.Direct && connectionInfo.allowLocalDirect && Settings.instance.synchronization.connectLocalDirectThroughRelay) {
Thread { launch(Dispatchers.IO) {
try { try {
Log.v( Log.v(TAG, "Attempting to connect directly, locally to '$targetKey'.")
TAG, connect(potentialLocalAddresses.map { it }.toTypedArray(), settings.listenerPort, targetKey, null)
"Attempting to connect directly, locally to '$targetKey'."
)
connect(
potentialLocalAddresses.map { it }
.toTypedArray(),
settings.listenerPort,
targetKey,
null
)
} catch (e: Throwable) { } catch (e: Throwable) {
Log.e( Log.e(TAG, "Failed to start direct connection using connection info with $targetKey.", e)
TAG,
"Failed to start direct connection using connection info with $targetKey.",
e
)
} }
}.start() }
} }
if (connectionInfo.allowRemoteDirect) { if (connectionInfo.allowRemoteDirect) {
@@ -587,7 +585,7 @@ class SyncService(
} catch (ex: Throwable) { } catch (ex: Throwable) {
Log.i(TAG, "Unhandled exception in relay loop.", ex) Log.i(TAG, "Unhandled exception in relay loop.", ex)
} }
}.apply { start() } }
} }
private fun createSocketSession(socket: Socket, isResponder: Boolean): SyncSocketSession { private fun createSocketSession(socket: Socket, isResponder: Boolean): SyncSocketSession {
@@ -699,14 +697,21 @@ class SyncService(
return _pairingCode == pairingCode return _pairingCode == pairingCode
} }
private fun sendRemotePendingStatusUpdate(remotePublicKey: String, complete: Boolean, message: String) {
synchronized(_remotePendingStatusUpdateDirect) {
_remotePendingStatusUpdateDirect.remove(remotePublicKey)?.invoke(complete, message)
}
synchronized(_remotePendingStatusUpdateRelayed) {
_remotePendingStatusUpdateRelayed.remove(remotePublicKey)?.invoke(complete, message)
}
}
private fun createNewSyncSession(rpk: String, remoteDeviceName: String?): SyncSession { private fun createNewSyncSession(rpk: String, remoteDeviceName: String?): SyncSession {
val remotePublicKey = rpk.base64ToByteArray().toBase64() val remotePublicKey = rpk.base64ToByteArray().toBase64()
return SyncSession( return SyncSession(
remotePublicKey, remotePublicKey,
onAuthorized = { it, isNewlyAuthorized, isNewSession -> onAuthorized = { it, isNewlyAuthorized, isNewSession ->
synchronized(_remotePendingStatusUpdate) { sendRemotePendingStatusUpdate(remotePublicKey, true, "Authorized")
_remotePendingStatusUpdate.remove(remotePublicKey)?.invoke(true, "Authorized")
}
if (isNewSession) { if (isNewSession) {
it.remoteDeviceName?.let { remoteDeviceName -> it.remoteDeviceName?.let { remoteDeviceName ->
@@ -719,10 +724,7 @@ class SyncService(
onAuthorized?.invoke(it, isNewlyAuthorized, isNewSession) onAuthorized?.invoke(it, isNewlyAuthorized, isNewSession)
}, },
onUnauthorized = { onUnauthorized = {
synchronized(_remotePendingStatusUpdate) { sendRemotePendingStatusUpdate(remotePublicKey, false, "Unauthorized")
_remotePendingStatusUpdate.remove(remotePublicKey)?.invoke(false, "Unauthorized")
}
onUnauthorized?.invoke(it) onUnauthorized?.invoke(it)
}, },
onConnectedChanged = { it, connected -> onConnectedChanged = { it, connected ->
@@ -733,9 +735,7 @@ class SyncService(
Logger.i(TAG, "$remotePublicKey closed") Logger.i(TAG, "$remotePublicKey closed")
removeSession(it.remotePublicKey) removeSession(it.remotePublicKey)
synchronized(_remotePendingStatusUpdate) { sendRemotePendingStatusUpdate(remotePublicKey, false, "Connection closed")
_remotePendingStatusUpdate.remove(remotePublicKey)?.invoke(false, "Connection closed")
}
onClose?.invoke(it) onClose?.invoke(it)
}, },
@@ -757,42 +757,67 @@ class SyncService(
fun getAllAuthorizedDevices(): Array<String>? = database.getAllAuthorizedDevices() fun getAllAuthorizedDevices(): Array<String>? = database.getAllAuthorizedDevices()
fun removeAuthorizedDevice(publicKey: String) = database.removeAuthorizedDevice(publicKey) fun removeAuthorizedDevice(publicKey: String) = database.removeAuthorizedDevice(publicKey)
fun connect(deviceInfo: SyncDeviceInfo, onStatusUpdate: ((complete: Boolean?, message: String) -> Unit)? = null) {
try {
connect(deviceInfo.addresses, deviceInfo.port, deviceInfo.publicKey, deviceInfo.pairingCode, onStatusUpdate)
} catch (e: Throwable) {
Logger.e(TAG, "Failed to connect directly", e)
val relaySession = _relaySession
if (relaySession != null && Settings.instance.synchronization.pairThroughRelay) {
onStatusUpdate?.invoke(null, "Connecting via relay...")
runBlocking { suspend fun connect(deviceInfo: SyncDeviceInfo, alsoTryRelayed: Boolean = false, timeout_ms: Int = 5_000, onStatusUpdate: ((complete: Boolean?, message: String) -> Unit)? = null) {
if (onStatusUpdate != null) { val rs = _relaySession
synchronized(_remotePendingStatusUpdate) { val startTime = System.currentTimeMillis()
_remotePendingStatusUpdate[deviceInfo.publicKey.base64ToByteArray().toBase64()] = onStatusUpdate if (alsoTryRelayed && rs != null && settings.relayPairAllowed) {
} onStatusUpdate?.invoke(null, "Connecting via relay...")
}
relaySession.startRelayedChannel(deviceInfo.publicKey.base64ToByteArray().toBase64(), appId, deviceInfo.pairingCode) if (onStatusUpdate != null) {
synchronized(_remotePendingStatusUpdateRelayed) {
_remotePendingStatusUpdateRelayed[deviceInfo.publicKey.base64ToByteArray().toBase64()] = onStatusUpdate
} }
} else { }
throw e //TODO: Do not try relayed channel here only for pairing mode?
rs.startRelayedChannel(deviceInfo.publicKey.base64ToByteArray().toBase64(), appId, deviceInfo.pairingCode)
}
try {
connect(deviceInfo.addresses, deviceInfo.port, deviceInfo.publicKey, deviceInfo.pairingCode, onStatusUpdate, timeout_ms)
} catch (e: Throwable) {
Log.e(TAG, "Failed to connect directly", e)
val waitTime_ms = timeout_ms - (System.currentTimeMillis() - startTime)
if (waitTime_ms > 0)
delay(waitTime_ms)
onStatusUpdate?.invoke(false, "Failed to connect.")
synchronized(_remotePendingStatusUpdateRelayed) {
_remotePendingStatusUpdateRelayed.remove(deviceInfo.publicKey.base64ToByteArray().toBase64())
} }
} }
} }
fun connect(addresses: Array<String>, port: Int, publicKey: String, pairingCode: String?, onStatusUpdate: ((complete: Boolean?, message: String) -> Unit)? = null): SyncSocketSession { suspend fun connect(addresses: Array<String>, port: Int, publicKey: String, pairingCode: String?, onStatusUpdate: ((complete: Boolean?, message: String) -> Unit)? = null, timeout_ms: Int = 10_000): SyncSocketSession {
val startTime_ms = System.currentTimeMillis()
onStatusUpdate?.invoke(null, "Connecting directly...") onStatusUpdate?.invoke(null, "Connecting directly...")
val socket = getConnectedSocket(addresses.map { InetAddress.getByName(it) }, port) ?: throw Exception("Failed to connect") val socket = getConnectedSocket(addresses.map { InetAddress.getByName(it) }, port, timeout_ms) ?: throw Exception("Failed to connect")
onStatusUpdate?.invoke(null, "Handshaking...") onStatusUpdate?.invoke(null, "Handshaking...")
val session = createSocketSession(socket, false) val session = createSocketSession(socket, false)
if (onStatusUpdate != null) { if (onStatusUpdate != null) {
synchronized(_remotePendingStatusUpdate) { synchronized(_remotePendingStatusUpdateDirect) {
_remotePendingStatusUpdate[publicKey.base64ToByteArray().toBase64()] = onStatusUpdate _remotePendingStatusUpdateDirect[publicKey.base64ToByteArray().toBase64()] = onStatusUpdate
} }
} }
session.startAsInitiator(publicKey, appId, pairingCode) session.startAsInitiator(publicKey, appId, pairingCode)
while (timeout_ms - (startTime_ms - System.currentTimeMillis()) > 0 && !session.isAuthorized && session.started) {
delay(100)
}
if (!session.isAuthorized) {
Log.i(TAG, "Session is not authorized after timeout, cancelling connection.")
session.stop()
onStatusUpdate?.invoke(false, "Session not authorized.")
synchronized(_remotePendingStatusUpdateDirect) {
_remotePendingStatusUpdateDirect.remove(publicKey.base64ToByteArray().toBase64())
}
}
return session return session
} }
@@ -811,6 +836,8 @@ class SyncService(
synchronized(_sessions) { synchronized(_sessions) {
_sessions.clear() _sessions.clear()
} }
_remotePendingStatusUpdateDirect.clear()
_remotePendingStatusUpdateRelayed.clear()
} }
private fun getDeviceName(): String { private fun getDeviceName(): String {
@@ -56,6 +56,7 @@ class SyncSocketSession {
private var _remotePublicKey: String? = null private var _remotePublicKey: String? = null
val remotePublicKey: String? get() = _remotePublicKey val remotePublicKey: String? get() = _remotePublicKey
private var _started: Boolean = false private var _started: Boolean = false
val started get() = _started
private val _localKeyPair: DHState private val _localKeyPair: DHState
private var _thread: Thread? = null private var _thread: Thread? = null
private var _localPublicKey: String private var _localPublicKey: String