new sshpkt API: {get,put} wraps ssh_packet_{get,put}, adds {get_end,disconnect}

This commit is contained in:
Markus Friedl
2012-01-15 10:44:50 +01:00
parent 54c5205bc4
commit 9e254e24c5
2 changed files with 335 additions and 68 deletions

View File

@@ -609,15 +609,9 @@ ssh_packet_get_encryption_key(struct ssh *ssh, u_char *key)
void
ssh_packet_start(struct ssh *ssh, u_char type)
{
u_char buf[9];
int len, ret;
int ret;
DBG(debug("packet_start[%d]", type));
len = compat20 ? 6 : 9;
memset(buf, 0, len - 1);
buf[len - 1] = type;
sshbuf_reset(ssh->state->outgoing_packet);
if ((ret = sshbuf_put(ssh->state->outgoing_packet, buf, len)) != 0)
if ((ret = sshpkt_start(ssh, type)) != 0)
fatal("%s: %s", __func__, ssh_err(ret));
}
@@ -635,25 +629,37 @@ ssh_packet_put_char(struct ssh *ssh, int value)
void
ssh_packet_put_int(struct ssh *ssh, u_int value)
{
buffer_put_int(ssh->state->outgoing_packet, value);
int ret;
if ((ret = sshpkt_put_u32(ssh, value)) != 0)
fatal("%s: %s", __func__, ssh_err(ret));
}
void
ssh_packet_put_int64(struct ssh *ssh, u_int64_t value)
{
buffer_put_int64(ssh->state->outgoing_packet, value);
int ret;
if ((ret = sshpkt_put_u64(ssh, value)) != 0)
fatal("%s: %s", __func__, ssh_err(ret));
}
void
ssh_packet_put_string(struct ssh *ssh, const void *buf, u_int len)
{
buffer_put_string(ssh->state->outgoing_packet, buf, len);
int ret;
if ((ret = sshpkt_put_string(ssh, buf, len)) != 0)
fatal("%s: %s", __func__, ssh_err(ret));
}
void
ssh_packet_put_cstring(struct ssh *ssh, const char *str)
{
buffer_put_cstring(ssh->state->outgoing_packet, str);
int ret;
if ((ret = sshpkt_put_cstring(ssh, str)) != 0)
fatal("%s: %s", __func__, ssh_err(ret));
}
void
@@ -661,26 +667,35 @@ ssh_packet_put_raw(struct ssh *ssh, const void *buf, u_int len)
{
int ret;
if ((ret = sshbuf_put(ssh->state->outgoing_packet, buf, len)) != 0)
if ((ret = sshpkt_put(ssh, buf, len)) != 0)
fatal("%s: %s", __func__, ssh_err(ret));
}
void
ssh_packet_put_bignum(struct ssh *ssh, BIGNUM * value)
{
buffer_put_bignum(ssh->state->outgoing_packet, value);
int ret;
if ((ret = sshpkt_put_bignum1(ssh, value)) != 0)
fatal("%s: %s", __func__, ssh_err(ret));
}
void
ssh_packet_put_bignum2(struct ssh *ssh, BIGNUM * value)
{
buffer_put_bignum2(ssh->state->outgoing_packet, value);
int ret;
if ((ret = sshpkt_put_bignum2(ssh, value)) != 0)
fatal("%s: %s", __func__, ssh_err(ret));
}
void
ssh_packet_put_ecpoint(struct ssh *ssh, const EC_GROUP *curve, const EC_POINT *point)
{
buffer_put_ecpoint(ssh->state->outgoing_packet, curve, point);
int ret;
if ((ret = sshpkt_put_ec(ssh, point, curve)) != 0)
fatal("%s: %s", __func__, ssh_err(ret));
}
/*
@@ -688,7 +703,7 @@ ssh_packet_put_ecpoint(struct ssh *ssh, const EC_GROUP *curve, const EC_POINT *p
* encrypts the packet before sending.
*/
void
int
ssh_packet_send1(struct ssh *ssh)
{
struct session_state *state = ssh->state;
@@ -705,17 +720,17 @@ ssh_packet_send1(struct ssh *ssh)
sshbuf_reset(state->compression_buffer);
/* Skip padding. */
if ((ret = sshbuf_consume(state->outgoing_packet, 8)) != 0)
goto fail;
goto out;
/* padding */
if ((ret = sshbuf_put(state->compression_buffer,
"\0\0\0\0\0\0\0\0", 8)) != 0)
goto fail;
goto out;
buffer_compress(state->outgoing_packet,
state->compression_buffer);
sshbuf_reset(state->outgoing_packet);
if ((ret = sshbuf_putb(state->outgoing_packet,
state->compression_buffer)) != 0)
goto fail;
goto out;
}
/* Compute packet length without padding (add checksum, remove padding). */
len = sshbuf_len(state->outgoing_packet) + 4 - 8;
@@ -732,14 +747,14 @@ ssh_packet_send1(struct ssh *ssh)
}
}
if ((ret = sshbuf_consume(state->outgoing_packet, 8 - padding)) != 0)
goto fail;
goto out;
/* Add check bytes. */
checksum = ssh_crc32(sshbuf_ptr(state->outgoing_packet),
sshbuf_len(state->outgoing_packet));
put_u32(buf, checksum);
if ((ret == sshbuf_put(state->outgoing_packet, buf, 4)) != 0)
goto fail;
if ((ret = sshbuf_put(state->outgoing_packet, buf, 4)) != 0)
goto out;
#ifdef PACKET_DEBUG
fprintf(stderr, "packet_send plain: ");
@@ -749,10 +764,10 @@ ssh_packet_send1(struct ssh *ssh)
/* Append to output. */
put_u32(buf, len);
if ((ret = sshbuf_put(state->output, buf, 4)) != 0)
goto fail;
goto out;
if ((ret = sshbuf_reserve(state->output,
sshbuf_len(state->outgoing_packet), &cp)) != 0)
goto fail;
goto out;
cipher_crypt(&state->send_context, cp,
sshbuf_ptr(state->outgoing_packet),
sshbuf_len(state->outgoing_packet));
@@ -771,9 +786,9 @@ ssh_packet_send1(struct ssh *ssh)
* actually sent until packet_write_wait or packet_write_poll is
* called.
*/
return;
fail:
fatal("%s: %s", __func__, ssh_err(ret));
ret = 0;
out:
return ret;
}
void
@@ -889,7 +904,7 @@ ssh_packet_enable_delayed_compress(struct ssh *ssh)
/*
* Finalize packet in SSH2 format (compress, mac, encrypt, enqueue)
*/
void
int
ssh_packet_send2_wrapped(struct ssh *ssh)
{
struct session_state *state = ssh->state;
@@ -922,7 +937,7 @@ ssh_packet_send2_wrapped(struct ssh *ssh)
len = sshbuf_len(state->outgoing_packet);
/* skip header, compress only payload */
if ((ret = sshbuf_consume(state->outgoing_packet, 5)) != 0)
goto fail;
goto out;
sshbuf_reset(state->compression_buffer);
buffer_compress(state->outgoing_packet,
state->compression_buffer);
@@ -931,8 +946,8 @@ ssh_packet_send2_wrapped(struct ssh *ssh)
"\0\0\0\0\0", 5)) != 0 ||
(ret = sshbuf_putb(state->outgoing_packet,
state->compression_buffer)) != 0)
goto fail;
DBG(debug("compression: raw %d compressed %d", len,
goto out;
DBG(debug("compression: raw %d compressed %zd", len,
sshbuf_len(state->outgoing_packet)));
}
@@ -958,7 +973,7 @@ ssh_packet_send2_wrapped(struct ssh *ssh)
state->extra_pad = 0;
}
if ((ret = sshbuf_reserve(state->outgoing_packet, padlen, &cp)) != 0)
goto fail;
goto out;
if (enc && !state->send_context.plaintext) {
/* random padding */
for (i = 0; i < padlen; i++) {
@@ -988,14 +1003,14 @@ ssh_packet_send2_wrapped(struct ssh *ssh)
/* encrypt packet and append to output buffer. */
if ((ret = sshbuf_reserve(state->output,
sshbuf_len(state->outgoing_packet), &cp)) != 0)
goto fail;
goto out;
cipher_crypt(&state->send_context, cp,
sshbuf_ptr(state->outgoing_packet),
sshbuf_len(state->outgoing_packet));
/* append unencrypted MAC */
if (mac && mac->enabled)
if ((ret = sshbuf_put(state->output, macbuf, mac->mac_len)) != 0)
goto fail;
goto out;
#ifdef PACKET_DEBUG
fprintf(stderr, "encrypted: ");
sshbuf_dump(state->output, stderr);
@@ -1014,17 +1029,18 @@ ssh_packet_send2_wrapped(struct ssh *ssh)
ssh_set_newkeys(ssh, MODE_OUT);
else if (type == SSH2_MSG_USERAUTH_SUCCESS && state->server_side)
ssh_packet_enable_delayed_compress(ssh);
return;
fail:
fatal("%s: %s", __func__, ssh_err(ret));
ret = 0;
out:
return ret;
}
void
int
ssh_packet_send2(struct ssh *ssh)
{
struct session_state *state = ssh->state;
struct packet *p;
u_char type, *cp;
int ret;
cp = sshbuf_ptr(state->outgoing_packet);
type = cp[5];
@@ -1036,14 +1052,16 @@ ssh_packet_send2(struct ssh *ssh)
(type == SSH2_MSG_SERVICE_REQUEST) ||
(type == SSH2_MSG_SERVICE_ACCEPT)) {
debug("enqueue packet: %u", type);
p = xmalloc(sizeof(*p));
p = calloc(1, sizeof(*p));
if (p == NULL)
return SSH_ERR_ALLOC_FAIL;
p->type = type;
p->payload = state->outgoing_packet;
TAILQ_INSERT_TAIL(&state->outgoing, p, next);
state->outgoing_packet = sshbuf_new();
if (state->outgoing_packet == NULL)
fatal("%s: sshbuf_new failed", __func__);
return;
return SSH_ERR_ALLOC_FAIL;
return 0;
}
}
@@ -1051,7 +1069,8 @@ ssh_packet_send2(struct ssh *ssh)
if (type == SSH2_MSG_KEXINIT)
state->rekeying = 1;
ssh_packet_send2_wrapped(ssh);
if ((ret = ssh_packet_send2_wrapped(ssh)) != 0)
return ret;
/* after a NEWKEYS message we can send the complete queue */
if (type == SSH2_MSG_NEWKEYS) {
@@ -1063,18 +1082,20 @@ ssh_packet_send2(struct ssh *ssh)
state->outgoing_packet = p->payload;
TAILQ_REMOVE(&state->outgoing, p, next);
xfree(p);
ssh_packet_send2_wrapped(ssh);
if ((ret = ssh_packet_send2_wrapped(ssh)) != 0)
return ret;
}
}
return 0;
}
void
ssh_packet_send(struct ssh *ssh)
{
if (compat20)
ssh_packet_send2(ssh);
else
ssh_packet_send1(ssh);
int ret;
if ((ret = sshpkt_send(ssh)) != 0)
fatal("%s: %s", __func__, ssh_err(ret));
DBG(debug("packet_send done"));
}
@@ -1431,7 +1452,7 @@ ssh_packet_read_poll2(struct ssh *ssh, u_int32_t *seqnr_p)
((ret = sshbuf_consume_end(state->incoming_packet, padlen)) != 0))
goto fail;
DBG(debug("input: len before de-compress %d",
DBG(debug("input: len before de-compress %zd",
sshbuf_len(state->incoming_packet)));
if (comp && comp->enabled) {
sshbuf_reset(state->compression_buffer);
@@ -1440,7 +1461,7 @@ ssh_packet_read_poll2(struct ssh *ssh, u_int32_t *seqnr_p)
sshbuf_reset(state->incoming_packet);
if ((ret = sshbuf_putb(state->incoming_packet, state->compression_buffer)) != 0)
goto fail;
DBG(debug("input: len after de-compress %d",
DBG(debug("input: len after de-compress %zd",
sshbuf_len(state->incoming_packet)));
}
/*
@@ -1568,10 +1589,12 @@ ssh_packet_process_incoming(struct ssh *ssh, const char *buf, u_int len)
u_int
ssh_packet_get_char(struct ssh *ssh)
{
char ch;
u_char ch;
int ret;
buffer_get(ssh->state->incoming_packet, &ch, 1);
return (u_char) ch;
if ((ret = sshpkt_get_u8(ssh, &ch)) != 0)
fatal("%s: %s", __func__, ssh_err(ret));
return ch;
}
/* Returns an integer from the packet data. */
@@ -1579,7 +1602,12 @@ ssh_packet_get_char(struct ssh *ssh)
u_int
ssh_packet_get_int(struct ssh *ssh)
{
return buffer_get_int(ssh->state->incoming_packet);
u_int val;
int ret;
if ((ret = sshpkt_get_u32(ssh, &val)) != 0)
fatal("%s: %s", __func__, ssh_err(ret));
return val;
}
/* Returns an 64 bit integer from the packet data. */
@@ -1587,7 +1615,12 @@ ssh_packet_get_int(struct ssh *ssh)
u_int64_t
ssh_packet_get_int64(struct ssh *ssh)
{
return buffer_get_int64(ssh->state->incoming_packet);
u_int64_t val;
int ret;
if ((ret = sshpkt_get_u64(ssh, &val)) != 0)
fatal("%s: %s", __func__, ssh_err(ret));
return val;
}
/*
@@ -1598,29 +1631,38 @@ ssh_packet_get_int64(struct ssh *ssh)
void
ssh_packet_get_bignum(struct ssh *ssh, BIGNUM * value)
{
buffer_get_bignum(ssh->state->incoming_packet, value);
int ret;
if ((ret = sshpkt_get_bignum1(ssh, value)) != 0)
fatal("%s: %s", __func__, ssh_err(ret));
}
void
ssh_packet_get_bignum2(struct ssh *ssh, BIGNUM * value)
{
buffer_get_bignum2(ssh->state->incoming_packet, value);
int ret;
if ((ret = sshpkt_get_bignum2(ssh, value)) != 0)
fatal("%s: %s", __func__, ssh_err(ret));
}
void
ssh_packet_get_ecpoint(struct ssh *ssh, const EC_GROUP *curve, EC_POINT *point)
{
buffer_get_ecpoint(ssh->state->incoming_packet, curve, point);
int ret;
if ((ret = sshpkt_get_ec(ssh, point, curve)) != 0)
fatal("%s: %s", __func__, ssh_err(ret));
}
void *
ssh_packet_get_raw(struct ssh *ssh, u_int *length_ptr)
{
u_int bytes = sshbuf_len(ssh->state->incoming_packet);
u_int bytes = sshbuf_len(ssh->state->incoming_packet);
if (length_ptr != NULL)
*length_ptr = bytes;
return sshbuf_ptr(ssh->state->incoming_packet);
if (length_ptr != NULL)
*length_ptr = bytes;
return sshbuf_ptr(ssh->state->incoming_packet);
}
int
@@ -1639,20 +1681,44 @@ ssh_packet_remaining(struct ssh *ssh)
void *
ssh_packet_get_string(struct ssh *ssh, u_int *length_ptr)
{
return buffer_get_string(ssh->state->incoming_packet, length_ptr);
int ret;
size_t len;
u_char *val;
if ((ret = sshpkt_get_string(ssh, &val, &len)) != 0)
fatal("%s: %s", __func__, ssh_err(ret));
if (length_ptr != NULL)
*length_ptr = (u_int)len;
return val;
}
const void *
ssh_packet_get_string_ptr(struct ssh *ssh, u_int *length_ptr)
{
return buffer_get_string_ptr(ssh->state->incoming_packet, length_ptr);
int ret;
size_t len;
const u_char *val;
if ((ret = sshpkt_get_string_direct(ssh, &val, &len)) != 0)
fatal("%s: %s", __func__, ssh_err(ret));
if (length_ptr != NULL)
*length_ptr = (u_int)len;
return val;
}
/* Ensures the returned string has no embedded \0 characters in it. */
char *
ssh_packet_get_cstring(struct ssh *ssh, u_int *length_ptr)
{
return buffer_get_cstring(ssh->state->incoming_packet, length_ptr);
int ret;
size_t len;
char *val;
if ((ret = sshpkt_get_cstring(ssh, &val, &len)) != 0)
fatal("%s: %s", __func__, ssh_err(ret));
if (length_ptr != NULL)
*length_ptr = (u_int)len;
return val;
}
/*
@@ -2087,3 +2153,178 @@ ssh_packet_set_postauth(struct ssh *ssh)
ssh_packet_init_compression(ssh);
}
}
/* NEW API */
/* put data to the incoming packet */
int
sshpkt_put(struct ssh *ssh, const void *v, size_t len)
{
return sshbuf_put(ssh->state->outgoing_packet, v, len);
}
int
sshpkt_put_u8(struct ssh *ssh, u_char val)
{
return sshbuf_put_u8(ssh->state->incoming_packet, val);
}
int
sshpkt_put_u32(struct ssh *ssh, u_int32_t val)
{
return sshbuf_put_u32(ssh->state->outgoing_packet, val);
}
int
sshpkt_put_u64(struct ssh *ssh, u_int64_t val)
{
return sshbuf_put_u64(ssh->state->outgoing_packet, val);
}
int
sshpkt_put_string(struct ssh *ssh, const void *v, size_t len)
{
return sshbuf_put_string(ssh->state->outgoing_packet, v, len);
}
int
sshpkt_put_cstring(struct ssh *ssh, const void *v)
{
return sshbuf_put_cstring(ssh->state->outgoing_packet, v);
}
int
sshpkt_put_ec(struct ssh *ssh, const EC_POINT *v, const EC_GROUP *g)
{
return sshbuf_put_ec(ssh->state->outgoing_packet, v, g);
}
int
sshpkt_put_bignum1(struct ssh *ssh, const BIGNUM *v)
{
return sshbuf_put_bignum1(ssh->state->outgoing_packet, v);
}
int
sshpkt_put_bignum2(struct ssh *ssh, const BIGNUM *v)
{
return sshbuf_put_bignum2(ssh->state->outgoing_packet, v);
}
/* fetch data from the incoming packet */
int
sshpkt_get_u8(struct ssh *ssh, u_char *valp)
{
return sshbuf_get_u8(ssh->state->incoming_packet, valp);
}
int
sshpkt_get_u32(struct ssh *ssh, u_int32_t *valp)
{
return sshbuf_get_u32(ssh->state->incoming_packet, valp);
}
int
sshpkt_get_u64(struct ssh *ssh, u_int64_t *valp)
{
return sshbuf_get_u64(ssh->state->incoming_packet, valp);
}
int
sshpkt_get_string(struct ssh *ssh, u_char **valp, size_t *lenp)
{
return sshbuf_get_string(ssh->state->incoming_packet, valp, lenp);
}
int
sshpkt_get_string_direct(struct ssh *ssh, const u_char **valp, size_t *lenp)
{
return sshbuf_get_string_direct(ssh->state->incoming_packet, valp, lenp);
}
int
sshpkt_get_cstring(struct ssh *ssh, char **valp, size_t *lenp)
{
return sshbuf_get_cstring(ssh->state->incoming_packet, valp, lenp);
}
int
sshpkt_get_ec(struct ssh *ssh, EC_POINT *v, const EC_GROUP *g)
{
return sshbuf_get_ec(ssh->state->incoming_packet, v, g);
}
int
sshpkt_get_bignum1(struct ssh *ssh, BIGNUM *v)
{
return sshbuf_get_bignum1(ssh->state->incoming_packet, v);
}
int
sshpkt_get_bignum2(struct ssh *ssh, BIGNUM *v)
{
return sshbuf_get_bignum2(ssh->state->incoming_packet, v);
}
int
sshpkt_get_end(struct ssh *ssh)
{
if (sshbuf_len(ssh->state->incoming_packet) > 0)
return SSH_ERR_UNEXPECTED_TRAILING_DATA;
return 0;
}
/* start a new packet */
int
sshpkt_start(struct ssh *ssh, u_char type)
{
u_char buf[9];
int len;
DBG(debug("packet_start[%d]", type));
len = compat20 ? 6 : 9;
memset(buf, 0, len - 1);
buf[len - 1] = type;
sshbuf_reset(ssh->state->outgoing_packet);
return sshbuf_put(ssh->state->outgoing_packet, buf, len);
}
/* send it */
int
sshpkt_send(struct ssh *ssh)
{
if (compat20)
return ssh_packet_send2(ssh);
else
return ssh_packet_send1(ssh);
}
int
sshpkt_disconnect(struct ssh *ssh, const char *fmt,...)
{
char buf[1024];
va_list args;
int r;
vsnprintf(buf, sizeof(buf), fmt, args);
va_end(args);
if (compat20) {
if ((r = sshpkt_start(ssh, SSH2_MSG_DISCONNECT)) != 0 ||
(r = sshpkt_put_u32(ssh, SSH2_DISCONNECT_PROTOCOL_ERROR)) != 0 ||
(r = sshpkt_put_cstring(ssh, buf)) != 0 ||
(r = sshpkt_put_cstring(ssh, "")) != 0 ||
(r = sshpkt_send(ssh)) != 0)
return r;
} else {
if ((r = sshpkt_start(ssh, SSH_MSG_DISCONNECT)) != 0 ||
(r = sshpkt_put_cstring(ssh, buf)) != 0 ||
(r = sshpkt_send(ssh)) != 0)
return r;
}
return 0;
}

View File

@@ -97,9 +97,9 @@ void ssh_packet_put_string(struct ssh *, const void *buf, u_int len);
void ssh_packet_put_cstring(struct ssh *, const char *str);
void ssh_packet_put_raw(struct ssh *, const void *buf, u_int len);
void ssh_packet_send(struct ssh *);
void ssh_packet_send1(struct ssh *);
void ssh_packet_send2_wrapped(struct ssh *);
void ssh_packet_send2(struct ssh *);
int ssh_packet_send1(struct ssh *);
int ssh_packet_send2_wrapped(struct ssh *);
int ssh_packet_send2(struct ssh *);
void ssh_packet_enable_delayed_compress(struct ssh *);
int ssh_packet_read(struct ssh *);
@@ -313,4 +313,30 @@ void packet_set_connection(int, int);
ssh_packet_set_postauth(active_state)
#endif
/* new API */
int sshpkt_start(struct ssh *ssh, u_char type);
int sshpkt_send(struct ssh *ssh);
int sshpkt_disconnect(struct ssh *, const char *fmt, ...) __attribute__((format(printf, 2, 3)));
int sshpkt_put(struct ssh *ssh, const void *v, size_t len);
int sshpkt_put_u8(struct ssh *ssh, u_char val);
int sshpkt_put_u32(struct ssh *ssh, u_int32_t val);
int sshpkt_put_u64(struct ssh *ssh, u_int64_t val);
int sshpkt_put_string(struct ssh *ssh, const void *v, size_t len);
int sshpkt_put_cstring(struct ssh *ssh, const void *v);
int sshpkt_put_ec(struct ssh *ssh, const EC_POINT *v, const EC_GROUP *g);
int sshpkt_put_bignum1(struct ssh *ssh, const BIGNUM *v);
int sshpkt_put_bignum2(struct ssh *ssh, const BIGNUM *v);
int sshpkt_get_u8(struct ssh *ssh, u_char *valp);
int sshpkt_get_u32(struct ssh *ssh, u_int32_t *valp);
int sshpkt_get_u64(struct ssh *ssh, u_int64_t *valp);
int sshpkt_get_string(struct ssh *ssh, u_char **valp, size_t *lenp);
int sshpkt_get_string_direct(struct ssh *ssh, const u_char **valp, size_t *lenp);
int sshpkt_get_cstring(struct ssh *ssh, char **valp, size_t *lenp);
int sshpkt_get_ec(struct ssh *ssh, EC_POINT *v, const EC_GROUP *g);
int sshpkt_get_bignum1(struct ssh *ssh, BIGNUM *v);
int sshpkt_get_bignum2(struct ssh *ssh, BIGNUM *v);
int sshpkt_get_end(struct ssh *ssh);
#endif /* PACKET_H */