replace Buffer with allocated 'struct sshbuf *' in session_state

This commit is contained in:
Markus Friedl
2012-01-13 14:22:07 +01:00
parent 7861f112ba
commit 9662f3f978

View File

@@ -93,7 +93,7 @@ struct packet_state {
struct packet {
TAILQ_ENTRY(packet) next;
u_char type;
Buffer payload;
struct sshbuf *payload;
};
struct session_state {
@@ -116,20 +116,20 @@ struct session_state {
CipherContext send_context;
/* Buffer for raw input data from the socket. */
Buffer input;
struct sshbuf *input;
/* Buffer for raw output data going to the socket. */
Buffer output;
struct sshbuf *output;
/* Buffer for the partial outgoing packet being constructed. */
Buffer outgoing_packet;
struct sshbuf *outgoing_packet;
/* Buffer for the incoming packet currently being processed. */
Buffer incoming_packet;
struct sshbuf *incoming_packet;
/* Scratch buffer for packet compression/decompression. */
Buffer compression_buffer;
int compression_buffer_ready;
struct sshbuf *compression_buffer;
int compression_buffer_ready; /** XXX */
/*
* Flag indicating whether packet compression/decompression is
@@ -193,15 +193,48 @@ struct session_state {
struct ssh *
ssh_alloc_session_state(void)
{
struct ssh *ssh = xcalloc(1, sizeof(*ssh));
struct session_state *state = xcalloc(1, sizeof(*state));
struct ssh *ssh;
struct session_state *state;
if ((ssh = calloc(1, sizeof(*ssh))) == NULL ||
(ssh->state = state = calloc(1, sizeof(*state))) == NULL) {
if (ssh)
free(ssh);
return NULL;
}
state->connection_in = -1;
state->connection_out = -1;
state->max_packet_size = 32768;
state->packet_timeout_ms = -1;
ssh->state = state;
if (!state->initialized) {
if ((state->input = sshbuf_new()) == NULL ||
(state->output = sshbuf_new()) == NULL ||
(state->outgoing_packet = sshbuf_new()) == NULL ||
(state->incoming_packet = sshbuf_new()) == NULL)
goto out;
TAILQ_INIT(&state->outgoing);
TAILQ_INIT(&ssh->private_keys);
TAILQ_INIT(&ssh->public_keys);
state->p_send.packets = state->p_read.packets = 0;
state->initialized = 1;
}
return ssh;
out:
if (state->input)
sshbuf_free(state->input);
if (state->output)
sshbuf_free(state->output);
if (state->incoming_packet)
sshbuf_free(state->incoming_packet);
if (state->outgoing_packet)
sshbuf_free(state->outgoing_packet);
state->input = NULL;
state->output = NULL;
state->incoming_packet = NULL;
state->outgoing_packet = NULL;
free(ssh);
free(state);
return NULL;
}
/*
@@ -218,6 +251,8 @@ ssh_packet_set_connection(struct ssh *ssh, int fd_in, int fd_out)
fatal("packet_set_connection: cannot load cipher 'none'");
if (ssh == NULL)
ssh = ssh_alloc_session_state();
if (ssh == NULL)
fatal("%s: cound not allocate state", __func__);
state = ssh->state;
state->connection_in = fd_in;
state->connection_out = fd_out;
@@ -226,17 +261,6 @@ ssh_packet_set_connection(struct ssh *ssh, int fd_in, int fd_out)
cipher_init(&state->receive_context, none, (const u_char *)"",
0, NULL, 0, CIPHER_DECRYPT);
state->newkeys[MODE_IN] = state->newkeys[MODE_OUT] = NULL;
if (!state->initialized) {
state->initialized = 1;
buffer_init(&state->input);
buffer_init(&state->output);
buffer_init(&state->outgoing_packet);
buffer_init(&state->incoming_packet);
TAILQ_INIT(&state->outgoing);
TAILQ_INIT(&ssh->private_keys);
TAILQ_INIT(&ssh->public_keys);
state->p_send.packets = state->p_read.packets = 0;
}
return ssh;
}
@@ -265,13 +289,13 @@ ssh_packet_stop_discard(struct ssh *ssh)
char buf[1024];
memset(buf, 'a', sizeof(buf));
while (buffer_len(&state->incoming_packet) <
while (buffer_len(state->incoming_packet) <
PACKET_MAX_SIZE)
buffer_append(&state->incoming_packet, buf,
buffer_append(state->incoming_packet, buf,
sizeof(buf));
(void) mac_compute(state->packet_discard_mac,
state->p_read.seqnr,
buffer_ptr(&state->incoming_packet),
buffer_ptr(state->incoming_packet),
PACKET_MAX_SIZE);
}
logit("Finished discarding for %.200s", get_remote_ipaddr());
@@ -288,10 +312,10 @@ ssh_packet_start_discard(struct ssh *ssh, Enc *enc, Mac *mac,
ssh_packet_disconnect(ssh, "Packet corrupt");
if (packet_length != PACKET_MAX_SIZE && mac && mac->enabled)
state->packet_discard_mac = mac;
if (buffer_len(&state->input) >= discard)
if (buffer_len(state->input) >= discard)
ssh_packet_stop_discard(ssh);
state->packet_discard = discard -
buffer_len(&state->input);
buffer_len(state->input);
}
/* Returns 1 if remote host is connected via socket, 0 if not. */
@@ -491,12 +515,12 @@ ssh_packet_close(struct ssh *ssh)
close(state->connection_in);
close(state->connection_out);
}
buffer_free(&state->input);
buffer_free(&state->output);
buffer_free(&state->outgoing_packet);
buffer_free(&state->incoming_packet);
buffer_free(state->input);
buffer_free(state->output);
buffer_free(state->outgoing_packet);
buffer_free(state->incoming_packet);
if (state->compression_buffer_ready) {
buffer_free(&state->compression_buffer);
buffer_free(state->compression_buffer);
buffer_compress_uninit();
}
cipher_cleanup(&state->send_context);
@@ -527,10 +551,10 @@ ssh_packet_get_protocol_flags(struct ssh *ssh)
void
ssh_packet_init_compression(struct ssh *ssh)
{
if (ssh->state->compression_buffer_ready == 1)
if (ssh->state->compression_buffer ||
((ssh->state->compression_buffer = sshbuf_new()) != NULL))
return;
ssh->state->compression_buffer_ready = 1;
buffer_init(&ssh->state->compression_buffer);
return; /* XX */
}
void
@@ -590,8 +614,8 @@ ssh_packet_start(struct ssh *ssh, u_char type)
len = compat20 ? 6 : 9;
memset(buf, 0, len - 1);
buf[len - 1] = type;
buffer_clear(&ssh->state->outgoing_packet);
buffer_append(&ssh->state->outgoing_packet, buf, len);
buffer_clear(ssh->state->outgoing_packet);
buffer_append(ssh->state->outgoing_packet, buf, len);
}
/* Append payload. */
@@ -600,55 +624,55 @@ ssh_packet_put_char(struct ssh *ssh, int value)
{
char ch = value;
buffer_append(&ssh->state->outgoing_packet, &ch, 1);
buffer_append(ssh->state->outgoing_packet, &ch, 1);
}
void
ssh_packet_put_int(struct ssh *ssh, u_int value)
{
buffer_put_int(&ssh->state->outgoing_packet, value);
buffer_put_int(ssh->state->outgoing_packet, value);
}
void
ssh_packet_put_int64(struct ssh *ssh, u_int64_t value)
{
buffer_put_int64(&ssh->state->outgoing_packet, value);
buffer_put_int64(ssh->state->outgoing_packet, value);
}
void
ssh_packet_put_string(struct ssh *ssh, const void *buf, u_int len)
{
buffer_put_string(&ssh->state->outgoing_packet, buf, len);
buffer_put_string(ssh->state->outgoing_packet, buf, len);
}
void
ssh_packet_put_cstring(struct ssh *ssh, const char *str)
{
buffer_put_cstring(&ssh->state->outgoing_packet, str);
buffer_put_cstring(ssh->state->outgoing_packet, str);
}
void
ssh_packet_put_raw(struct ssh *ssh, const void *buf, u_int len)
{
buffer_append(&ssh->state->outgoing_packet, buf, len);
buffer_append(ssh->state->outgoing_packet, buf, len);
}
void
ssh_packet_put_bignum(struct ssh *ssh, BIGNUM * value)
{
buffer_put_bignum(&ssh->state->outgoing_packet, value);
buffer_put_bignum(ssh->state->outgoing_packet, value);
}
void
ssh_packet_put_bignum2(struct ssh *ssh, BIGNUM * value)
{
buffer_put_bignum2(&ssh->state->outgoing_packet, value);
buffer_put_bignum2(ssh->state->outgoing_packet, value);
}
void
ssh_packet_put_ecpoint(struct ssh *ssh, const EC_GROUP *curve, const EC_POINT *point)
{
buffer_put_ecpoint(&ssh->state->outgoing_packet, curve, point);
buffer_put_ecpoint(ssh->state->outgoing_packet, curve, point);
}
/*
@@ -670,26 +694,26 @@ ssh_packet_send1(struct ssh *ssh)
* packet.
*/
if (state->packet_compression) {
buffer_clear(&state->compression_buffer);
buffer_clear(state->compression_buffer);
/* Skip padding. */
buffer_consume(&state->outgoing_packet, 8);
buffer_consume(state->outgoing_packet, 8);
/* padding */
buffer_append(&state->compression_buffer,
buffer_append(state->compression_buffer,
"\0\0\0\0\0\0\0\0", 8);
buffer_compress(&state->outgoing_packet,
&state->compression_buffer);
buffer_clear(&state->outgoing_packet);
buffer_append(&state->outgoing_packet,
buffer_ptr(&state->compression_buffer),
buffer_len(&state->compression_buffer));
buffer_compress(state->outgoing_packet,
state->compression_buffer);
buffer_clear(state->outgoing_packet);
buffer_append(state->outgoing_packet,
buffer_ptr(state->compression_buffer),
buffer_len(state->compression_buffer));
}
/* Compute packet length without padding (add checksum, remove padding). */
len = buffer_len(&state->outgoing_packet) + 4 - 8;
len = buffer_len(state->outgoing_packet) + 4 - 8;
/* Insert padding. Initialized to zero in packet_start1() */
padding = 8 - len % 8;
if (!state->send_context.plaintext) {
cp = buffer_ptr(&state->outgoing_packet);
cp = buffer_ptr(state->outgoing_packet);
for (i = 0; i < padding; i++) {
if (i % 4 == 0)
rnd = arc4random();
@@ -697,36 +721,36 @@ ssh_packet_send1(struct ssh *ssh)
rnd >>= 8;
}
}
buffer_consume(&state->outgoing_packet, 8 - padding);
buffer_consume(state->outgoing_packet, 8 - padding);
/* Add check bytes. */
checksum = ssh_crc32(buffer_ptr(&state->outgoing_packet),
buffer_len(&state->outgoing_packet));
checksum = ssh_crc32(buffer_ptr(state->outgoing_packet),
buffer_len(state->outgoing_packet));
put_u32(buf, checksum);
buffer_append(&state->outgoing_packet, buf, 4);
buffer_append(state->outgoing_packet, buf, 4);
#ifdef PACKET_DEBUG
fprintf(stderr, "packet_send plain: ");
buffer_dump(&state->outgoing_packet);
buffer_dump(state->outgoing_packet);
#endif
/* Append to output. */
put_u32(buf, len);
buffer_append(&state->output, buf, 4);
cp = buffer_append_space(&state->output,
buffer_len(&state->outgoing_packet));
buffer_append(state->output, buf, 4);
cp = buffer_append_space(state->output,
buffer_len(state->outgoing_packet));
cipher_crypt(&state->send_context, cp,
buffer_ptr(&state->outgoing_packet),
buffer_len(&state->outgoing_packet));
buffer_ptr(state->outgoing_packet),
buffer_len(state->outgoing_packet));
#ifdef PACKET_DEBUG
fprintf(stderr, "encrypted: ");
buffer_dump(&state->output);
buffer_dump(state->output);
#endif
state->p_send.packets++;
state->p_send.bytes += len +
buffer_len(&state->outgoing_packet);
buffer_clear(&state->outgoing_packet);
buffer_len(state->outgoing_packet);
buffer_clear(state->outgoing_packet);
/*
* Note that the packet is now only buffered in output. It won't be
@@ -869,32 +893,32 @@ ssh_packet_send2_wrapped(struct ssh *ssh)
}
block_size = enc ? enc->block_size : 8;
cp = buffer_ptr(&state->outgoing_packet);
cp = buffer_ptr(state->outgoing_packet);
type = cp[5];
#ifdef PACKET_DEBUG
fprintf(stderr, "plain: ");
buffer_dump(&state->outgoing_packet);
buffer_dump(state->outgoing_packet);
#endif
if (comp && comp->enabled) {
len = buffer_len(&state->outgoing_packet);
len = buffer_len(state->outgoing_packet);
/* skip header, compress only payload */
buffer_consume(&state->outgoing_packet, 5);
buffer_clear(&state->compression_buffer);
buffer_compress(&state->outgoing_packet,
&state->compression_buffer);
buffer_clear(&state->outgoing_packet);
buffer_append(&state->outgoing_packet, "\0\0\0\0\0", 5);
buffer_append(&state->outgoing_packet,
buffer_ptr(&state->compression_buffer),
buffer_len(&state->compression_buffer));
buffer_consume(state->outgoing_packet, 5);
buffer_clear(state->compression_buffer);
buffer_compress(state->outgoing_packet,
state->compression_buffer);
buffer_clear(state->outgoing_packet);
buffer_append(state->outgoing_packet, "\0\0\0\0\0", 5);
buffer_append(state->outgoing_packet,
buffer_ptr(state->compression_buffer),
buffer_len(state->compression_buffer));
DBG(debug("compression: raw %d compressed %d", len,
buffer_len(&state->outgoing_packet)));
buffer_len(state->outgoing_packet)));
}
/* sizeof (packet_len + pad_len + payload) */
len = buffer_len(&state->outgoing_packet);
len = buffer_len(state->outgoing_packet);
/*
* calc size of padding, alloc space, get random data,
@@ -914,7 +938,7 @@ ssh_packet_send2_wrapped(struct ssh *ssh)
padlen += pad;
state->extra_pad = 0;
}
cp = buffer_append_space(&state->outgoing_packet, padlen);
cp = buffer_append_space(state->outgoing_packet, padlen);
if (enc && !state->send_context.plaintext) {
/* random padding */
for (i = 0; i < padlen; i++) {
@@ -928,8 +952,8 @@ ssh_packet_send2_wrapped(struct ssh *ssh)
memset(cp, 0, padlen);
}
/* packet_length includes payload, padding and padding length field */
packet_length = buffer_len(&state->outgoing_packet) - 4;
cp = buffer_ptr(&state->outgoing_packet);
packet_length = buffer_len(state->outgoing_packet) - 4;
cp = buffer_ptr(state->outgoing_packet);
put_u32(cp, packet_length);
cp[4] = padlen;
DBG(debug("send: len %d (includes padlen %d)", packet_length+4, padlen));
@@ -937,22 +961,22 @@ ssh_packet_send2_wrapped(struct ssh *ssh)
/* compute MAC over seqnr and packet(length fields, payload, padding) */
if (mac && mac->enabled) {
macbuf = mac_compute(mac, state->p_send.seqnr,
buffer_ptr(&state->outgoing_packet),
buffer_len(&state->outgoing_packet));
buffer_ptr(state->outgoing_packet),
buffer_len(state->outgoing_packet));
DBG(debug("done calc MAC out #%d", state->p_send.seqnr));
}
/* encrypt packet and append to output buffer. */
cp = buffer_append_space(&state->output,
buffer_len(&state->outgoing_packet));
cp = buffer_append_space(state->output,
buffer_len(state->outgoing_packet));
cipher_crypt(&state->send_context, cp,
buffer_ptr(&state->outgoing_packet),
buffer_len(&state->outgoing_packet));
buffer_ptr(state->outgoing_packet),
buffer_len(state->outgoing_packet));
/* append unencrypted MAC */
if (mac && mac->enabled)
buffer_append(&state->output, macbuf, mac->mac_len);
buffer_append(state->output, macbuf, mac->mac_len);
#ifdef PACKET_DEBUG
fprintf(stderr, "encrypted: ");
buffer_dump(&state->output);
buffer_dump(state->output);
#endif
/* increment sequence number for outgoing packets */
if (++state->p_send.seqnr == 0)
@@ -962,7 +986,7 @@ ssh_packet_send2_wrapped(struct ssh *ssh)
fatal("XXX too many packets with same key");
state->p_send.blocks += (packet_length + 4) / block_size;
state->p_send.bytes += packet_length + 4;
buffer_clear(&state->outgoing_packet);
buffer_clear(state->outgoing_packet);
if (type == SSH2_MSG_NEWKEYS)
ssh_set_newkeys(ssh, MODE_OUT);
@@ -977,7 +1001,7 @@ ssh_packet_send2(struct ssh *ssh)
struct packet *p;
u_char type, *cp;
cp = buffer_ptr(&state->outgoing_packet);
cp = buffer_ptr(state->outgoing_packet);
type = cp[5];
/* during rekeying we can only send key exchange messages */
@@ -989,10 +1013,11 @@ ssh_packet_send2(struct ssh *ssh)
debug("enqueue packet: %u", type);
p = xmalloc(sizeof(*p));
p->type = type;
memcpy(&p->payload, &state->outgoing_packet,
sizeof(Buffer));
buffer_init(&state->outgoing_packet);
p->payload = state->outgoing_packet;
TAILQ_INSERT_TAIL(&state->outgoing, p, next);
state->outgoing_packet = sshbuf_new();
if (state->outgoing_packet == NULL)
fatal("%s: sshbuf_new failed", __func__);
return;
}
}
@@ -1009,9 +1034,8 @@ ssh_packet_send2(struct ssh *ssh)
while ((p = TAILQ_FIRST(&state->outgoing))) {
type = p->type;
debug("dequeue packet: %u", type);
buffer_free(&state->outgoing_packet);
memcpy(&state->outgoing_packet, &p->payload,
sizeof(Buffer));
sshbuf_free(state->outgoing_packet);
state->outgoing_packet = p->payload;
TAILQ_REMOVE(&state->outgoing, p, next);
xfree(p);
ssh_packet_send2_wrapped(ssh);
@@ -1162,10 +1186,10 @@ ssh_packet_read_poll1(struct ssh *ssh)
u_int checksum, stored_checksum;
/* Check if input size is less than minimum packet size. */
if (buffer_len(&state->input) < 4 + 8)
if (buffer_len(state->input) < 4 + 8)
return SSH_MSG_NONE;
/* Get length of incoming packet. */
cp = buffer_ptr(&state->input);
cp = buffer_ptr(state->input);
len = get_u32(cp);
if (len < 1 + 2 + 2 || len > 256 * 1024)
ssh_packet_disconnect(ssh, "Bad packet length %u.",
@@ -1173,13 +1197,13 @@ ssh_packet_read_poll1(struct ssh *ssh)
padded_len = (len + 8) & ~7;
/* Check if the packet has been entirely received. */
if (buffer_len(&state->input) < 4 + padded_len)
if (buffer_len(state->input) < 4 + padded_len)
return SSH_MSG_NONE;
/* The entire packet is in buffer. */
/* Consume packet length. */
buffer_consume(&state->input, 4);
buffer_consume(state->input, 4);
/*
* Cryptographic attack detector for ssh
@@ -1187,7 +1211,7 @@ ssh_packet_read_poll1(struct ssh *ssh)
* Ariel Futoransky(futo@core-sdi.com)
*/
if (!state->receive_context.plaintext) {
switch (detect_attack(buffer_ptr(&state->input),
switch (detect_attack(buffer_ptr(state->input),
padded_len)) {
case DEATTACK_DETECTED:
ssh_packet_disconnect(ssh,
@@ -1200,50 +1224,50 @@ ssh_packet_read_poll1(struct ssh *ssh)
}
/* Decrypt data to incoming_packet. */
buffer_clear(&state->incoming_packet);
cp = buffer_append_space(&state->incoming_packet, padded_len);
buffer_clear(state->incoming_packet);
cp = buffer_append_space(state->incoming_packet, padded_len);
cipher_crypt(&state->receive_context, cp,
buffer_ptr(&state->input), padded_len);
buffer_ptr(state->input), padded_len);
buffer_consume(&state->input, padded_len);
buffer_consume(state->input, padded_len);
#ifdef PACKET_DEBUG
fprintf(stderr, "read_poll plain: ");
buffer_dump(&state->incoming_packet);
buffer_dump(state->incoming_packet);
#endif
/* Compute packet checksum. */
checksum = ssh_crc32(buffer_ptr(&state->incoming_packet),
buffer_len(&state->incoming_packet) - 4);
checksum = ssh_crc32(buffer_ptr(state->incoming_packet),
buffer_len(state->incoming_packet) - 4);
/* Skip padding. */
buffer_consume(&state->incoming_packet, 8 - len % 8);
buffer_consume(state->incoming_packet, 8 - len % 8);
/* Test check bytes. */
if (len != buffer_len(&state->incoming_packet))
if (len != buffer_len(state->incoming_packet))
ssh_packet_disconnect(ssh,
"packet_read_poll1: len %d != buffer_len %d.",
len, buffer_len(&state->incoming_packet));
len, buffer_len(state->incoming_packet));
cp = (u_char *)buffer_ptr(&state->incoming_packet) + len - 4;
cp = (u_char *)buffer_ptr(state->incoming_packet) + len - 4;
stored_checksum = get_u32(cp);
if (checksum != stored_checksum)
ssh_packet_disconnect(ssh,
"Corrupted check bytes on input.");
buffer_consume_end(&state->incoming_packet, 4);
buffer_consume_end(state->incoming_packet, 4);
if (state->packet_compression) {
buffer_clear(&state->compression_buffer);
buffer_uncompress(&state->incoming_packet,
&state->compression_buffer);
buffer_clear(&state->incoming_packet);
buffer_append(&state->incoming_packet,
buffer_ptr(&state->compression_buffer),
buffer_len(&state->compression_buffer));
buffer_clear(state->compression_buffer);
buffer_uncompress(state->incoming_packet,
state->compression_buffer);
buffer_clear(state->incoming_packet);
buffer_append(state->incoming_packet,
buffer_ptr(state->compression_buffer),
buffer_len(state->compression_buffer));
}
state->p_read.packets++;
state->p_read.bytes += padded_len + 4;
type = buffer_get_char(&state->incoming_packet);
type = buffer_get_char(state->incoming_packet);
if (type < SSH_MSG_MIN || type > SSH_MSG_MAX)
ssh_packet_disconnect(ssh,
"Invalid ssh1 packet type: %d", type);
@@ -1277,19 +1301,19 @@ ssh_packet_read_poll2(struct ssh *ssh, u_int32_t *seqnr_p)
* check if input size is less than the cipher block size,
* decrypt first block and extract length of incoming packet
*/
if (buffer_len(&state->input) < block_size)
if (buffer_len(state->input) < block_size)
return SSH_MSG_NONE;
buffer_clear(&state->incoming_packet);
cp = buffer_append_space(&state->incoming_packet,
buffer_clear(state->incoming_packet);
cp = buffer_append_space(state->incoming_packet,
block_size);
cipher_crypt(&state->receive_context, cp,
buffer_ptr(&state->input), block_size);
cp = buffer_ptr(&state->incoming_packet);
buffer_ptr(state->input), block_size);
cp = buffer_ptr(state->incoming_packet);
state->packlen = get_u32(cp);
if (state->packlen < 1 + 4 ||
state->packlen > PACKET_MAX_SIZE) {
#ifdef PACKET_DEBUG
buffer_dump(&state->incoming_packet);
buffer_dump(state->incoming_packet);
#endif
logit("Bad packet length %u.", state->packlen);
ssh_packet_start_discard(ssh, enc, mac,
@@ -1297,7 +1321,7 @@ ssh_packet_read_poll2(struct ssh *ssh, u_int32_t *seqnr_p)
return SSH_MSG_NONE;
}
DBG(debug("input: packet len %u", state->packlen+4));
buffer_consume(&state->input, block_size);
buffer_consume(state->input, block_size);
}
/* we have a partial packet of block_size bytes */
need = 4 + state->packlen - block_size;
@@ -1314,25 +1338,25 @@ ssh_packet_read_poll2(struct ssh *ssh, u_int32_t *seqnr_p)
* check if the entire packet has been received and
* decrypt into incoming_packet
*/
if (buffer_len(&state->input) < need + maclen)
if (buffer_len(state->input) < need + maclen)
return SSH_MSG_NONE;
#ifdef PACKET_DEBUG
fprintf(stderr, "read_poll enc/full: ");
buffer_dump(&state->input);
buffer_dump(state->input);
#endif
cp = buffer_append_space(&state->incoming_packet, need);
cp = buffer_append_space(state->incoming_packet, need);
cipher_crypt(&state->receive_context, cp,
buffer_ptr(&state->input), need);
buffer_consume(&state->input, need);
buffer_ptr(state->input), need);
buffer_consume(state->input, need);
/*
* compute MAC over seqnr and packet,
* increment sequence number for incoming packet
*/
if (mac && mac->enabled) {
macbuf = mac_compute(mac, state->p_read.seqnr,
buffer_ptr(&state->incoming_packet),
buffer_len(&state->incoming_packet));
if (timingsafe_bcmp(macbuf, buffer_ptr(&state->input),
buffer_ptr(state->incoming_packet),
buffer_len(state->incoming_packet));
if (timingsafe_bcmp(macbuf, buffer_ptr(state->input),
mac->mac_len) != 0) {
logit("Corrupted MAC on input.");
if (need > PACKET_MAX_SIZE)
@@ -1343,7 +1367,7 @@ ssh_packet_read_poll2(struct ssh *ssh, u_int32_t *seqnr_p)
}
DBG(debug("MAC #%d ok", state->p_read.seqnr));
buffer_consume(&state->input, mac->mac_len);
buffer_consume(state->input, mac->mac_len);
}
/* XXX now it's safe to use fatal/packet_disconnect */
if (seqnr_p != NULL)
@@ -1357,7 +1381,7 @@ ssh_packet_read_poll2(struct ssh *ssh, u_int32_t *seqnr_p)
state->p_read.bytes += state->packlen + 4;
/* get padlen */
cp = buffer_ptr(&state->incoming_packet);
cp = buffer_ptr(state->incoming_packet);
padlen = cp[4];
DBG(debug("input: padlen %d", padlen));
if (padlen < 4)
@@ -1365,27 +1389,27 @@ ssh_packet_read_poll2(struct ssh *ssh, u_int32_t *seqnr_p)
"Corrupted padlen %d on input.", padlen);
/* skip packet size + padlen, discard padding */
buffer_consume(&state->incoming_packet, 4 + 1);
buffer_consume_end(&state->incoming_packet, padlen);
buffer_consume(state->incoming_packet, 4 + 1);
buffer_consume_end(state->incoming_packet, padlen);
DBG(debug("input: len before de-compress %d",
buffer_len(&state->incoming_packet)));
buffer_len(state->incoming_packet)));
if (comp && comp->enabled) {
buffer_clear(&state->compression_buffer);
buffer_uncompress(&state->incoming_packet,
&state->compression_buffer);
buffer_clear(&state->incoming_packet);
buffer_append(&state->incoming_packet,
buffer_ptr(&state->compression_buffer),
buffer_len(&state->compression_buffer));
buffer_clear(state->compression_buffer);
buffer_uncompress(state->incoming_packet,
state->compression_buffer);
buffer_clear(state->incoming_packet);
buffer_append(state->incoming_packet,
buffer_ptr(state->compression_buffer),
buffer_len(state->compression_buffer));
DBG(debug("input: len after de-compress %d",
buffer_len(&state->incoming_packet)));
buffer_len(state->incoming_packet)));
}
/*
* get packet type, implies consume.
* return length of payload (without type field)
*/
type = buffer_get_char(&state->incoming_packet);
type = buffer_get_char(state->incoming_packet);
if (type < SSH2_MSG_MIN || type >= SSH2_MSG_LOCAL_MIN)
ssh_packet_disconnect(ssh,
"Invalid ssh2 packet type: %d", type);
@@ -1396,7 +1420,7 @@ ssh_packet_read_poll2(struct ssh *ssh, u_int32_t *seqnr_p)
ssh_packet_enable_delayed_compress(ssh);
#ifdef PACKET_DEBUG
fprintf(stderr, "read/plain[%d]:\r\n", type);
buffer_dump(&state->incoming_packet);
buffer_dump(state->incoming_packet);
#endif
/* reset for next packet */
state->packlen = 0;
@@ -1494,7 +1518,7 @@ ssh_packet_process_incoming(struct ssh *ssh, const char *buf, u_int len)
state->packet_discard -= len;
return;
}
buffer_append(&ssh->state->input, buf, len);
buffer_append(ssh->state->input, buf, len);
}
/* Returns a character from the packet. */
@@ -1504,7 +1528,7 @@ ssh_packet_get_char(struct ssh *ssh)
{
char ch;
buffer_get(&ssh->state->incoming_packet, &ch, 1);
buffer_get(ssh->state->incoming_packet, &ch, 1);
return (u_char) ch;
}
@@ -1513,7 +1537,7 @@ ssh_packet_get_char(struct ssh *ssh)
u_int
ssh_packet_get_int(struct ssh *ssh)
{
return buffer_get_int(&ssh->state->incoming_packet);
return buffer_get_int(ssh->state->incoming_packet);
}
/* Returns an 64 bit integer from the packet data. */
@@ -1521,7 +1545,7 @@ ssh_packet_get_int(struct ssh *ssh)
u_int64_t
ssh_packet_get_int64(struct ssh *ssh)
{
return buffer_get_int64(&ssh->state->incoming_packet);
return buffer_get_int64(ssh->state->incoming_packet);
}
/*
@@ -1532,35 +1556,35 @@ ssh_packet_get_int64(struct ssh *ssh)
void
ssh_packet_get_bignum(struct ssh *ssh, BIGNUM * value)
{
buffer_get_bignum(&ssh->state->incoming_packet, value);
buffer_get_bignum(ssh->state->incoming_packet, value);
}
void
ssh_packet_get_bignum2(struct ssh *ssh, BIGNUM * value)
{
buffer_get_bignum2(&ssh->state->incoming_packet, value);
buffer_get_bignum2(ssh->state->incoming_packet, value);
}
void
ssh_packet_get_ecpoint(struct ssh *ssh, const EC_GROUP *curve, EC_POINT *point)
{
buffer_get_ecpoint(&ssh->state->incoming_packet, curve, point);
buffer_get_ecpoint(ssh->state->incoming_packet, curve, point);
}
void *
ssh_packet_get_raw(struct ssh *ssh, u_int *length_ptr)
{
u_int bytes = buffer_len(&ssh->state->incoming_packet);
u_int bytes = buffer_len(ssh->state->incoming_packet);
if (length_ptr != NULL)
*length_ptr = bytes;
return buffer_ptr(&ssh->state->incoming_packet);
return buffer_ptr(ssh->state->incoming_packet);
}
int
ssh_packet_remaining(struct ssh *ssh)
{
return buffer_len(&ssh->state->incoming_packet);
return buffer_len(ssh->state->incoming_packet);
}
/*
@@ -1573,20 +1597,20 @@ ssh_packet_remaining(struct ssh *ssh)
void *
ssh_packet_get_string(struct ssh *ssh, u_int *length_ptr)
{
return buffer_get_string(&ssh->state->incoming_packet, length_ptr);
return buffer_get_string(ssh->state->incoming_packet, length_ptr);
}
const void *
ssh_packet_get_string_ptr(struct ssh *ssh, u_int *length_ptr)
{
return buffer_get_string_ptr(&ssh->state->incoming_packet, length_ptr);
return buffer_get_string_ptr(ssh->state->incoming_packet, length_ptr);
}
/* Ensures the returned string has no embedded \0 characters in it. */
char *
ssh_packet_get_cstring(struct ssh *ssh, u_int *length_ptr)
{
return buffer_get_cstring(&ssh->state->incoming_packet, length_ptr);
return buffer_get_cstring(ssh->state->incoming_packet, length_ptr);
}
/*
@@ -1680,13 +1704,13 @@ void
ssh_packet_write_poll(struct ssh *ssh)
{
struct session_state *state = ssh->state;
int len = buffer_len(&state->output);
int len = buffer_len(state->output);
int cont;
if (len > 0) {
cont = 0;
len = roaming_write(state->connection_out,
buffer_ptr(&state->output), len, &cont);
buffer_ptr(state->output), len, &cont);
if (len == -1) {
if (errno == EINTR || errno == EAGAIN)
return;
@@ -1694,7 +1718,7 @@ ssh_packet_write_poll(struct ssh *ssh)
}
if (len == 0 && !cont)
fatal("Write connection closed");
buffer_consume(&state->output, len);
buffer_consume(state->output, len);
}
}
@@ -1756,7 +1780,7 @@ ssh_packet_write_wait(struct ssh *ssh)
int
ssh_packet_have_data_to_write(struct ssh *ssh)
{
return buffer_len(&ssh->state->output) != 0;
return buffer_len(ssh->state->output) != 0;
}
/* Returns true if there is not too much data to write to the connection. */
@@ -1765,9 +1789,9 @@ int
ssh_packet_not_very_much_data_to_write(struct ssh *ssh)
{
if (ssh->state->interactive_mode)
return buffer_len(&ssh->state->output) < 16384;
return buffer_len(ssh->state->output) < 16384;
else
return buffer_len(&ssh->state->output) < 128 * 1024;
return buffer_len(ssh->state->output) < 128 * 1024;
}
void
@@ -1934,13 +1958,13 @@ ssh_packet_set_authenticated(struct ssh *ssh)
void *
ssh_packet_get_input(struct ssh *ssh)
{
return (void *)&ssh->state->input;
return (void *)ssh->state->input;
}
void *
ssh_packet_get_output(struct ssh *ssh)
{
return (void *)&ssh->state->output;
return (void *)ssh->state->output;
}
void *
@@ -1991,11 +2015,11 @@ ssh_packet_restore_state(struct ssh *ssh,
backup_state->state->connection_in = -1;
ssh->state->connection_out = backup_state->state->connection_out;
backup_state->state->connection_out = -1;
len = buffer_len(&backup_state->state->input);
len = buffer_len(backup_state->state->input);
if (len > 0) {
buf = buffer_ptr(&backup_state->state->input);
buffer_append(&ssh->state->input, buf, len);
buffer_clear(&backup_state->state->input);
buf = buffer_ptr(backup_state->state->input);
buffer_append(ssh->state->input, buf, len);
buffer_clear(backup_state->state->input);
add_recv_bytes(len);
}
}