diff --git a/ssh/ssh-proxy.c b/ssh/ssh-proxy.c index cbefaf5..c4916e9 100644 --- a/ssh/ssh-proxy.c +++ b/ssh/ssh-proxy.c @@ -41,10 +41,12 @@ struct side { struct event input, output; struct ssh *ssh; }; +#define SESSION_CONNECTED 0x01 +#define SESSION_NEEDS_FLUSH 0x02 struct session { struct side client, server; - int connected; TAILQ_ENTRY(session) next; + int flags; }; Forward fwd; @@ -57,6 +59,7 @@ int do_connect(const char *, int); int do_listen(const char *, int); void session_close(struct session *); int ssh_packet_fwd(struct side *, struct side *); +int ssh_prepare_output(struct side *); void usage(void); uid_t original_real_uid; /* XXX */ @@ -165,7 +168,7 @@ session_close(struct session *s) close(s->client.fd); if (s->server.fd != -1) close(s->server.fd); - if (s->connected == 1) { + if (s->flags & SESSION_CONNECTED) { event_del(&s->client.input); event_del(&s->client.output); event_del(&s->server.input); @@ -262,7 +265,7 @@ connect_cb(int fd, short type, void *arg) event_set(&s->server.output, s->server.fd, EV_WRITE, output_cb, s); event_add(&s->server.input, NULL); event_add(&s->client.input, NULL); - s->connected = 1; + s->flags = SESSION_CONNECTED; TAILQ_INSERT_TAIL(&sessions, s, next); return; fail: @@ -272,6 +275,20 @@ connect_cb(int fd, short type, void *arg) return; } +/* schedule output event and return 1 if there is any output pending */ +int +ssh_prepare_output(struct side *side) +{ + u_int len; + + ssh_output_ptr(side->ssh, &len); + if (len) { + debug3("output %d for %d", len, side->fd); + event_add(&side->output, NULL); + } + return len > 0; +} + int ssh_packet_fwd(struct side *from, struct side *to) { @@ -286,7 +303,7 @@ ssh_packet_fwd(struct side *from, struct side *to) return ret; if (!type) { debug3("no packet on %d", from->fd); - break; + return 0; } data = ssh_packet_payload(from->ssh, &len); debug("ssh_packet_fwd %d->%d type %d len %d", @@ -306,17 +323,6 @@ ssh_packet_fwd(struct side *from, struct side *to) if ((ret = ssh_packet_put(to->ssh, type, data, len)) != 0) return ret; } - ssh_output_ptr(from->ssh, &len); - if (len) { - debug3("output %d for %d", len, from->fd); - event_add(&from->output, NULL); - } - ssh_output_ptr(to->ssh, &len); - if (len) { - debug3("output %d for %d", len, to->fd); - event_add(&to->output, NULL); - } - return 0; } void @@ -326,8 +332,8 @@ input_cb(int fd, short type, void *arg) struct session *s = arg; struct side *r, *w; ssize_t len; + int pending, r1, r2; const char *tag; - int ret; if (fd == s->client.fd) { tag = "client"; @@ -351,10 +357,15 @@ input_cb(int fd, short type, void *arg) event_add(&r->input, NULL); ssh_input_append(r->ssh, buf, len); } - if ((ret = ssh_packet_fwd(r, w)) != 0 || - (ret = ssh_packet_fwd(w, r)) != 0) { - error("ssh_packet_fwd: %s", ssh_err(ret)); - session_close(s); + r1 = ssh_packet_fwd(r, w); + r2 = ssh_packet_fwd(w, r); + pending = ssh_prepare_output(r) + ssh_prepare_output(w); + if (r1 || r1) { + error("ssh_packet_fwd: %s/%s", ssh_err(r1), ssh_err(r2)); + if (pending) + s->flags |= SESSION_NEEDS_FLUSH; + else + session_close(s); } } @@ -363,9 +374,9 @@ output_cb(int fd, short type, void *arg) { struct session *s = arg; struct side *r, *w; - ssize_t len; + ssize_t len, olen; + int pending; const char *tag; - ssize_t olen; char *obuf; if (fd == s->client.fd) { @@ -396,8 +407,16 @@ output_cb(int fd, short type, void *arg) ssh_output_consume(w->ssh, len); } } - ssh_packet_fwd(r, w); - ssh_packet_fwd(w, r); + if (!(s->flags & SESSION_NEEDS_FLUSH)) { + ssh_packet_fwd(r, w); + ssh_packet_fwd(w, r); + } + pending = ssh_prepare_output(r) + ssh_prepare_output(w); + if ((s->flags & SESSION_NEEDS_FLUSH) && !pending) { + debug("delayed close %p", s); + s->flags &= ~SESSION_NEEDS_FLUSH; + session_close(s); + } } void