/*	$NetBSD: lr0.c,v 1.14 2024/09/14 21:29:02 christos Exp $	*/

/* Id: lr0.c,v 1.21 2021/05/20 23:57:23 tom Exp  */

#include "defs.h"

#include <sys/cdefs.h>
__RCSID("$NetBSD: lr0.c,v 1.14 2024/09/14 21:29:02 christos Exp $");

/*
 * Forward declarations of the file-local helpers defined further down, so
 * that they can be called before their definitions appear.
 */
static core *new_state(int symbol);
static Value_t get_state(int symbol);
static void allocate_itemsets(void);
static void allocate_storage(void);
static void append_states(void);
static void free_storage(void);
static void generate_states(void);
static void initialize_states(void);
static void new_itemsets(void);
static void save_reductions(void);
static void save_shifts(void);
static void set_derives(void);
static void set_nullable(void);

/*
 * Results of the state construction.  These are not static, so they are
 * visible to other translation units.
 */
/* number of states created so far (1 after initialize_states()) */
Value_t nstates;
/* head of the list of states, chained through core.next */
core *first_state;
/* head of the list of shift records, chained through shifts.next */
shifts *first_shift;
/* head of the list of reduction records, chained through reductions.next */
reductions *first_reduction;

/* lookup table for get_state(): indexed by a state's first kernel item,
 * each entry heads a chain of states linked through core.link */
static core **state_set;
/* state currently being expanded by generate_states() */
static core *this_state;
/* tail of the state list; new_state() appends here */
static core *last_state;
/* tail of the shift-record list; save_shifts() appends here */
static shifts *last_shift;
/* tail of the reduction-record list; save_reductions() appends here */
static reductions *last_reduction;

/* number of valid entries in shift_symbol[] for the current state */
static int nshifts;
/* symbols recorded by new_itemsets() for the current state; sorted into
 * ascending order by append_states() */
static Value_t *shift_symbol;

/* backing storage for the lists that derives[] points into (set_derives()) */
static Value_t *rules;

/* scratch array: values collected by save_reductions() for one state */
static Value_t *redset;
/* scratch array: state numbers returned by get_state() in append_states() */
static Value_t *shiftset;

/* kernel_base[s]: start of symbol s's slice of kernel_items */
static Value_t **kernel_base;
/* kernel_end[s]: one past the last item stored for s in the current state,
 * or 0 when no item has been stored for s yet */
static Value_t **kernel_end;
/* backing storage shared by all per-symbol kernel slices */
static Value_t *kernel_items;

/*
 * Allocate the per-symbol kernel tables used while building states.
 *
 * Counts how many times each symbol occurs in ritem[] (entries >= 0), then
 * divides kernel_items into one slice per symbol, sized by that count, and
 * stores the start of each slice in kernel_base[].  The counting array is
 * kept and handed over as shift_symbol; kernel_end[] is only allocated
 * here and is filled in by new_itemsets().
 *
 * NOTE(review): the counters are incremented without being initialised
 * here, so this assumes NEW2() returns zero-filled storage -- confirm.
 */
static void
allocate_itemsets(void)
{
    Value_t *itemp;
    Value_t *item_end;
    int i;
    int count;
    Value_t *symbol_count;

    count = 0;
    symbol_count = NEW2(nsyms, Value_t);

    /* tally total and per-symbol occurrences; negative entries are skipped */
    item_end = ritem + nitems;
    for (itemp = ritem; itemp < item_end; itemp++)
    {
	int symbol = *itemp;

	if (symbol >= 0)
	{
	    count++;
	    symbol_count[symbol]++;
	}
    }

    kernel_base = NEW2(nsyms, Value_t *);
    kernel_items = NEW2(count, Value_t);

    /* running offset: each symbol's slice starts where the previous ended */
    count = 0;
    for (i = 0; i < nsyms; i++)
    {
	kernel_base[i] = kernel_items + count;
	count += symbol_count[i];
    }

    /* the counts are no longer needed; reuse the array as shift_symbol */
    shift_symbol = symbol_count;
    kernel_end = NEW2(nsyms, Value_t *);
}

/*
 * Allocate all scratch tables used by generate_states(): the kernel tables
 * (via allocate_itemsets()) plus state_set, redset and shiftset.
 */
static void
allocate_storage(void)
{
    /* kernel_base, kernel_end, kernel_items and shift_symbol */
    allocate_itemsets();

    /* one chain head per item position */
    state_set = NEW2(nitems, core *);

    /* nrules + 1 entries for save_reductions() */
    redset = NEW2(nrules + 1, Value_t);

    /* one entry per symbol for append_states() */
    shiftset = NEW2(nsyms, Value_t);
}

/*
 * Sort the nshifts entries of shift_symbol[] into ascending order, then
 * look up (or create) the state reached on each symbol with get_state()
 * and store its number in the matching slot of shiftset[].
 */
static void
append_states(void)
{
    int pos;

#ifdef	TRACE
    fprintf(stderr, "Entering append_states()\n");
#endif

    /* insertion sort: slide larger entries up, drop the pending one in */
    for (pos = 1; pos < nshifts; pos++)
    {
	Value_t pending = shift_symbol[pos];
	int hole;

	for (hole = pos; hole > 0; hole--)
	{
	    if (shift_symbol[hole - 1] <= pending)
		break;
	    shift_symbol[hole] = shift_symbol[hole - 1];
	}
	shift_symbol[hole] = pending;
    }

    for (pos = 0; pos < nshifts; pos++)
	shiftset[pos] = get_state(shift_symbol[pos]);
}

/*
 * Release the scratch tables set up by allocate_storage() and
 * allocate_itemsets().  The state, shift and reduction lists are left
 * untouched.
 */
static void
free_storage(void)
{
    /* kernel tables */
    FREE(kernel_items);
    FREE(kernel_end);
    FREE(kernel_base);

    /* per-state scratch arrays */
    FREE(shiftset);
    FREE(redset);
    FREE(shift_symbol);

    /* state lookup table */
    FREE(state_set);
}

/*
 * Build the complete set of states.
 *
 * Starting from the state made by initialize_states(), walk the state list
 * in order; for each state compute its closure, record its reductions,
 * group the items by the symbol they advance over, and resolve each group
 * to a (possibly new) state.  New states are appended to the list by
 * new_state(), so the walk ends once no unprocessed state remains.
 *
 * itemset and ruleset are allocated here and are not released by
 * free_storage().
 */
static void
generate_states(void)
{
    allocate_storage();
    itemset = NEW2(nitems, Value_t);
    ruleset = NEW2(WORDSIZE(nrules), unsigned);
    set_first_derives();
    initialize_states();

    for (; this_state != 0; this_state = this_state->next)
    {
	closure(this_state->items, this_state->nitems);
	save_reductions();
	new_itemsets();
	append_states();

	/* a shift record is only stored when there is something to shift */
	if (nshifts <= 0)
	    continue;
	save_shifts();
    }

    free_storage();
}

/*
 * Return the number of the state whose kernel equals the items currently
 * stored for `symbol` (kernel_base[symbol] .. kernel_end[symbol]).
 *
 * States are looked up through state_set[], indexed by the first kernel
 * item; states sharing that first item are chained through core.link.
 * When no state on the chain matches, a new one is created with
 * new_state() and added at the end of the chain.
 */
static Value_t
get_state(int symbol)
{
    Value_t *kbegin;
    Value_t *kfinish;
    core *sp;
    int nkernel;
    int key;

#ifdef	TRACE
    fprintf(stderr, "Entering get_state(%d)\n", symbol);
#endif

    kbegin = kernel_base[symbol];
    kfinish = kernel_end[symbol];
    nkernel = (int)(kfinish - kbegin);

    key = *kbegin;
    assert(0 <= key && key < nitems);

    sp = state_set[key];
    if (sp == 0)
    {
	/* first state seen with this leading item: it heads the chain */
	sp = new_state(symbol);
	state_set[key] = sp;
	return (sp->number);
    }

    for (;;)
    {
	if (sp->nitems == nkernel)
	{
	    Value_t *want = kbegin;
	    Value_t *have = sp->items;

	    /* advance while the two item lists agree */
	    while (want < kfinish && *want == *have)
	    {
		want++;
		have++;
	    }
	    if (want == kfinish)
		break;		/* every item matched: existing state */
	}

	if (sp->link == 0)
	{
	    /* end of chain without a match: append a fresh state */
	    core *fresh = new_state(symbol);

	    sp->link = fresh;
	    sp = fresh;
	    break;
	}

	sp = sp->link;
    }

    return (sp->number);
}

static void
initialize_states(void)
{
    unsigned i;
    Value_t *start_derives;
    core *p;

    start_derives = derives[start_symbol];
    for (i = 0; start_derives[i] >= 0; ++i)
	continue;

    p = (core *)MALLOC(sizeof(core) + i * sizeof(Value_t));
    NO_SPACE(p);

    p->next = 0;
    p->link = 0;
    p->number = 0;
    p->accessing_symbol = 0;
    p->nitems = (Value_t)i;

    for (i = 0; start_derives[i] >= 0; ++i)
	p->items[i] = rrhs[start_derives[i]];

    first_state = last_state = this_state = p;
    nstates = 1;
}

static void
new_itemsets(void)
{
    Value_t i;
    int shiftcount;
    Value_t *isp;
    Value_t *ksp;

    for (i = 0; i < nsyms; i++)
	kernel_end[i] = 0;

    shiftcount = 0;
    isp = itemset;
    while (isp < itemsetend)
    {
	int j = *isp++;
	Value_t symbol = ritem[j];

	if (symbol > 0)
	{
	    ksp = kernel_end[symbol];
	    if (!ksp)
	    {
		shift_symbol[shiftcount++] = symbol;
		ksp = kernel_base[symbol];
	    }

	    *ksp++ = (Value_t)(j + 1);
	    kernel_end[symbol] = ksp;
	}
    }

    nshifts = shiftcount;
}

static core *
new_state(int symbol)
{
    unsigned n;
    core *p;
    Value_t *isp1;
    Value_t *isp2;
    Value_t *iend;

#ifdef	TRACE
    fprintf(stderr, "Entering new_state(%d)\n", symbol);
#endif

    if (nstates >= MAXYYINT)
	fatal("too many states");

    isp1 = kernel_base[symbol];
    iend = kernel_end[symbol];
    n = (unsigned)(iend - isp1);

    p = (core *)allocate((sizeof(core) + (n - 1) * sizeof(Value_t)));
    p->accessing_symbol = (Value_t)symbol;
    p->number = (Value_t)nstates;
    p->nitems = (Value_t)n;

    isp2 = p->items;
    while (isp1 < iend)
	*isp2++ = *isp1++;

    last_state->next = p;
    last_state = p;

    nstates++;

    return (p);
}

/* show_cores is used for debugging */
#ifdef DEBUG
/*
 * Print every state on the first_state list to stdout: a header line with
 * its position, number and accessing symbol, then one line per item
 * showing the item index and its rule with a "." at the item position.
 */
void
show_cores(void)
{
    core *p;
    int k = 0;

    for (p = first_state; p; p = p->next, ++k)
    {
	int n;
	int i;

	if (k != 0)
	    printf("\n");
	printf("state %d, number = %d, accessing symbol = %s\n",
	       k, p->number, symbol_name[p->accessing_symbol]);

	n = p->nitems;
	for (i = 0; i < n; ++i)
	{
	    int itemno = p->items[i];
	    int rule_end = itemno;
	    int j;

	    printf("%4d  ", itemno);

	    /* scan forward to the negative entry that closes the rule */
	    while (ritem[rule_end] >= 0)
		++rule_end;
	    printf("%s :", symbol_name[rlhs[-ritem[rule_end]]]);

	    /* symbols before the item position */
	    for (j = rrhs[-ritem[rule_end]]; j < itemno; j++)
		printf(" %s", symbol_name[ritem[j]]);
	    printf(" .");

	    /* symbols from the item position to the end of the rule */
	    for (; ritem[j] >= 0; j++)
		printf(" %s", symbol_name[ritem[j]]);
	    printf("\n");
	    fflush(stdout);
	}
    }
}

/* show_ritems is used for debugging: prints each ritem[] entry on a line */

void
show_ritems(void)
{
    int idx = 0;

    while (idx < nitems)
    {
	printf("ritem[%d] = %d\n", idx, ritem[idx]);
	++idx;
    }
}

/* show_rrhs is used for debugging: prints each rrhs[] entry on a line */
void
show_rrhs(void)
{
    int idx = 0;

    while (idx < nrules)
    {
	printf("rrhs[%d] = %d\n", idx, rrhs[idx]);
	++idx;
    }
}

/*
 * show_shifts is used for debugging: prints every record on the
 * first_shift list, followed by the entries of its shift[] array.
 */

void
show_shifts(void)
{
    shifts *p = first_shift;
    int k;

    for (k = 0; p; p = p->next, ++k)
    {
	int count;
	int i;

	if (k != 0)
	    printf("\n");
	printf("shift %d, number = %d, nshifts = %d\n", k, p->number,
	       p->nshifts);

	count = p->nshifts;
	for (i = 0; i < count; ++i)
	    printf("\t%d\n", p->shift[i]);
    }
}
#endif

/*
 * Store the first nshifts entries of shiftset[] in a newly allocated
 * shifts record, tagged with the number of this_state, and append the
 * record to the first_shift/last_shift list.
 */
static void
save_shifts(void)
{
    shifts *p;
    int i;

    /* sizeof(shifts) is counted with one entry, hence nshifts - 1 extra */
    p = (shifts *)allocate((sizeof(shifts) +
			      (unsigned)(nshifts - 1) * sizeof(Value_t)));

    p->number = this_state->number;
    p->nshifts = (Value_t)nshifts;

    for (i = 0; i < nshifts; i++)
	p->shift[i] = shiftset[i];

    /* link at the tail; an empty list also gets its head set */
    if (last_shift == 0)
	first_shift = p;
    else
	last_shift->next = p;
    last_shift = p;
}

/*
 * Collect the reductions of the current item set.
 *
 * For each item in itemset .. itemsetend whose ritem[] entry is negative,
 * the negated entry is stored in redset[].  If anything was collected,
 * the values are copied into a newly allocated reductions record, tagged
 * with the number of this_state, which is appended to the
 * first_reduction/last_reduction list.
 */
static void
save_reductions(void)
{
    Value_t *isp;
    Value_t count = 0;
    Value_t k;
    reductions *p;

    for (isp = itemset; isp < itemsetend; isp++)
    {
	int item = ritem[*isp];

	if (item < 0)
	    redset[count++] = (Value_t)-item;
    }

    /* no record is created when nothing was collected */
    if (count == 0)
	return;

    /* sizeof(reductions) is counted with one entry, hence count - 1 extra */
    p = (reductions *)allocate((sizeof(reductions) +
				  (unsigned)(count - 1) *
				sizeof(Value_t)));

    p->number = this_state->number;
    p->nreds = count;

    for (k = 0; k < count; k++)
	p->rules[k] = redset[k];

    /* link at the tail; an empty list also gets its head set */
    if (last_reduction == 0)
	first_reduction = p;
    else
	last_reduction->next = p;
    last_reduction = p;
}

/*
 * Build derives[]: for each symbol from start_symbol up to nsyms - 1,
 * derives[symbol] points at a list, terminated by -1, of the indices of
 * the rules whose left-hand side (rlhs[]) is that symbol.  All lists are
 * stored back to back in the rules array.
 */
static void
set_derives(void)
{
    Value_t rno;
    Value_t used = 0;
    int lhs;

    derives = NEW2(nsyms, Value_t *);
    rules = NEW2(nvars + nrules, Value_t);

    for (lhs = start_symbol; lhs < nsyms; lhs++)
    {
	derives[lhs] = rules + used;

	for (rno = 0; rno < nrules; rno++)
	{
	    if (rlhs[rno] != lhs)
		continue;
	    rules[used] = rno;
	    used++;
	}

	/* terminator for this symbol's list */
	rules[used] = -1;
	used++;
    }

#ifdef	DEBUG
    print_derives();
#endif
}

#ifdef	DEBUG
void
print_derives(void)
{
    int i;
    Value_t *sp;

    printf("\nDERIVES\n\n");

    for (i = start_symbol; i < nsyms; i++)
    {
	printf("%s derives ", symbol_name[i]);
	for (sp = derives[i]; *sp >= 0; sp++)
	{
	    printf("  %d", *sp);
	}
	putchar('\n');
    }

    putchar('\n');
}
#endif

static void
set_nullable(void)
{
    int i, j;
    int empty;
    int done_flag;

    nullable = TMALLOC(char, nsyms);
    NO_SPACE(nullable);

    for (i = 0; i < nsyms; ++i)
	nullable[i] = 0;

    done_flag = 0;
    while (!done_flag)
    {
	done_flag = 1;
	for (i = 1; i < nitems; i++)
	{
	    empty = 1;
	    while ((j = ritem[i]) >= 0)
	    {
		if (!nullable[j])
		    empty = 0;
		++i;
	    }
	    if (empty)
	    {
		j = rlhs[-j];
		if (!nullable[j])
		{
		    nullable[j] = 1;
		    done_flag = 0;
		}
	    }
	}
    }

#ifdef DEBUG
    for (i = 0; i < nsyms; i++)
    {
	if (nullable[i])
	    printf("%s is nullable\n", symbol_name[i]);
	else
	    printf("%s is not nullable\n", symbol_name[i]);
    }
#endif
}

/*
 * Entry point of this file: builds the derives[] table, then the
 * nullable[] table, then generates the states.  The results are left in
 * the file-scope variables (first_state, first_shift, first_reduction,
 * nstates, derives, nullable).
 */
void
lr0(void)
{
    set_derives();
    set_nullable();
    generate_states();
}

#ifdef NO_LEAKS
/*
 * Release the tables built by set_derives() and set_nullable(); compiled
 * only when NO_LEAKS is defined.
 *
 * NOTE(review): set_derives() in this file always leaves
 * derives[start_symbol] equal to rules, so the separate free below only
 * matters if other code replaces that entry -- confirm against callers.
 */
void
lr0_leaks(void)
{
    if (derives)
    {
	/* free the start symbol's list only if it is not part of rules */
	if (derives[start_symbol] != rules)
	{
	    DO_FREE(derives[start_symbol]);
	}
	DO_FREE(derives);
	DO_FREE(rules);
    }
    DO_FREE(nullable);
}
#endif