mirror of
https://gitlab.futo.org/videostreaming/grayjay.git
synced 2026-05-17 21:32:39 +02:00
Compare commits
5 Commits
upgrade
...
remote-sync
| Author | SHA1 | Date | |
|---|---|---|---|
| b025e8a30f | |||
| 5b2f8b8617 | |||
| 955ba23b0d | |||
| 1ae9f0ea26 | |||
| 97381739dd |
@@ -0,0 +1,266 @@
|
||||
package com.futo.platformplayer
|
||||
|
||||
import com.futo.platformplayer.noise.protocol.Noise
|
||||
import com.futo.platformplayer.sync.internal.*
|
||||
import kotlinx.coroutines.*
|
||||
import org.junit.Assert.*
|
||||
import org.junit.Test
|
||||
import java.net.Socket
|
||||
import java.nio.ByteBuffer
|
||||
import kotlin.random.Random
|
||||
import kotlin.time.Duration.Companion.milliseconds
|
||||
|
||||
class SyncServerTests {
|
||||
|
||||
//private val relayHost = "relay.grayjay.app"
|
||||
//private val relayKey = "xGbHRzDOvE6plRbQaFgSen82eijF+gxS0yeUaeEErkw="
|
||||
private val relayKey = "XlUaSpIlRaCg0TGzZ7JYmPupgUHDqTZXUUBco2K7ejw="
|
||||
private val relayHost = "192.168.1.175"
|
||||
private val relayPort = 9000
|
||||
|
||||
/** Creates a client connected to the live relay server. */
|
||||
private suspend fun createClient(
|
||||
onHandshakeComplete: ((SyncSocketSession) -> Unit)? = null,
|
||||
onData: ((SyncSocketSession, UByte, UByte, ByteBuffer) -> Unit)? = null,
|
||||
onNewChannel: ((SyncSocketSession, ChannelRelayed) -> Unit)? = null,
|
||||
isHandshakeAllowed: ((SyncSocketSession, String, String?) -> Boolean)? = null
|
||||
): SyncSocketSession = withContext(Dispatchers.IO) {
|
||||
val p = Noise.createDH("25519")
|
||||
p.generateKeyPair()
|
||||
val socket = Socket(relayHost, relayPort)
|
||||
val inputStream = LittleEndianDataInputStream(socket.getInputStream())
|
||||
val outputStream = LittleEndianDataOutputStream(socket.getOutputStream())
|
||||
val tcs = CompletableDeferred<Boolean>()
|
||||
val socketSession = SyncSocketSession(
|
||||
relayHost,
|
||||
p,
|
||||
inputStream,
|
||||
outputStream,
|
||||
onClose = { socket.close() },
|
||||
onHandshakeComplete = { s ->
|
||||
onHandshakeComplete?.invoke(s)
|
||||
tcs.complete(true)
|
||||
},
|
||||
onData = onData ?: { _, _, _, _ -> },
|
||||
onNewChannel = onNewChannel ?: { _, _ -> },
|
||||
isHandshakeAllowed = isHandshakeAllowed ?: { _, _, _ -> true }
|
||||
)
|
||||
socketSession.authorizable = AlwaysAuthorized()
|
||||
socketSession.startAsInitiator(relayKey)
|
||||
withTimeout(5000.milliseconds) { tcs.await() }
|
||||
return@withContext socketSession
|
||||
}
|
||||
|
||||
@Test
|
||||
fun multipleClientsHandshake_Success() = runBlocking {
|
||||
val client1 = createClient()
|
||||
val client2 = createClient()
|
||||
assertNotNull(client1.remotePublicKey, "Client 1 handshake failed")
|
||||
assertNotNull(client2.remotePublicKey, "Client 2 handshake failed")
|
||||
client1.stop()
|
||||
client2.stop()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun publishAndRequestConnectionInfo_Authorized_Success() = runBlocking {
|
||||
val clientA = createClient()
|
||||
val clientB = createClient()
|
||||
val clientC = createClient()
|
||||
clientA.publishConnectionInformation(arrayOf(clientB.localPublicKey), 12345, true, true, true, true)
|
||||
delay(100.milliseconds)
|
||||
val infoB = clientB.requestConnectionInfo(clientA.localPublicKey)
|
||||
val infoC = clientC.requestConnectionInfo(clientA.localPublicKey)
|
||||
assertNotNull("Client B should receive connection info", infoB)
|
||||
assertEquals(12345.toUShort(), infoB!!.port)
|
||||
assertNull("Client C should not receive connection info (unauthorized)", infoC)
|
||||
clientA.stop()
|
||||
clientB.stop()
|
||||
clientC.stop()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun relayedTransport_Bidirectional_Success() = runBlocking {
|
||||
val tcsA = CompletableDeferred<ChannelRelayed>()
|
||||
val tcsB = CompletableDeferred<ChannelRelayed>()
|
||||
val clientA = createClient(onNewChannel = { _, c -> tcsA.complete(c) })
|
||||
val clientB = createClient(onNewChannel = { _, c -> tcsB.complete(c) })
|
||||
val channelTask = async { clientA.startRelayedChannel(clientB.localPublicKey) }
|
||||
val channelA = withTimeout(5000.milliseconds) { tcsA.await() }
|
||||
channelA.authorizable = AlwaysAuthorized()
|
||||
val channelB = withTimeout(5000.milliseconds) { tcsB.await() }
|
||||
channelB.authorizable = AlwaysAuthorized()
|
||||
channelTask.await()
|
||||
|
||||
val tcsDataB = CompletableDeferred<ByteArray>()
|
||||
channelB.setDataHandler { _, _, o, so, d ->
|
||||
val b = ByteArray(d.remaining())
|
||||
d.get(b)
|
||||
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataB.complete(b)
|
||||
}
|
||||
channelA.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(byteArrayOf(1, 2, 3)))
|
||||
|
||||
val tcsDataA = CompletableDeferred<ByteArray>()
|
||||
channelA.setDataHandler { _, _, o, so, d ->
|
||||
val b = ByteArray(d.remaining())
|
||||
d.get(b)
|
||||
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataA.complete(b)
|
||||
}
|
||||
channelB.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(byteArrayOf(4, 5, 6)))
|
||||
|
||||
val receivedB = withTimeout(5000.milliseconds) { tcsDataB.await() }
|
||||
val receivedA = withTimeout(5000.milliseconds) { tcsDataA.await() }
|
||||
assertArrayEquals(byteArrayOf(1, 2, 3), receivedB)
|
||||
assertArrayEquals(byteArrayOf(4, 5, 6), receivedA)
|
||||
clientA.stop()
|
||||
clientB.stop()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun relayedTransport_MaximumMessageSize_Success() = runBlocking {
|
||||
val MAX_DATA_PER_PACKET = SyncSocketSession.MAXIMUM_PACKET_SIZE - SyncSocketSession.HEADER_SIZE - 8 - 16 - 16
|
||||
val maxSizeData = ByteArray(MAX_DATA_PER_PACKET).apply { Random.nextBytes(this) }
|
||||
val tcsA = CompletableDeferred<ChannelRelayed>()
|
||||
val tcsB = CompletableDeferred<ChannelRelayed>()
|
||||
val clientA = createClient(onNewChannel = { _, c -> tcsA.complete(c) })
|
||||
val clientB = createClient(onNewChannel = { _, c -> tcsB.complete(c) })
|
||||
val channelTask = async { clientA.startRelayedChannel(clientB.localPublicKey) }
|
||||
val channelA = withTimeout(5000.milliseconds) { tcsA.await() }
|
||||
channelA.authorizable = AlwaysAuthorized()
|
||||
val channelB = withTimeout(5000.milliseconds) { tcsB.await() }
|
||||
channelB.authorizable = AlwaysAuthorized()
|
||||
channelTask.await()
|
||||
|
||||
val tcsDataB = CompletableDeferred<ByteArray>()
|
||||
channelB.setDataHandler { _, _, o, so, d ->
|
||||
val b = ByteArray(d.remaining())
|
||||
d.get(b)
|
||||
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataB.complete(b)
|
||||
}
|
||||
channelA.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(maxSizeData))
|
||||
val receivedData = withTimeout(5000.milliseconds) { tcsDataB.await() }
|
||||
assertArrayEquals(maxSizeData, receivedData)
|
||||
clientA.stop()
|
||||
clientB.stop()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun publishAndGetRecord_Success() = runBlocking {
|
||||
val clientA = createClient()
|
||||
val clientB = createClient()
|
||||
val clientC = createClient()
|
||||
val data = byteArrayOf(1, 2, 3)
|
||||
val success = clientA.publishRecords(listOf(clientB.localPublicKey), "testKey", data)
|
||||
val recordB = clientB.getRecord(clientA.localPublicKey, "testKey")
|
||||
val recordC = clientC.getRecord(clientA.localPublicKey, "testKey")
|
||||
assertTrue(success)
|
||||
assertNotNull(recordB)
|
||||
assertArrayEquals(data, recordB!!.first)
|
||||
assertNull("Unauthorized client should not access record", recordC)
|
||||
clientA.stop()
|
||||
clientB.stop()
|
||||
clientC.stop()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun getNonExistentRecord_ReturnsNull() = runBlocking {
|
||||
val clientA = createClient()
|
||||
val clientB = createClient()
|
||||
val record = clientB.getRecord(clientA.localPublicKey, "nonExistentKey")
|
||||
assertNull("Getting non-existent record should return null", record)
|
||||
clientA.stop()
|
||||
clientB.stop()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun updateRecord_TimestampUpdated() = runBlocking {
|
||||
val clientA = createClient()
|
||||
val clientB = createClient()
|
||||
val key = "updateKey"
|
||||
val data1 = byteArrayOf(1)
|
||||
val data2 = byteArrayOf(2)
|
||||
clientA.publishRecords(listOf(clientB.localPublicKey), key, data1)
|
||||
val record1 = clientB.getRecord(clientA.localPublicKey, key)
|
||||
delay(1000.milliseconds)
|
||||
clientA.publishRecords(listOf(clientB.localPublicKey), key, data2)
|
||||
val record2 = clientB.getRecord(clientA.localPublicKey, key)
|
||||
assertNotNull(record1)
|
||||
assertNotNull(record2)
|
||||
assertTrue(record2!!.second > record1!!.second)
|
||||
assertArrayEquals(data2, record2.first)
|
||||
clientA.stop()
|
||||
clientB.stop()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun deleteRecord_Success() = runBlocking {
|
||||
val clientA = createClient()
|
||||
val clientB = createClient()
|
||||
val data = byteArrayOf(1, 2, 3)
|
||||
clientA.publishRecords(listOf(clientB.localPublicKey), "toDelete", data)
|
||||
val success = clientB.deleteRecords(clientA.localPublicKey, clientB.localPublicKey, listOf("toDelete"))
|
||||
val record = clientB.getRecord(clientA.localPublicKey, "toDelete")
|
||||
assertTrue(success)
|
||||
assertNull(record)
|
||||
clientA.stop()
|
||||
clientB.stop()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun listRecordKeys_Success() = runBlocking {
|
||||
val clientA = createClient()
|
||||
val clientB = createClient()
|
||||
val keys = arrayOf("key1", "key2", "key3")
|
||||
keys.forEach { key ->
|
||||
clientA.publishRecords(listOf(clientB.localPublicKey), key, byteArrayOf(1))
|
||||
}
|
||||
val listedKeys = clientB.listRecordKeys(clientA.localPublicKey, clientB.localPublicKey)
|
||||
assertArrayEquals(keys, listedKeys.map { it.first }.toTypedArray())
|
||||
clientA.stop()
|
||||
clientB.stop()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun singleLargeMessageViaRelayedChannel_Success() = runBlocking {
|
||||
val largeData = ByteArray(100000).apply { Random.nextBytes(this) }
|
||||
val tcsA = CompletableDeferred<ChannelRelayed>()
|
||||
val tcsB = CompletableDeferred<ChannelRelayed>()
|
||||
val clientA = createClient(onNewChannel = { _, c -> tcsA.complete(c) })
|
||||
val clientB = createClient(onNewChannel = { _, c -> tcsB.complete(c) })
|
||||
val channelTask = async { clientA.startRelayedChannel(clientB.localPublicKey) }
|
||||
val channelA = withTimeout(5000.milliseconds) { tcsA.await() }
|
||||
channelA.authorizable = AlwaysAuthorized()
|
||||
val channelB = withTimeout(5000.milliseconds) { tcsB.await() }
|
||||
channelB.authorizable = AlwaysAuthorized()
|
||||
channelTask.await()
|
||||
|
||||
val tcsDataB = CompletableDeferred<ByteArray>()
|
||||
channelB.setDataHandler { _, _, o, so, d ->
|
||||
val b = ByteArray(d.remaining())
|
||||
d.get(b)
|
||||
if (o == Opcode.DATA.value && so == 0u.toUByte()) tcsDataB.complete(b)
|
||||
}
|
||||
channelA.send(Opcode.DATA.value, 0u, ByteBuffer.wrap(largeData))
|
||||
val receivedData = withTimeout(10000.milliseconds) { tcsDataB.await() }
|
||||
assertArrayEquals(largeData, receivedData)
|
||||
clientA.stop()
|
||||
clientB.stop()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun publishAndGetLargeRecord_Success() = runBlocking {
|
||||
val largeData = ByteArray(1000000).apply { Random.nextBytes(this) }
|
||||
val clientA = createClient()
|
||||
val clientB = createClient()
|
||||
val success = clientA.publishRecords(listOf(clientB.localPublicKey), "largeRecord", largeData)
|
||||
val record = clientB.getRecord(clientA.localPublicKey, "largeRecord")
|
||||
assertTrue(success)
|
||||
assertNotNull(record)
|
||||
assertArrayEquals(largeData, record!!.first)
|
||||
clientA.stop()
|
||||
clientB.stop()
|
||||
}
|
||||
}
|
||||
|
||||
class AlwaysAuthorized : IAuthorizable {
|
||||
override val isAuthorized: Boolean get() = true
|
||||
}
|
||||
@@ -100,7 +100,8 @@ class SyncHomeActivity : AppCompatActivity() {
|
||||
|
||||
private fun updateDeviceView(syncDeviceView: SyncDeviceView, publicKey: String, session: SyncSession?): SyncDeviceView {
|
||||
val connected = session?.connected ?: false
|
||||
syncDeviceView.setLinkType(if (connected) LinkType.Local else LinkType.None)
|
||||
|
||||
syncDeviceView.setLinkType(session?.linkType ?: LinkType.None)
|
||||
.setName(session?.displayName ?: StateSync.instance.getCachedName(publicKey) ?: publicKey)
|
||||
//TODO: also display public key?
|
||||
.setStatus(if (connected) "Connected" else "Disconnected")
|
||||
|
||||
@@ -109,9 +109,9 @@ class SyncPairActivity : AppCompatActivity() {
|
||||
|
||||
lifecycleScope.launch(Dispatchers.IO) {
|
||||
try {
|
||||
StateSync.instance.connect(deviceInfo) { session, complete, message ->
|
||||
StateSync.instance.connect(deviceInfo) { complete, message ->
|
||||
lifecycleScope.launch(Dispatchers.Main) {
|
||||
if (complete) {
|
||||
if (complete != null && complete) {
|
||||
_layoutPairingSuccess.visibility = View.VISIBLE
|
||||
_layoutPairing.visibility = View.GONE
|
||||
} else {
|
||||
|
||||
@@ -31,6 +31,7 @@ import com.futo.platformplayer.sync.SyncSessionData
|
||||
import com.futo.platformplayer.sync.internal.ChannelSocket
|
||||
import com.futo.platformplayer.sync.internal.GJSyncOpcodes
|
||||
import com.futo.platformplayer.sync.internal.IAuthorizable
|
||||
import com.futo.platformplayer.sync.internal.IChannel
|
||||
import com.futo.platformplayer.sync.internal.Opcode
|
||||
import com.futo.platformplayer.sync.internal.SyncDeviceInfo
|
||||
import com.futo.platformplayer.sync.internal.SyncKeyPair
|
||||
@@ -51,6 +52,7 @@ import kotlinx.coroutines.withContext
|
||||
import kotlinx.serialization.encodeToString
|
||||
import kotlinx.serialization.json.Json
|
||||
import java.io.ByteArrayInputStream
|
||||
import java.lang.Thread.sleep
|
||||
import java.net.InetAddress
|
||||
import java.net.InetSocketAddress
|
||||
import java.net.ServerSocket
|
||||
@@ -92,6 +94,8 @@ class StateSync {
|
||||
val deviceRemoved: Event1<String> = Event1()
|
||||
val deviceUpdatedOrAdded: Event2<String, SyncSession> = Event2()
|
||||
|
||||
//TODO: Should authorize acknowledge be implemented?
|
||||
|
||||
fun hasAuthorizedDevice(): Boolean {
|
||||
synchronized(_sessions) {
|
||||
return _sessions.any{ it.value.connected && it.value.isAuthorized };
|
||||
@@ -220,6 +224,7 @@ class StateSync {
|
||||
try {
|
||||
Log.i(TAG, "Starting relay session...")
|
||||
|
||||
var socketClosed = false;
|
||||
val socket = Socket(RELAY_SERVER, 9000)
|
||||
_relaySession = SyncSocketSession(
|
||||
(socket.remoteSocketAddress as InetSocketAddress).address.hostAddress!!,
|
||||
@@ -271,61 +276,61 @@ class StateSync {
|
||||
session?.removeChannel(channel)
|
||||
}
|
||||
},
|
||||
onChannelEstablished = { _, channel, isResponder ->
|
||||
handleAuthorization(channel, isResponder)
|
||||
},
|
||||
onClose = { socketClosed = true },
|
||||
onHandshakeComplete = { relaySession ->
|
||||
try {
|
||||
while (_started) {
|
||||
val unconnectedAuthorizedDevices = synchronized(_authorizedDevices) {
|
||||
_authorizedDevices.values.filter { !isConnected(it) }.toTypedArray()
|
||||
}
|
||||
Thread {
|
||||
try {
|
||||
while (_started && !socketClosed) {
|
||||
val unconnectedAuthorizedDevices = synchronized(_authorizedDevices) {
|
||||
_authorizedDevices.values.filter { !isConnected(it) }.toTypedArray()
|
||||
}
|
||||
|
||||
relaySession.publishConnectionInformation(unconnectedAuthorizedDevices, PORT, true, false, false, true)
|
||||
relaySession.publishConnectionInformation(unconnectedAuthorizedDevices, PORT, true, false, false, true)
|
||||
|
||||
val connectionInfos = runBlocking { relaySession.requestBulkConnectionInfo(unconnectedAuthorizedDevices) }
|
||||
val connectionInfos = runBlocking { relaySession.requestBulkConnectionInfo(unconnectedAuthorizedDevices) }
|
||||
|
||||
for ((targetKey, connectionInfo) in connectionInfos) {
|
||||
val potentialLocalAddresses = connectionInfo.ipv4Addresses.union(connectionInfo.ipv6Addresses)
|
||||
.filter { it != connectionInfo.remoteIp }
|
||||
if (connectionInfo.allowLocalDirect) {
|
||||
Thread {
|
||||
for ((targetKey, connectionInfo) in connectionInfos) {
|
||||
val potentialLocalAddresses = connectionInfo.ipv4Addresses.union(connectionInfo.ipv6Addresses)
|
||||
.filter { it != connectionInfo.remoteIp }
|
||||
if (connectionInfo.allowLocalDirect) {
|
||||
Thread {
|
||||
try {
|
||||
Log.v(TAG, "Attempting to connect directly, locally to '$targetKey'.")
|
||||
connect(potentialLocalAddresses.map { it }.toTypedArray(), PORT, targetKey, null)
|
||||
} catch (e: Throwable) {
|
||||
Log.e(TAG, "Failed to start direct connection using connection info with $targetKey.", e)
|
||||
}
|
||||
}.start()
|
||||
}
|
||||
|
||||
if (connectionInfo.allowRemoteDirect) {
|
||||
// TODO: Implement direct remote connection if needed
|
||||
}
|
||||
|
||||
if (connectionInfo.allowRemoteHolePunched) {
|
||||
// TODO: Implement hole punching if needed
|
||||
}
|
||||
|
||||
if (connectionInfo.allowRemoteProxied) {
|
||||
try {
|
||||
val syncDeviceInfo = SyncDeviceInfo(
|
||||
targetKey,
|
||||
potentialLocalAddresses.map { it }.toTypedArray(),
|
||||
PORT,
|
||||
null
|
||||
)
|
||||
Log.v(TAG, "Attempting to connect directly, locally to '$targetKey'.")
|
||||
connect(syncDeviceInfo)
|
||||
Log.v(TAG, "Attempting relayed connection with '$targetKey'.")
|
||||
runBlocking { relaySession.startRelayedChannel(targetKey, null) }
|
||||
} catch (e: Throwable) {
|
||||
Log.e(TAG, "Failed to start direct connection using connection info with $targetKey.", e)
|
||||
Log.e(TAG, "Failed to start relayed channel with $targetKey.", e)
|
||||
}
|
||||
}.start()
|
||||
}
|
||||
|
||||
if (connectionInfo.allowRemoteDirect) {
|
||||
// TODO: Implement direct remote connection if needed
|
||||
}
|
||||
|
||||
if (connectionInfo.allowRemoteHolePunched) {
|
||||
// TODO: Implement hole punching if needed
|
||||
}
|
||||
|
||||
if (connectionInfo.allowRemoteProxied) {
|
||||
try {
|
||||
Log.v(TAG, "Attempting relayed connection with '$targetKey'.")
|
||||
runBlocking { relaySession.startRelayedChannel(targetKey, null) }
|
||||
} catch (e: Throwable) {
|
||||
Log.e(TAG, "Failed to start relayed channel with $targetKey.", e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Thread.sleep(15000)
|
||||
Thread.sleep(15000)
|
||||
}
|
||||
} catch (e: Throwable) {
|
||||
Log.e(TAG, "Unhandled exception in relay session.", e)
|
||||
relaySession.stop()
|
||||
}
|
||||
} catch (e: Throwable) {
|
||||
Log.e(TAG, "Unhandled exception in relay session.", e)
|
||||
relaySession.stop()
|
||||
}
|
||||
}.start()
|
||||
}
|
||||
)
|
||||
|
||||
@@ -718,7 +723,6 @@ class StateSync {
|
||||
}
|
||||
|
||||
deviceRemoved.emit(it.remotePublicKey)
|
||||
|
||||
},
|
||||
dataHandler = { it, opcode, subOpcode, data ->
|
||||
handleData(it, opcode, subOpcode, data)
|
||||
@@ -782,65 +786,7 @@ class StateSync {
|
||||
session!!.addChannel(channelSocket!!)
|
||||
}
|
||||
|
||||
if (isResponder) {
|
||||
val isAuthorized = synchronized(_authorizedDevices) {
|
||||
_authorizedDevices.values.contains(remotePublicKey)
|
||||
}
|
||||
|
||||
if (!isAuthorized) {
|
||||
val scope = StateApp.instance.scopeOrNull
|
||||
val activity = SyncShowPairingCodeActivity.activity
|
||||
|
||||
if (scope != null && activity != null) {
|
||||
scope.launch(Dispatchers.Main) {
|
||||
UIDialogs.showConfirmationDialog(activity, "Allow connection from ${remotePublicKey}?",
|
||||
action = {
|
||||
scope.launch(Dispatchers.IO) {
|
||||
try {
|
||||
session!!.authorize()
|
||||
Logger.i(TAG, "Connection authorized for $remotePublicKey by confirmation")
|
||||
} catch (e: Throwable) {
|
||||
Logger.e(TAG, "Failed to send authorize", e)
|
||||
}
|
||||
}
|
||||
},
|
||||
cancelAction = {
|
||||
scope.launch(Dispatchers.IO) {
|
||||
try {
|
||||
unauthorize(remotePublicKey)
|
||||
} catch (e: Throwable) {
|
||||
Logger.w(TAG, "Failed to send unauthorize", e)
|
||||
}
|
||||
|
||||
synchronized(_sessions) {
|
||||
session?.close()
|
||||
_sessions.remove(remotePublicKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
} else {
|
||||
val publicKey = session!!.remotePublicKey
|
||||
session!!.unauthorize()
|
||||
session!!.close()
|
||||
|
||||
synchronized(_sessions) {
|
||||
_sessions.remove(publicKey)
|
||||
}
|
||||
|
||||
Logger.i(TAG, "Connection unauthorized for $remotePublicKey because not authorized and not on pairing activity to ask")
|
||||
}
|
||||
} else {
|
||||
//Responder does not need to check because already approved
|
||||
session!!.authorize()
|
||||
Logger.i(TAG, "Connection authorized for $remotePublicKey because already authorized")
|
||||
}
|
||||
} else {
|
||||
//Initiator does not need to check because the manual action of scanning the QR counts as approval
|
||||
session!!.authorize()
|
||||
Logger.i(TAG, "Connection authorized for $remotePublicKey because initiator")
|
||||
}
|
||||
handleAuthorization(channelSocket!!, isResponder)
|
||||
},
|
||||
onData = { s, opcode, subOpcode, data ->
|
||||
session?.handlePacket(opcode, subOpcode, data)
|
||||
@@ -848,6 +794,71 @@ class StateSync {
|
||||
)
|
||||
}
|
||||
|
||||
private fun handleAuthorization(channel: IChannel, isResponder: Boolean) {
|
||||
val syncSession = channel.syncSession!!
|
||||
val remotePublicKey = channel.remotePublicKey!!
|
||||
|
||||
if (isResponder) {
|
||||
val isAuthorized = synchronized(_authorizedDevices) {
|
||||
_authorizedDevices.values.contains(remotePublicKey)
|
||||
}
|
||||
|
||||
if (!isAuthorized) {
|
||||
val scope = StateApp.instance.scopeOrNull
|
||||
val activity = SyncShowPairingCodeActivity.activity
|
||||
|
||||
if (scope != null && activity != null) {
|
||||
scope.launch(Dispatchers.Main) {
|
||||
UIDialogs.showConfirmationDialog(activity, "Allow connection from ${remotePublicKey}?",
|
||||
action = {
|
||||
scope.launch(Dispatchers.IO) {
|
||||
try {
|
||||
syncSession.authorize()
|
||||
Logger.i(TAG, "Connection authorized for $remotePublicKey by confirmation")
|
||||
} catch (e: Throwable) {
|
||||
Logger.e(TAG, "Failed to send authorize", e)
|
||||
}
|
||||
}
|
||||
},
|
||||
cancelAction = {
|
||||
scope.launch(Dispatchers.IO) {
|
||||
try {
|
||||
unauthorize(remotePublicKey)
|
||||
} catch (e: Throwable) {
|
||||
Logger.w(TAG, "Failed to send unauthorize", e)
|
||||
}
|
||||
|
||||
syncSession.close()
|
||||
synchronized(_sessions) {
|
||||
_sessions.remove(remotePublicKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
} else {
|
||||
val publicKey = syncSession.remotePublicKey
|
||||
syncSession.unauthorize()
|
||||
syncSession.close()
|
||||
|
||||
synchronized(_sessions) {
|
||||
_sessions.remove(publicKey)
|
||||
}
|
||||
|
||||
Logger.i(TAG, "Connection unauthorized for $remotePublicKey because not authorized and not on pairing activity to ask")
|
||||
}
|
||||
} else {
|
||||
//Responder does not need to check because already approved
|
||||
syncSession.authorize()
|
||||
Logger.i(TAG, "Connection authorized for $remotePublicKey because already authorized")
|
||||
}
|
||||
} else {
|
||||
//Initiator does not need to check because the manual action of scanning the QR counts as approval
|
||||
syncSession.authorize()
|
||||
Logger.i(TAG, "Connection authorized for $remotePublicKey because initiator")
|
||||
}
|
||||
}
|
||||
|
||||
inline fun <reified T> broadcastJsonData(subOpcode: UByte, data: T) {
|
||||
broadcast(Opcode.DATA.value, subOpcode, Json.encodeToString(data));
|
||||
}
|
||||
@@ -895,16 +906,35 @@ class StateSync {
|
||||
_relaySession = null
|
||||
}
|
||||
|
||||
fun connect(deviceInfo: SyncDeviceInfo, onStatusUpdate: ((session: SyncSession?, complete: Boolean, message: String) -> Unit)? = null): SyncSocketSession {
|
||||
onStatusUpdate?.invoke(null, false, "Connecting...")
|
||||
val socket = getConnectedSocket(deviceInfo.addresses.map { InetAddress.getByName(it) }, deviceInfo.port) ?: throw Exception("Failed to connect")
|
||||
onStatusUpdate?.invoke(null, false, "Handshaking...")
|
||||
fun connect(deviceInfo: SyncDeviceInfo, onStatusUpdate: ((complete: Boolean?, message: String) -> Unit)? = null) {
|
||||
try {
|
||||
connect(deviceInfo.addresses, deviceInfo.port, deviceInfo.publicKey, deviceInfo.pairingCode, onStatusUpdate)
|
||||
} catch (e: Throwable) {
|
||||
Logger.e(TAG, "Failed to connect directly", e)
|
||||
val relaySession = _relaySession
|
||||
if (relaySession != null) {
|
||||
onStatusUpdate?.invoke(null, "Connecting via relay...")
|
||||
|
||||
runBlocking {
|
||||
relaySession.startRelayedChannel(deviceInfo.publicKey, deviceInfo.pairingCode)
|
||||
onStatusUpdate?.invoke(true, "Connected")
|
||||
}
|
||||
} else {
|
||||
throw Exception("Failed to connect.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun connect(addresses: Array<String>, port: Int, publicKey: String, pairingCode: String?, onStatusUpdate: ((complete: Boolean?, message: String) -> Unit)? = null): SyncSocketSession {
|
||||
onStatusUpdate?.invoke(null, "Connecting directly...")
|
||||
val socket = getConnectedSocket(addresses.map { InetAddress.getByName(it) }, port) ?: throw Exception("Failed to connect")
|
||||
onStatusUpdate?.invoke(null, "Handshaking...")
|
||||
|
||||
val session = createSocketSession(socket, false) { s ->
|
||||
onStatusUpdate?.invoke(s, true, "Handshake complete")
|
||||
onStatusUpdate?.invoke(true, "Authorized")
|
||||
}
|
||||
|
||||
session.startAsInitiator(deviceInfo.publicKey, deviceInfo.pairingCode)
|
||||
session.startAsInitiator(publicKey, pairingCode)
|
||||
return session
|
||||
}
|
||||
|
||||
|
||||
@@ -13,8 +13,9 @@ interface IChannel : AutoCloseable {
|
||||
val remotePublicKey: String?
|
||||
val remoteVersion: Int?
|
||||
var authorizable: IAuthorizable?
|
||||
var syncSession: SyncSession?
|
||||
fun setDataHandler(onData: ((SyncSocketSession, IChannel, UByte, UByte, ByteBuffer) -> Unit)?)
|
||||
fun send(opcode: UByte, subOpcode: UByte, data: ByteBuffer? = null)
|
||||
fun send(opcode: UByte, subOpcode: UByte = 0u, data: ByteBuffer? = null)
|
||||
fun setCloseHandler(onClose: ((IChannel) -> Unit)?)
|
||||
}
|
||||
|
||||
@@ -27,6 +28,7 @@ class ChannelSocket(private val session: SyncSocketSession) : IChannel {
|
||||
override var authorizable: IAuthorizable?
|
||||
get() = session.authorizable
|
||||
set(value) { session.authorizable = value }
|
||||
override var syncSession: SyncSession? = null
|
||||
|
||||
override fun setDataHandler(onData: ((SyncSocketSession, IChannel, UByte, UByte, ByteBuffer) -> Unit)?) {
|
||||
this.onData = onData
|
||||
@@ -76,10 +78,11 @@ class ChannelRelayed(
|
||||
override var authorizable: IAuthorizable? = null
|
||||
val isAuthorized: Boolean get() = authorizable?.isAuthorized ?: false
|
||||
var connectionId: Long = 0L
|
||||
override var remotePublicKey: String? = null
|
||||
override var remotePublicKey: String? = publicKey
|
||||
private set
|
||||
override var remoteVersion: Int? = null
|
||||
private set
|
||||
override var syncSession: SyncSession? = null
|
||||
|
||||
private var onData: ((SyncSocketSession, IChannel, UByte, UByte, ByteBuffer) -> Unit)? = null
|
||||
private var onClose: ((IChannel) -> Unit)? = null
|
||||
@@ -153,7 +156,7 @@ class ChannelRelayed(
|
||||
put(encryptedPayload, 0, encryptedLength)
|
||||
}
|
||||
|
||||
session.send(Opcode.RELAY.value, RelayOpcode.RELAYED_DATA.value, ByteBuffer.wrap(relayedPacket))
|
||||
session.send(Opcode.RELAY.value, RelayOpcode.DATA.value, ByteBuffer.wrap(relayedPacket).order(ByteOrder.LITTLE_ENDIAN))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -173,7 +176,7 @@ class ChannelRelayed(
|
||||
put(encryptedPayload, 0, encryptedLength)
|
||||
}
|
||||
|
||||
session.send(Opcode.RELAY.value, RelayOpcode.RELAYED_ERROR.value, ByteBuffer.wrap(relayedPacket))
|
||||
session.send(Opcode.RELAY.value, RelayOpcode.ERROR.value, ByteBuffer.wrap(relayedPacket))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -251,8 +254,7 @@ class ChannelRelayed(
|
||||
if (publicKeyBytes.size != 32) throw IllegalArgumentException("Public key must be 32 bytes")
|
||||
|
||||
val (pairingMessageLength, pairingMessage) = if (pairingCode != null) {
|
||||
val pairingProtocolName = "Noise_N_25519_ChaChaPoly_Blake2b"
|
||||
val pairingHandshake = HandshakeState(pairingProtocolName, HandshakeState.INITIATOR).apply {
|
||||
val pairingHandshake = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.INITIATOR).apply {
|
||||
remotePublicKey.setPublicKey(publicKeyBytes, 0)
|
||||
start()
|
||||
}
|
||||
@@ -312,7 +314,7 @@ class ChannelRelayed(
|
||||
val decryptedPayload = ByteArray(encryptedBytes.size - 16)
|
||||
val plen = transport!!.receiver.decryptWithAd(null, encryptedBytes, 0, decryptedPayload, 0, encryptedBytes.size)
|
||||
if (plen != decryptedPayload.size) throw IllegalStateException("Expected decrypted payload length to be $plen")
|
||||
return ByteBuffer.wrap(decryptedPayload)
|
||||
return ByteBuffer.wrap(decryptedPayload).order(ByteOrder.LITTLE_ENDIAN)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -327,12 +329,4 @@ class ChannelRelayed(
|
||||
completeHandshake(remoteVersion, transport)
|
||||
}
|
||||
}
|
||||
|
||||
fun handleData(data: ByteBuffer) {
|
||||
val size = data.int
|
||||
if (size != data.remaining() + 2) throw IllegalStateException("Incomplete packet received")
|
||||
val opcode = data.get().toUByte()
|
||||
val subOpcode = data.get().toUByte()
|
||||
invokeDataHandler(opcode, subOpcode, data)
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,6 @@ package com.futo.platformplayer.sync.internal;
|
||||
|
||||
public enum LinkType {
|
||||
None,
|
||||
Local,
|
||||
Proxied
|
||||
Direct,
|
||||
Relayed
|
||||
}
|
||||
|
||||
@@ -33,6 +33,30 @@ class SyncSession : IAuthorizable {
|
||||
private set
|
||||
val displayName: String get() = remoteDeviceName ?: remotePublicKey
|
||||
|
||||
val linkType: LinkType get()
|
||||
{
|
||||
var hasRelayed = false
|
||||
var hasDirect = false
|
||||
synchronized(_channels)
|
||||
{
|
||||
for (channel in _channels)
|
||||
{
|
||||
if (channel is ChannelRelayed)
|
||||
hasRelayed = true
|
||||
if (channel is ChannelSocket)
|
||||
hasDirect = true
|
||||
if (hasRelayed && hasDirect)
|
||||
return LinkType.Direct
|
||||
}
|
||||
}
|
||||
|
||||
if (hasRelayed)
|
||||
return LinkType.Relayed
|
||||
if (hasDirect)
|
||||
return LinkType.Direct
|
||||
return LinkType.None
|
||||
}
|
||||
|
||||
var connected: Boolean = false
|
||||
private set(v) {
|
||||
if (field != v) {
|
||||
@@ -70,6 +94,7 @@ class SyncSession : IAuthorizable {
|
||||
}
|
||||
|
||||
channel.authorizable = this
|
||||
channel.syncSession = this
|
||||
}
|
||||
|
||||
fun authorize() {
|
||||
|
||||
@@ -37,8 +37,8 @@ class SyncSocketSession {
|
||||
private val _onClose: ((session: SyncSocketSession) -> Unit)?
|
||||
private val _onHandshakeComplete: ((session: SyncSocketSession) -> Unit)?
|
||||
private val _onNewChannel: ((session: SyncSocketSession, channel: ChannelRelayed) -> Unit)?
|
||||
private val _onChannelEstablished: ((session: SyncSocketSession, channel: ChannelRelayed, isResponder: Boolean) -> Unit)?
|
||||
private val _isHandshakeAllowed: ((session: SyncSocketSession, remotePublicKey: String, pairingCode: String?) -> Boolean)?
|
||||
private var _thread: Thread? = null
|
||||
private var _cipherStatePair: CipherStatePair? = null
|
||||
private var _remotePublicKey: String? = null
|
||||
val remotePublicKey: String? get() = _remotePublicKey
|
||||
@@ -86,6 +86,7 @@ class SyncSocketSession {
|
||||
onHandshakeComplete: ((session: SyncSocketSession) -> Unit)? = null,
|
||||
onData: ((session: SyncSocketSession, opcode: UByte, subOpcode: UByte, data: ByteBuffer) -> Unit)? = null,
|
||||
onNewChannel: ((session: SyncSocketSession, channel: ChannelRelayed) -> Unit)? = null,
|
||||
onChannelEstablished: ((session: SyncSocketSession, channel: ChannelRelayed, isResponder: Boolean) -> Unit)? = null,
|
||||
isHandshakeAllowed: ((session: SyncSocketSession, remotePublicKey: String, pairingCode: String?) -> Boolean)? = null
|
||||
) {
|
||||
_inputStream = inputStream
|
||||
@@ -95,6 +96,7 @@ class SyncSocketSession {
|
||||
_localKeyPair = localKeyPair
|
||||
_onData = onData
|
||||
_onNewChannel = onNewChannel
|
||||
_onChannelEstablished = onChannelEstablished
|
||||
_isHandshakeAllowed = isHandshakeAllowed
|
||||
this.remoteAddress = remoteAddress
|
||||
|
||||
@@ -105,33 +107,29 @@ class SyncSocketSession {
|
||||
|
||||
fun startAsInitiator(remotePublicKey: String, pairingCode: String? = null) {
|
||||
_started = true
|
||||
_thread = Thread {
|
||||
try {
|
||||
handshakeAsInitiator(remotePublicKey, pairingCode)
|
||||
_onHandshakeComplete?.invoke(this)
|
||||
receiveLoop()
|
||||
} catch (e: Throwable) {
|
||||
Logger.e(TAG, "Failed to run as initiator", e)
|
||||
} finally {
|
||||
stop()
|
||||
}
|
||||
}.apply { start() }
|
||||
try {
|
||||
handshakeAsInitiator(remotePublicKey, pairingCode)
|
||||
_onHandshakeComplete?.invoke(this)
|
||||
receiveLoop()
|
||||
} catch (e: Throwable) {
|
||||
Logger.e(TAG, "Failed to run as initiator", e)
|
||||
} finally {
|
||||
stop()
|
||||
}
|
||||
}
|
||||
|
||||
fun startAsResponder() {
|
||||
_started = true
|
||||
_thread = Thread {
|
||||
try {
|
||||
if (handshakeAsResponder()) {
|
||||
_onHandshakeComplete?.invoke(this)
|
||||
receiveLoop()
|
||||
}
|
||||
} catch (e: Throwable) {
|
||||
Logger.e(TAG, "Failed to run as responder", e)
|
||||
} finally {
|
||||
stop()
|
||||
try {
|
||||
if (handshakeAsResponder()) {
|
||||
_onHandshakeComplete?.invoke(this)
|
||||
receiveLoop()
|
||||
}
|
||||
}.apply { start() }
|
||||
} catch (e: Throwable) {
|
||||
Logger.e(TAG, "Failed to run as responder", e)
|
||||
} finally {
|
||||
stop()
|
||||
}
|
||||
}
|
||||
|
||||
private fun receiveLoop() {
|
||||
@@ -155,7 +153,7 @@ class SyncSocketSession {
|
||||
val plen: Int = _cipherStatePair!!.receiver.decryptWithAd(null, _buffer, 0, _bufferDecrypted, 0, messageSize)
|
||||
//Logger.i(TAG, "Decrypted message (size = ${plen})")
|
||||
|
||||
handleData(_bufferDecrypted, plen)
|
||||
handleData(_bufferDecrypted, plen, null)
|
||||
} catch (e: Throwable) {
|
||||
Logger.e(TAG, "Exception while receiving data", e)
|
||||
break
|
||||
@@ -191,7 +189,6 @@ class SyncSocketSession {
|
||||
_outputStream.close()
|
||||
_cipherStatePair?.sender?.destroy()
|
||||
_cipherStatePair?.receiver?.destroy()
|
||||
_thread = null
|
||||
Logger.i(TAG, "Session closed")
|
||||
}
|
||||
|
||||
@@ -206,8 +203,7 @@ class SyncSocketSession {
|
||||
val pairingMessage: ByteArray
|
||||
val pairingMessageLength: Int
|
||||
if (pairingCode != null) {
|
||||
val pairingProtocolName = "Noise_N_25519_ChaChaPoly_Blake2b"
|
||||
val pairingHandshake = HandshakeState(pairingProtocolName, HandshakeState.INITIATOR)
|
||||
val pairingHandshake = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.INITIATOR)
|
||||
pairingHandshake.remotePublicKey.setPublicKey(Base64.getDecoder().decode(remotePublicKey), 0)
|
||||
pairingHandshake.start()
|
||||
val pairingCodeBytes = pairingCode.toByteArray(Charsets.UTF_8)
|
||||
@@ -261,8 +257,7 @@ class SyncSocketSession {
|
||||
|
||||
var pairingCode: String? = null
|
||||
if (pairingMessageLength > 0) {
|
||||
val pairingProtocolName = "Noise_N_25519_ChaChaPoly_Blake2b"
|
||||
val pairingHandshake = HandshakeState(pairingProtocolName, HandshakeState.RESPONDER)
|
||||
val pairingHandshake = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.RESPONDER)
|
||||
pairingHandshake.localKeyPair.copyFrom(_localKeyPair)
|
||||
pairingHandshake.start()
|
||||
val pairingPlaintext = ByteArray(512)
|
||||
@@ -298,47 +293,8 @@ class SyncSocketSession {
|
||||
throw Exception("Invalid version")
|
||||
}
|
||||
|
||||
private fun handshake(handshakeState: HandshakeState): CipherStatePair {
|
||||
handshakeState.start()
|
||||
|
||||
val message = ByteArray(8192)
|
||||
val plaintext = ByteArray(8192)
|
||||
|
||||
while (_started) {
|
||||
when (handshakeState.action) {
|
||||
HandshakeState.READ_MESSAGE -> {
|
||||
val messageSize = _inputStream.readInt()
|
||||
Logger.i(TAG, "Handshake read message (size = ${messageSize})")
|
||||
|
||||
var bytesRead = 0
|
||||
while (bytesRead < messageSize) {
|
||||
val read = _inputStream.read(message, bytesRead, messageSize - bytesRead)
|
||||
if (read == -1)
|
||||
throw Exception("Stream closed")
|
||||
bytesRead += read
|
||||
}
|
||||
|
||||
handshakeState.readMessage(message, 0, messageSize, plaintext, 0)
|
||||
}
|
||||
HandshakeState.WRITE_MESSAGE -> {
|
||||
val messageSize = handshakeState.writeMessage(message, 0, null, 0, 0)
|
||||
Logger.i(TAG, "Handshake wrote message (size = ${messageSize})")
|
||||
_outputStream.writeInt(messageSize)
|
||||
_outputStream.write(message, 0, messageSize)
|
||||
}
|
||||
HandshakeState.SPLIT -> {
|
||||
//Logger.i(TAG, "Handshake split")
|
||||
return handshakeState.split()
|
||||
}
|
||||
else -> throw Exception("Unexpected state (handshakeState.action = ${handshakeState.action})")
|
||||
}
|
||||
}
|
||||
|
||||
throw Exception("Handshake finished without completing")
|
||||
}
|
||||
|
||||
fun generateStreamId(): Int = synchronized(_streamIdGeneratorLock) { _streamIdGenerator++ }
|
||||
fun generateRequestId(): Int = synchronized(_requestIdGeneratorLock) { _requestIdGenerator++ }
|
||||
private fun generateRequestId(): Int = synchronized(_requestIdGeneratorLock) { _requestIdGenerator++ }
|
||||
|
||||
fun send(opcode: UByte, subOpcode: UByte, data: ByteBuffer) {
|
||||
ensureNotMainThread()
|
||||
@@ -415,22 +371,25 @@ class SyncSocketSession {
|
||||
}
|
||||
}
|
||||
|
||||
@OptIn(ExperimentalUnsignedTypes::class)
|
||||
private fun handleData(data: ByteArray, length: Int) {
|
||||
private fun handleData(data: ByteArray, length: Int, sourceChannel: ChannelRelayed?) {
|
||||
return handleData(ByteBuffer.wrap(data, 0, length).order(ByteOrder.LITTLE_ENDIAN), sourceChannel)
|
||||
}
|
||||
|
||||
private fun handleData(data: ByteBuffer, sourceChannel: ChannelRelayed?) {
|
||||
val length = data.remaining()
|
||||
if (length < HEADER_SIZE)
|
||||
throw Exception("Packet must be at least 6 bytes (header size)")
|
||||
|
||||
val size = ByteBuffer.wrap(data, 0, 4).order(ByteOrder.LITTLE_ENDIAN).int
|
||||
val size = data.int
|
||||
if (size != length - 4)
|
||||
throw Exception("Incomplete packet received")
|
||||
|
||||
val opcode = data.asUByteArray()[4]
|
||||
val subOpcode = data.asUByteArray()[5]
|
||||
val packetData = ByteBuffer.wrap(data, HEADER_SIZE, size - 2)
|
||||
handlePacket(opcode, subOpcode, packetData.order(ByteOrder.LITTLE_ENDIAN))
|
||||
val opcode = data.get().toUByte()
|
||||
val subOpcode = data.get().toUByte()
|
||||
handlePacket(opcode, subOpcode, data, sourceChannel)
|
||||
}
|
||||
|
||||
private fun handleRequest(subOpcode: UByte, data: ByteBuffer) {
|
||||
private fun handleRequest(subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) {
|
||||
when (subOpcode) {
|
||||
RequestOpcode.TRANSPORT_RELAYED.value -> {
|
||||
Logger.i(TAG, "Received request for a relayed transport")
|
||||
@@ -453,7 +412,7 @@ class SyncSocketSession {
|
||||
val channelHandshakeMessage = ByteArray(channelMessageLength).also { data.get(it) }
|
||||
val publicKey = Base64.getEncoder().encodeToString(publicKeyBytes)
|
||||
val pairingCode = if (pairingMessageLength > 0) {
|
||||
val pairingProtocol = HandshakeState("Noise_N_25519_ChaChaPoly_Blake2b", HandshakeState.RESPONDER).apply {
|
||||
val pairingProtocol = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.RESPONDER).apply {
|
||||
localKeyPair.copyFrom(_localKeyPair)
|
||||
start()
|
||||
}
|
||||
@@ -467,20 +426,22 @@ class SyncSocketSession {
|
||||
rp.putInt(2) // Status code for not allowed
|
||||
rp.putLong(connectionId)
|
||||
rp.putInt(requestId)
|
||||
rp.rewind()
|
||||
send(Opcode.RESPONSE.value, ResponseOpcode.TRANSPORT.value, rp)
|
||||
return
|
||||
}
|
||||
val channel = ChannelRelayed(this, _localKeyPair, publicKey, false)
|
||||
_onNewChannel?.invoke(this, channel)
|
||||
channel.connectionId = connectionId
|
||||
_onNewChannel?.invoke(this, channel)
|
||||
_channels[connectionId] = channel
|
||||
channel.sendResponseTransport(remoteVersion, requestId, channelHandshakeMessage)
|
||||
_onChannelEstablished?.invoke(this, channel, true)
|
||||
}
|
||||
else -> Logger.w(TAG, "Unhandled request opcode: $subOpcode")
|
||||
}
|
||||
}
|
||||
|
||||
private fun handleResponse(subOpcode: UByte, data: ByteBuffer) {
|
||||
private fun handleResponse(subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) {
|
||||
if (data.remaining() < 8) {
|
||||
Logger.e(TAG, "Response packet too short")
|
||||
return
|
||||
@@ -502,7 +463,7 @@ class SyncSocketSession {
|
||||
}
|
||||
} ?: Logger.e(TAG, "No pending request for requestId $requestId")
|
||||
}
|
||||
ResponseOpcode.TRANSPORT.value -> {
|
||||
ResponseOpcode.TRANSPORT_RELAYED.value -> {
|
||||
if (statusCode == 0) {
|
||||
if (data.remaining() < 16) {
|
||||
Logger.e(TAG, "RESPONSE_TRANSPORT packet too short")
|
||||
@@ -520,6 +481,7 @@ class SyncSocketSession {
|
||||
channel.handleTransportRelayed(remoteVersion, connectionId, handshakeMessage)
|
||||
_channels[connectionId] = channel
|
||||
tcs.complete(channel)
|
||||
_onChannelEstablished?.invoke(this, channel, false)
|
||||
} ?: Logger.e(TAG, "No pending channel for requestId $requestId")
|
||||
} else {
|
||||
_pendingChannels.remove(requestId)?.let { (channel, tcs) ->
|
||||
@@ -564,7 +526,7 @@ class SyncSocketSession {
|
||||
val blobLength = data.int
|
||||
val encryptedBlob = ByteArray(blobLength).also { data.get(it) }
|
||||
val timestamp = data.long
|
||||
val protocol = HandshakeState("Noise_N_25519_ChaChaPoly_Blake2b", HandshakeState.RESPONDER).apply {
|
||||
val protocol = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.RESPONDER).apply {
|
||||
localKeyPair.copyFrom(_localKeyPair)
|
||||
start()
|
||||
}
|
||||
@@ -607,7 +569,7 @@ class SyncSocketSession {
|
||||
val blobLength = data.int
|
||||
val encryptedBlob = ByteArray(blobLength).also { data.get(it) }
|
||||
val timestamp = data.long
|
||||
val protocol = HandshakeState("Noise_N_25519_ChaChaPoly_Blake2b", HandshakeState.RESPONDER).apply {
|
||||
val protocol = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.RESPONDER).apply {
|
||||
localKeyPair.copyFrom(_localKeyPair)
|
||||
start()
|
||||
}
|
||||
@@ -667,7 +629,7 @@ class SyncSocketSession {
|
||||
val remoteIp = remoteIpBytes.joinToString(".") { it.toUByte().toString() }
|
||||
val handshakeMessage = ByteArray(48).also { data.get(it) }
|
||||
val ciphertext = ByteArray(data.remaining()).also { data.get(it) }
|
||||
val protocol = HandshakeState("Noise_N_25519_ChaChaPoly_Blake2b", HandshakeState.RESPONDER).apply {
|
||||
val protocol = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.RESPONDER).apply {
|
||||
localKeyPair.copyFrom(_localKeyPair)
|
||||
start()
|
||||
}
|
||||
@@ -691,9 +653,14 @@ class SyncSocketSession {
|
||||
return ConnectionInfo(port, name, remoteIp, ipv4Addresses, ipv6Addresses, allowLocalDirect, allowRemoteDirect, allowRemoteHolePunched, allowRemoteProxied)
|
||||
}
|
||||
|
||||
private fun handleNotify(subOpcode: UByte, data: ByteBuffer) {
|
||||
private fun handleNotify(subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) {
|
||||
when (subOpcode) {
|
||||
NotifyOpcode.AUTHORIZED.value, NotifyOpcode.UNAUTHORIZED.value -> _onData?.invoke(this, Opcode.NOTIFY.value, subOpcode, data)
|
||||
NotifyOpcode.AUTHORIZED.value, NotifyOpcode.UNAUTHORIZED.value -> {
|
||||
if (sourceChannel != null)
|
||||
sourceChannel.invokeDataHandler(Opcode.NOTIFY.value, subOpcode, data)
|
||||
else
|
||||
_onData?.invoke(this, Opcode.NOTIFY.value, subOpcode, data)
|
||||
}
|
||||
NotifyOpcode.CONNECTION_INFO.value -> { /* Handle connection info if needed */ }
|
||||
}
|
||||
}
|
||||
@@ -702,10 +669,11 @@ class SyncSocketSession {
|
||||
val packet = ByteBuffer.allocate(12).order(ByteOrder.LITTLE_ENDIAN)
|
||||
packet.putLong(connectionId)
|
||||
packet.putInt(errorCode.value)
|
||||
packet.rewind()
|
||||
send(Opcode.RELAY.value, RelayOpcode.RELAY_ERROR.value, packet)
|
||||
}
|
||||
|
||||
private fun handleRelay(subOpcode: UByte, data: ByteBuffer) {
|
||||
private fun handleRelay(subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) {
|
||||
when (subOpcode) {
|
||||
RelayOpcode.RELAYED_DATA.value -> {
|
||||
if (data.remaining() < 8) {
|
||||
@@ -719,7 +687,7 @@ class SyncSocketSession {
|
||||
}
|
||||
val decryptedPayload = channel.decrypt(data)
|
||||
try {
|
||||
channel.handleData(decryptedPayload)
|
||||
handleData(decryptedPayload, channel)
|
||||
} catch (e: Exception) {
|
||||
Logger.e(TAG, "Exception while handling relayed data", e)
|
||||
channel.sendError(SyncErrorCode.ConnectionClosed)
|
||||
@@ -765,33 +733,36 @@ class SyncSocketSession {
|
||||
}
|
||||
}
|
||||
|
||||
private fun handlePacket(opcode: UByte, subOpcode: UByte, data: ByteBuffer) {
|
||||
private fun handlePacket(opcode: UByte, subOpcode: UByte, data: ByteBuffer, sourceChannel: ChannelRelayed?) {
|
||||
Logger.i(TAG, "Handle packet (opcode = ${opcode}, subOpcode = ${subOpcode})")
|
||||
|
||||
when (opcode) {
|
||||
Opcode.PING.value -> {
|
||||
send(Opcode.PONG.value)
|
||||
if (sourceChannel != null)
|
||||
sourceChannel.send(Opcode.PONG.value)
|
||||
else
|
||||
send(Opcode.PONG.value)
|
||||
//Logger.i(TAG, "Received ping, sent pong")
|
||||
return
|
||||
}
|
||||
Opcode.PONG.value -> {
|
||||
//Logger.i(TAG, "Received pong")
|
||||
Logger.v(TAG, "Received pong")
|
||||
return
|
||||
}
|
||||
Opcode.REQUEST.value -> {
|
||||
handleRequest(subOpcode, data)
|
||||
handleRequest(subOpcode, data, sourceChannel)
|
||||
return
|
||||
}
|
||||
Opcode.RESPONSE.value -> {
|
||||
handleResponse(subOpcode, data)
|
||||
handleResponse(subOpcode, data, sourceChannel)
|
||||
return
|
||||
}
|
||||
Opcode.NOTIFY.value -> {
|
||||
handleNotify(subOpcode, data)
|
||||
handleNotify(subOpcode, data, sourceChannel)
|
||||
return
|
||||
}
|
||||
Opcode.RELAY.value -> {
|
||||
handleRelay(subOpcode, data)
|
||||
handleRelay(subOpcode, data, sourceChannel)
|
||||
return
|
||||
}
|
||||
else -> if (isAuthorized) when (opcode) {
|
||||
@@ -848,18 +819,20 @@ class SyncSocketSession {
|
||||
throw Exception("After sync stream end, the stream must be complete")
|
||||
}
|
||||
|
||||
handlePacket(syncStream.opcode, syncStream.subOpcode, syncStream.getBytes().let { ByteBuffer.wrap(it).order(ByteOrder.LITTLE_ENDIAN) })
|
||||
}
|
||||
else -> {
|
||||
Logger.w(TAG, "Unknown opcode received (opcode = ${opcode}, subOpcode = ${subOpcode})")
|
||||
handlePacket(syncStream.opcode, syncStream.subOpcode, syncStream.getBytes().let { ByteBuffer.wrap(it).order(ByteOrder.LITTLE_ENDIAN) }, sourceChannel)
|
||||
}
|
||||
}
|
||||
Opcode.DATA.value -> {
|
||||
if (sourceChannel != null)
|
||||
sourceChannel.invokeDataHandler(opcode, subOpcode, data)
|
||||
else
|
||||
_onData?.invoke(this, opcode, subOpcode, data)
|
||||
}
|
||||
else -> {
|
||||
Logger.w(TAG, "Unknown opcode received (opcode = ${opcode}, subOpcode = ${subOpcode})")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (authorizable?.isAuthorized != true) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun requestConnectionInfo(publicKey: String): ConnectionInfo? {
|
||||
@@ -872,6 +845,7 @@ class SyncSocketSession {
|
||||
val packet = ByteBuffer.allocate(4 + 32).order(ByteOrder.LITTLE_ENDIAN)
|
||||
packet.putInt(requestId)
|
||||
packet.put(publicKeyBytes)
|
||||
packet.rewind()
|
||||
send(Opcode.REQUEST.value, RequestOpcode.CONNECTION_INFO.value, packet)
|
||||
} catch (e: Exception) {
|
||||
_pendingConnectionInfoRequests.remove(requestId)?.completeExceptionally(e)
|
||||
@@ -893,6 +867,7 @@ class SyncSocketSession {
|
||||
if (pkBytes.size != 32) throw IllegalArgumentException("Invalid public key length for $pk")
|
||||
packet.put(pkBytes)
|
||||
}
|
||||
packet.rewind()
|
||||
send(Opcode.REQUEST.value, RequestOpcode.BULK_CONNECTION_INFO.value, packet)
|
||||
} catch (e: Exception) {
|
||||
_pendingBulkConnectionInfoRequests.remove(requestId)?.completeExceptionally(e)
|
||||
@@ -949,7 +924,6 @@ class SyncSocketSession {
|
||||
) {
|
||||
if (authorizedKeys.size > 255) throw IllegalArgumentException("Number of authorized keys exceeds 255")
|
||||
|
||||
// **Step 1: Collect Network Information**
|
||||
val ipv4Addresses = mutableListOf<String>()
|
||||
val ipv6Addresses = mutableListOf<String>()
|
||||
for (nic in NetworkInterface.getNetworkInterfaces()) {
|
||||
@@ -965,11 +939,9 @@ class SyncSocketSession {
|
||||
}
|
||||
}
|
||||
|
||||
// **Step 2: Get Device Name**
|
||||
val deviceName = getDeviceName()
|
||||
val nameBytes = getLimitedUtf8Bytes(deviceName, 255)
|
||||
|
||||
// **Step 3: Serialize Connection Information**
|
||||
val blobSize = 2 + 1 + nameBytes.size + 1 + ipv4Addresses.size * 4 + 1 + ipv6Addresses.size * 16 + 1 + 1 + 1 + 1
|
||||
val data = ByteBuffer.allocate(blobSize).order(ByteOrder.LITTLE_ENDIAN)
|
||||
data.putShort(port.toShort())
|
||||
@@ -990,19 +962,19 @@ class SyncSocketSession {
|
||||
data.put(if (allowRemoteHolePunched) 1 else 0)
|
||||
data.put(if (allowRemoteProxied) 1 else 0)
|
||||
|
||||
// **Step 4: Precalculate Total Size**
|
||||
val handshakeSize = 48 // Noise handshake size for N pattern
|
||||
|
||||
data.rewind()
|
||||
val ciphertextSize = data.remaining() + 16 // Encrypted data size
|
||||
val totalSize = 1 + authorizedKeys.size * (32 + handshakeSize + 4 + ciphertextSize)
|
||||
val publishBytes = ByteBuffer.allocate(totalSize).order(ByteOrder.LITTLE_ENDIAN)
|
||||
publishBytes.put(authorizedKeys.size.toByte())
|
||||
|
||||
// **Step 5: Encrypt Data for Each Authorized Key**
|
||||
for (key in authorizedKeys) {
|
||||
val publicKeyBytes = Base64.getDecoder().decode(key)
|
||||
if (publicKeyBytes.size != 32) throw IllegalArgumentException("Public key must be 32 bytes")
|
||||
|
||||
val protocol = HandshakeState("Noise_N_25519_ChaChaPoly_Blake2b", HandshakeState.INITIATOR)
|
||||
val protocol = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.INITIATOR)
|
||||
protocol.remotePublicKey.setPublicKey(publicKeyBytes, 0)
|
||||
protocol.start()
|
||||
|
||||
@@ -1023,7 +995,7 @@ class SyncSocketSession {
|
||||
publishBytes.put(ciphertext, 0, ciphertextBytesWritten)
|
||||
}
|
||||
|
||||
// **Step 6: Send the Encrypted Data**
|
||||
publishBytes.rewind()
|
||||
send(Opcode.NOTIFY.value, NotifyOpcode.CONNECTION_INFO.value, publishBytes)
|
||||
}
|
||||
|
||||
@@ -1035,23 +1007,32 @@ class SyncSocketSession {
|
||||
val deferred = CompletableDeferred<Boolean>()
|
||||
_pendingPublishRequests[requestId] = deferred
|
||||
try {
|
||||
val MAX_PLAINTEXT_SIZE = 65535 - 16 // Adjust for tag size
|
||||
val MAX_PLAINTEXT_SIZE = 65535
|
||||
val HANDSHAKE_SIZE = 48
|
||||
val LENGTH_SIZE = 4
|
||||
val TAG_SIZE = 16
|
||||
val chunkCount = (data.size + MAX_PLAINTEXT_SIZE - 1) / MAX_PLAINTEXT_SIZE
|
||||
val blobSize = HANDSHAKE_SIZE + chunkCount * (LENGTH_SIZE + MAX_PLAINTEXT_SIZE + TAG_SIZE)
|
||||
|
||||
var blobSize = HANDSHAKE_SIZE
|
||||
var dataOffset = 0
|
||||
for (i in 0 until chunkCount) {
|
||||
val chunkSize = minOf(MAX_PLAINTEXT_SIZE, data.size - dataOffset)
|
||||
blobSize += LENGTH_SIZE + (chunkSize + TAG_SIZE)
|
||||
dataOffset += chunkSize
|
||||
}
|
||||
|
||||
val totalPacketSize = 4 + 1 + keyBytes.size + 1 + consumerPublicKeys.size * (32 + 4 + blobSize)
|
||||
val packet = ByteBuffer.allocate(totalPacketSize).order(ByteOrder.LITTLE_ENDIAN)
|
||||
packet.putInt(requestId)
|
||||
packet.put(keyBytes.size.toByte())
|
||||
packet.put(keyBytes)
|
||||
packet.put(consumerPublicKeys.size.toByte())
|
||||
|
||||
for (consumer in consumerPublicKeys) {
|
||||
val consumerBytes = Base64.getDecoder().decode(consumer)
|
||||
if (consumerBytes.size != 32) throw IllegalArgumentException("Consumer public key must be 32 bytes")
|
||||
packet.put(consumerBytes)
|
||||
val protocol = HandshakeState("Noise_N_25519_ChaChaPoly_Blake2b", HandshakeState.INITIATOR).apply {
|
||||
val protocol = HandshakeState(SyncSocketSession.nProtocolName, HandshakeState.INITIATOR).apply {
|
||||
remotePublicKey.setPublicKey(consumerBytes, 0)
|
||||
start()
|
||||
}
|
||||
@@ -1060,9 +1041,10 @@ class SyncSocketSession {
|
||||
val transportPair = protocol.split()
|
||||
packet.putInt(blobSize)
|
||||
packet.put(handshakeMessage)
|
||||
var dataOffset = 0
|
||||
|
||||
dataOffset = 0
|
||||
for (i in 0 until chunkCount) {
|
||||
val chunkSize = min(MAX_PLAINTEXT_SIZE, data.size - dataOffset)
|
||||
val chunkSize = minOf(MAX_PLAINTEXT_SIZE, data.size - dataOffset)
|
||||
val plaintext = data.copyOfRange(dataOffset, dataOffset + chunkSize)
|
||||
val ciphertext = ByteArray(chunkSize + TAG_SIZE)
|
||||
val written = transportPair.sender.encryptWithAd(null, plaintext, 0, ciphertext, 0, plaintext.size)
|
||||
@@ -1071,6 +1053,7 @@ class SyncSocketSession {
|
||||
dataOffset += chunkSize
|
||||
}
|
||||
}
|
||||
packet.rewind()
|
||||
send(Opcode.REQUEST.value, RequestOpcode.BULK_PUBLISH_RECORD.value, packet)
|
||||
} catch (e: Exception) {
|
||||
_pendingPublishRequests.remove(requestId)?.completeExceptionally(e)
|
||||
@@ -1093,6 +1076,7 @@ class SyncSocketSession {
|
||||
packet.put(publisherBytes)
|
||||
packet.put(keyBytes.size.toByte())
|
||||
packet.put(keyBytes)
|
||||
packet.rewind()
|
||||
send(Opcode.REQUEST.value, RequestOpcode.GET_RECORD.value, packet)
|
||||
} catch (e: Exception) {
|
||||
_pendingGetRecordRequests.remove(requestId)?.completeExceptionally(e)
|
||||
@@ -1119,6 +1103,7 @@ class SyncSocketSession {
|
||||
if (bytes.size != 32) throw IllegalArgumentException("Publisher public key must be 32 bytes")
|
||||
packet.put(bytes)
|
||||
}
|
||||
packet.rewind()
|
||||
send(Opcode.REQUEST.value, RequestOpcode.BULK_GET_RECORD.value, packet)
|
||||
} catch (e: Exception) {
|
||||
_pendingBulkGetRecordRequests.remove(requestId)?.completeExceptionally(e)
|
||||
@@ -1148,6 +1133,7 @@ class SyncSocketSession {
|
||||
packet.put(keyBytes.size.toByte())
|
||||
packet.put(keyBytes)
|
||||
}
|
||||
packet.rewind()
|
||||
send(Opcode.REQUEST.value, RequestOpcode.BULK_DELETE_RECORD.value, packet)
|
||||
} catch (e: Exception) {
|
||||
_pendingDeleteRequests.remove(requestId)?.completeExceptionally(e)
|
||||
@@ -1169,6 +1155,7 @@ class SyncSocketSession {
|
||||
packet.putInt(requestId)
|
||||
packet.put(publisherBytes)
|
||||
packet.put(consumerBytes)
|
||||
packet.rewind()
|
||||
send(Opcode.REQUEST.value, RequestOpcode.LIST_RECORD_KEYS.value, packet)
|
||||
} catch (e: Exception) {
|
||||
_pendingListKeysRequests.remove(requestId)?.completeExceptionally(e)
|
||||
@@ -1178,6 +1165,12 @@ class SyncSocketSession {
|
||||
}
|
||||
|
||||
companion object {
|
||||
val dh = "25519"
|
||||
val pattern = "N"
|
||||
val cipher = "ChaChaPoly"
|
||||
val hash = "BLAKE2b"
|
||||
var nProtocolName = "Noise_${pattern}_${dh}_${cipher}_${hash}"
|
||||
|
||||
private const val TAG = "SyncSocketSession"
|
||||
const val MAXIMUM_PACKET_SIZE = 65535 - 16
|
||||
const val MAXIMUM_PACKET_SIZE_ENCRYPTED = MAXIMUM_PACKET_SIZE + 16
|
||||
|
||||
@@ -43,13 +43,13 @@ class SyncDeviceView : ConstraintLayout {
|
||||
|
||||
_layoutLinkType.visibility = View.VISIBLE
|
||||
_imageLinkType.setImageResource(when (linkType) {
|
||||
LinkType.Proxied -> R.drawable.ic_internet
|
||||
LinkType.Local -> R.drawable.ic_lan
|
||||
LinkType.Relayed -> R.drawable.ic_internet
|
||||
LinkType.Direct -> R.drawable.ic_lan
|
||||
else -> 0
|
||||
})
|
||||
_textLinkType.text = when(linkType) {
|
||||
LinkType.Proxied -> "Proxied"
|
||||
LinkType.Local -> "Local"
|
||||
LinkType.Relayed -> "Relayed"
|
||||
LinkType.Direct -> "Direct"
|
||||
else -> null
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user