diff --git a/backend/src/main/kotlin/org/pavloveugene/iot/backend/Application.kt b/backend/src/main/kotlin/org/pavloveugene/iot/backend/Application.kt index 27c7977..ae5f416 100644 --- a/backend/src/main/kotlin/org/pavloveugene/iot/backend/Application.kt +++ b/backend/src/main/kotlin/org/pavloveugene/iot/backend/Application.kt @@ -7,6 +7,7 @@ import io.ktor.server.plugins.contentnegotiation.* import io.ktor.serialization.kotlinx.json.* import io.ktor.server.routing.* import io.ktor.server.websocket.* +import org.apache.commons.logging.LogFactory import org.pavloveugene.iot.backend.config.AppConfig import org.pavloveugene.iot.backend.config.configureSerialization import org.pavloveugene.iot.backend.config.configureWebSockets @@ -22,6 +23,7 @@ import kotlin.system.exitProcess fun main(args: Array) { + Database.init() runMigrations() diff --git a/backend/src/main/kotlin/org/pavloveugene/iot/backend/config/AppConfig.kt b/backend/src/main/kotlin/org/pavloveugene/iot/backend/config/AppConfig.kt index d06ec4d..6e86538 100644 --- a/backend/src/main/kotlin/org/pavloveugene/iot/backend/config/AppConfig.kt +++ b/backend/src/main/kotlin/org/pavloveugene/iot/backend/config/AppConfig.kt @@ -17,4 +17,6 @@ object AppConfig { val dbUser = config.getString("ktor.database.user") val dbPassword = config.getString("ktor.database.password") + val storagePath = config.getString("iot.firmware.storage") + } \ No newline at end of file diff --git a/backend/src/main/kotlin/org/pavloveugene/iot/backend/db/Migration.kt b/backend/src/main/kotlin/org/pavloveugene/iot/backend/db/Migration.kt index 2e07247..f20db93 100644 --- a/backend/src/main/kotlin/org/pavloveugene/iot/backend/db/Migration.kt +++ b/backend/src/main/kotlin/org/pavloveugene/iot/backend/db/Migration.kt @@ -1,10 +1,14 @@ package org.pavloveugene.iot.backend.db +import org.slf4j.LoggerFactory + data class Migration( val version: Int, val sql: String ) +private val log= LoggerFactory.getLogger("DB Migrations") + fun loadMigrations(): List { val cl = Thread.currentThread().contextClassLoader @@ -81,7 +85,7 @@ fun runMigrations() { for (m in migrations) { if (m.version in applied) continue - println("Applying migration V${m.version}") + log.info("Applying migration V${m.version}") try { conn.autoCommit = false @@ -90,7 +94,7 @@ fun runMigrations() { for (stmt in statements) { - println("Applying migration statement:\n $stmt\n========================") + log.info("Applying migration statement:\n $stmt\n========================") conn.createStatement().use { st -> st.execute(stmt) diff --git a/backend/src/main/kotlin/org/pavloveugene/iot/backend/dto/EventDto.kt b/backend/src/main/kotlin/org/pavloveugene/iot/backend/dto/EventDto.kt index 967275c..18b90bd 100644 --- a/backend/src/main/kotlin/org/pavloveugene/iot/backend/dto/EventDto.kt +++ b/backend/src/main/kotlin/org/pavloveugene/iot/backend/dto/EventDto.kt @@ -4,5 +4,11 @@ import kotlinx.serialization.Serializable @Serializable data class EventDto ( - val type: String + val t: String, + val v: Int, + val hp: Int, + val hl: Int, + val rs: Int, + val ip: Int, + val si: String, ) \ No newline at end of file diff --git a/backend/src/main/kotlin/org/pavloveugene/iot/backend/dto/FirmwareUpdateCommandDto.kt b/backend/src/main/kotlin/org/pavloveugene/iot/backend/dto/FirmwareUpdateCommandDto.kt new file mode 100644 index 0000000..b1eb91e --- /dev/null +++ b/backend/src/main/kotlin/org/pavloveugene/iot/backend/dto/FirmwareUpdateCommandDto.kt @@ -0,0 +1,10 @@ +package org.pavloveugene.iot.backend.dto + +import kotlinx.serialization.Serializable + +@Serializable +data class FirmwareUpdateCommandDto( + val t: String, + val u: String, + val s: String, +) diff --git a/backend/src/main/kotlin/org/pavloveugene/iot/backend/routes/FirmwareRoutes.kt b/backend/src/main/kotlin/org/pavloveugene/iot/backend/routes/FirmwareRoutes.kt index eae1ae3..d2a91b7 100644 --- a/backend/src/main/kotlin/org/pavloveugene/iot/backend/routes/FirmwareRoutes.kt +++ b/backend/src/main/kotlin/org/pavloveugene/iot/backend/routes/FirmwareRoutes.kt @@ -4,6 +4,7 @@ import io.ktor.http.HttpStatusCode import io.ktor.http.content.PartData import io.ktor.http.content.forEachPart import io.ktor.http.content.streamProvider +import io.ktor.server.application.Application import io.ktor.server.application.call import io.ktor.server.request.receiveMultipart import io.ktor.server.response.respond @@ -11,19 +12,62 @@ import io.ktor.server.response.respondFile import io.ktor.server.routing.Route import io.ktor.server.routing.get import io.ktor.server.routing.post +import kotlinx.serialization.Serializable +import org.pavloveugene.iot.backend.config.AppConfig.storagePath import org.pavloveugene.iot.backend.db.Database +import org.pavloveugene.iot.backend.services.sendUpdateCommand +import org.pavloveugene.iot.backend.services.updateUri +import org.slf4j.LoggerFactory import java.io.File -fun Route.firmwareRouting() { - uploadFirmware() - getFirmware() +private val log = LoggerFactory.getLogger("FirmwareRoutes") + +fun Route.firmwareRouting(app: Application) { + uploadFirmware(app) + getFirmware(app) + otaTrigger() +} + +private const val DEVICE_ID = "device_id" + +fun Route.otaTrigger() { post("/ota_trigger") { - // send WS command + + val par = call.parameters[DEVICE_ID] + if (par.isNullOrBlank()) { + call.respond(HttpStatusCode.BadRequest, "Device id is mandatory") + return@post + } + val devId = par.toLong(); + log.info("OTA trigger request. id: $par") + + if (Database.queryOne("select count(0) cnt from devices where id=?", listOf(devId)) + ?.get("cnt") as Number == 0 + ) { + log.error("No device has id $devId in database") + call.respond(HttpStatusCode.BadRequest, "Invalid device id") + return@post + } + if (sendUpdateCommand(devId)) { + call.respond(HttpStatusCode.OK) + } else { + call.respond(HttpStatusCode.BadRequest, "Something went wrong") + } } } -fun Route.uploadFirmware() { +@Serializable +data class UploadResponse( + val status: String, + val filename: String, + val version: Int +) + +fun Route.uploadFirmware(app: Application) { post("/firmware_upload") { + + log.info("Uploading Firmware file") + val multipart = call.receiveMultipart() var deviceId: Long? = null @@ -34,7 +78,7 @@ fun Route.uploadFirmware() { when (part) { is PartData.FormItem -> { when (part.name) { - "device_id" -> deviceId = part.value.toLong() + DEVICE_ID -> deviceId = part.value.toLong() } } @@ -49,70 +93,90 @@ fun Route.uploadFirmware() { part.dispose() } - if (deviceId == null || fileBytes == null) { - call.respond(HttpStatusCode.BadRequest, "Missing fields") + if (deviceId == null) { + call.respond(HttpStatusCode.BadRequest, "Missing device id") + log.error("Device id is missing") return@post } - // 👉 генерим имя файла - val filename = "${java.util.UUID.randomUUID()}.bin" - val path = "storage/firmware/$filename" - - // 👉 сохраняем файл - java.io.File(path).apply { - parentFile.mkdirs() - writeBytes(fileBytes!!) + val bytes = fileBytes ?: run { + call.respond(HttpStatusCode.BadRequest, "Missing file") + log.error("File is missing") + return@post } - // 👉 считаем sha256 - val sha256 = java.security.MessageDigest - .getInstance("SHA-256") - .digest(fileBytes!!) - .joinToString("") { "%02x".format(it) } + log.info("Parameters ok") - // 👉 сохраняем в БД - while (true) { - version = ((Database.queryOne( - """ + try { + + // 👉 генерим имя файла + val baseDir = File(storagePath) + val filename = "${java.util.UUID.randomUUID()}.bin" + val file = File(baseDir, filename) + val path = file.path + + file.parentFile.mkdirs() + file.writeBytes(bytes) + + log.info("File created: $path") + + // 👉 считаем sha256 + val sha256 = java.security.MessageDigest + .getInstance("SHA-256") + .digest(bytes) + .joinToString("") { "%02x".format(it) } + + log.info("sha-256= $sha256") + + // 👉 сохраняем в БД + while (true) { + version = ((Database.queryOne( + """ select coalesce(max(version), 0) + 1 as v from firmware where device_id = ? """, - listOf(deviceId) - )?.get("v") as Number?)?.toInt() ?: 1) + listOf(deviceId) + )?.get("v") as Number?)?.toInt() ?: 1) - try { - Database.execute( - """ + log.info("version: $version") + + try { + Database.execute( + """ insert into firmware (device_id, version, path, sha256, size) values (?, ?, ?, ?, ?) """, listOf( - deviceId, - version, - path, - sha256, - fileBytes.size + deviceId, + version, + path, + sha256, + bytes.size + ) ) - ) - break - } catch (e: Exception) { - println("Retry insert firmware: ${e.message}") + break + } catch (e: Exception) { + log.info("Retry insert firmware: ${e.message}") + } } - } - call.respond( - mapOf( - "status" to "ok", - "filename" to filename, - "version" to version, + call.respond( + UploadResponse( + status = "ok", + filename = filename, + version = version + ) ) - ) + } catch (e: Exception) { + log.error(e.message, e) + } } + } -fun Route.getFirmware() { - get("/firmware/download") { +fun Route.getFirmware(app: Application) { + get(updateUri) { val id = call.parameters["id"]!!.toInt() val row = Database.queryOne( diff --git a/backend/src/main/kotlin/org/pavloveugene/iot/backend/routes/ProtocolWebSocket.kt b/backend/src/main/kotlin/org/pavloveugene/iot/backend/routes/ProtocolWebSocket.kt index c027baf..7870c06 100644 --- a/backend/src/main/kotlin/org/pavloveugene/iot/backend/routes/ProtocolWebSocket.kt +++ b/backend/src/main/kotlin/org/pavloveugene/iot/backend/routes/ProtocolWebSocket.kt @@ -9,8 +9,11 @@ import kotlinx.serialization.json.Json import org.pavloveugene.iot.backend.config.AppConfig import org.pavloveugene.iot.backend.services.DeviceConnections import org.pavloveugene.iot.backend.services.ProtocolService +import org.slf4j.LoggerFactory import java.io.IOException +private val log= LoggerFactory.getLogger("ProtocolWebSocket") + fun Route.protocolWebSocket() { val json = Json { ignoreUnknownKeys = false @@ -20,7 +23,7 @@ fun Route.protocolWebSocket() { webSocket(AppConfig.wsPath) { - println("WS connected") + log.info("WS connected") var devId: UInt? = null @@ -32,34 +35,36 @@ fun Route.protocolWebSocket() { val msg = try { json.decodeFromString(text) } catch (e: Exception) { - println("WS decode error: ${e.message}") + log.info("WS decode error: ${e.message}") safeSend("""{"error":"invalid"}""") continue } try { devId = msg.d - protocolService.handleMessage(msg, this) + protocolService.handleMessage(msg, this, call.request) } catch (e: Exception) { - println("WS handler error: ${e.message}") + log.info("WS handler error: ${e.message}") } } } } catch (e: CancellationException) { // нормальный shutdown — молчим } catch (e: IOException) { - println("WS disconnected: ${e.message}") + log.info("WS disconnected: ${e.message}") } catch (e: Exception) { - println("WS error: ${e.message}") + log.info("WS error: ${e.message}") } finally { devId?.let { DeviceConnections.unregister(devId) - println("WS disconnected: $it") + log.info("WS disconnected: $it") } } } } + + suspend fun DefaultWebSocketServerSession.safeSend(text: String) { try { send(text) diff --git a/backend/src/main/kotlin/org/pavloveugene/iot/backend/services/CleanupService.kt b/backend/src/main/kotlin/org/pavloveugene/iot/backend/services/CleanupService.kt index c7387b4..83a7da9 100644 --- a/backend/src/main/kotlin/org/pavloveugene/iot/backend/services/CleanupService.kt +++ b/backend/src/main/kotlin/org/pavloveugene/iot/backend/services/CleanupService.kt @@ -1,14 +1,17 @@ package org.pavloveugene.iot.backend.services import org.pavloveugene.iot.backend.db.Database +import org.slf4j.LoggerFactory import java.sql.SQLException +private val log = LoggerFactory.getLogger("CleanupService") + fun executeCleanup() { val ds = Database.dataSource val start = System.currentTimeMillis() val cutoff = start / 1000 - 60 * 60 * 24 * 2 - println("Begin cleanup") + log.info("Begin cleanup") ds.connection.use { conn -> conn.autoCommit = false @@ -17,7 +20,8 @@ fun executeCleanup() { var total = 0 do { - val deleted = conn.prepareStatement(""" + val deleted = conn.prepareStatement( + """ delete t from telemetry t join ( @@ -26,7 +30,8 @@ fun executeCleanup() { where ts < ? and processed = true limit 1000 ) p on p.id = t.id - """.trimIndent()).use { ps -> + """.trimIndent() + ).use { ps -> ps.setLong(1, cutoff) ps.executeUpdate() } @@ -34,7 +39,7 @@ fun executeCleanup() { total += deleted if (deleted > 0) { - println("Deleted $deleted rows (total $total)") + log.info("Deleted $deleted rows (total $total)") } conn.commit() @@ -49,5 +54,5 @@ fun executeCleanup() { } } - println("Cleanup complete in ${System.currentTimeMillis() - start} ms") + log.info("Cleanup complete in ${System.currentTimeMillis() - start} ms") } \ No newline at end of file diff --git a/backend/src/main/kotlin/org/pavloveugene/iot/backend/services/DeviceConnection.kt b/backend/src/main/kotlin/org/pavloveugene/iot/backend/services/DeviceConnection.kt index c3cfa6b..ff51ed6 100644 --- a/backend/src/main/kotlin/org/pavloveugene/iot/backend/services/DeviceConnection.kt +++ b/backend/src/main/kotlin/org/pavloveugene/iot/backend/services/DeviceConnection.kt @@ -1,8 +1,11 @@ package org.pavloveugene.iot.backend.services +import io.ktor.server.request.ApplicationRequest import io.ktor.websocket.WebSocketSession data class DeviceConnection( val session: WebSocketSession, - var lastSeen: Long + val request: ApplicationRequest, + var lastId: UInt = 0u, + var lastSeen: Long, ) \ No newline at end of file diff --git a/backend/src/main/kotlin/org/pavloveugene/iot/backend/services/KtorServer.kt b/backend/src/main/kotlin/org/pavloveugene/iot/backend/services/KtorServer.kt index 37ad005..17807d7 100644 --- a/backend/src/main/kotlin/org/pavloveugene/iot/backend/services/KtorServer.kt +++ b/backend/src/main/kotlin/org/pavloveugene/iot/backend/services/KtorServer.kt @@ -16,13 +16,18 @@ import org.pavloveugene.iot.backend.routes.protocolWebSocket import java.time.Duration import io.ktor.server.plugins.compression.* import org.pavloveugene.iot.backend.routes.firmwareRouting +import org.slf4j.LoggerFactory + +private val log = LoggerFactory.getLogger("KtorServer") fun startKtorServer() { + val server = embeddedServer( Netty, port = AppConfig.serverPort, host = AppConfig.serverHost, ) { + install(ContentNegotiation) { json() } @@ -35,18 +40,19 @@ fun startKtorServer() { } routing { + log.info("CONFIG = ${application.environment.config.toMap()}") protocolRoutes() protocolWebSocket() - firmwareRouting() + firmwareRouting(application) } } Runtime.getRuntime().addShutdownHook(Thread { - println("Shutting down...") + log.info("Shutting down...") try { server.stop(1000, 2000) } catch (e: Exception) { - println("Shutdown error: ${e.message}") + log.info("Shutdown error: ${e.message}") } }) diff --git a/backend/src/main/kotlin/org/pavloveugene/iot/backend/services/NormalizeService.kt b/backend/src/main/kotlin/org/pavloveugene/iot/backend/services/NormalizeService.kt index b5cdd25..5142de2 100644 --- a/backend/src/main/kotlin/org/pavloveugene/iot/backend/services/NormalizeService.kt +++ b/backend/src/main/kotlin/org/pavloveugene/iot/backend/services/NormalizeService.kt @@ -1,6 +1,9 @@ package org.pavloveugene.iot.backend.services import org.pavloveugene.iot.backend.db.Database +import org.slf4j.LoggerFactory + +private val log = LoggerFactory.getLogger("NormalizeService") fun runNormalizeLoop() { var total = 0 @@ -9,10 +12,10 @@ fun runNormalizeLoop() { do { val count = normalizeBatch() total += count - println("Processed batch: $count") + log.info("Processed batch: $count") } while (count > 0) - println("Done. Total processed: $total in ${System.currentTimeMillis() - start} ms") + log.info("Done. Total processed: $total in ${System.currentTimeMillis() - start} ms") } @@ -124,7 +127,7 @@ fun verifyData(): Boolean { val ds = Database.dataSource var ret = true; - println("Executing verification") + log.info("Executing verification") ds.connection.use { conn -> @@ -155,9 +158,9 @@ fun verifyData(): Boolean { val count = rs.getInt(1) ret = count == 0 if (ret) { - println("All ok!") + log.info("All ok!") } else { - println("$count errors detected") + log.info("$count errors detected") } } @@ -181,7 +184,7 @@ fun verifyData(): Boolean { } } - println("Verification complete") + log.info("Verification complete") return ret } \ No newline at end of file diff --git a/backend/src/main/kotlin/org/pavloveugene/iot/backend/services/ProtocolService.kt b/backend/src/main/kotlin/org/pavloveugene/iot/backend/services/ProtocolService.kt index 2fb7601..390ed3e 100644 --- a/backend/src/main/kotlin/org/pavloveugene/iot/backend/services/ProtocolService.kt +++ b/backend/src/main/kotlin/org/pavloveugene/iot/backend/services/ProtocolService.kt @@ -1,7 +1,7 @@ package org.pavloveugene.iot.backend.services import MessageDto -import io.ktor.server.websocket.WebSocketServerSession +import io.ktor.server.request.ApplicationRequest import io.ktor.websocket.WebSocketSession import kotlinx.serialization.builtins.ListSerializer import kotlinx.serialization.builtins.serializer @@ -9,42 +9,58 @@ import kotlinx.serialization.json.Json import org.pavloveugene.iot.backend.db.Database import org.pavloveugene.iot.backend.dto.EventDto import org.pavloveugene.iot.backend.dto.TelemetryDto +import org.slf4j.LoggerFactory +import java.net.InetAddress + +private val log = LoggerFactory.getLogger("ProtocolService") class ProtocolService( private val json: Json ) { - fun handleMessage(msg: MessageDto, session: WebSocketSession) { + fun handleMessage(msg: MessageDto, session: WebSocketSession, request: ApplicationRequest) { when (msg.t) { MessageType.TELEMETRY -> { handleTelemetry(msg) } MessageType.EVENT -> { - println("=== EVENT ===") - println(msg.p) - handleEvent(msg, session) + log.info("=== EVENT ===") + log.info("${msg.p}") + handleEvent(msg, session, request) } MessageType.COMMAND -> { - println("=== COMMAND ===") - println(msg.p) + log.info("=== COMMAND ===") + log.info("${msg.p}") } } } - private fun handleEvent(msg: MessageDto, session: WebSocketSession) { + private fun intToIp(ip: Int): String { + val reverse = Integer.reverseBytes(ip) + val bytes = byteArrayOf( + (reverse shr 24).toByte(), + (reverse shr 16).toByte(), + (reverse shr 8).toByte(), + reverse.toByte() + ) + return InetAddress.getByAddress(bytes).hostAddress + } + + private fun handleEvent(msg: MessageDto, session: WebSocketSession, request: ApplicationRequest) { val payload = json.decodeFromJsonElement(EventDto.serializer(), msg.p) - when (payload.type) { - "HB" -> { - println("=== HB devId = ${msg.d} ===") + when (payload.t) { + "hb" -> { + log.info("=== HB devId = ${msg.d} IP = ${intToIp(payload.ip)} ===") val connection = DeviceConnections.get(msg.d) if (connection == null) { DeviceConnections.register( msg.d, DeviceConnection( session = session, - lastSeen = System.currentTimeMillis() + lastSeen = System.currentTimeMillis(), + request = request ) ) } else { @@ -85,18 +101,18 @@ class ProtocolService( } if (!isEnabled) { - println("device ${msg.d} locked, message ignored") + log.info("device ${msg.d} locked, message ignored") conn.commit() return } val payload = json.decodeFromJsonElement(TelemetryDto.serializer(), msg.p) - println("=== TELEMETRY ===") - println("device=${msg.d}") - println("ts=${msg.ts}") - println("metric=${payload.m}") - println("values=${payload.v}") + log.info("=== TELEMETRY ===") + log.info("device=${msg.d}") + log.info("ts=${msg.ts}") + log.info("metric=${payload.m}") + log.info("values=${payload.v}") // insert telemetry conn.prepareStatement( diff --git a/backend/src/main/kotlin/org/pavloveugene/iot/backend/services/UpdateTrigger.kt b/backend/src/main/kotlin/org/pavloveugene/iot/backend/services/UpdateTrigger.kt new file mode 100644 index 0000000..22ea7f9 --- /dev/null +++ b/backend/src/main/kotlin/org/pavloveugene/iot/backend/services/UpdateTrigger.kt @@ -0,0 +1,136 @@ +package org.pavloveugene.iot.backend.services + +import MessageDto +import MessageType +import io.ktor.server.plugins.origin +import io.ktor.util.date.getTimeMillis +import io.ktor.websocket.send +import kotlinx.serialization.json.Json +import org.pavloveugene.iot.backend.db.Database +import org.pavloveugene.iot.backend.db.Database.queryOne +import org.pavloveugene.iot.backend.dto.FirmwareUpdateCommandDto +import org.slf4j.LoggerFactory + +private val log = LoggerFactory.getLogger("OTA") + +const val updateUri = "/firmware/download" + +suspend fun sendUpdateCommand( + devId: Long +): Boolean { + var ret = true + + log.info("Sending update command to device: $devId") + + val ver = queryOne( + """ + select * + from firmware + where device_id = ? + order by version desc + limit 1 + """.trimIndent(), listOf(devId) + ) + + if (ver == null) { + log.error("No firmware uploaded for device: $devId.") + ret = false + } else { + val firmware_id = ver["id"] as Int + val sha256 = ver["sha256"] as String + + var update = queryOne( + """ + select fu.* + from firmware_updates fu + where fu.firmware_id = ? + """.trimIndent(), + listOf(firmware_id) + ) + + if (update == null) { + val update_id = Database.insertAndReturnId( + """ + insert into firmware_updates (firmware_id, status) + values (?, ?) + """.trimIndent(), + listOf(firmware_id, "pending") + ) + + update = queryOne( + """ + select fu.* + from firmware_updates fu + where fu.id = ? + """.trimIndent(), + listOf(update_id), + ) + } + + if (update == null) { + ret = false + log.error("No firmware sent to device: $devId. Something went wrong.") + } else { + val status = update["status"] as String + + if (status == "pending") { + + val deviceConnection = DeviceConnections.get(devId.toUInt()) + if (deviceConnection == null) { + ret = false + log.warn("No firmware sent to device: $devId. Device not connected.") + } else { + val origin = deviceConnection.request.origin + val portPart = + if (origin.serverPort == 80 || origin.serverPort == 443) "" + else ":${origin.serverPort}" + + val url = "${origin.scheme}://${origin.serverHost}$portPart$updateUri?id=${firmware_id}" + + val updPayload = FirmwareUpdateCommandDto( + t = "fw", + u = url, + s = sha256, + ) + + deviceConnection.lastId++ + + val updateCommand = MessageDto( + v = 1, + id = deviceConnection.lastId, + d = devId.toUInt(), + t = MessageType.COMMAND, + ts = getTimeMillis(), + p = Json.encodeToJsonElement(FirmwareUpdateCommandDto.serializer(), updPayload), + ) + + val cmdJson: String = Json.encodeToJsonElement( + MessageDto.serializer(), + updateCommand + ).toString(); + + try { + deviceConnection.session.send(cmdJson) + + Database.execute( + """ + update firmware_updates + set status = 'sent' + where id = ? + """.trimIndent(), listOf(update["id"]) + ) + + log.info("Firmware update sent for device: $devId. Command: $cmdJson") + + } catch (e: Exception) { + ret = false + log.error("Error updating device $devId", e) + } + + } + } + } + + } + return ret +} \ No newline at end of file diff --git a/esp32/CMakeLists.txt b/esp32/CMakeLists.txt index 67c0a87..3bd4463 100644 --- a/esp32/CMakeLists.txt +++ b/esp32/CMakeLists.txt @@ -4,6 +4,7 @@ cmake_minimum_required(VERSION 3.22) file(READ "${CMAKE_SOURCE_DIR}/version.txt" PROJECT_VER) string(STRIP "${PROJECT_VER}" PROJECT_VER) +add_compile_definitions(FW_VERSION=${PROJECT_VER}) include($ENV{IDF_PATH}/tools/cmake/project.cmake) project(esp32) diff --git a/esp32/main/CMakeLists.txt b/esp32/main/CMakeLists.txt index 0a9c12d..d3e3a22 100644 --- a/esp32/main/CMakeLists.txt +++ b/esp32/main/CMakeLists.txt @@ -29,6 +29,7 @@ idf_component_register( esp_driver_gpio esp_http_client app_update + mbedtls ) # добавляем кастомный Kconfig set(COMPONENT_KCONFIG "Kconfig") \ No newline at end of file diff --git a/esp32/main/heartbeat_task.cpp b/esp32/main/heartbeat_task.cpp index 893e50a..64fa80a 100644 --- a/esp32/main/heartbeat_task.cpp +++ b/esp32/main/heartbeat_task.cpp @@ -22,9 +22,11 @@ static void heartbeat_task(void* arg) time(&now); protocol_send_event_hb((uint32_t)now); + vTaskDelay(pdMS_TO_TICKS(30000)); // 30 сек + } else + { + vTaskDelay(pdMS_TO_TICKS(2000)); } - - vTaskDelay(pdMS_TO_TICKS(10000)); // 10 сек } } diff --git a/esp32/main/ota.cpp b/esp32/main/ota.cpp index 9df6c3f..fbfb512 100644 --- a/esp32/main/ota.cpp +++ b/esp32/main/ota.cpp @@ -8,57 +8,102 @@ #include "ws.h" #include "heartbeat_task.h" #include "fw_command.h" +#include "psa/crypto.h" static TaskHandle_t ota_task_handle = nullptr; static char url_copy[256]; static char sha256[65]; +#define TAG "OTA" + void perform_ota(void* par) { + sampler_task_stop(); + sender_task_stop(); + heartbeat_task_stop(); + ws_disconnect(); + + ESP_LOGI(TAG, "OTA task started"); esp_ota_handle_t ota_handle; uint8_t buf[1024]; int read_bytes; const esp_partition_t* partition = nullptr; + int content_length = 0; + psa_hash_operation_t op = PSA_HASH_OPERATION_INIT; const char* url = static_cast(par); + ESP_LOGI(TAG, "URL: %s", url); + esp_http_client_config_t config = {}; config.url = url; config.timeout_ms = 5000; config.buffer_size = 1024; - esp_http_client_handle_t client = esp_http_client_init(&config); + ESP_LOGI(TAG, "HTTP init complete"); + + if (esp_http_client_open(client, 0) != ESP_OK) { goto cleanup; } + ESP_LOGI(TAG, "HTTP connection established"); + + content_length = esp_http_client_fetch_headers(client); + ESP_LOGI(TAG, "content_length=%d", content_length); + partition = esp_ota_get_next_update_partition(nullptr); + ESP_LOGI(TAG, "OTA update partition selected %s (offset 0x%08x)", partition->label, partition->address); + if (esp_ota_begin(partition, OTA_SIZE_UNKNOWN, &ota_handle) != ESP_OK) { goto cleanup; } - while ((read_bytes = esp_http_client_read(client, (char*)buf, sizeof(buf))) > 0) + psa_crypto_init(); + psa_hash_setup(&op, PSA_ALG_SHA_256); + + ESP_LOGI(TAG, "OTA update started"); + + while (1) { - if (esp_ota_write(ota_handle, buf, read_bytes) != ESP_OK) + int data_read = esp_http_client_read(client, (char*)buf, sizeof(buf)); + + if (data_read < 0) { - esp_ota_end(ota_handle); + ESP_LOGE(TAG, "read error"); + break; + } + else if (data_read == 0) + { + ESP_LOGI(TAG, "download finished"); + break; + } + + ESP_LOGI(TAG, "write chunk: %d bytes", data_read); + + if (esp_ota_write(ota_handle, buf, data_read) != ESP_OK) + { + ESP_LOGE(TAG, "flash write failed"); goto cleanup; } + if (psa_hash_update(&op, buf, data_read) != PSA_SUCCESS) { + ESP_LOGE(TAG, "hash update failed"); + } } + ESP_LOGI(TAG, "finalizing..."); + if (esp_ota_end(ota_handle) == ESP_OK) { uint8_t hash[32]; + size_t hash_len; - if (esp_partition_get_sha256(partition, hash) != ESP_OK) - { - goto cleanup; - } + psa_hash_finish(&op, hash, sizeof(hash), &hash_len); // переводим в hex char hash_str[65]; @@ -68,16 +113,24 @@ void perform_ota(void* par) } hash_str[64] = '\0'; + ESP_LOGI(TAG, "sha256 = %s. Comparing...", hash_str); + // сравнение if (strncmp(hash_str, sha256, 64) != 0) { - printf("SHA256 mismatch!\n"); + ESP_LOGE(TAG, "SHA256 mismatch!"); goto cleanup; } + ESP_LOGI(TAG, "set boot partition"); esp_ota_set_boot_partition(partition); + ESP_LOGI(TAG, "restart pending"); esp_restart(); } + else + { + ESP_LOGE(TAG, "ota_end failed"); + } cleanup: esp_http_client_cleanup(client); @@ -89,10 +142,6 @@ void ota_task_start(const fw_cmd_t* cmd) { strncpy(url_copy, cmd->url, sizeof(url_copy)); strncpy(sha256, cmd->sha256, sizeof(sha256)); - sampler_task_stop(); - sender_task_stop(); - heartbeat_task_stop(); - ws_disconnect(); if (ota_task_handle == nullptr) { xTaskCreate( diff --git a/esp32/main/protocol.cpp b/esp32/main/protocol.cpp index ed843a5..76c8ddc 100644 --- a/esp32/main/protocol.cpp +++ b/esp32/main/protocol.cpp @@ -3,6 +3,8 @@ #include #include #include +#include "esp_wifi.h" +#include "esp_netif.h" #define PROTOCOL_VERSION 1 @@ -195,16 +197,45 @@ uint32_t protocol_next_id() void protocol_send_event_hb(int64_t ts) { - // формируешь JSON строго по контракту + char buf[256]; - // пример: - char buf[128]; + wifi_ap_record_t ap; + esp_wifi_sta_get_ap_info(&ap); + + int rssi = ap.rssi; + const char *ssid = (const char *)ap.ssid; + + esp_netif_ip_info_t ip_info; + esp_netif_t *netif = esp_netif_get_handle_from_ifkey("WIFI_STA_DEF"); + + esp_netif_get_ip_info(netif, &ip_info); + + uint32_t ip = ip_info.ip.addr; + + size_t heap_free = heap_caps_get_free_size(MALLOC_CAP_DEFAULT); + size_t heap_largest = heap_caps_get_largest_free_block(MALLOC_CAP_DEFAULT); uint32_t id = protocol_next_id(); snprintf(buf, sizeof(buf), - "{\"v\":1,\"id\":%" PRIu32 ",\"t\":\"e\",\"ts\":%" PRIu64 ",\"d\":%u,\"p\":{\"type\":\"hb\"}}", - id, ts, CONFIG_DEVICE_ID); + "{\"v\":1,\"id\":%" PRIu32 ",\"t\":\"e\",\"ts\":%" PRIu64 ",\"d\":%u," + "\"p\":{" + "\"t\":\"hb\"," + "\"v\":%u," + "\"hp\":%u," + "\"hl\":%u," + "\"rs\":%d," + "\"ip\":%" PRIu32 "," + "\"si\":\"%s\"" + "}}", + id, ts, CONFIG_DEVICE_ID, + FW_VERSION, + heap_free, + heap_largest, + rssi, + ip, + ssid +); ws_send(buf); } diff --git a/esp32/version.txt b/esp32/version.txt index 56a6051..301160a 100644 --- a/esp32/version.txt +++ b/esp32/version.txt @@ -1 +1 @@ -1 \ No newline at end of file +8 \ No newline at end of file