#include "main.h"
#include "modbus_crc.h"
#include "speck.h"
#define BUS_BUF_SIZE 150
/*
 * Largest cdnet data section (the bytes starting at bus_buf[5]) that still
 * fits in bus_buf together with the 3-byte cdbus header, the 2 cdnet port
 * bytes and the 2-byte CRC: BUS_BUF_SIZE - 3 - 2 - 2.
 * Handlers clamp reads to CDN_MAX_DAT - 1 and reply with one extra status
 * byte, so the largest reply is exactly CDN_MAX_DAT bytes of cdnet data.
 */
#define CDN_MAX_DAT     (BUS_BUF_SIZE - 7)

/* Shared receive/reply buffer holding one cdbus frame:
 * [src_mac, dst_mac, len, payload[len], crc16]. */
uint8_t bus_buf[BUS_BUF_SIZE];


/*
 * Spin until USART1 no longer reports BUSY, calling __dekey() on every poll
 * (what __dekey() does is not visible here - presumably a watchdog kick or
 * idle hook; confirm).
 */
static inline void bus_wait_idle(void)
{
    while (USART_GetFlagStatus(USART1, USART_FLAG_BUSY))
        __dekey();
}

/*
 * Blocking transmit of len bytes from buf on USART1.
 *
 * Each byte is handed to the USART only once it is idle, and the function
 * returns only after the final byte has gone out: on the single-wire bus the
 * caller switches the transceiver back to receive as soon as this returns.
 */
static inline void bus_transmit(const uint8_t *buf, uint16_t len)
{
    const uint8_t *const end = buf + len;

    while (buf != end) {
        bus_wait_idle();
        USART_SendData(USART1, *buf);
        buf++;
    }

    bus_wait_idle();
}

static void send_frame(uint8_t len)
{
    bus_buf[1] = bus_buf[0];        // swap cdbus src_mac and dst_mac
    bus_buf[0] = csa.mac;
    bus_buf[2] = len + 2;                // cdnet payload size -> cdbus payload size
    swap(bus_buf[3], bus_buf[4]);   // swap cdnet src_port and dst_port
    uint16_t cal_crc16 = crc16(bus_buf , bus_buf[2] + 3);
    put_unaligned16(cal_crc16, bus_buf + bus_buf[2] + 3);
    TTL_OR_485_TXMODE();
    bus_transmit(bus_buf, bus_buf[2] + 5);
    TTL_OR_485_RXMODE();
}

/*
 * XOR every complete 32-bit word of buf[0..len) with the address it
 * corresponds to (word k gets addr + 4 * k), in native byte order.
 *
 * buf is a plain byte array with no alignment guarantee, so each word is
 * moved through memcpy instead of dereferencing buf as (uint32_t *): that
 * cast violates strict aliasing and faults on cores that trap unaligned
 * word accesses (such as ARMv6-M). Trailing bytes that do not form a whole
 * word are left untouched.
 */
static void flash_xor_addr(uint8_t *buf, uint32_t addr, int len)
{
    for (int i = 0; i + 4 <= len; i += 4) {
        uint32_t word;
        memcpy(&word, buf + i, sizeof(word));
        word ^= addr + (uint32_t) i;
        memcpy(buf + i, &word, sizeof(word));
    }
}

/*
 * Serve one request on the flash-operation port.
 *
 * Request (cdnet data at bus_buf[5], p_len bytes long):
 *   p_dat[0]  command; bit 7 set = do not send a reply
 *   FLASH_ERASE   (p_len == 9): addr32, len32
 *   FLASH_READ    (p_len == 6): addr32, len8
 *   FLASH_WRITE   (p_len  > 8): addr32, data[p_len - 5]
 *   FLASH_CAL_CRC (p_len == 9): addr32, len32
 * The reply is built in place: p_dat[0] becomes a status byte (0 = success),
 * followed by read data (FLASH_READ) or a 16-bit CRC (FLASH_CAL_CRC).
 *
 * With ENABL_SPECK64, READ/WRITE/CAL_CRC data travels encrypted: each 32-bit
 * word is XORed with its address and then Speck64/96 is applied per 8-byte
 * block. Addresses and lengths must therefore be multiples of 8, and
 * READ/WRITE handle at most 128 bytes per request. Otherwise the status is
 * 2 (READ/WRITE) or 1 (CAL_CRC).
 */
static void flash_operation(void)
{
    uint8_t *p_dat = bus_buf + 5;
    uint8_t p_len = bus_buf[2] - 2;
    bool reply = !(*p_dat & 0x80);      // bit 7 of the command suppresses the reply
    *p_dat &= 0x7f;

    uint8_t buf0[128] = {0};

#if ENABL_SPECK64

    uint8_t buf1[128] = {0};
    static const uint8_t k96[12+1] = ENC_KEY; // random; +1 presumably for a string-literal NUL
    uint8_t k[12];
    for (int i = 0; i < 12; i++)
        k[i] = k96[i] ^ 0xcd;           // effective key = ENC_KEY bytes XOR 0xcd
    speck64_t speck_ctx;
    crypto_speck64_setkey(&speck_ctx, k, SPECK64_96_KEY_SIZE);

    if (*p_dat == FLASH_ERASE && p_len == 9) {
        uint32_t addr = get_unaligned32(p_dat + 1);
        uint32_t len = get_unaligned32(p_dat + 5);
        uint8_t ret = flash_erase(addr, len);
        *p_dat = ret != FLASH_COMPLETE ? 1 : 0;
        if (reply)
            send_frame(1);

    } else if (*p_dat == FLASH_READ && p_len == 6) {
        uint32_t addr = get_unaligned32(p_dat + 1);
        const uint8_t *src = (const uint8_t *) addr;
        uint8_t len = min(p_dat[5], CDN_MAX_DAT - 1);

        if ((addr % 8) == 0 && (len % 8) == 0 && len <= 128) {
            memcpy(buf0, src, len);
            flash_xor_addr(buf0, addr, len);
            for (int i = 0; i < len / 8; i++)
                crypto_speck64_encrypt(&speck_ctx, buf1 + i * 8, buf0 + i * 8);
            memcpy(p_dat + 1, buf1, len);
            *p_dat = 0;
        } else {
            *p_dat = 2;
            len = 0;
        }
        if (reply)
            send_frame(len + 1);

    } else if (*p_dat == FLASH_WRITE && p_len > 8) {
        uint32_t addr = get_unaligned32(p_dat + 1);
        uint8_t len = p_len - 5;

        if ((addr % 8) == 0 && (len % 8) == 0 && len <= 128) {
            memcpy(buf0, p_dat + 5, len);
            for (int i = 0; i < len / 8; i++)
                crypto_speck64_decrypt(&speck_ctx, buf1 + i * 8, buf0 + i * 8);
            flash_xor_addr(buf1, addr, len);
            uint8_t ret = flash_write(addr, len, buf1);
            *p_dat = ret != FLASH_COMPLETE ? 1 : 0;
        } else {
            *p_dat = 0x2;
        }
        if (reply)
            send_frame(1);

    } else if (*p_dat == FLASH_CAL_CRC && p_len == 9) {
        uint32_t addr = get_unaligned32(p_dat + 1);
        uint32_t len = get_unaligned32(p_dat + 5);

        if ((addr % 8) == 0 && (len % 8) == 0) {
            // CRC over the same encrypted stream FLASH_READ would return,
            // processed in chunks of up to 128 bytes.
            uint32_t pos = addr;
            uint16_t crc_val = 0xffff;
            while (true) {
                int sub_size = min(128, len - (pos - addr));
                if (!sub_size)
                    break;
                memcpy(buf0, (const uint8_t *) pos, sub_size);
                flash_xor_addr(buf0, pos, sub_size);
                for (int i = 0; i < sub_size / 8; i++)
                    crypto_speck64_encrypt(&speck_ctx, buf1 + i * 8, buf0 + i * 8);
                crc_val = crc16_sub(buf1, sub_size, crc_val);
                pos += sub_size;
            }
            *p_dat = 0;
            put_unaligned16(crc_val, p_dat + 1);
        } else {
            *p_dat = 1;
        }
        if (reply)
            send_frame(*p_dat ? 1 : 3);
    }
#else

    if (*p_dat == FLASH_ERASE && p_len == 9) {
        uint32_t addr = get_unaligned32(p_dat + 1);
        uint32_t len = get_unaligned32(p_dat + 5);
        uint8_t ret = flash_erase(addr, len);
        *p_dat = ret != FLASH_COMPLETE ? 1 : 0;
        if (reply)
            send_frame(1);

    } else if (*p_dat == FLASH_READ && p_len == 6) {
        uint32_t addr = get_unaligned32(p_dat + 1);
        const uint8_t *src = (const uint8_t *) addr;
        uint8_t len = min(p_dat[5], CDN_MAX_DAT - 1);
        memcpy(p_dat + 1, src, len);
        *p_dat = 0;
        if (reply)
            send_frame(len + 1);

    } else if (*p_dat == FLASH_WRITE && p_len > 8) {
        uint32_t addr = get_unaligned32(p_dat + 1);
        uint8_t len = p_len - 5;
        uint8_t ret = flash_write(addr, len, p_dat + 5);
        *p_dat = ret != FLASH_COMPLETE ? 1 : 0;
        if (reply)
            send_frame(1);

    } else if (*p_dat == FLASH_CAL_CRC && p_len == 9) {
        uint32_t addr = get_unaligned32(p_dat + 1);
        uint32_t len = get_unaligned32(p_dat + 5);
        uint32_t pos = addr;
        uint16_t crc_val = 0xffff;
        while (true) {
            int sub_size = min(128, len - (pos - addr));
            if (!sub_size)
                break;
            memcpy(buf0, (const uint8_t *) pos, sub_size);
            // NOTE(review): unlike FLASH_READ/FLASH_WRITE in this build, the
            // data is XORed with its address before the CRC, so the result is
            // not the CRC of the plain flash contents. Confirm the host tool
            // does the same before changing either side.
            flash_xor_addr(buf0, pos, sub_size);
            crc_val = crc16_sub(buf0, sub_size, crc_val);
            pos += sub_size;
        }
        *p_dat = 0;
        put_unaligned16(crc_val, p_dat + 1);
        if (reply)
            send_frame(*p_dat ? 1 : 3);
    }
#endif
}
// Device-info template sent by read_dev_info().
// read_dev_info() overwrites the 24 '-' characters at offset 28 (after "S: ")
// with the hex-encoded 12-byte UID, and writes the firmware version string at
// offset 72 (right after "SW:"). Those offsets must be kept in sync with this
// literal if any text before them changes. 80 bytes leaves room for the
// 3-character version plus the terminating NUL.
static char dev_info[80] = {"M: Kpower wheel hub(bl), S: ------------------------, HW: PB225.Y01, SW:"};

/*
 * Write the 12 bytes stored at 0x1FFFF7B0 (presumably the MCU's factory
 * unique ID - confirm against the reference manual) into buf as 24 lowercase
 * hex characters. No terminating NUL is written: buf points into the middle of
 * the dev_info template, whose following text must be preserved.
 */
static void get_uid(char *buf)
{
    static const char hex_digits[] = "0123456789abcdef";
    const char *uid = (const char *) 0x1FFFF7B0;

    for (int n = 0; n < 12; n++) {
        const uint8_t byte = (uint8_t) uid[n];
        *buf++ = hex_digits[byte >> 4];
        *buf++ = hex_digits[byte & 0x0f];
    }
}

/*
 * Serve the device-info port: fill the UID and firmware-version fields of
 * dev_info and, when the request carries no cdnet data, reply with the whole
 * string (without its NUL terminator).
 */
static void read_dev_info(void)
{
    const uint8_t cdn_len = bus_buf[2] - 2;

    get_uid(dev_info + 28);             // 24 hex chars over the "S: ---" placeholder
    strcpy(dev_info + 72, "2.0");       // firmware version right after "SW:"

    if (cdn_len != 0)
        return;

    const size_t info_len = strlen(dev_info);
    memcpy(bus_buf + 5, dev_info, info_len);
    send_frame(info_len);
}


/*
 * Serve the parameter port: read or write raw bytes of the csa structure.
 *
 * Request (cdnet data at bus_buf[5]):
 *   [0]  command; bit 7 set = do not send a reply
 *   READ_PARAM  (4 bytes):   offset16, count8  -> reply: status 0, csa bytes
 *   WRITE_PARAM (> 3 bytes): offset16, data... -> reply: status 0
 * Writes are clipped to the bounds of csa_t. Reads are NOT checked against
 * sizeof(csa_t), so the host can read past the end of csa.
 * NOTE(review): confirm this is intended; an offset into unmapped memory
 * would fault.
 * Both copies run between local_irq_save()/local_irq_restore().
 */
static void read_write_para(void)
{
    uint8_t *const req = bus_buf + 5;
    const uint8_t req_len = bus_buf[2] - 2;
    const bool want_reply = (req[0] & 0x80) == 0;
    uint8_t *const csa_bytes = (uint8_t *) &csa;
    uint32_t irq_flags;

    req[0] &= 0x7f;

    if (req[0] == READ_PARAM && req_len == 4) {
        const uint16_t offset = get_unaligned16(req + 1);
        const uint8_t count = min(req[3], CDN_MAX_DAT - 1);

        local_irq_save(irq_flags);
        memcpy(req + 1, csa_bytes + offset, count);
        local_irq_restore(irq_flags);

        req[0] = 0;
        if (want_reply)
            send_frame(count + 1);
        return;
    }

    if (req[0] == WRITE_PARAM && req_len > 3) {
        const uint16_t offset = get_unaligned16(req + 1);
        const uint8_t count = req_len - 3;
        const uint8_t *const data = req + 3;
        // Keep only the part of [offset, offset + count) that lies inside csa.
        const uint16_t first = clip(offset, 0, sizeof(csa_t));
        const uint16_t last = clip(offset + count, 0, sizeof(csa_t));

        local_irq_save(irq_flags);
        memcpy(csa_bytes + first, data + (first - offset), last - first);
        local_irq_restore(irq_flags);

        req[0] = 0;
        if (want_reply)
            send_frame(1);
    }
}



// Gap between received bytes, in get_systick() ticks, after which a partially
// assembled frame is discarded (tick unit not visible here).
#define CDUART_IDLE_TIME   5
static uint16_t            rx_byte_cnt = 0;     // bytes of the current frame seen so far (dropped frames included)
static bool                rx_drop = false;     // current frame is being skipped rather than stored
static uint32_t            t_last = 0;          // get_systick() value when the previous byte arrived
static bool                rx_pend = false;     // a complete, CRC-valid frame in bus_buf awaits serial_task()

/*
 * Feed one received byte into the cdbus frame assembler.
 *
 * Frames are collected in bus_buf as [src_mac, dst_mac, len, payload[len], crc16].
 * - If more than CDUART_IDLE_TIME ticks passed since the previous byte, any
 *   partial frame is discarded and assembly restarts with this byte.
 * - After the 3rd byte (the length) the frame is marked for dropping if its
 *   payload would not fit in bus_buf, or if dst_mac is neither 0xff nor
 *   csa.mac. A dropped frame is still counted to its full length, so the
 *   next frame starts on the right byte.
 * - When len + 5 bytes have arrived, the CRC is computed over the whole frame
 *   (CRC bytes included); a result of 0 is treated as valid and rx_pend is
 *   set for serial_task().
 */
static void cduart_rx_handle(const uint8_t dat)
{
    if (rx_byte_cnt != 0 && get_systick() - t_last > CDUART_IDLE_TIME)
    {
        printf("bus: timeout [%x %x %x] %x\n",
               bus_buf[0], bus_buf[1], bus_buf[2], rx_byte_cnt);
        rx_byte_cnt = 0;
        rx_drop = false;
    }
    t_last = get_systick();

    // For the first 3 bytes bus_buf[2] may still hold the previous frame's
    // length, but bus_buf[2] + 5 > 2 always, so those bytes are stored.
    // Afterwards only frames that passed the length check are stored, which
    // keeps the index below BUS_BUF_SIZE.
    if (!rx_drop && rx_byte_cnt < bus_buf[2] + 5)
        bus_buf[rx_byte_cnt] = dat;
    rx_byte_cnt++;

    // The length byte has just been stored: reject oversize frames and frames
    // addressed to another node (0xff is accepted as broadcast).
    if (rx_byte_cnt == 3 &&
            (bus_buf[2] > BUS_BUF_SIZE - 5 ||
             (bus_buf[1] != 0xff && bus_buf[1] != csa.mac)))   // todo: use dev_mac backup
    {
        printf("bus: drop [%x %x %x]\n", bus_buf[0], bus_buf[1], bus_buf[2]);
        rx_drop = true;
    }

    // Frame complete (header 3 + payload + CRC 2): validate, then reset the
    // assembler for the next frame whether or not this one was kept.
    if (rx_byte_cnt == bus_buf[2] + 5)
    {
        if (!rx_drop)
        {
            uint16_t rx_crc = crc16(bus_buf, rx_byte_cnt);
            if (rx_crc != 0)
                printf("bus: !crc [%x %x %x]\n", bus_buf[0], bus_buf[1], bus_buf[2]);
            else
                rx_pend = true;
        }
        rx_byte_cnt = 0;
        rx_drop = false;
    }
}


/*
 * Dispatch the frame in bus_buf to the service selected by bus_buf[4]
 * (presumably the cdnet destination port; the bytes at [3]/[4] are the port
 * pair that send_frame() swaps).
 *
 * A cdbus payload shorter than 2 bytes carries no port pair: bus_buf[4] would
 * then be a CRC or stale byte, and every handler's `bus_buf[2] - 2` length
 * would wrap around to 254/255. Such frames are ignored. Unknown ports get no
 * reply.
 */
static inline void serial_server_allot(void)
{
    if (bus_buf[2] < 2)
        return;

    switch (bus_buf[4])
    {
        case READ_DEV_INFO: read_dev_info(); break;
        case READ_WRITE_PARA: read_write_para(); break;
        case FLASH_OPERATION: flash_operation(); break;
        default: break;
    }
}

/*
 * Poll the bus once: hand at most one received byte to the frame assembler,
 * then service a completed frame if one is pending. Intended to be called
 * repeatedly from the main loop.
 */
void serial_task(void)
{
    // RXFF: receive flag of USART1 (original note said "exceeds 4 bytes";
    // the exact FIFO semantics are not visible here - confirm).
    if (USART_GetFlagStatus(USART1, USART_FLAG_RXFF))
        cduart_rx_handle(USART_ReceiveData(USART1));

    if (!rx_pend)
        return;

    rx_pend = false;
    serial_server_allot();
}

/*
 * Configure the bus pins and USART1, then leave the transceiver in receive
 * mode. The register values target 1 Mbaud, 8N1 (per the commented-out
 * library configuration and the baud-rate comments below).
 */
void serial_init(void)
{
/*
Reference: USART alternate-function options per pin
PA0 | AF1 USART1_CTS | AF4 USART1_RX
PA1 | AF1 USART1_RTS | AF4 USART1_TX
PA2 | AF1 USART1_RX | AF3 USART1_TX
PA3 | AF1 USART1_TX | AF7 USART1_RX
PA4 | AF1 USART0_RTS | AF4 USART1_CK | AF7 USART1_TX
PA6 | AF6 USART0_CK
PA7 | AF4 USART0_RX
PA8 | USART0_CTS | AF4 USART0_TX
PA9 | AF2 USART0_TX | AF7 USART0_RX
PA10 | AF2 USART0_RX | AF5 USART0_TX
PA11 | AF4 USART0_TX | AR6 USART0_CTS
PA12 | AF4 USART0_RX | AF6 USART0_RTS
PA13 | AF4 USART1_RX 
PA14 | AF1 USART1TX | AF7 USART1_RX
PA15 | AF1 USART1_RX
PB0 | AF4 USART0_TX | AF6 USART0_RX
PB1 | AF4 USART1_RTS | AF5 USART0_RX | AF7 USART0_TX
PB2 | AF6 USART1_TX | AF7 USART0_RX
PB3 | AF7 USART1_TX
PB4 | AF3 USART0_RX
PB5 | AF4 USART0_TX
PB6 | AF0 USART0_TX | AF7 USART0_RX
PB7 | AF0 USART0_RX | AF4 USART0_TX
  
 */

    // IO init: PB3 = USART1 TX (AF7), PA15 = USART1 RX (AF1),
    // PA8 = 485_SW push-pull output (presumably the transceiver direction switch)
    GPIO_InitTypeDef GPIO_InitStructure;
    
    memset(&GPIO_InitStructure, 0, sizeof(GPIO_InitTypeDef));
    GPIO_InitStructure.GPIO_Pin = GPIO_Pin_3;               //PB3 UART1_TX
    GPIO_InitStructure.GPIO_Speed = GPIO_Speed_Level_2;
    GPIO_InitStructure.GPIO_Mode = GPIO_Mode_AF;
    GPIO_InitStructure.GPIO_PuPd = GPIO_PuPd_NOPULL;
    GPIO_InitStructure.GPIO_OType = GPIO_OType_PP;
    GPIO_Init(GPIOB, &GPIO_InitStructure);
    GPIO_PinAFConfig(GPIOB, GPIO_PinSource3, GPIO_AF_7);  //TX

    // Same speed/pull/output-type settings reused for the RX pin.
    GPIO_InitStructure.GPIO_Pin = GPIO_Pin_15;               //PA15 UART1_RX
    GPIO_Init(GPIOA, &GPIO_InitStructure);
    GPIO_PinAFConfig(GPIOA, GPIO_PinSource15, GPIO_AF_1);  //RX

    GPIO_InitStructure.GPIO_Pin = GPIO_Pin_8;               //PA8 485_SW
    GPIO_InitStructure.GPIO_Mode = GPIO_Mode_OUT;
    GPIO_Init(GPIOA, &GPIO_InitStructure);




    TTL_OR_485_RXMODE();

    // Peripheral init for USART1. Note: the library USART_Init() path costs a
    // lot of flash; when flash is tight, program the registers directly (below).
    // USART_InitTypeDef uart_init_struct;
    // uart_init_struct.USART_Mode = USART_Mode_Rx | USART_Mode_Tx;
    // uart_init_struct.USART_BaudRate = 1000000;
    // uart_init_struct.USART_WordLength = USART_WordLength_8b;
    // uart_init_struct.USART_Parity = USART_Parity_No;
    // uart_init_struct.USART_StopBits = USART_StopBits_1;
    // uart_init_struct.USART_HardwareFlowControl = USART_HardwareFlowControl_None;

    // USART_Init(USART1, &uart_init_struct);
    // USART_Cmd(USART1, ENABLE);


    // Note: keep the register setup below consistent with the library
    // configuration above (1 Mbaud, 8 data bits, no parity, 1 stop bit).
    USART1->IBRD = 2;	    // IBRD: 38400 -> 68, 115200 -> 22, 1000000 -> 2
    USART1->FBRD = 40;       // FBRD: 38400 -> 23, 115200 -> 50, 1000000 -> 40
    USART1->LCR = USART_WordLength_8b | USART_StopBits_1 | USART_Parity_No;
    USART1->CR |= USART_CR_USARTEN_Msk;     // enable the UART
}





