diff --git a/cipher-ctr-mt.c b/cipher-ctr-mt.c index fdc9b2f..300cd90 100644 --- a/cipher-ctr-mt.c +++ b/cipher-ctr-mt.c @@ -127,7 +127,7 @@ struct kq { u_char keys[KQLEN][AES_BLOCK_SIZE]; u_char ctr[AES_BLOCK_SIZE]; u_char pad0[CACHELINE_LEN]; - volatile int qstate; + int qstate; pthread_mutex_t lock; pthread_cond_t cond; u_char pad1[CACHELINE_LEN]; @@ -141,6 +141,11 @@ struct ssh_aes_ctr_ctx STATS_STRUCT(stats); u_char aes_counter[AES_BLOCK_SIZE]; pthread_t tid[CIPHER_THREADS]; + pthread_rwlock_t tid_lock; +#ifdef __APPLE__ + pthread_rwlock_t stop_lock; + int exit_flag; +#endif /* __APPLE__ */ int state; int qidx; int ridx; @@ -187,6 +192,57 @@ thread_loop_cleanup(void *x) pthread_mutex_unlock((pthread_mutex_t *)x); } +#ifdef __APPLE__ +/* Check if we should exit, we are doing both cancel and exit condition + * since on OSX threads seem to occasionally fail to notice when they have + * been cancelled. We want to have a backup to make sure that we won't hang + * when the main process join()-s the cancelled thread. + */ +static void +thread_loop_check_exit(struct ssh_aes_ctr_ctx *c) +{ + int exit_flag; + + pthread_rwlock_rdlock(&c->stop_lock); + exit_flag = c->exit_flag; + pthread_rwlock_unlock(&c->stop_lock); + + if (exit_flag) + pthread_exit(NULL); +} +#else +# define thread_loop_check_exit(s) +#endif /* __APPLE__ */ + +/* + * Helper function to terminate the helper threads + */ +static void +stop_and_join_pregen_threads(struct ssh_aes_ctr_ctx *c) +{ + int i; + +#ifdef __APPLE__ + /* notify threads that they should exit */ + pthread_rwlock_wrlock(&c->stop_lock); + c->exit_flag = TRUE; + pthread_rwlock_unlock(&c->stop_lock); +#endif /* __APPLE__ */ + + /* Cancel pregen threads */ + for (i = 0; i < CIPHER_THREADS; i++) { + pthread_cancel(c->tid[i]); + } + for (i = 0; i < NUMKQ; i++) { + pthread_mutex_lock(&c->q[i].lock); + pthread_cond_broadcast(&c->q[i].cond); + pthread_mutex_unlock(&c->q[i].lock); + } + for (i = 0; i < CIPHER_THREADS; i++) { + pthread_join(c->tid[i], NULL); + } +} + /* * The life of a pregen thread: * Find empty keystream queues and fill them using their counter. @@ -201,6 +257,7 @@ thread_loop(void *x) struct kq *q; int i; int qidx; + pthread_t first_tid; /* Threads stats on cancellation */ STATS_INIT(stats); @@ -211,11 +268,15 @@ thread_loop(void *x) /* Thread local copy of AES key */ memcpy(&key, &c->aes_ctx, sizeof(key)); + pthread_rwlock_rdlock(&c->tid_lock); + first_tid = c->tid[0]; + pthread_rwlock_unlock(&c->tid_lock); + /* * Handle the special case of startup, one thread must fill * the first KQ then mark it as draining. Lock held throughout. */ - if (pthread_equal(pthread_self(), c->tid[0])) { + if (pthread_equal(pthread_self(), first_tid)) { q = &c->q[0]; pthread_mutex_lock(&q->lock); if (q->qstate == KQINIT) { @@ -245,12 +306,16 @@ thread_loop(void *x) /* Check if I was cancelled, also checked in cond_wait */ pthread_testcancel(); + /* Check if we should exit as well */ + thread_loop_check_exit(c); + /* Lock queue and block if its draining */ q = &c->q[qidx]; pthread_mutex_lock(&q->lock); pthread_cleanup_push(thread_loop_cleanup, &q->lock); while (q->qstate == KQDRAINING || q->qstate == KQINIT) { STATS_WAIT(stats); + thread_loop_check_exit(c); pthread_cond_wait(&q->cond, &q->lock); } pthread_cleanup_pop(0); @@ -268,6 +333,7 @@ thread_loop(void *x) * can see that it's being filled. */ q->qstate = KQFILLING; + pthread_cond_broadcast(&q->cond); pthread_mutex_unlock(&q->lock); for (i = 0; i < KQLEN; i++) { AES_encrypt(q->ctr, q->keys[i], &key); @@ -279,7 +345,7 @@ thread_loop(void *x) ssh_ctr_add(q->ctr, KQLEN * (NUMKQ - 1), AES_BLOCK_SIZE); q->qstate = KQFULL; STATS_FILL(stats); - pthread_cond_signal(&q->cond); + pthread_cond_broadcast(&q->cond); pthread_mutex_unlock(&q->lock); } @@ -371,6 +437,7 @@ ssh_aes_ctr(EVP_CIPHER_CTX *ctx, u_char *dest, const u_char *src, pthread_cond_wait(&q->cond, &q->lock); } q->qstate = KQDRAINING; + pthread_cond_broadcast(&q->cond); pthread_mutex_unlock(&q->lock); /* Mark consumed queue empty and signal producers */ @@ -397,6 +464,11 @@ ssh_aes_ctr_init(EVP_CIPHER_CTX *ctx, const u_char *key, const u_char *iv, if ((c = EVP_CIPHER_CTX_get_app_data(ctx)) == NULL) { c = xmalloc(sizeof(*c)); + pthread_rwlock_init(&c->tid_lock, NULL); +#ifdef __APPLE__ + pthread_rwlock_init(&c->stop_lock, NULL); + c->exit_flag = FALSE; +#endif /* __APPLE__ */ c->state = HAVE_NONE; for (i = 0; i < NUMKQ; i++) { @@ -409,11 +481,14 @@ ssh_aes_ctr_init(EVP_CIPHER_CTX *ctx, const u_char *key, const u_char *iv, } if (c->state == (HAVE_KEY | HAVE_IV)) { - /* Cancel pregen threads */ - for (i = 0; i < CIPHER_THREADS; i++) - pthread_cancel(c->tid[i]); - for (i = 0; i < CIPHER_THREADS; i++) - pthread_join(c->tid[i], NULL); + /* tell the pregen threads to exit */ + stop_and_join_pregen_threads(c); + +#ifdef __APPLE__ + /* reset the exit flag */ + c->exit_flag = FALSE; +#endif /* __APPLE__ */ + /* Start over getting key & iv */ c->state = HAVE_NONE; } @@ -444,10 +519,12 @@ ssh_aes_ctr_init(EVP_CIPHER_CTX *ctx, const u_char *key, const u_char *iv, /* Start threads */ for (i = 0; i < CIPHER_THREADS; i++) { debug("spawned a thread"); + pthread_rwlock_wrlock(&c->tid_lock); pthread_create(&c->tid[i], NULL, thread_loop, c); + pthread_rwlock_unlock(&c->tid_lock); } pthread_mutex_lock(&c->q[0].lock); - while (c->q[0].qstate != KQDRAINING) + while (c->q[0].qstate == KQINIT) pthread_cond_wait(&c->q[0].cond, &c->q[0].lock); pthread_mutex_unlock(&c->q[0].lock); } @@ -461,15 +538,10 @@ void ssh_aes_ctr_thread_destroy(EVP_CIPHER_CTX *ctx) { struct ssh_aes_ctr_ctx *c; - int i; + c = EVP_CIPHER_CTX_get_app_data(ctx); - /* destroy threads */ - for (i = 0; i < CIPHER_THREADS; i++) { - pthread_cancel(c->tid[i]); - } - for (i = 0; i < CIPHER_THREADS; i++) { - pthread_join(c->tid[i], NULL); - } + + stop_and_join_pregen_threads(c); } void @@ -481,7 +553,9 @@ ssh_aes_ctr_thread_reconstruction(EVP_CIPHER_CTX *ctx) /* reconstruct threads */ for (i = 0; i < CIPHER_THREADS; i++) { debug("spawned a thread"); + pthread_rwlock_wrlock(&c->tid_lock); pthread_create(&c->tid[i], NULL, thread_loop, c); + pthread_rwlock_unlock(&c->tid_lock); } } @@ -489,18 +563,13 @@ static int ssh_aes_ctr_cleanup(EVP_CIPHER_CTX *ctx) { struct ssh_aes_ctr_ctx *c; - int i; if ((c = EVP_CIPHER_CTX_get_app_data(ctx)) != NULL) { #ifdef CIPHER_THREAD_STATS debug("main thread: %u drains, %u waits", c->stats.drains, c->stats.waits); #endif - /* Cancel pregen threads */ - for (i = 0; i < CIPHER_THREADS; i++) - pthread_cancel(c->tid[i]); - for (i = 0; i < CIPHER_THREADS; i++) - pthread_join(c->tid[i], NULL); + stop_and_join_pregen_threads(c); memset(c, 0, sizeof(*c)); free(c);