OTA pipeline working

This commit is contained in:
2026-04-25 22:27:17 +03:00
parent 90037587c0
commit 8f15fe0f1d
19 changed files with 458 additions and 112 deletions

View File

@@ -7,6 +7,7 @@ import io.ktor.server.plugins.contentnegotiation.*
import io.ktor.serialization.kotlinx.json.*
import io.ktor.server.routing.*
import io.ktor.server.websocket.*
import org.apache.commons.logging.LogFactory
import org.pavloveugene.iot.backend.config.AppConfig
import org.pavloveugene.iot.backend.config.configureSerialization
import org.pavloveugene.iot.backend.config.configureWebSockets
@@ -22,6 +23,7 @@ import kotlin.system.exitProcess
fun main(args: Array<String>) {
Database.init()
runMigrations()

View File

@@ -17,4 +17,6 @@ object AppConfig {
val dbUser = config.getString("ktor.database.user")
val dbPassword = config.getString("ktor.database.password")
val storagePath = config.getString("iot.firmware.storage")
}

View File

@@ -1,10 +1,14 @@
package org.pavloveugene.iot.backend.db
import org.slf4j.LoggerFactory
data class Migration(
val version: Int,
val sql: String
)
private val log= LoggerFactory.getLogger("DB Migrations")
fun loadMigrations(): List<Migration> {
val cl = Thread.currentThread().contextClassLoader
@@ -81,7 +85,7 @@ fun runMigrations() {
for (m in migrations) {
if (m.version in applied) continue
println("Applying migration V${m.version}")
log.info("Applying migration V${m.version}")
try {
conn.autoCommit = false
@@ -90,7 +94,7 @@ fun runMigrations() {
for (stmt in statements) {
println("Applying migration statement:\n $stmt\n========================")
log.info("Applying migration statement:\n $stmt\n========================")
conn.createStatement().use { st ->
st.execute(stmt)

View File

@@ -4,5 +4,11 @@ import kotlinx.serialization.Serializable
@Serializable
data class EventDto (
val type: String
val t: String,
val v: Int,
val hp: Int,
val hl: Int,
val rs: Int,
val ip: Int,
val si: String,
)

View File

@@ -0,0 +1,10 @@
package org.pavloveugene.iot.backend.dto
import kotlinx.serialization.Serializable
@Serializable
data class FirmwareUpdateCommandDto(
val t: String,
val u: String,
val s: String,
)

View File

@@ -4,6 +4,7 @@ import io.ktor.http.HttpStatusCode
import io.ktor.http.content.PartData
import io.ktor.http.content.forEachPart
import io.ktor.http.content.streamProvider
import io.ktor.server.application.Application
import io.ktor.server.application.call
import io.ktor.server.request.receiveMultipart
import io.ktor.server.response.respond
@@ -11,19 +12,62 @@ import io.ktor.server.response.respondFile
import io.ktor.server.routing.Route
import io.ktor.server.routing.get
import io.ktor.server.routing.post
import kotlinx.serialization.Serializable
import org.pavloveugene.iot.backend.config.AppConfig.storagePath
import org.pavloveugene.iot.backend.db.Database
import org.pavloveugene.iot.backend.services.sendUpdateCommand
import org.pavloveugene.iot.backend.services.updateUri
import org.slf4j.LoggerFactory
import java.io.File
fun Route.firmwareRouting() {
uploadFirmware()
getFirmware()
private val log = LoggerFactory.getLogger("FirmwareRoutes")
fun Route.firmwareRouting(app: Application) {
uploadFirmware(app)
getFirmware(app)
otaTrigger()
}
private const val DEVICE_ID = "device_id"
fun Route.otaTrigger() {
post("/ota_trigger") {
// send WS command
val par = call.parameters[DEVICE_ID]
if (par.isNullOrBlank()) {
call.respond(HttpStatusCode.BadRequest, "Device id is mandatory")
return@post
}
val devId = par.toLong();
log.info("OTA trigger request. id: $par")
if (Database.queryOne("select count(0) cnt from devices where id=?", listOf(devId))
?.get("cnt") as Number == 0
) {
log.error("No device has id $devId in database")
call.respond(HttpStatusCode.BadRequest, "Invalid device id")
return@post
}
if (sendUpdateCommand(devId)) {
call.respond(HttpStatusCode.OK)
} else {
call.respond(HttpStatusCode.BadRequest, "Something went wrong")
}
}
}
fun Route.uploadFirmware() {
@Serializable
data class UploadResponse(
val status: String,
val filename: String,
val version: Int
)
fun Route.uploadFirmware(app: Application) {
post("/firmware_upload") {
log.info("Uploading Firmware file")
val multipart = call.receiveMultipart()
var deviceId: Long? = null
@@ -34,7 +78,7 @@ fun Route.uploadFirmware() {
when (part) {
is PartData.FormItem -> {
when (part.name) {
"device_id" -> deviceId = part.value.toLong()
DEVICE_ID -> deviceId = part.value.toLong()
}
}
@@ -49,70 +93,90 @@ fun Route.uploadFirmware() {
part.dispose()
}
if (deviceId == null || fileBytes == null) {
call.respond(HttpStatusCode.BadRequest, "Missing fields")
if (deviceId == null) {
call.respond(HttpStatusCode.BadRequest, "Missing device id")
log.error("Device id is missing")
return@post
}
// 👉 генерим имя файла
val filename = "${java.util.UUID.randomUUID()}.bin"
val path = "storage/firmware/$filename"
// 👉 сохраняем файл
java.io.File(path).apply {
parentFile.mkdirs()
writeBytes(fileBytes!!)
val bytes = fileBytes ?: run {
call.respond(HttpStatusCode.BadRequest, "Missing file")
log.error("File is missing")
return@post
}
// 👉 считаем sha256
val sha256 = java.security.MessageDigest
.getInstance("SHA-256")
.digest(fileBytes!!)
.joinToString("") { "%02x".format(it) }
log.info("Parameters ok")
// 👉 сохраняем в БД
while (true) {
version = ((Database.queryOne(
"""
try {
// 👉 генерим имя файла
val baseDir = File(storagePath)
val filename = "${java.util.UUID.randomUUID()}.bin"
val file = File(baseDir, filename)
val path = file.path
file.parentFile.mkdirs()
file.writeBytes(bytes)
log.info("File created: $path")
// 👉 считаем sha256
val sha256 = java.security.MessageDigest
.getInstance("SHA-256")
.digest(bytes)
.joinToString("") { "%02x".format(it) }
log.info("sha-256= $sha256")
// 👉 сохраняем в БД
while (true) {
version = ((Database.queryOne(
"""
select coalesce(max(version), 0) + 1 as v
from firmware
where device_id = ?
""",
listOf(deviceId)
)?.get("v") as Number?)?.toInt() ?: 1)
listOf(deviceId)
)?.get("v") as Number?)?.toInt() ?: 1)
try {
Database.execute(
"""
log.info("version: $version")
try {
Database.execute(
"""
insert into firmware (device_id, version, path, sha256, size)
values (?, ?, ?, ?, ?)
""", listOf(
deviceId,
version,
path,
sha256,
fileBytes.size
deviceId,
version,
path,
sha256,
bytes.size
)
)
)
break
} catch (e: Exception) {
println("Retry insert firmware: ${e.message}")
break
} catch (e: Exception) {
log.info("Retry insert firmware: ${e.message}")
}
}
}
call.respond(
mapOf(
"status" to "ok",
"filename" to filename,
"version" to version,
call.respond(
UploadResponse(
status = "ok",
filename = filename,
version = version
)
)
)
} catch (e: Exception) {
log.error(e.message, e)
}
}
}
fun Route.getFirmware() {
get("/firmware/download") {
fun Route.getFirmware(app: Application) {
get(updateUri) {
val id = call.parameters["id"]!!.toInt()
val row = Database.queryOne(

View File

@@ -9,8 +9,11 @@ import kotlinx.serialization.json.Json
import org.pavloveugene.iot.backend.config.AppConfig
import org.pavloveugene.iot.backend.services.DeviceConnections
import org.pavloveugene.iot.backend.services.ProtocolService
import org.slf4j.LoggerFactory
import java.io.IOException
private val log= LoggerFactory.getLogger("ProtocolWebSocket")
fun Route.protocolWebSocket() {
val json = Json {
ignoreUnknownKeys = false
@@ -20,7 +23,7 @@ fun Route.protocolWebSocket() {
webSocket(AppConfig.wsPath) {
println("WS connected")
log.info("WS connected")
var devId: UInt? = null
@@ -32,34 +35,36 @@ fun Route.protocolWebSocket() {
val msg = try {
json.decodeFromString<MessageDto>(text)
} catch (e: Exception) {
println("WS decode error: ${e.message}")
log.info("WS decode error: ${e.message}")
safeSend("""{"error":"invalid"}""")
continue
}
try {
devId = msg.d
protocolService.handleMessage(msg, this)
protocolService.handleMessage(msg, this, call.request)
} catch (e: Exception) {
println("WS handler error: ${e.message}")
log.info("WS handler error: ${e.message}")
}
}
}
} catch (e: CancellationException) {
// нормальный shutdown — молчим
} catch (e: IOException) {
println("WS disconnected: ${e.message}")
log.info("WS disconnected: ${e.message}")
} catch (e: Exception) {
println("WS error: ${e.message}")
log.info("WS error: ${e.message}")
} finally {
devId?.let {
DeviceConnections.unregister(devId)
println("WS disconnected: $it")
log.info("WS disconnected: $it")
}
}
}
}
suspend fun DefaultWebSocketServerSession.safeSend(text: String) {
try {
send(text)

View File

@@ -1,14 +1,17 @@
package org.pavloveugene.iot.backend.services
import org.pavloveugene.iot.backend.db.Database
import org.slf4j.LoggerFactory
import java.sql.SQLException
private val log = LoggerFactory.getLogger("CleanupService")
fun executeCleanup() {
val ds = Database.dataSource
val start = System.currentTimeMillis()
val cutoff = start / 1000 - 60 * 60 * 24 * 2
println("Begin cleanup")
log.info("Begin cleanup")
ds.connection.use { conn ->
conn.autoCommit = false
@@ -17,7 +20,8 @@ fun executeCleanup() {
var total = 0
do {
val deleted = conn.prepareStatement("""
val deleted = conn.prepareStatement(
"""
delete t
from telemetry t
join (
@@ -26,7 +30,8 @@ fun executeCleanup() {
where ts < ? and processed = true
limit 1000
) p on p.id = t.id
""".trimIndent()).use { ps ->
""".trimIndent()
).use { ps ->
ps.setLong(1, cutoff)
ps.executeUpdate()
}
@@ -34,7 +39,7 @@ fun executeCleanup() {
total += deleted
if (deleted > 0) {
println("Deleted $deleted rows (total $total)")
log.info("Deleted $deleted rows (total $total)")
}
conn.commit()
@@ -49,5 +54,5 @@ fun executeCleanup() {
}
}
println("Cleanup complete in ${System.currentTimeMillis() - start} ms")
log.info("Cleanup complete in ${System.currentTimeMillis() - start} ms")
}

View File

@@ -1,8 +1,11 @@
package org.pavloveugene.iot.backend.services
import io.ktor.server.request.ApplicationRequest
import io.ktor.websocket.WebSocketSession
data class DeviceConnection(
val session: WebSocketSession,
var lastSeen: Long
val request: ApplicationRequest,
var lastId: UInt = 0u,
var lastSeen: Long,
)

View File

@@ -16,13 +16,18 @@ import org.pavloveugene.iot.backend.routes.protocolWebSocket
import java.time.Duration
import io.ktor.server.plugins.compression.*
import org.pavloveugene.iot.backend.routes.firmwareRouting
import org.slf4j.LoggerFactory
private val log = LoggerFactory.getLogger("KtorServer")
fun startKtorServer() {
val server = embeddedServer(
Netty,
port = AppConfig.serverPort,
host = AppConfig.serverHost,
) {
install(ContentNegotiation) {
json()
}
@@ -35,18 +40,19 @@ fun startKtorServer() {
}
routing {
log.info("CONFIG = ${application.environment.config.toMap()}")
protocolRoutes()
protocolWebSocket()
firmwareRouting()
firmwareRouting(application)
}
}
Runtime.getRuntime().addShutdownHook(Thread {
println("Shutting down...")
log.info("Shutting down...")
try {
server.stop(1000, 2000)
} catch (e: Exception) {
println("Shutdown error: ${e.message}")
log.info("Shutdown error: ${e.message}")
}
})

View File

@@ -1,6 +1,9 @@
package org.pavloveugene.iot.backend.services
import org.pavloveugene.iot.backend.db.Database
import org.slf4j.LoggerFactory
private val log = LoggerFactory.getLogger("NormalizeService")
fun runNormalizeLoop() {
var total = 0
@@ -9,10 +12,10 @@ fun runNormalizeLoop() {
do {
val count = normalizeBatch()
total += count
println("Processed batch: $count")
log.info("Processed batch: $count")
} while (count > 0)
println("Done. Total processed: $total in ${System.currentTimeMillis() - start} ms")
log.info("Done. Total processed: $total in ${System.currentTimeMillis() - start} ms")
}
@@ -124,7 +127,7 @@ fun verifyData(): Boolean {
val ds = Database.dataSource
var ret = true;
println("Executing verification")
log.info("Executing verification")
ds.connection.use { conn ->
@@ -155,9 +158,9 @@ fun verifyData(): Boolean {
val count = rs.getInt(1)
ret = count == 0
if (ret) {
println("All ok!")
log.info("All ok!")
} else {
println("$count errors detected")
log.info("$count errors detected")
}
}
@@ -181,7 +184,7 @@ fun verifyData(): Boolean {
}
}
println("Verification complete")
log.info("Verification complete")
return ret
}

View File

@@ -1,7 +1,7 @@
package org.pavloveugene.iot.backend.services
import MessageDto
import io.ktor.server.websocket.WebSocketServerSession
import io.ktor.server.request.ApplicationRequest
import io.ktor.websocket.WebSocketSession
import kotlinx.serialization.builtins.ListSerializer
import kotlinx.serialization.builtins.serializer
@@ -9,42 +9,58 @@ import kotlinx.serialization.json.Json
import org.pavloveugene.iot.backend.db.Database
import org.pavloveugene.iot.backend.dto.EventDto
import org.pavloveugene.iot.backend.dto.TelemetryDto
import org.slf4j.LoggerFactory
import java.net.InetAddress
private val log = LoggerFactory.getLogger("ProtocolService")
class ProtocolService(
private val json: Json
) {
fun handleMessage(msg: MessageDto, session: WebSocketSession) {
fun handleMessage(msg: MessageDto, session: WebSocketSession, request: ApplicationRequest) {
when (msg.t) {
MessageType.TELEMETRY -> {
handleTelemetry(msg)
}
MessageType.EVENT -> {
println("=== EVENT ===")
println(msg.p)
handleEvent(msg, session)
log.info("=== EVENT ===")
log.info("${msg.p}")
handleEvent(msg, session, request)
}
MessageType.COMMAND -> {
println("=== COMMAND ===")
println(msg.p)
log.info("=== COMMAND ===")
log.info("${msg.p}")
}
}
}
private fun handleEvent(msg: MessageDto, session: WebSocketSession) {
private fun intToIp(ip: Int): String {
val reverse = Integer.reverseBytes(ip)
val bytes = byteArrayOf(
(reverse shr 24).toByte(),
(reverse shr 16).toByte(),
(reverse shr 8).toByte(),
reverse.toByte()
)
return InetAddress.getByAddress(bytes).hostAddress
}
private fun handleEvent(msg: MessageDto, session: WebSocketSession, request: ApplicationRequest) {
val payload = json.decodeFromJsonElement(EventDto.serializer(), msg.p)
when (payload.type) {
"HB" -> {
println("=== HB devId = ${msg.d} ===")
when (payload.t) {
"hb" -> {
log.info("=== HB devId = ${msg.d} IP = ${intToIp(payload.ip)} ===")
val connection = DeviceConnections.get(msg.d)
if (connection == null) {
DeviceConnections.register(
msg.d, DeviceConnection(
session = session,
lastSeen = System.currentTimeMillis()
lastSeen = System.currentTimeMillis(),
request = request
)
)
} else {
@@ -85,18 +101,18 @@ class ProtocolService(
}
if (!isEnabled) {
println("device ${msg.d} locked, message ignored")
log.info("device ${msg.d} locked, message ignored")
conn.commit()
return
}
val payload = json.decodeFromJsonElement(TelemetryDto.serializer(), msg.p)
println("=== TELEMETRY ===")
println("device=${msg.d}")
println("ts=${msg.ts}")
println("metric=${payload.m}")
println("values=${payload.v}")
log.info("=== TELEMETRY ===")
log.info("device=${msg.d}")
log.info("ts=${msg.ts}")
log.info("metric=${payload.m}")
log.info("values=${payload.v}")
// insert telemetry
conn.prepareStatement(

View File

@@ -0,0 +1,136 @@
package org.pavloveugene.iot.backend.services
import MessageDto
import MessageType
import io.ktor.server.plugins.origin
import io.ktor.util.date.getTimeMillis
import io.ktor.websocket.send
import kotlinx.serialization.json.Json
import org.pavloveugene.iot.backend.db.Database
import org.pavloveugene.iot.backend.db.Database.queryOne
import org.pavloveugene.iot.backend.dto.FirmwareUpdateCommandDto
import org.slf4j.LoggerFactory
private val log = LoggerFactory.getLogger("OTA")
const val updateUri = "/firmware/download"
suspend fun sendUpdateCommand(
devId: Long
): Boolean {
var ret = true
log.info("Sending update command to device: $devId")
val ver = queryOne(
"""
select *
from firmware
where device_id = ?
order by version desc
limit 1
""".trimIndent(), listOf(devId)
)
if (ver == null) {
log.error("No firmware uploaded for device: $devId.")
ret = false
} else {
val firmware_id = ver["id"] as Int
val sha256 = ver["sha256"] as String
var update = queryOne(
"""
select fu.*
from firmware_updates fu
where fu.firmware_id = ?
""".trimIndent(),
listOf(firmware_id)
)
if (update == null) {
val update_id = Database.insertAndReturnId(
"""
insert into firmware_updates (firmware_id, status)
values (?, ?)
""".trimIndent(),
listOf(firmware_id, "pending")
)
update = queryOne(
"""
select fu.*
from firmware_updates fu
where fu.id = ?
""".trimIndent(),
listOf(update_id),
)
}
if (update == null) {
ret = false
log.error("No firmware sent to device: $devId. Something went wrong.")
} else {
val status = update["status"] as String
if (status == "pending") {
val deviceConnection = DeviceConnections.get(devId.toUInt())
if (deviceConnection == null) {
ret = false
log.warn("No firmware sent to device: $devId. Device not connected.")
} else {
val origin = deviceConnection.request.origin
val portPart =
if (origin.serverPort == 80 || origin.serverPort == 443) ""
else ":${origin.serverPort}"
val url = "${origin.scheme}://${origin.serverHost}$portPart$updateUri?id=${firmware_id}"
val updPayload = FirmwareUpdateCommandDto(
t = "fw",
u = url,
s = sha256,
)
deviceConnection.lastId++
val updateCommand = MessageDto(
v = 1,
id = deviceConnection.lastId,
d = devId.toUInt(),
t = MessageType.COMMAND,
ts = getTimeMillis(),
p = Json.encodeToJsonElement(FirmwareUpdateCommandDto.serializer(), updPayload),
)
val cmdJson: String = Json.encodeToJsonElement(
MessageDto.serializer(),
updateCommand
).toString();
try {
deviceConnection.session.send(cmdJson)
Database.execute(
"""
update firmware_updates
set status = 'sent'
where id = ?
""".trimIndent(), listOf(update["id"])
)
log.info("Firmware update sent for device: $devId. Command: $cmdJson")
} catch (e: Exception) {
ret = false
log.error("Error updating device $devId", e)
}
}
}
}
}
return ret
}