OTA pipeline working

This commit is contained in:
2026-04-25 22:27:17 +03:00
parent 90037587c0
commit 8f15fe0f1d
19 changed files with 458 additions and 112 deletions

View File

@@ -7,6 +7,7 @@ import io.ktor.server.plugins.contentnegotiation.*
import io.ktor.serialization.kotlinx.json.* import io.ktor.serialization.kotlinx.json.*
import io.ktor.server.routing.* import io.ktor.server.routing.*
import io.ktor.server.websocket.* import io.ktor.server.websocket.*
import org.apache.commons.logging.LogFactory
import org.pavloveugene.iot.backend.config.AppConfig import org.pavloveugene.iot.backend.config.AppConfig
import org.pavloveugene.iot.backend.config.configureSerialization import org.pavloveugene.iot.backend.config.configureSerialization
import org.pavloveugene.iot.backend.config.configureWebSockets import org.pavloveugene.iot.backend.config.configureWebSockets
@@ -22,6 +23,7 @@ import kotlin.system.exitProcess
fun main(args: Array<String>) { fun main(args: Array<String>) {
Database.init() Database.init()
runMigrations() runMigrations()

View File

@@ -17,4 +17,6 @@ object AppConfig {
val dbUser = config.getString("ktor.database.user") val dbUser = config.getString("ktor.database.user")
val dbPassword = config.getString("ktor.database.password") val dbPassword = config.getString("ktor.database.password")
val storagePath = config.getString("iot.firmware.storage")
} }

View File

@@ -1,10 +1,14 @@
package org.pavloveugene.iot.backend.db package org.pavloveugene.iot.backend.db
import org.slf4j.LoggerFactory
data class Migration( data class Migration(
val version: Int, val version: Int,
val sql: String val sql: String
) )
private val log= LoggerFactory.getLogger("DB Migrations")
fun loadMigrations(): List<Migration> { fun loadMigrations(): List<Migration> {
val cl = Thread.currentThread().contextClassLoader val cl = Thread.currentThread().contextClassLoader
@@ -81,7 +85,7 @@ fun runMigrations() {
for (m in migrations) { for (m in migrations) {
if (m.version in applied) continue if (m.version in applied) continue
println("Applying migration V${m.version}") log.info("Applying migration V${m.version}")
try { try {
conn.autoCommit = false conn.autoCommit = false
@@ -90,7 +94,7 @@ fun runMigrations() {
for (stmt in statements) { for (stmt in statements) {
println("Applying migration statement:\n $stmt\n========================") log.info("Applying migration statement:\n $stmt\n========================")
conn.createStatement().use { st -> conn.createStatement().use { st ->
st.execute(stmt) st.execute(stmt)

View File

@@ -4,5 +4,11 @@ import kotlinx.serialization.Serializable
@Serializable @Serializable
data class EventDto ( data class EventDto (
val type: String val t: String,
val v: Int,
val hp: Int,
val hl: Int,
val rs: Int,
val ip: Int,
val si: String,
) )

View File

@@ -0,0 +1,10 @@
package org.pavloveugene.iot.backend.dto
import kotlinx.serialization.Serializable
@Serializable
data class FirmwareUpdateCommandDto(
val t: String,
val u: String,
val s: String,
)

View File

@@ -4,6 +4,7 @@ import io.ktor.http.HttpStatusCode
import io.ktor.http.content.PartData import io.ktor.http.content.PartData
import io.ktor.http.content.forEachPart import io.ktor.http.content.forEachPart
import io.ktor.http.content.streamProvider import io.ktor.http.content.streamProvider
import io.ktor.server.application.Application
import io.ktor.server.application.call import io.ktor.server.application.call
import io.ktor.server.request.receiveMultipart import io.ktor.server.request.receiveMultipart
import io.ktor.server.response.respond import io.ktor.server.response.respond
@@ -11,19 +12,62 @@ import io.ktor.server.response.respondFile
import io.ktor.server.routing.Route import io.ktor.server.routing.Route
import io.ktor.server.routing.get import io.ktor.server.routing.get
import io.ktor.server.routing.post import io.ktor.server.routing.post
import kotlinx.serialization.Serializable
import org.pavloveugene.iot.backend.config.AppConfig.storagePath
import org.pavloveugene.iot.backend.db.Database import org.pavloveugene.iot.backend.db.Database
import org.pavloveugene.iot.backend.services.sendUpdateCommand
import org.pavloveugene.iot.backend.services.updateUri
import org.slf4j.LoggerFactory
import java.io.File import java.io.File
fun Route.firmwareRouting() { private val log = LoggerFactory.getLogger("FirmwareRoutes")
uploadFirmware()
getFirmware() fun Route.firmwareRouting(app: Application) {
uploadFirmware(app)
getFirmware(app)
otaTrigger()
}
private const val DEVICE_ID = "device_id"
fun Route.otaTrigger() {
post("/ota_trigger") { post("/ota_trigger") {
// send WS command
val par = call.parameters[DEVICE_ID]
if (par.isNullOrBlank()) {
call.respond(HttpStatusCode.BadRequest, "Device id is mandatory")
return@post
}
val devId = par.toLong();
log.info("OTA trigger request. id: $par")
if (Database.queryOne("select count(0) cnt from devices where id=?", listOf(devId))
?.get("cnt") as Number == 0
) {
log.error("No device has id $devId in database")
call.respond(HttpStatusCode.BadRequest, "Invalid device id")
return@post
}
if (sendUpdateCommand(devId)) {
call.respond(HttpStatusCode.OK)
} else {
call.respond(HttpStatusCode.BadRequest, "Something went wrong")
}
} }
} }
fun Route.uploadFirmware() { @Serializable
data class UploadResponse(
val status: String,
val filename: String,
val version: Int
)
fun Route.uploadFirmware(app: Application) {
post("/firmware_upload") { post("/firmware_upload") {
log.info("Uploading Firmware file")
val multipart = call.receiveMultipart() val multipart = call.receiveMultipart()
var deviceId: Long? = null var deviceId: Long? = null
@@ -34,7 +78,7 @@ fun Route.uploadFirmware() {
when (part) { when (part) {
is PartData.FormItem -> { is PartData.FormItem -> {
when (part.name) { when (part.name) {
"device_id" -> deviceId = part.value.toLong() DEVICE_ID -> deviceId = part.value.toLong()
} }
} }
@@ -49,70 +93,90 @@ fun Route.uploadFirmware() {
part.dispose() part.dispose()
} }
if (deviceId == null || fileBytes == null) { if (deviceId == null) {
call.respond(HttpStatusCode.BadRequest, "Missing fields") call.respond(HttpStatusCode.BadRequest, "Missing device id")
log.error("Device id is missing")
return@post return@post
} }
// 👉 генерим имя файла val bytes = fileBytes ?: run {
val filename = "${java.util.UUID.randomUUID()}.bin" call.respond(HttpStatusCode.BadRequest, "Missing file")
val path = "storage/firmware/$filename" log.error("File is missing")
return@post
// 👉 сохраняем файл
java.io.File(path).apply {
parentFile.mkdirs()
writeBytes(fileBytes!!)
} }
// 👉 считаем sha256 log.info("Parameters ok")
val sha256 = java.security.MessageDigest
.getInstance("SHA-256")
.digest(fileBytes!!)
.joinToString("") { "%02x".format(it) }
// 👉 сохраняем в БД try {
while (true) {
version = ((Database.queryOne( // 👉 генерим имя файла
""" val baseDir = File(storagePath)
val filename = "${java.util.UUID.randomUUID()}.bin"
val file = File(baseDir, filename)
val path = file.path
file.parentFile.mkdirs()
file.writeBytes(bytes)
log.info("File created: $path")
// 👉 считаем sha256
val sha256 = java.security.MessageDigest
.getInstance("SHA-256")
.digest(bytes)
.joinToString("") { "%02x".format(it) }
log.info("sha-256= $sha256")
// 👉 сохраняем в БД
while (true) {
version = ((Database.queryOne(
"""
select coalesce(max(version), 0) + 1 as v select coalesce(max(version), 0) + 1 as v
from firmware from firmware
where device_id = ? where device_id = ?
""", """,
listOf(deviceId) listOf(deviceId)
)?.get("v") as Number?)?.toInt() ?: 1) )?.get("v") as Number?)?.toInt() ?: 1)
try { log.info("version: $version")
Database.execute(
""" try {
Database.execute(
"""
insert into firmware (device_id, version, path, sha256, size) insert into firmware (device_id, version, path, sha256, size)
values (?, ?, ?, ?, ?) values (?, ?, ?, ?, ?)
""", listOf( """, listOf(
deviceId, deviceId,
version, version,
path, path,
sha256, sha256,
fileBytes.size bytes.size
)
) )
)
break break
} catch (e: Exception) { } catch (e: Exception) {
println("Retry insert firmware: ${e.message}") log.info("Retry insert firmware: ${e.message}")
}
} }
}
call.respond( call.respond(
mapOf( UploadResponse(
"status" to "ok", status = "ok",
"filename" to filename, filename = filename,
"version" to version, version = version
)
) )
) } catch (e: Exception) {
log.error(e.message, e)
}
} }
} }
fun Route.getFirmware() { fun Route.getFirmware(app: Application) {
get("/firmware/download") { get(updateUri) {
val id = call.parameters["id"]!!.toInt() val id = call.parameters["id"]!!.toInt()
val row = Database.queryOne( val row = Database.queryOne(

View File

@@ -9,8 +9,11 @@ import kotlinx.serialization.json.Json
import org.pavloveugene.iot.backend.config.AppConfig import org.pavloveugene.iot.backend.config.AppConfig
import org.pavloveugene.iot.backend.services.DeviceConnections import org.pavloveugene.iot.backend.services.DeviceConnections
import org.pavloveugene.iot.backend.services.ProtocolService import org.pavloveugene.iot.backend.services.ProtocolService
import org.slf4j.LoggerFactory
import java.io.IOException import java.io.IOException
private val log= LoggerFactory.getLogger("ProtocolWebSocket")
fun Route.protocolWebSocket() { fun Route.protocolWebSocket() {
val json = Json { val json = Json {
ignoreUnknownKeys = false ignoreUnknownKeys = false
@@ -20,7 +23,7 @@ fun Route.protocolWebSocket() {
webSocket(AppConfig.wsPath) { webSocket(AppConfig.wsPath) {
println("WS connected") log.info("WS connected")
var devId: UInt? = null var devId: UInt? = null
@@ -32,34 +35,36 @@ fun Route.protocolWebSocket() {
val msg = try { val msg = try {
json.decodeFromString<MessageDto>(text) json.decodeFromString<MessageDto>(text)
} catch (e: Exception) { } catch (e: Exception) {
println("WS decode error: ${e.message}") log.info("WS decode error: ${e.message}")
safeSend("""{"error":"invalid"}""") safeSend("""{"error":"invalid"}""")
continue continue
} }
try { try {
devId = msg.d devId = msg.d
protocolService.handleMessage(msg, this) protocolService.handleMessage(msg, this, call.request)
} catch (e: Exception) { } catch (e: Exception) {
println("WS handler error: ${e.message}") log.info("WS handler error: ${e.message}")
} }
} }
} }
} catch (e: CancellationException) { } catch (e: CancellationException) {
// нормальный shutdown — молчим // нормальный shutdown — молчим
} catch (e: IOException) { } catch (e: IOException) {
println("WS disconnected: ${e.message}") log.info("WS disconnected: ${e.message}")
} catch (e: Exception) { } catch (e: Exception) {
println("WS error: ${e.message}") log.info("WS error: ${e.message}")
} finally { } finally {
devId?.let { devId?.let {
DeviceConnections.unregister(devId) DeviceConnections.unregister(devId)
println("WS disconnected: $it") log.info("WS disconnected: $it")
} }
} }
} }
} }
suspend fun DefaultWebSocketServerSession.safeSend(text: String) { suspend fun DefaultWebSocketServerSession.safeSend(text: String) {
try { try {
send(text) send(text)

View File

@@ -1,14 +1,17 @@
package org.pavloveugene.iot.backend.services package org.pavloveugene.iot.backend.services
import org.pavloveugene.iot.backend.db.Database import org.pavloveugene.iot.backend.db.Database
import org.slf4j.LoggerFactory
import java.sql.SQLException import java.sql.SQLException
private val log = LoggerFactory.getLogger("CleanupService")
fun executeCleanup() { fun executeCleanup() {
val ds = Database.dataSource val ds = Database.dataSource
val start = System.currentTimeMillis() val start = System.currentTimeMillis()
val cutoff = start / 1000 - 60 * 60 * 24 * 2 val cutoff = start / 1000 - 60 * 60 * 24 * 2
println("Begin cleanup") log.info("Begin cleanup")
ds.connection.use { conn -> ds.connection.use { conn ->
conn.autoCommit = false conn.autoCommit = false
@@ -17,7 +20,8 @@ fun executeCleanup() {
var total = 0 var total = 0
do { do {
val deleted = conn.prepareStatement(""" val deleted = conn.prepareStatement(
"""
delete t delete t
from telemetry t from telemetry t
join ( join (
@@ -26,7 +30,8 @@ fun executeCleanup() {
where ts < ? and processed = true where ts < ? and processed = true
limit 1000 limit 1000
) p on p.id = t.id ) p on p.id = t.id
""".trimIndent()).use { ps -> """.trimIndent()
).use { ps ->
ps.setLong(1, cutoff) ps.setLong(1, cutoff)
ps.executeUpdate() ps.executeUpdate()
} }
@@ -34,7 +39,7 @@ fun executeCleanup() {
total += deleted total += deleted
if (deleted > 0) { if (deleted > 0) {
println("Deleted $deleted rows (total $total)") log.info("Deleted $deleted rows (total $total)")
} }
conn.commit() conn.commit()
@@ -49,5 +54,5 @@ fun executeCleanup() {
} }
} }
println("Cleanup complete in ${System.currentTimeMillis() - start} ms") log.info("Cleanup complete in ${System.currentTimeMillis() - start} ms")
} }

View File

@@ -1,8 +1,11 @@
package org.pavloveugene.iot.backend.services package org.pavloveugene.iot.backend.services
import io.ktor.server.request.ApplicationRequest
import io.ktor.websocket.WebSocketSession import io.ktor.websocket.WebSocketSession
data class DeviceConnection( data class DeviceConnection(
val session: WebSocketSession, val session: WebSocketSession,
var lastSeen: Long val request: ApplicationRequest,
var lastId: UInt = 0u,
var lastSeen: Long,
) )

View File

@@ -16,13 +16,18 @@ import org.pavloveugene.iot.backend.routes.protocolWebSocket
import java.time.Duration import java.time.Duration
import io.ktor.server.plugins.compression.* import io.ktor.server.plugins.compression.*
import org.pavloveugene.iot.backend.routes.firmwareRouting import org.pavloveugene.iot.backend.routes.firmwareRouting
import org.slf4j.LoggerFactory
private val log = LoggerFactory.getLogger("KtorServer")
fun startKtorServer() { fun startKtorServer() {
val server = embeddedServer( val server = embeddedServer(
Netty, Netty,
port = AppConfig.serverPort, port = AppConfig.serverPort,
host = AppConfig.serverHost, host = AppConfig.serverHost,
) { ) {
install(ContentNegotiation) { install(ContentNegotiation) {
json() json()
} }
@@ -35,18 +40,19 @@ fun startKtorServer() {
} }
routing { routing {
log.info("CONFIG = ${application.environment.config.toMap()}")
protocolRoutes() protocolRoutes()
protocolWebSocket() protocolWebSocket()
firmwareRouting() firmwareRouting(application)
} }
} }
Runtime.getRuntime().addShutdownHook(Thread { Runtime.getRuntime().addShutdownHook(Thread {
println("Shutting down...") log.info("Shutting down...")
try { try {
server.stop(1000, 2000) server.stop(1000, 2000)
} catch (e: Exception) { } catch (e: Exception) {
println("Shutdown error: ${e.message}") log.info("Shutdown error: ${e.message}")
} }
}) })

View File

@@ -1,6 +1,9 @@
package org.pavloveugene.iot.backend.services package org.pavloveugene.iot.backend.services
import org.pavloveugene.iot.backend.db.Database import org.pavloveugene.iot.backend.db.Database
import org.slf4j.LoggerFactory
private val log = LoggerFactory.getLogger("NormalizeService")
fun runNormalizeLoop() { fun runNormalizeLoop() {
var total = 0 var total = 0
@@ -9,10 +12,10 @@ fun runNormalizeLoop() {
do { do {
val count = normalizeBatch() val count = normalizeBatch()
total += count total += count
println("Processed batch: $count") log.info("Processed batch: $count")
} while (count > 0) } while (count > 0)
println("Done. Total processed: $total in ${System.currentTimeMillis() - start} ms") log.info("Done. Total processed: $total in ${System.currentTimeMillis() - start} ms")
} }
@@ -124,7 +127,7 @@ fun verifyData(): Boolean {
val ds = Database.dataSource val ds = Database.dataSource
var ret = true; var ret = true;
println("Executing verification") log.info("Executing verification")
ds.connection.use { conn -> ds.connection.use { conn ->
@@ -155,9 +158,9 @@ fun verifyData(): Boolean {
val count = rs.getInt(1) val count = rs.getInt(1)
ret = count == 0 ret = count == 0
if (ret) { if (ret) {
println("All ok!") log.info("All ok!")
} else { } else {
println("$count errors detected") log.info("$count errors detected")
} }
} }
@@ -181,7 +184,7 @@ fun verifyData(): Boolean {
} }
} }
println("Verification complete") log.info("Verification complete")
return ret return ret
} }

View File

@@ -1,7 +1,7 @@
package org.pavloveugene.iot.backend.services package org.pavloveugene.iot.backend.services
import MessageDto import MessageDto
import io.ktor.server.websocket.WebSocketServerSession import io.ktor.server.request.ApplicationRequest
import io.ktor.websocket.WebSocketSession import io.ktor.websocket.WebSocketSession
import kotlinx.serialization.builtins.ListSerializer import kotlinx.serialization.builtins.ListSerializer
import kotlinx.serialization.builtins.serializer import kotlinx.serialization.builtins.serializer
@@ -9,42 +9,58 @@ import kotlinx.serialization.json.Json
import org.pavloveugene.iot.backend.db.Database import org.pavloveugene.iot.backend.db.Database
import org.pavloveugene.iot.backend.dto.EventDto import org.pavloveugene.iot.backend.dto.EventDto
import org.pavloveugene.iot.backend.dto.TelemetryDto import org.pavloveugene.iot.backend.dto.TelemetryDto
import org.slf4j.LoggerFactory
import java.net.InetAddress
private val log = LoggerFactory.getLogger("ProtocolService")
class ProtocolService( class ProtocolService(
private val json: Json private val json: Json
) { ) {
fun handleMessage(msg: MessageDto, session: WebSocketSession) { fun handleMessage(msg: MessageDto, session: WebSocketSession, request: ApplicationRequest) {
when (msg.t) { when (msg.t) {
MessageType.TELEMETRY -> { MessageType.TELEMETRY -> {
handleTelemetry(msg) handleTelemetry(msg)
} }
MessageType.EVENT -> { MessageType.EVENT -> {
println("=== EVENT ===") log.info("=== EVENT ===")
println(msg.p) log.info("${msg.p}")
handleEvent(msg, session) handleEvent(msg, session, request)
} }
MessageType.COMMAND -> { MessageType.COMMAND -> {
println("=== COMMAND ===") log.info("=== COMMAND ===")
println(msg.p) log.info("${msg.p}")
} }
} }
} }
private fun handleEvent(msg: MessageDto, session: WebSocketSession) { private fun intToIp(ip: Int): String {
val reverse = Integer.reverseBytes(ip)
val bytes = byteArrayOf(
(reverse shr 24).toByte(),
(reverse shr 16).toByte(),
(reverse shr 8).toByte(),
reverse.toByte()
)
return InetAddress.getByAddress(bytes).hostAddress
}
private fun handleEvent(msg: MessageDto, session: WebSocketSession, request: ApplicationRequest) {
val payload = json.decodeFromJsonElement(EventDto.serializer(), msg.p) val payload = json.decodeFromJsonElement(EventDto.serializer(), msg.p)
when (payload.type) { when (payload.t) {
"HB" -> { "hb" -> {
println("=== HB devId = ${msg.d} ===") log.info("=== HB devId = ${msg.d} IP = ${intToIp(payload.ip)} ===")
val connection = DeviceConnections.get(msg.d) val connection = DeviceConnections.get(msg.d)
if (connection == null) { if (connection == null) {
DeviceConnections.register( DeviceConnections.register(
msg.d, DeviceConnection( msg.d, DeviceConnection(
session = session, session = session,
lastSeen = System.currentTimeMillis() lastSeen = System.currentTimeMillis(),
request = request
) )
) )
} else { } else {
@@ -85,18 +101,18 @@ class ProtocolService(
} }
if (!isEnabled) { if (!isEnabled) {
println("device ${msg.d} locked, message ignored") log.info("device ${msg.d} locked, message ignored")
conn.commit() conn.commit()
return return
} }
val payload = json.decodeFromJsonElement(TelemetryDto.serializer(), msg.p) val payload = json.decodeFromJsonElement(TelemetryDto.serializer(), msg.p)
println("=== TELEMETRY ===") log.info("=== TELEMETRY ===")
println("device=${msg.d}") log.info("device=${msg.d}")
println("ts=${msg.ts}") log.info("ts=${msg.ts}")
println("metric=${payload.m}") log.info("metric=${payload.m}")
println("values=${payload.v}") log.info("values=${payload.v}")
// insert telemetry // insert telemetry
conn.prepareStatement( conn.prepareStatement(

View File

@@ -0,0 +1,136 @@
package org.pavloveugene.iot.backend.services
import MessageDto
import MessageType
import io.ktor.server.plugins.origin
import io.ktor.util.date.getTimeMillis
import io.ktor.websocket.send
import kotlinx.serialization.json.Json
import org.pavloveugene.iot.backend.db.Database
import org.pavloveugene.iot.backend.db.Database.queryOne
import org.pavloveugene.iot.backend.dto.FirmwareUpdateCommandDto
import org.slf4j.LoggerFactory
private val log = LoggerFactory.getLogger("OTA")
const val updateUri = "/firmware/download"
suspend fun sendUpdateCommand(
devId: Long
): Boolean {
var ret = true
log.info("Sending update command to device: $devId")
val ver = queryOne(
"""
select *
from firmware
where device_id = ?
order by version desc
limit 1
""".trimIndent(), listOf(devId)
)
if (ver == null) {
log.error("No firmware uploaded for device: $devId.")
ret = false
} else {
val firmware_id = ver["id"] as Int
val sha256 = ver["sha256"] as String
var update = queryOne(
"""
select fu.*
from firmware_updates fu
where fu.firmware_id = ?
""".trimIndent(),
listOf(firmware_id)
)
if (update == null) {
val update_id = Database.insertAndReturnId(
"""
insert into firmware_updates (firmware_id, status)
values (?, ?)
""".trimIndent(),
listOf(firmware_id, "pending")
)
update = queryOne(
"""
select fu.*
from firmware_updates fu
where fu.id = ?
""".trimIndent(),
listOf(update_id),
)
}
if (update == null) {
ret = false
log.error("No firmware sent to device: $devId. Something went wrong.")
} else {
val status = update["status"] as String
if (status == "pending") {
val deviceConnection = DeviceConnections.get(devId.toUInt())
if (deviceConnection == null) {
ret = false
log.warn("No firmware sent to device: $devId. Device not connected.")
} else {
val origin = deviceConnection.request.origin
val portPart =
if (origin.serverPort == 80 || origin.serverPort == 443) ""
else ":${origin.serverPort}"
val url = "${origin.scheme}://${origin.serverHost}$portPart$updateUri?id=${firmware_id}"
val updPayload = FirmwareUpdateCommandDto(
t = "fw",
u = url,
s = sha256,
)
deviceConnection.lastId++
val updateCommand = MessageDto(
v = 1,
id = deviceConnection.lastId,
d = devId.toUInt(),
t = MessageType.COMMAND,
ts = getTimeMillis(),
p = Json.encodeToJsonElement(FirmwareUpdateCommandDto.serializer(), updPayload),
)
val cmdJson: String = Json.encodeToJsonElement(
MessageDto.serializer(),
updateCommand
).toString();
try {
deviceConnection.session.send(cmdJson)
Database.execute(
"""
update firmware_updates
set status = 'sent'
where id = ?
""".trimIndent(), listOf(update["id"])
)
log.info("Firmware update sent for device: $devId. Command: $cmdJson")
} catch (e: Exception) {
ret = false
log.error("Error updating device $devId", e)
}
}
}
}
}
return ret
}

View File

@@ -4,6 +4,7 @@ cmake_minimum_required(VERSION 3.22)
file(READ "${CMAKE_SOURCE_DIR}/version.txt" PROJECT_VER) file(READ "${CMAKE_SOURCE_DIR}/version.txt" PROJECT_VER)
string(STRIP "${PROJECT_VER}" PROJECT_VER) string(STRIP "${PROJECT_VER}" PROJECT_VER)
add_compile_definitions(FW_VERSION=${PROJECT_VER})
include($ENV{IDF_PATH}/tools/cmake/project.cmake) include($ENV{IDF_PATH}/tools/cmake/project.cmake)
project(esp32) project(esp32)

View File

@@ -29,6 +29,7 @@ idf_component_register(
esp_driver_gpio esp_driver_gpio
esp_http_client esp_http_client
app_update app_update
mbedtls
) )
# добавляем кастомный Kconfig # добавляем кастомный Kconfig
set(COMPONENT_KCONFIG "Kconfig") set(COMPONENT_KCONFIG "Kconfig")

View File

@@ -22,9 +22,11 @@ static void heartbeat_task(void* arg)
time(&now); time(&now);
protocol_send_event_hb((uint32_t)now); protocol_send_event_hb((uint32_t)now);
vTaskDelay(pdMS_TO_TICKS(30000)); // 30 сек
} else
{
vTaskDelay(pdMS_TO_TICKS(2000));
} }
vTaskDelay(pdMS_TO_TICKS(10000)); // 10 сек
} }
} }

View File

@@ -8,57 +8,102 @@
#include "ws.h" #include "ws.h"
#include "heartbeat_task.h" #include "heartbeat_task.h"
#include "fw_command.h" #include "fw_command.h"
#include "psa/crypto.h"
static TaskHandle_t ota_task_handle = nullptr; static TaskHandle_t ota_task_handle = nullptr;
static char url_copy[256]; static char url_copy[256];
static char sha256[65]; static char sha256[65];
#define TAG "OTA"
void perform_ota(void* par) void perform_ota(void* par)
{ {
sampler_task_stop();
sender_task_stop();
heartbeat_task_stop();
ws_disconnect();
ESP_LOGI(TAG, "OTA task started");
esp_ota_handle_t ota_handle; esp_ota_handle_t ota_handle;
uint8_t buf[1024]; uint8_t buf[1024];
int read_bytes; int read_bytes;
const esp_partition_t* partition = nullptr; const esp_partition_t* partition = nullptr;
int content_length = 0;
psa_hash_operation_t op = PSA_HASH_OPERATION_INIT;
const char* url = static_cast<const char*>(par); const char* url = static_cast<const char*>(par);
ESP_LOGI(TAG, "URL: %s", url);
esp_http_client_config_t config = {}; esp_http_client_config_t config = {};
config.url = url; config.url = url;
config.timeout_ms = 5000; config.timeout_ms = 5000;
config.buffer_size = 1024; config.buffer_size = 1024;
esp_http_client_handle_t client = esp_http_client_init(&config); esp_http_client_handle_t client = esp_http_client_init(&config);
ESP_LOGI(TAG, "HTTP init complete");
if (esp_http_client_open(client, 0) != ESP_OK) if (esp_http_client_open(client, 0) != ESP_OK)
{ {
goto cleanup; goto cleanup;
} }
ESP_LOGI(TAG, "HTTP connection established");
content_length = esp_http_client_fetch_headers(client);
ESP_LOGI(TAG, "content_length=%d", content_length);
partition = esp_ota_get_next_update_partition(nullptr); partition = esp_ota_get_next_update_partition(nullptr);
ESP_LOGI(TAG, "OTA update partition selected %s (offset 0x%08x)", partition->label, partition->address);
if (esp_ota_begin(partition, OTA_SIZE_UNKNOWN, &ota_handle) != ESP_OK) if (esp_ota_begin(partition, OTA_SIZE_UNKNOWN, &ota_handle) != ESP_OK)
{ {
goto cleanup; goto cleanup;
} }
while ((read_bytes = esp_http_client_read(client, (char*)buf, sizeof(buf))) > 0) psa_crypto_init();
psa_hash_setup(&op, PSA_ALG_SHA_256);
ESP_LOGI(TAG, "OTA update started");
while (1)
{ {
if (esp_ota_write(ota_handle, buf, read_bytes) != ESP_OK) int data_read = esp_http_client_read(client, (char*)buf, sizeof(buf));
if (data_read < 0)
{ {
esp_ota_end(ota_handle); ESP_LOGE(TAG, "read error");
break;
}
else if (data_read == 0)
{
ESP_LOGI(TAG, "download finished");
break;
}
ESP_LOGI(TAG, "write chunk: %d bytes", data_read);
if (esp_ota_write(ota_handle, buf, data_read) != ESP_OK)
{
ESP_LOGE(TAG, "flash write failed");
goto cleanup; goto cleanup;
} }
if (psa_hash_update(&op, buf, data_read) != PSA_SUCCESS) {
ESP_LOGE(TAG, "hash update failed");
}
} }
ESP_LOGI(TAG, "finalizing...");
if (esp_ota_end(ota_handle) == ESP_OK) if (esp_ota_end(ota_handle) == ESP_OK)
{ {
uint8_t hash[32]; uint8_t hash[32];
size_t hash_len;
if (esp_partition_get_sha256(partition, hash) != ESP_OK) psa_hash_finish(&op, hash, sizeof(hash), &hash_len);
{
goto cleanup;
}
// переводим в hex // переводим в hex
char hash_str[65]; char hash_str[65];
@@ -68,16 +113,24 @@ void perform_ota(void* par)
} }
hash_str[64] = '\0'; hash_str[64] = '\0';
ESP_LOGI(TAG, "sha256 = %s. Comparing...", hash_str);
// сравнение // сравнение
if (strncmp(hash_str, sha256, 64) != 0) if (strncmp(hash_str, sha256, 64) != 0)
{ {
printf("SHA256 mismatch!\n"); ESP_LOGE(TAG, "SHA256 mismatch!");
goto cleanup; goto cleanup;
} }
ESP_LOGI(TAG, "set boot partition");
esp_ota_set_boot_partition(partition); esp_ota_set_boot_partition(partition);
ESP_LOGI(TAG, "restart pending");
esp_restart(); esp_restart();
} }
else
{
ESP_LOGE(TAG, "ota_end failed");
}
cleanup: cleanup:
esp_http_client_cleanup(client); esp_http_client_cleanup(client);
@@ -89,10 +142,6 @@ void ota_task_start(const fw_cmd_t* cmd)
{ {
strncpy(url_copy, cmd->url, sizeof(url_copy)); strncpy(url_copy, cmd->url, sizeof(url_copy));
strncpy(sha256, cmd->sha256, sizeof(sha256)); strncpy(sha256, cmd->sha256, sizeof(sha256));
sampler_task_stop();
sender_task_stop();
heartbeat_task_stop();
ws_disconnect();
if (ota_task_handle == nullptr) if (ota_task_handle == nullptr)
{ {
xTaskCreate( xTaskCreate(

View File

@@ -3,6 +3,8 @@
#include <stdio.h> #include <stdio.h>
#include <stdarg.h> #include <stdarg.h>
#include <inttypes.h> #include <inttypes.h>
#include "esp_wifi.h"
#include "esp_netif.h"
#define PROTOCOL_VERSION 1 #define PROTOCOL_VERSION 1
@@ -195,16 +197,45 @@ uint32_t protocol_next_id()
void protocol_send_event_hb(int64_t ts) void protocol_send_event_hb(int64_t ts)
{ {
// формируешь JSON строго по контракту char buf[256];
// пример: wifi_ap_record_t ap;
char buf[128]; esp_wifi_sta_get_ap_info(&ap);
int rssi = ap.rssi;
const char *ssid = (const char *)ap.ssid;
esp_netif_ip_info_t ip_info;
esp_netif_t *netif = esp_netif_get_handle_from_ifkey("WIFI_STA_DEF");
esp_netif_get_ip_info(netif, &ip_info);
uint32_t ip = ip_info.ip.addr;
size_t heap_free = heap_caps_get_free_size(MALLOC_CAP_DEFAULT);
size_t heap_largest = heap_caps_get_largest_free_block(MALLOC_CAP_DEFAULT);
uint32_t id = protocol_next_id(); uint32_t id = protocol_next_id();
snprintf(buf, sizeof(buf), snprintf(buf, sizeof(buf),
"{\"v\":1,\"id\":%" PRIu32 ",\"t\":\"e\",\"ts\":%" PRIu64 ",\"d\":%u,\"p\":{\"type\":\"hb\"}}", "{\"v\":1,\"id\":%" PRIu32 ",\"t\":\"e\",\"ts\":%" PRIu64 ",\"d\":%u,"
id, ts, CONFIG_DEVICE_ID); "\"p\":{"
"\"t\":\"hb\","
"\"v\":%u,"
"\"hp\":%u,"
"\"hl\":%u,"
"\"rs\":%d,"
"\"ip\":%" PRIu32 ","
"\"si\":\"%s\""
"}}",
id, ts, CONFIG_DEVICE_ID,
FW_VERSION,
heap_free,
heap_largest,
rssi,
ip,
ssid
);
ws_send(buf); ws_send(buf);
} }

View File

@@ -1 +1 @@
1 8