#include #include "lwip/altcp.h" #include "lwip/debug.h" #include "mbedtls/base64.h" #include "mbedtls/sha1.h" #include "websocket.h" static const char WS_GUID[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; static const char WS_RESPONSE[] = "HTTP/1.1 101 Switching Protocols\r\n" \ "Upgrade: websocket\r\n" \ "Connection: Upgrade\r\n" \ "Sec-WebSocket-Accept: "; static tWSHandler ws_receive_cb = NULL; static tWSOpenHandler ws_open_cb = NULL; static struct ws_state * ws_connections; static uint8_t ws_num_conns = 0; // allocate memory for ws_state instance static struct ws_state * ws_state_alloc(void) { struct ws_state *ret = WS_ALLOC_WS_STATE(); if ( ret != NULL) { ws_state_init(ret); if (ws_connections == NULL) { ws_connections = ret; } else { struct ws_state *last; for (last=ws_connections; last->next != NULL; last=last->next); LWIP_ASSERT("last != NULL", last != NULL); last->next = ret; } } return ret; } // initiate ws_state instance static void ws_state_init(struct ws_state *wss) { memset(wss, 0, sizeof(struct ws_state)); wss->active = false; } // free memory from ws_state instance static void ws_state_free(struct ws_state *wss) { if (wss != NULL) { if (ws_connections == wss) { ws_connections = wss->next; } else { struct ws_state * last; for (last = ws_connections; last->next != NULL; last = last->next) { if (last->next = wss) { last->next = wss->next; break; } } } mem_free(wss); } } // initiate websocket server on specified pcb static void ws_server_init_pcb( struct altcp_pcb *pcb, uint16_t port) { err_t err; if (pcb) { altcp_setprio(pcb, TCP_PRIO_MIN); err = altcp_bind(pcb, IP_ANY_TYPE, port); LWIP_UNUSED_ARG(err); LWIP_ASSERT("ws_server_init: tcp_bind failed", err == ERR_OK); pcb = altcp_listen(pcb); LWIP_ASSERT("ws_server_init: tcp_listen failed", pcb != NULL); altcp_accept(pcb, ws_accept); } } // initiate a websocket server void ws_server_init(void) { struct altcp_pcb *pcb = altcp_tcp_new_ip_type(IPADDR_TYPE_ANY); LWIP_ASSERT("ws_server_init: tcp_new failed", pcb != NULL); ws_server_init_pcb(pcb, WS_PORT); } // set ws_receive_handler void ws_set_receive_handler( tWSHandler ws_handler) { ws_receive_cb = ws_handler; } // set ws_open_handler void ws_set_open_handler( tWSOpenHandler ws_handler) { ws_open_cb = ws_handler; } // callback for accepted websocket connection static err_t ws_accept(void *arg, struct altcp_pcb *pcb, err_t err) { struct ws_state *wss; LWIP_UNUSED_ARG(err); LWIP_UNUSED_ARG(arg); LWIP_DEBUGF(WS_DEBUG, ("ws_accept %p / %p\n", (void *)pcb, arg)); if ((err != ERR_OK) || (pcb == NULL)) { return ERR_VAL; } // create new ws_state object wss = ws_state_alloc(); if (wss == NULL) { LWIP_DEBUGF(WS_DEBUG, ("ws_accept: Out of memory, RST\n")); return ERR_MEM; } wss->pcb = pcb; // make ws_state object the argument of callbacks altcp_arg(pcb, wss); // register callbacks for tcp events altcp_recv(pcb, ws_recv); altcp_sent(pcb, ws_sent); altcp_poll(pcb, ws_poll, WS_POLL_INTERVAL); altcp_err(pcb, ws_err); return ERR_OK; } // call when data is received static err_t ws_recv(void *arg, struct altcp_pcb *pcb, struct pbuf *p, err_t err) { struct ws_state *wss = (struct ws_state *) arg; if ((err != ERR_OK) || (p == NULL) || (wss == NULL)) { // error or closed by client if (p != NULL) { // inform TCP that we have taken the data altcp_recved(pcb, p->tot_len); pbuf_free(p); } if (wss == NULL) { // should not occur LWIP_DEBUGF(WS_DEBUG, ("Error, ws_recv: wss is NULL, close\n")); } ws_close_conn(pcb, wss); return ERR_OK; } if (wss->active) { // process websocket message err = ws_read(pcb, wss, p); } else { // init websocket connection LWIP_DEBUGF(WS_DEBUG, ("ws_recv: websocket inactive, checking for handshake\n")); err = ws_handshake(pcb, wss, p); } // inform TCP that we have taken the data. altcp_recved(pcb, p->tot_len); pbuf_free(p); if (err == ERR_CLSD) { ws_close_conn(pcb, wss); } return ERR_OK; } // called when data has been sent over the websocket static err_t ws_sent(void *arg, struct altcp_pcb *pcb, uint16_t len) { struct ws_state *wss = (struct ws_state *)arg; LWIP_DEBUGF(WS_DEBUG | LWIP_DBG_TRACE, ("ws_sent %p\n", (void*) pcb)); LWIP_UNUSED_ARG(len); if (wss == NULL) { return ERR_OK; } wss->retries = 0; return ERR_OK; } // called when there is a websocket error static void ws_err (void *arg, err_t err) { struct ws_state *wss = (struct ws_state *) arg; LWIP_UNUSED_ARG(err); LWIP_DEBUGF(WS_DEBUG, ("ws_err: %s", lwip_strerr(err))); if (wss != NULL) { ws_state_free(wss); } } // initiate close of connection static err_t ws_close_conn(struct altcp_pcb *pcb, struct ws_state *wss) { return ws_close_or_abort_conn(pcb, wss, 0); } // call when closing connection or connection was aborted static err_t ws_close_or_abort_conn(struct altcp_pcb *pcb, struct ws_state *wss, uint8_t abort_conn) { err_t err; LWIP_DEBUGF(WS_DEBUG, ("Closing connection %p\n", (void *)pcb)); // clear callbacks altcp_arg(pcb, NULL); altcp_recv(pcb, NULL); altcp_sent(pcb, NULL); altcp_poll(pcb, NULL, 0); altcp_err(pcb, NULL); // remove and free memory from ws_state object if (wss != NULL) { ws_state_free(wss); } if (abort_conn) { altcp_abort(pcb); return ERR_OK; } err = altcp_close(pcb); if (err != ERR_OK) { LWIP_DEBUGF(WS_DEBUG, ("Error %d closing %p\n", err, (void *)pcb)); // error closing, try again later in poll altcp_poll(pcb, ws_poll, WS_POLL_INTERVAL); } return err; } // callback for polling process static err_t ws_poll(void *arg, struct altcp_pcb *pcb) { struct ws_state *wss = (struct ws_state *) arg; if (wss == NULL) { err_t closed; LWIP_DEBUGF(WS_DEBUG, ("ws_poll: arg is NULL, close\n")); closed = ws_close_conn(pcb, NULL); LWIP_UNUSED_ARG(closed); if (closed == ERR_MEM) { altcp_abort(pcb); return ERR_ABRT; } return ERR_OK; } else { wss->retries++; if (wss->retries == WS_MAX_RETRIES) { LWIP_DEBUGF(WS_DEBUG, ("ws_poll: too may retries, close\n")); ws_close_conn(pcb, wss); return ERR_OK; } } return ERR_OK; } // check for and complete handshake with client static err_t ws_handshake(struct altcp_pcb *pcb, struct ws_state *wss, struct pbuf *p){ uint8_t *data = (uint8_t *) p->payload; uint16_t len = p->len; // check if client is initiating a websocket connecttion if (strstr(data, "Upgrade: websocket")) { LWIP_DEBUGF(WS_DEBUG, ("ws_handshake: received websocket upgrade request\n")); // search for websocket security key char *key_start = strstr(data, "Sec-WebSocket-Key: "); if (key_start) { key_start += 19; const char *key_end = strstr(key_start, "\r\n"); if (key_end) { char key[64]; uint16_t key_len = key_end-key_start; if ( (key_len>0) && (key_len + sizeof(WS_GUID) < sizeof(key)) ) { // create response key by concatenating with websocket GUID, // taking SHA1 hash, then encoding in base 64 strncpy(key, key_start, key_len); strlcpy(&key[key_len], WS_GUID, sizeof(key)-key_len); key_len += sizeof(WS_GUID)-1; unsigned char key_sha1[20]; unsigned char key_base64[29]; size_t encoded_len; mbedtls_sha1( (unsigned char *) key, key_len, key_sha1); mbedtls_base64_encode( key_base64, 29, &encoded_len, key_sha1, 20); // create response packet with encoded response key unsigned char response[sizeof(WS_RESPONSE) + sizeof(key_base64)+3]; //strncpy(response, WS_RESPONSE, sizeof(WS_RESPONSE)); //strlcpy(&response[sizeof(WS_RESPONSE)-1], key_base64, strlen(key_base64)); size_t count = sprintf(response, "%s%s\r\n\r\n", WS_RESPONSE, key_base64); // send completed data packet LWIP_DEBUGF(WS_DEBUG, ("ws_handshake: sending response\n")); if(altcp_write(pcb, response, strlen(response), TCP_WRITE_FLAG_COPY) == ERR_OK) { wss->active = true; } if (ws_open_cb != NULL) { ws_open_cb(wss); } return ERR_OK; } } LWIP_DEBUGF(WS_DEBUG, ("ws_handshake: key overflow\n")); return ERR_MEM; } else { LWIP_DEBUGF(WS_DEBUG, ("ws_handshake: key not received\n")); return ERR_ARG; } } LWIP_DEBUGF(WS_DEBUG, ("ws_handshake: not a websocket request\n")); return ERR_ARG; } // handle reading of websocket data and pass to ws_receive_cb static err_t ws_read(struct altcp_pcb *pcb, struct ws_state *wss, struct pbuf *p) { uint8_t *data = (uint8_t *) p->payload; uint16_t len = p->len; if (data != NULL && len > 1) { // successful read, reset timeout wss->retries = 0; uint8_t mode = data[0] & 0x0F; uint16_t msg_len = data[1] & 0x7F; switch (mode) { case 0x01: //text LWIP_DEBUGF(WS_DEBUG, ("ws_read: received text data\n")); case 0x02: //binary LWIP_DEBUGF(WS_DEBUG, ("ws_read: decoding data\n")); if (len >= 6 && ws_receive_cb != NULL) { uint8_t *mask = &data[2]; uint8_t *msg = &data[6]; for (int i=0; ipcb, buf, len+2, TCP_WRITE_FLAG_COPY); if (err == ERR_OK) { altcp_output(wss->pcb); } return err; } void ws_send_all(uint8_t *data, uint16_t len) { // send message to all connections if (ws_connections != NULL) { struct ws_state *wss; err_t err; for (wss=ws_connections; wss != NULL; wss=wss->next) { err = ws_send(wss, data, len); if (err != ERR_OK ) { LWIP_DEBUGF(WS_DEBUG, ("ws_send_all: error sending to %p\n", wss)); } } } }