375 lines
9.5 KiB
C
375 lines
9.5 KiB
C
#include <string.h>
|
|
|
|
#include "lwip/altcp.h"
|
|
#include "lwip/debug.h"
|
|
#include "mbedtls/base64.h"
|
|
#include "mbedtls/sha1.h"
|
|
|
|
#include "websocket.h"
|
|
|
|
/* GUID that ws_handshake appends to the client's Sec-WebSocket-Key before
 * hashing (the fixed value defined by the WebSocket protocol, RFC 6455). */
static const char WS_GUID[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

/* Fixed head of the HTTP upgrade reply. The computed accept key and the
 * terminating blank line are appended when the reply is built. */
static const char WS_RESPONSE[] =
    "HTTP/1.1 101 Switching Protocols\r\n"
    "Upgrade: websocket\r\n"
    "Connection: Upgrade\r\n"
    "Sec-WebSocket-Accept: ";
|
|
|
|
static tWSHandler ws_receive_cb = NULL;
|
|
static tWSOpenHandler ws_open_cb = NULL;
|
|
|
|
static struct ws_state * ws_connections;
|
|
static uint8_t ws_num_conns = 0;
|
|
|
|
// allocate memory for ws_state instance
|
|
static struct ws_state * ws_state_alloc(void) {
|
|
struct ws_state *ret = WS_ALLOC_WS_STATE();
|
|
|
|
if ( ret != NULL) {
|
|
ws_state_init(ret);
|
|
if (ws_connections == NULL) {
|
|
ws_connections = ret;
|
|
} else {
|
|
struct ws_state *last;
|
|
for (last=ws_connections; last->next != NULL; last=last->next);
|
|
LWIP_ASSERT("last != NULL", last != NULL);
|
|
last->next = ret;
|
|
}
|
|
}
|
|
|
|
return ret;
|
|
}
|
|
|
|
// initiate ws_state instance
|
|
static void ws_state_init(struct ws_state *wss) {
|
|
memset(wss, 0, sizeof(struct ws_state));
|
|
wss->active = false;
|
|
}
|
|
|
|
// free memory from ws_state instance
|
|
static void ws_state_free(struct ws_state *wss) {
|
|
if (wss != NULL) {
|
|
if (ws_connections == wss) {
|
|
ws_connections = wss->next;
|
|
} else {
|
|
struct ws_state * last;
|
|
for (last = ws_connections; last->next != NULL; last = last->next) {
|
|
if (last->next = wss) {
|
|
last->next = wss->next;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
mem_free(wss);
|
|
}
|
|
}
|
|
|
|
// initiate websocket server on specified pcb
|
|
static void ws_server_init_pcb( struct altcp_pcb *pcb, uint16_t port) {
|
|
err_t err;
|
|
|
|
if (pcb) {
|
|
altcp_setprio(pcb, TCP_PRIO_MIN);
|
|
err = altcp_bind(pcb, IP_ANY_TYPE, port);
|
|
LWIP_UNUSED_ARG(err);
|
|
LWIP_ASSERT("ws_server_init: tcp_bind failed", err == ERR_OK);
|
|
pcb = altcp_listen(pcb);
|
|
LWIP_ASSERT("ws_server_init: tcp_listen failed", pcb != NULL);
|
|
altcp_accept(pcb, ws_accept);
|
|
}
|
|
}
|
|
|
|
// initiate a websocket server
|
|
void ws_server_init(void) {
|
|
struct altcp_pcb *pcb = altcp_tcp_new_ip_type(IPADDR_TYPE_ANY);
|
|
LWIP_ASSERT("ws_server_init: tcp_new failed", pcb != NULL);
|
|
ws_server_init_pcb(pcb, WS_PORT);
|
|
}
|
|
|
|
// set ws_receive_handler
|
|
void ws_set_receive_handler( tWSHandler ws_handler)
|
|
{
|
|
ws_receive_cb = ws_handler;
|
|
}
|
|
|
|
// set ws_open_handler
|
|
void ws_set_open_handler( tWSOpenHandler ws_handler)
|
|
{
|
|
ws_open_cb = ws_handler;
|
|
}
|
|
|
|
// callback for accepted websocket connection
|
|
static err_t ws_accept(void *arg, struct altcp_pcb *pcb, err_t err) {
|
|
struct ws_state *wss;
|
|
LWIP_UNUSED_ARG(err);
|
|
LWIP_UNUSED_ARG(arg);
|
|
LWIP_DEBUGF(WS_DEBUG, ("ws_accept %p / %p\n", (void *)pcb, arg));
|
|
|
|
if ((err != ERR_OK) || (pcb == NULL)) {
|
|
return ERR_VAL;
|
|
}
|
|
|
|
// create new ws_state object
|
|
wss = ws_state_alloc();
|
|
if (wss == NULL) {
|
|
LWIP_DEBUGF(WS_DEBUG, ("ws_accept: Out of memory, RST\n"));
|
|
return ERR_MEM;
|
|
}
|
|
wss->pcb = pcb;
|
|
|
|
// make ws_state object the argument of callbacks
|
|
altcp_arg(pcb, wss);
|
|
|
|
// register callbacks for tcp events
|
|
altcp_recv(pcb, ws_recv);
|
|
altcp_poll(pcb, ws_poll, WS_POLL_INTERVAL);
|
|
altcp_err(pcb, ws_err);
|
|
|
|
return ERR_OK;
|
|
}
|
|
|
|
// call when data is received
|
|
static err_t ws_recv(void *arg, struct altcp_pcb *pcb, struct pbuf *p, err_t err) {
|
|
struct ws_state *wss = (struct ws_state *) arg;
|
|
|
|
if ((err != ERR_OK) || (p == NULL) || (wss == NULL)) {
|
|
// error or closed by client
|
|
if (p != NULL) {
|
|
// inform TCP that we have taken the data
|
|
altcp_recved(pcb, p->tot_len);
|
|
pbuf_free(p);
|
|
}
|
|
if (wss == NULL) {
|
|
// should not occur
|
|
LWIP_DEBUGF(WS_DEBUG, ("Error, ws_recv: wss is NULL, close\n"));
|
|
}
|
|
ws_close_conn(pcb, wss);
|
|
return ERR_OK;
|
|
}
|
|
|
|
if (wss->active) {
|
|
// process websocket message
|
|
err = ws_read(pcb, wss, p);
|
|
} else {
|
|
// init websocket connection
|
|
LWIP_DEBUGF(WS_DEBUG, ("ws_recv: websocket inactive, checking for handshake\n"));
|
|
|
|
err = ws_handshake(pcb, wss, p);
|
|
}
|
|
|
|
// inform TCP that we have taken the data.
|
|
altcp_recved(pcb, p->tot_len);
|
|
pbuf_free(p);
|
|
|
|
if (err == ERR_CLSD) {
|
|
ws_close_conn(pcb, wss);
|
|
}
|
|
|
|
return ERR_OK;
|
|
}
|
|
|
|
/**
 * Error callback of a connection: logs the error and releases the
 * connection's ws_state. The pcb itself is not touched here.
 *
 * @param arg ws_state of the connection, may be NULL
 * @param err error being reported
 */
static void ws_err(void *arg, err_t err) {
    struct ws_state *state = (struct ws_state *) arg;

    LWIP_UNUSED_ARG(err);
    LWIP_DEBUGF(WS_DEBUG, ("ws_err: %s", lwip_strerr(err)));

    if (state == NULL) {
        return;
    }
    ws_state_free(state);
}
|
|
|
|
// initiate close of connection
|
|
static err_t ws_close_conn(struct altcp_pcb *pcb, struct ws_state *wss) {
|
|
return ws_close_or_abort_conn(pcb, wss, 0);
|
|
}
|
|
|
|
// call when closing connection or connection was aborted
|
|
static err_t ws_close_or_abort_conn(struct altcp_pcb *pcb, struct ws_state *wss,
|
|
uint8_t abort_conn) {
|
|
|
|
err_t err;
|
|
LWIP_DEBUGF(WS_DEBUG, ("Closing connection %p\n", (void *)pcb));
|
|
|
|
// clear callbacks
|
|
altcp_arg(pcb, NULL);
|
|
altcp_recv(pcb, NULL);
|
|
altcp_poll(pcb, NULL, 0);
|
|
altcp_err(pcb, NULL);
|
|
|
|
// remove and free memory from ws_state object
|
|
if (wss != NULL) {
|
|
ws_state_free(wss);
|
|
}
|
|
|
|
if (abort_conn) {
|
|
altcp_abort(pcb);
|
|
return ERR_OK;
|
|
}
|
|
err = altcp_close(pcb);
|
|
if (err != ERR_OK) {
|
|
LWIP_DEBUGF(WS_DEBUG, ("Error %d closing %p\n", err, (void *)pcb));
|
|
// error closing, try again later in poll
|
|
altcp_poll(pcb, ws_poll, WS_POLL_INTERVAL);
|
|
}
|
|
return err;
|
|
}
|
|
|
|
// callback for polling process
|
|
static err_t ws_poll(void *arg, struct altcp_pcb *pcb) {
|
|
struct ws_state *wss = (struct ws_state *) arg;
|
|
if (wss == NULL) {
|
|
err_t closed;
|
|
LWIP_DEBUGF(WS_DEBUG, ("ws_poll: arg is NULL, close\n"));
|
|
closed = ws_close_conn(pcb, NULL);
|
|
LWIP_UNUSED_ARG(closed);
|
|
if (closed == ERR_MEM) {
|
|
altcp_abort(pcb);
|
|
return ERR_ABRT;
|
|
}
|
|
return ERR_OK;
|
|
} else {
|
|
wss->retries++;
|
|
if (wss->retries == WS_MAX_RETRIES) {
|
|
LWIP_DEBUGF(WS_DEBUG, ("ws_poll: too may retries, close\n"));
|
|
ws_close_conn(pcb, wss);
|
|
return ERR_OK;
|
|
}
|
|
}
|
|
|
|
return ERR_OK;
|
|
}
|
|
|
|
// check for and complete handshake with client
|
|
static err_t ws_handshake(struct altcp_pcb *pcb, struct ws_state *wss, struct pbuf *p){
|
|
uint8_t *data = (uint8_t *) p->payload;
|
|
uint16_t len = p->len;
|
|
|
|
// check if client is initiating a websocket connecttion
|
|
if (strstr(data, "Upgrade: websocket")) {
|
|
LWIP_DEBUGF(WS_DEBUG, ("ws_handshake: received websocket upgrade request\n"));
|
|
|
|
// search for websocket security key
|
|
char *key_start = strstr(data, "Sec-WebSocket-Key: ");
|
|
|
|
if (key_start) {
|
|
key_start += 19;
|
|
const char *key_end = strstr(key_start, "\r\n");
|
|
if (key_end) {
|
|
char key[64];
|
|
uint16_t key_len = key_end-key_start;
|
|
if ( (key_len>0) && (key_len + sizeof(WS_GUID) < sizeof(key)) ) {
|
|
// create response key by concatenating with websocket GUID,
|
|
// taking SHA1 hash, then encoding in base 64
|
|
strncpy(key, key_start, key_len);
|
|
strlcpy(&key[key_len], WS_GUID, sizeof(key)-key_len);
|
|
|
|
key_len += sizeof(WS_GUID)-1;
|
|
unsigned char key_sha1[20];
|
|
unsigned char key_base64[29];
|
|
size_t encoded_len;
|
|
mbedtls_sha1( (unsigned char *) key, key_len, key_sha1);
|
|
mbedtls_base64_encode( key_base64, 29, &encoded_len, key_sha1, 20);
|
|
|
|
// create response packet with encoded response key
|
|
unsigned char response[sizeof(WS_RESPONSE) + sizeof(key_base64)+3];
|
|
//strncpy(response, WS_RESPONSE, sizeof(WS_RESPONSE));
|
|
//strlcpy(&response[sizeof(WS_RESPONSE)-1], key_base64, strlen(key_base64));
|
|
size_t count = sprintf(response, "%s%s\r\n\r\n", WS_RESPONSE, key_base64);
|
|
|
|
// send completed data packet
|
|
LWIP_DEBUGF(WS_DEBUG, ("ws_handshake: sending response\n"));
|
|
if(altcp_write(pcb, response, strlen(response), TCP_WRITE_FLAG_COPY) == ERR_OK) {
|
|
wss->active = true;
|
|
}
|
|
|
|
if (ws_open_cb != NULL) {
|
|
ws_open_cb(wss);
|
|
}
|
|
|
|
return ERR_OK;
|
|
}
|
|
}
|
|
|
|
LWIP_DEBUGF(WS_DEBUG, ("ws_handshake: key overflow\n"));
|
|
return ERR_MEM;
|
|
} else {
|
|
LWIP_DEBUGF(WS_DEBUG, ("ws_handshake: key not received\n"));
|
|
return ERR_ARG;
|
|
}
|
|
}
|
|
|
|
LWIP_DEBUGF(WS_DEBUG, ("ws_handshake: not a websocket request\n"));
|
|
return ERR_ARG;
|
|
}
|
|
|
|
// handle reading of websocket data and pass to ws_receive_cb
|
|
static err_t ws_read(struct altcp_pcb *pcb, struct ws_state *wss, struct pbuf *p) {
|
|
uint8_t *data = (uint8_t *) p->payload;
|
|
uint16_t len = p->len;
|
|
|
|
if (data != NULL && len > 1) {
|
|
// successful read, reset timeout
|
|
wss->retries = 0;
|
|
|
|
uint8_t mode = data[0] & 0x0F;
|
|
uint16_t msg_len = data[1] & 0x7F;
|
|
switch (mode) {
|
|
case 0x01: //text
|
|
LWIP_DEBUGF(WS_DEBUG, ("ws_read: received text data\n"));
|
|
case 0x02: //binary
|
|
LWIP_DEBUGF(WS_DEBUG, ("ws_read: decoding data\n"));
|
|
if (len >= 6 && ws_receive_cb != NULL) {
|
|
uint8_t *mask = &data[2];
|
|
uint8_t *msg = &data[6];
|
|
|
|
for (int i=0; i<msg_len; i++) {
|
|
msg[i] ^= mask[i % 4];
|
|
}
|
|
msg[msg_len]=0;
|
|
|
|
ws_receive_cb(msg, msg_len);
|
|
}
|
|
break;
|
|
case 0x08: //close
|
|
LWIP_DEBUGF(WS_DEBUG, ("ws_read: close request"));
|
|
return ERR_CLSD;
|
|
default:
|
|
LWIP_DEBUGF(WS_DEBUG, ("ws_read: invalid data mode %02X\n", mode));
|
|
return ERR_ARG;
|
|
}
|
|
|
|
return ERR_OK;
|
|
}
|
|
LWIP_DEBUGF(WS_DEBUG, ("ws_read: received empty payload\n"));
|
|
return ERR_VAL;
|
|
}
|
|
|
|
static err_t ws_send(struct ws_state *wss, uint8_t *data, uint16_t len) {
|
|
uint8_t buf[128];
|
|
buf[0] = 0x81;
|
|
buf[1] = len & 0x7F;
|
|
memcpy(&buf[2], data, len);
|
|
|
|
err_t err;
|
|
err = altcp_write(wss->pcb, buf, len+2, TCP_WRITE_FLAG_COPY);
|
|
if (err == ERR_OK) {
|
|
altcp_output(wss->pcb);
|
|
wss->retries = 0;
|
|
}
|
|
|
|
return err;
|
|
}
|
|
|
|
void ws_send_all(uint8_t *data, uint16_t len) {
|
|
// send message to all connections
|
|
if (ws_connections != NULL) {
|
|
struct ws_state *wss;
|
|
err_t err;
|
|
for (wss=ws_connections; wss != NULL; wss=wss->next) {
|
|
err = ws_send(wss, data, len);
|
|
if (err != ERR_OK ) {
|
|
LWIP_DEBUGF(WS_DEBUG, ("ws_send_all: error sending to %p\n", wss));
|
|
}
|
|
}
|
|
}
|
|
}
|