#!/usr/bin/python
# solver for https://sutom.nocle.fr/
#
# example:
# >>> import soluce
# >>> solver = soluce.SutomSolver('/usr/share/dict/french', 7)
# >>> solver.AddHint(['b', None, None, None, None, None, None], set(), set())
# >>> solver.PickCandidate()
# 'bâchage'
# >>> solver.AddHint(['b', None, None, None, None, None, None], set(), set('achage'))
# >>> solver.PickCandidate()
# 'bâfrons'
# >>> solver.AddHint(['b', None, None, 'r', None, None, None], set('on'), set('âfs'))
# >>> solver.PickCandidate()
# 'boiront'
# >>> solver.AddHint(['b', 'o', None, 'r', None, None, None], set('in'), set('t'))
# >>> solver.PickCandidate()
# 'bourrin'

import re
from typing import List, Optional, Set, Text, Tuple

# Path of a French word list; it is the same path the usage example at the top
# of this file passes to SutomSolver.
# NOTE(review): nothing in the visible code reads this constant -- confirm
# whether external callers rely on it.
DICT_PATH = '/usr/share/dict/french'

# TODO: figure out what to do with non-ASCII
#       either normalise all to ASCII
#       or exclude words that have non-ASCII characters
class SutomSolver(object):
  """Suggests guesses for the Sutom word game from accumulated hints.

  The solver accumulates what previous guesses revealed and filters a
  dictionary file (one word per line) against that knowledge.

  Attributes:
    length: number of letters in the word to find.
    fixed: one entry per position, holding the letter known to sit at that
      position, or None while the position is unknown.
    floating: letters known to be in the word.
    excluded: letters known to have no occurrence outside the fixed positions.
    dict: path of the dictionary file, re-read on every GetCandidates() call.
    encoding: text encoding used to read the dictionary; None means the
      platform default chosen by open().
  """

  def __init__(self, dictionary_path: Text, word_length: int,
               encoding: Optional[Text] = None):
    """Creates a solver with no hints yet.

    Args:
      dictionary_path: path of a text file containing one word per line.
      word_length: number of letters in the word to find; must be positive.
      encoding: optional encoding of the dictionary file (e.g. 'utf-8' or
        'latin-1'). Defaults to None, i.e. the platform default.

    Raises:
      AssertionError: if word_length is not positive.
    """
    assert word_length > 0
    self.length = word_length
    self.fixed = [None] * self.length  # type: List[Optional[Text]]
    self.floating = set()  # type: Set[Text]
    self.excluded = set()  # type: Set[Text]
    self.dict = dictionary_path
    self.encoding = encoding

  def AddHint(self,
              fixed_chars: List[Optional[Text]],
              floating_chars: Set[Text],
              excluded_chars: Set[Text]) -> None:
    """Merges the outcome of one guess into the solver's knowledge.

    Args:
      fixed_chars: one entry per position; the confirmed letter, or None.
        Positions confirmed by an earlier hint must be repeated unchanged.
      floating_chars: letters present in the word at an unknown position.
      excluded_chars: letters with no (further) occurrence in the word.

    Raises:
      AssertionError: if fixed_chars has the wrong length, if a letter is
        both floating and excluded, or if fixed_chars contradicts a position
        that is already known. The solver is left untouched in that case.
    """
    assert len(fixed_chars) == self.length
    assert floating_chars.isdisjoint(excluded_chars)
    # Validate every position before mutating anything, so that a rejected
    # hint cannot leave the solver half-updated.
    for known, new in zip(self.fixed, fixed_chars):
      assert not known or known == new
    # Slice-assign to keep the identity of the list object.
    self.fixed[:] = [known or new
                     for known, new in zip(self.fixed, fixed_chars)]
    self.excluded.update(excluded_chars)
    self.floating.update(floating_chars)
    # A letter excluded by any hint is no longer required as floating.
    self.floating.difference_update(self.excluded)

  @staticmethod
  def BuildFixedRegex(fixed_characters: List[Optional[Text]]) -> Text:
    """Builds a regex source matching words that agree with the fixed letters.

    Args:
      fixed_characters: one entry per position; a letter, or a falsy value
        (None) for an unknown position.

    Returns:
      A pattern with '.' for each unknown position and the escaped letter for
      each known one. Escaping keeps characters such as '.' or '-' literal.
    """
    return ''.join(re.escape(c) if c else '.' for c in fixed_characters)

  def SeenChars(self) -> Set[Text]:
    """Returns every letter that appears in any hint received so far."""
    seen = self.excluded.copy()
    seen.update(self.floating)
    seen.update(self.fixed)
    seen.discard(None)
    return seen

  def GetWordScore(self, word: Text,
                   seen: Optional[Set[Text]] = None) -> int:
    """Scores a word by how many not-yet-tried letters it would test.

    Args:
      word: the candidate word.
      seen: optional precomputed result of SeenChars(); lets callers scoring
        many words avoid rebuilding the set each time.

    Returns:
      The number of distinct letters of word that are not in seen.
    """
    if seen is None:
      seen = self.SeenChars()
    # TODO: score by char frequency among remaining candidates
    return len(set(word).difference(seen))

  def GetCandidates(self) -> List[Tuple[Text, int]]:
    """Reads the dictionary and returns the words compatible with all hints.

    A word is kept when it matches the fixed letters position by position,
    contains every floating letter, and has no excluded letter at a position
    that is still unknown.

    Returns:
      A list of (word, score) tuples sorted by decreasing score; words with
      equal scores keep their dictionary order.
    """
    # Hoist loop invariants: compile the pattern and build the seen set once.
    fixed_re = re.compile(self.BuildFixedRegex(self.fixed))
    seen = self.SeenChars()
    excluded = self.excluded
    candidates = []  # type: List[Tuple[Text, int]]
    with open(self.dict, encoding=self.encoding) as f:
      for line in f:
        candidate = line.rstrip()
        # The pattern has one token per position, so a full match also
        # guarantees that the candidate has the expected length.
        if not fixed_re.fullmatch(candidate):
          continue
        if not self.floating.issubset(set(candidate)):
          continue
        # An excluded letter may only appear where it is confirmed as fixed;
        # reject it at any position that is still unknown.
        if any(not known and letter in excluded
               for known, letter in zip(self.fixed, candidate)):
          continue
        candidates.append((candidate, self.GetWordScore(candidate, seen)))
    candidates.sort(key=lambda x: x[1], reverse=True)
    return candidates

  def PickCandidate(self) -> Text:
    """Returns the best-scoring word compatible with the hints.

    Raises:
      IndexError: if no dictionary word matches the current hints.
    """
    candidates = self.GetCandidates()
    if not candidates:
      raise IndexError('no dictionary word matches the current hints')
    return candidates[0][0]
