Files
alloy_elite2_rgb/websocket.c
T

471 lines
12 KiB
C

#include <string.h>
#include "lwip/altcp.h"
#include "lwip/debug.h"
#include "mbedtls/base64.h"
#include "mbedtls/sha1.h"
#include "websocket.h"
// Fixed GUID that ws_handshake() appends to the client's Sec-WebSocket-Key
// before hashing (the constant value given in RFC 6455).
static const char WS_GUID[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
// Leading part of the HTTP 101 reply; ws_handshake() appends the computed
// accept key followed by the terminating blank line.
static const char WS_RESPONSE[] = "HTTP/1.1 101 Switching Protocols\r\n" \
"Upgrade: websocket\r\n" \
"Connection: Upgrade\r\n" \
"Sec-WebSocket-Accept: ";
// Message buffer filled by ws_read() and handed to ws_receive_cb once a frame
// with the FIN bit arrives. It is a single file-level buffer, so it is shared
// by every connection. buf_len is the number of bytes currently stored in buf.
static uint8_t buf[WS_BUFFER_SIZE];
static uint16_t buf_len=0;
// Application callbacks, installed with ws_set_receive_handler() and
// ws_set_open_handler(). NULL means "no handler registered".
static tWSHandler ws_receive_cb = NULL;
static tWSOpenHandler ws_open_cb = NULL;
// Head of the singly linked list (via ws_state::next) of connection states;
// NULL when there are no connections.
static struct ws_state * ws_connections;
// NOTE(review): not referenced anywhere in the visible part of this file.
static uint8_t ws_num_conns = 0;
// --- Forward declarations of file-local helpers ---
// Connection-state bookkeeping (linked list rooted at ws_connections).
static struct ws_state* ws_state_alloc(void);
static void ws_state_init(struct ws_state *wss);
static void ws_state_free(struct ws_state *wss);
// Listener setup.
static void ws_server_init_pcb( struct altcp_pcb *pcb, uint16_t port);
// Callbacks registered with altcp (accept / recv / sent / err / poll).
static err_t ws_accept(void *arg, struct altcp_pcb *pcb, err_t err);
static err_t ws_recv(void *arg, struct altcp_pcb *pcb, struct pbuf *p, err_t err);
static err_t ws_sent(void *arg, struct altcp_pcb *pcb, uint16_t len);
static void ws_err (void *arg, err_t err);
// Connection teardown.
static err_t ws_close_conn(struct altcp_pcb *pcb, struct ws_state *wss);
static err_t ws_close_or_abort_conn(struct altcp_pcb *pcb, struct ws_state *wss, uint8_t abort_conn);
static err_t ws_poll(void *arg, struct altcp_pcb *pcb);
// Protocol handling: HTTP upgrade handshake, frame parsing, frame sending.
static err_t ws_handshake(struct altcp_pcb *pcb, struct ws_state *wss, struct pbuf *p);
static err_t ws_read(struct altcp_pcb *pcb, struct ws_state *wss, struct pbuf *p);
static err_t ws_send(struct ws_state *wss, uint8_t *data, uint16_t len);
// Allocate a ws_state via WS_ALLOC_WS_STATE(), initialise it with
// ws_state_init() and append it to the tail of the ws_connections list.
// Returns the new state, or NULL when the allocation failed.
static struct ws_state * ws_state_alloc(void) {
    struct ws_state *state = WS_ALLOC_WS_STATE();
    if (state == NULL) {
        return NULL;
    }
    ws_state_init(state);
    if (ws_connections == NULL) {
        // first connection becomes the list head
        ws_connections = state;
        return state;
    }
    // otherwise walk to the current tail and hook the new state behind it
    struct ws_state *tail = ws_connections;
    while (tail->next != NULL) {
        tail = tail->next;
    }
    LWIP_ASSERT("last != NULL", tail != NULL);
    tail->next = state;
    return state;
}
// Reset a ws_state: every byte is cleared and the connection is marked as
// not (yet) upgraded to a websocket.
static void ws_state_init(struct ws_state *wss) {
    memset(wss, 0, sizeof *wss);
    wss->active = false;
}
// Unlink a ws_state from the ws_connections list and release its memory
// with mem_free(). Passing NULL is a no-op. A state that is not found in the
// list is still freed.
static void ws_state_free(struct ws_state *wss) {
    if (wss == NULL) {
        return;
    }
    // Walk the links rather than the nodes: the head needs no special case,
    // and an empty list is handled without dereferencing a NULL node.
    for (struct ws_state **link = &ws_connections; *link != NULL; link = &(*link)->next) {
        if (*link == wss) {
            *link = wss->next;
            break;
        }
    }
    mem_free(wss);
}
// Turn `pcb` into a listening socket for the websocket server.
//   pcb  - freshly created pcb; NULL is ignored
//   port - local TCP port to bind to (any local address)
// On success ws_accept() is installed as the accept callback. If binding or
// listening fails (and assertions are compiled out) the pcb is closed and the
// function returns without installing anything.
static void ws_server_init_pcb( struct altcp_pcb *pcb, uint16_t port) {
    if (pcb == NULL) {
        return;
    }
    altcp_setprio(pcb, TCP_PRIO_MIN);

    err_t err = altcp_bind(pcb, IP_ANY_TYPE, port);
    LWIP_ASSERT("ws_server_init: tcp_bind failed", err == ERR_OK);
    if (err != ERR_OK) {
        altcp_close(pcb);
        return;
    }

    // Keep the original pcb in a separate variable so it can still be
    // released when switching to the listen state fails.
    struct altcp_pcb *listen_pcb = altcp_listen(pcb);
    LWIP_ASSERT("ws_server_init: tcp_listen failed", listen_pcb != NULL);
    if (listen_pcb == NULL) {
        altcp_close(pcb);
        return;
    }
    altcp_accept(listen_pcb, ws_accept);
}
// Public entry point: create a pcb for any IP address type and start the
// websocket server listening on WS_PORT.
void ws_server_init(void) {
    struct altcp_pcb *listener;

    listener = altcp_tcp_new_ip_type(IPADDR_TYPE_ANY);
    LWIP_ASSERT("ws_server_init: tcp_new failed", listener != NULL);
    ws_server_init_pcb(listener, WS_PORT);
}
// Register the callback that ws_read() invokes with each completed message.
// Passing NULL removes the handler; ws_read() only processes data frames
// while a handler is set.
void ws_set_receive_handler(tWSHandler ws_handler) {
    ws_receive_cb = ws_handler;
}
// Register the callback that ws_handshake() invokes for a newly upgraded
// connection. Passing NULL removes the handler.
void ws_set_open_handler(tWSOpenHandler ws_handler) {
    ws_open_cb = ws_handler;
}
// Accept callback installed on the listening pcb.
// Allocates a ws_state for the new connection, stores the pcb in it and
// registers the recv/sent/poll/err callbacks with that state as argument.
// Returns ERR_VAL for a bad accept, ERR_MEM when no state could be
// allocated, ERR_OK otherwise.
static err_t ws_accept(void *arg, struct altcp_pcb *pcb, err_t err) {
    LWIP_UNUSED_ARG(err);
    LWIP_UNUSED_ARG(arg);
    LWIP_DEBUGF(WS_DEBUG, ("ws_accept %p / %p\n", (void *)pcb, arg));

    if ((pcb == NULL) || (err != ERR_OK)) {
        return ERR_VAL;
    }

    // per-connection state
    struct ws_state *state = ws_state_alloc();
    if (state == NULL) {
        LWIP_DEBUGF(WS_DEBUG, ("ws_accept: Out of memory, RST\n"));
        return ERR_MEM;
    }
    state->pcb = pcb;

    // the state travels with the pcb as the argument of every callback
    altcp_arg(pcb, state);

    // hook up the TCP event callbacks
    altcp_recv(pcb, ws_recv);
    altcp_sent(pcb, ws_sent);
    altcp_poll(pcb, ws_poll, WS_POLL_INTERVAL);
    altcp_err(pcb, ws_err);
    return ERR_OK;
}
// Receive callback.
// A non-OK err, a NULL pbuf or a missing state closes the connection.
// Otherwise the pbuf goes to ws_read() when the websocket is already active,
// or to ws_handshake() when it is not. The pbuf is always acknowledged and
// freed here. Only an ERR_CLSD result from the handler closes the
// connection; other handler errors are ignored. Always returns ERR_OK.
static err_t ws_recv(void *arg, struct altcp_pcb *pcb, struct pbuf *p, err_t err) {
    struct ws_state *state = (struct ws_state *) arg;
    const int usable = (err == ERR_OK) && (p != NULL) && (state != NULL);

    if (!usable) {
        // error or closed by client
        if (p != NULL) {
            // tell TCP the data has been consumed before dropping it
            altcp_recved(pcb, p->tot_len);
            pbuf_free(p);
        }
        if (state == NULL) {
            // should not occur
            LWIP_DEBUGF(WS_DEBUG, ("Error, ws_recv: wss is NULL, close\n"));
        }
        ws_close_conn(pcb, state);
        return ERR_OK;
    }

    err_t result;
    if (!state->active) {
        // still plain HTTP: look for the upgrade request
        LWIP_DEBUGF(WS_DEBUG, ("ws_recv: websocket inactive, checking for handshake\n"));
        result = ws_handshake(pcb, state, p);
    } else {
        // established websocket: decode the frame
        result = ws_read(pcb, state, p);
    }

    // tell TCP the data has been consumed
    altcp_recved(pcb, p->tot_len);
    pbuf_free(p);

    if (result == ERR_CLSD) {
        ws_close_conn(pcb, state);
    }
    return ERR_OK;
}
// Sent callback: resets the connection's retry counter (when a state is
// attached). Always returns ERR_OK.
static err_t ws_sent(void *arg, struct altcp_pcb *pcb, uint16_t len) {
    struct ws_state *state = (struct ws_state *)arg;
    (void) pcb;
    LWIP_UNUSED_ARG(len);
    LWIP_DEBUGF(WS_DEBUG | LWIP_DBG_TRACE, ("ws_sent %p\n", (void*) pcb));

    if (state != NULL) {
        state->retries = 0;
    }
    return ERR_OK;
}
// Error callback: logs the error and releases the connection state that was
// attached to the pcb, if any. The pcb itself is not touched here.
static void ws_err (void *arg, err_t err) {
    LWIP_UNUSED_ARG(err);
    LWIP_DEBUGF(WS_DEBUG, ("ws_err: %s", lwip_strerr(err)));

    if (arg == NULL) {
        return;
    }
    ws_state_free((struct ws_state *) arg);
}
// Close a connection gracefully: shorthand for ws_close_or_abort_conn() with
// the abort flag cleared. Returns that function's result.
static err_t ws_close_conn(struct altcp_pcb *pcb, struct ws_state *wss) {
    const uint8_t abort_conn = 0;
    return ws_close_or_abort_conn(pcb, wss, abort_conn);
}
// Tear down a connection.
//   pcb        - connection to close or abort
//   wss        - its state (may be NULL); unlinked and freed here
//   abort_conn - non-zero: abort the pcb and return ERR_OK;
//                zero: close it and return the result of altcp_close()
// All callbacks are detached first. When the close fails, ws_poll() is
// re-installed so the close is attempted again from the poll callback.
static err_t ws_close_or_abort_conn(struct altcp_pcb *pcb, struct ws_state *wss,
                                    uint8_t abort_conn) {
    LWIP_DEBUGF(WS_DEBUG, ("Closing connection %p\n", (void *)pcb));

    // detach every callback and the callback argument
    altcp_arg(pcb, NULL);
    altcp_recv(pcb, NULL);
    altcp_sent(pcb, NULL);
    altcp_poll(pcb, NULL, 0);
    altcp_err(pcb, NULL);

    // drop the connection state
    if (wss != NULL) {
        ws_state_free(wss);
    }

    if (!abort_conn) {
        err_t close_result = altcp_close(pcb);
        if (close_result != ERR_OK) {
            LWIP_DEBUGF(WS_DEBUG, ("Error %d closing %p\n", close_result, (void *)pcb));
            // could not close now; ws_poll (called with a NULL arg) retries
            altcp_poll(pcb, ws_poll, WS_POLL_INTERVAL);
        }
        return close_result;
    }

    altcp_abort(pcb);
    return ERR_OK;
}
// Poll callback.
// With a state attached: bump the retry counter and close the connection
// once it reaches WS_MAX_RETRIES. Without a state (arg == NULL): try to close
// the pcb; if that reports ERR_MEM, abort it and return ERR_ABRT.
// Returns ERR_OK in every other case.
static err_t ws_poll(void *arg, struct altcp_pcb *pcb) {
    struct ws_state *state = (struct ws_state *) arg;

    if (state != NULL) {
        state->retries++;
        if (state->retries == WS_MAX_RETRIES) {
            LWIP_DEBUGF(WS_DEBUG, ("ws_poll: too may retries, close\n"));
            ws_close_conn(pcb, state);
        }
        return ERR_OK;
    }

    LWIP_DEBUGF(WS_DEBUG, ("ws_poll: arg is NULL, close\n"));
    err_t closed = ws_close_conn(pcb, NULL);
    LWIP_UNUSED_ARG(closed);
    if (closed == ERR_MEM) {
        altcp_abort(pcb);
        return ERR_ABRT;
    }
    return ERR_OK;
}
// check for and complete handshake with client
static err_t ws_handshake(struct altcp_pcb *pcb, struct ws_state *wss, struct pbuf *p){
uint8_t *data = (uint8_t *) p->payload;
uint16_t len = p->len;
// check if client is initiating a websocket connecttion
if (strstr(data, "Upgrade: websocket")) {
LWIP_DEBUGF(WS_DEBUG, ("ws_handshake: received websocket upgrade request\n"));
// search for websocket security key
char *key_start = strstr(data, "Sec-WebSocket-Key: ");
if (key_start) {
key_start += 19;
const char *key_end = strstr(key_start, "\r\n");
if (key_end) {
char key[64];
uint16_t key_len = key_end-key_start;
if ( (key_len>0) && (key_len + sizeof(WS_GUID) < sizeof(key)) ) {
// create response key by concatenating with websocket GUID,
// taking SHA1 hash, then encoding in base 64
strncpy(key, key_start, key_len);
strlcpy(&key[key_len], WS_GUID, sizeof(key)-key_len);
key_len += sizeof(WS_GUID)-1;
unsigned char key_sha1[20];
unsigned char key_base64[29];
size_t encoded_len;
mbedtls_sha1( (unsigned char *) key, key_len, key_sha1);
mbedtls_base64_encode( key_base64, 29, &encoded_len, key_sha1, 20);
// create response packet with encoded response key
unsigned char response[sizeof(WS_RESPONSE) + sizeof(key_base64)+3];
size_t count = sprintf(response, "%s%s\r\n\r\n", WS_RESPONSE, key_base64);
// send completed data packet
LWIP_DEBUGF(WS_DEBUG, ("ws_handshake: sending response\n"));
if(altcp_write(pcb, response, count, TCP_WRITE_FLAG_COPY) == ERR_OK) {
wss->active = true;
}
if (ws_open_cb != NULL) {
ws_open_cb(wss);
}
return ERR_OK;
}
}
LWIP_DEBUGF(WS_DEBUG, ("ws_handshake: key overflow\n"));
return ERR_MEM;
} else {
LWIP_DEBUGF(WS_DEBUG, ("ws_handshake: key not received\n"));
return ERR_ARG;
}
}
LWIP_DEBUGF(WS_DEBUG, ("ws_handshake: not a websocket request\n"));
return ERR_ARG;
}
// handle reading of websocket data and pass to ws_receive_cb
static err_t ws_read(struct altcp_pcb *pcb, struct ws_state *wss, struct pbuf *p) {
(void) pcb;
uint8_t *data = (uint8_t *) p->payload;
uint16_t len = p->len;
if (data != NULL && len > 1) {
// successful read, reset timeout
wss->retries = 0;
uint8_t fin = data[0] & 0x80;
uint8_t opcode = data[0] & 0x0F;
uint8_t masked = data[1] & 0x80;
uint16_t msg_len = data[1] & 0x7F;
uint8_t *msg;
switch (msg_len) {
case 126: // next two bytes are length
memcpy(&msg_len, &data[2], 2);
if (len >= 8) {
msg = &data[8];
}
break;
case 127: // next four bytes are length
// lwIP's pbuf only handles 16-bit lengths, so error
LWIP_DEBUGF(WS_DEBUG, ("ws_read: received 64-bit length %u\n", msg_len));
return ERR_MEM;
default:
if (len >= 6) {
msg = &data[6];
}
break;
}
switch (opcode) {
case OP_CONT:
LWIP_DEBUGF(WS_DEBUG, ("ws_read: received continuation frame\n"));
case OP_TEXT:
LWIP_DEBUGF(WS_DEBUG, ("ws_read: received text data\n"));
case OP_BINARY:
LWIP_DEBUGF(WS_DEBUG, ("ws_read: decoding data, len=%u\n", msg_len));
if (msg && ws_receive_cb != NULL) {
// unmask the data if mask bit is received
if (masked) {
uint8_t *mask = &data[2];
for (int i=0; i<msg_len; i++) {
msg[i] ^= mask[i % 4];
}
} else {
// messages from client must be masked - disconnect
LWIP_DEBUGF(WS_DEBUG, ("ws_read: received unmasked message"));
return ERR_CLSD;
}
msg[msg_len]=0;
if (opcode != OP_CONT) { // not a continuation frame, reset buffer
buf_len=0;
memset(buf, 0x00, sizeof(buf));
}
if (buf_len + msg_len > WS_BUFFER_SIZE) {
LWIP_DEBUGF(WS_DEBUG, ("ws_read: message exceeds buffer size %u+%u\n", buf_len, msg_len));
return ERR_MEM;
}
memcpy(&buf[buf_len], msg, msg_len);
buf_len += msg_len;
if (fin) { // last packet in message, process completed message
ws_receive_cb(buf, buf_len);
}
}
break;
case OP_CLOSE:
LWIP_DEBUGF(WS_DEBUG, ("ws_read: close request"));
return ERR_CLSD;
case OP_PING:
// control frames cannot exceed 125 bytes in length
if (msg && msg_len <= 125) {
// send back a pong
uint8_t pong[2+msg_len];
pong[0]=0x8A;
pong[1]=msg_len;
memcpy(&pong[2], msg, msg_len);
return ws_send(wss, pong, msg_len+2);
}
return ERR_ARG;
case OP_PONG: // no response required for pong
return ERR_OK;
default:
LWIP_DEBUGF(WS_DEBUG, ("ws_read: invalid opcode %02X\n", opcode));
return ERR_ARG;
}
return ERR_OK;
}
LWIP_DEBUGF(WS_DEBUG, ("ws_read: received empty payload\n"));
return ERR_VAL;
}
static err_t ws_send(struct ws_state *wss, uint8_t *data, uint16_t len) {
uint8_t buf[128];
buf[0] = 0x81;
buf[1] = len & 0x7F;
memcpy(&buf[2], data, len);
err_t err;
err = altcp_write(wss->pcb, buf, len+2, TCP_WRITE_FLAG_COPY);
if (err == ERR_OK) {
altcp_output(wss->pcb);
}
return err;
}
void ws_send_all(uint8_t *data, uint16_t len) {
// send message to all connections
if (ws_connections != NULL) {
struct ws_state *wss;
err_t err;
for (wss=ws_connections; wss != NULL; wss=wss->next) {
err = ws_send(wss, data, len);
if (err != ERR_OK ) {
LWIP_DEBUGF(WS_DEBUG, ("ws_send_all: error sending to %p\n", wss));
}
}
}
}