Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/python/handler/py_CT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,10 @@ py::UniqueObj CTModelObject::getCorrelations(PyObject* topicId) const
// Return the CT model's prior covariance matrix as a K x K NumPy float array,
// or Python None when the model exposes no covariance data yet.
// NOTE: the scraped diff contained both the pre- and post-merge declaration of
// `cov`; this is the merged version — fetch the matrix FIRST and bail out
// before allocating the output array, so an empty model never returns an
// uninitialized buffer.
py::UniqueObj CTModelObject::getPriorCov() const
{
    auto* inst = getInst<tomoto::ICTModel>();
    auto cov = inst->getPriorCov();
    if (cov.empty()) return py::buildPyValue(nullptr); // None instead of garbage
    float* ptr;
    auto ret = py::newEmptyArray(ptr, inst->getK(), inst->getK());
    memcpy(ptr, cov.data(), sizeof(float) * inst->getK() * inst->getK());
    return ret;
}
Expand Down
40 changes: 40 additions & 0 deletions test/unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,22 @@
curpath = os.path.dirname(os.path.realpath(__file__))
print(curpath)

def test_concat():
    """Smoke-test n-gram extraction and concatenation on a tiny 5-doc corpus."""
    corpus = tp.utils.Corpus(tokenizer=tp.utils.SimpleTokenizer())
    documents = [
        'a b 0 d e',
        'a b 1 d e',
        'a b d e 2',
        'a b 3 d e',
        'a b 4 d e',
    ]
    corpus.process(documents)
    ngram_cands = corpus.extract_ngrams(min_cf=5, min_df=5, normalized=True, min_score=0.5)
    print(ngram_cands)
    corpus.concat_ngrams(ngram_cands)
    for document in corpus:
        print(document.words, document.span)

model_cases = [
(tp.LDAModel, curpath + '/sample.txt', 0, None, {'k':40}, None),
(tp.LLDAModel, curpath + '/sample_with_md.txt', 1, lambda x:x, {'k':5}, None),
Expand Down Expand Up @@ -65,6 +81,23 @@
(tp.PTModel, curpath + '/sample.txt', 0, None, {'k':10, 'p':100}, [tp.ParallelScheme.PARTITION]),
]

def properties(cls, inputFile, mdFields, f, kargs, ps):
    """Instantiate *cls* and read back every public attribute, printing each
    value; re-raises on any attribute access failure so the test fails loudly.

    Only ``cls`` and ``kargs`` are used here; the remaining parameters exist to
    match the shared ``model_cases`` test-driver signature.
    """
    print('Test properties')
    tw = 0
    print('Initialize model %s with TW=%s ...' % (str(cls), ['one', 'idf', 'pmi'][tw]))
    model = cls(tw=tw, min_df=2, rm_top=2, **kargs)
    # Known-problematic properties that must not be touched for these models.
    skipped = {'CTModel.alpha', 'DTModel.eta'}
    public_attrs = [name for name in dir(model) if not name.startswith('_')]
    for name in public_attrs:
        if '{}.{}'.format(cls.__name__, name) in skipped:
            print('Skipping property {}.{}'.format(cls.__name__, name))
            continue
        try:
            print(name, getattr(model, name), sep=': ')
        except Exception as e:
            print('Error accessing attribute {}: {}'.format(name, e))
            raise

def null_doc(cls, inputFile, mdFields, f, kargs, ps):
tw = 0
print('Initialize model %s with TW=%s ...' % (str(cls), ['one', 'idf', 'pmi'][tw]))
Expand Down Expand Up @@ -648,6 +681,13 @@ def test_purge_dead_topics():
mdl.train(100)
print('Iteration: {}\tLog-likelihood: {}\tNum. of topics: {}\tNum. of tables: {}'.format(i, mdl.ll_per_word, mdl.live_k, mdl.num_tables))

# Dynamically register one pytest-discoverable test function per
# (model class, parallel scheme) pair in model_cases.
for model_case in model_cases:
    pss = model_case[5]
    # A falsy scheme list means "use the default scheme only".
    if not pss: pss = [tp.ParallelScheme.DEFAULT]
    # Only the first scheme is exercised for the property checks.
    for ps in pss[:1]:
        for func in [properties]:
            # At module level, locals() is globals(), so this assignment creates
            # a real module attribute named test_<Model>_<func>_<scheme> that
            # pytest can collect.  The outer lambda immediately binds f/mc/ps,
            # avoiding the classic late-binding-closure pitfall.
            locals()['test_{}_{}_{}'.format(model_case[0].__name__, func.__name__, ps.name)] = (lambda f, mc, ps: lambda: f(*(mc + (ps,))))(func, model_case[:-1], ps)

for model_case in model_cases:
pss = model_case[5]
if not pss: pss = [tp.ParallelScheme.COPY_MERGE, tp.ParallelScheme.PARTITION]
Expand Down
11 changes: 7 additions & 4 deletions tomotopy/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,9 +456,9 @@ def set_word_prior(self, word, prior) -> None:

@classmethod
def _summary_extract_param_desc(cls:type):
doc_string = cls.__doc__ or cls.__init__.__doc__
doc_string = cls.__init__.__doc__
if not doc_string: return {}
ps = doc_string.split('\nParameters\n')[1].split('\n')
ps = doc_string.split('Parameters\n')[1].split('\n')
param_name = re.compile(r'^([a-zA-Z0-9_]+)\s*:\s*')
directive = re.compile(r'^\s*\.\.')
descriptive = re.compile(r'\s+([^\s].*)')
Expand Down Expand Up @@ -503,7 +503,10 @@ def _summary_training_info(self, file):
print('| Log-likelihood per word: {:.5f}'.format(self.ll_per_word), file=file)

def _summary_initial_params_info(self, file):
param_desc = self._summary_extract_param_desc()
try:
param_desc = self._summary_extract_param_desc()
except:
param_desc = {}
if hasattr(self, 'init_params'):
for k, v in self.init_params.items():
if type(v) is float: fmt = ':.5'
Expand Down Expand Up @@ -1390,7 +1393,7 @@ def get_topic_word_dist(self, topic_id, normalize=True) -> List[float]:
@property
def k_g(self) -> int:
    '''the hyperparameter k_g (read-only)'''
    # The scraped diff showed both the removed `return self._k_g` and the
    # added `return self._k`; this keeps only the post-merge added line.
    # NOTE(review): backed by the `_k` field — confirm `_k` indeed stores the
    # global topic count for this model class.
    return self._k

@property
def k_l(self) -> int:
Expand Down
31 changes: 30 additions & 1 deletion tomotopy/viewer/template.html
Original file line number Diff line number Diff line change
Expand Up @@ -444,17 +444,46 @@ <h3>{{get_topic_label(topic, prefix="Topic ", id_suffix=True)}}: <small>{{", ".j
var category_labels = {{categorical_metadata}};
</script>
<script>
// Chart.js dataset fill colors: semi-transparent (alpha 0.5) versions of the
// palette, cycled per dataset via `datasets.length % backgroundColors.length`.
const backgroundColors = [
    'rgba(54, 162, 235, 0.5)',
    'rgba(255, 99, 132, 0.5)',
    'rgba(255, 159, 64, 0.5)',
    'rgba(255, 205, 86, 0.5)',
    'rgba(75, 192, 192, 0.5)',
    'rgba(153, 102, 255, 0.5)',
    'rgba(201, 203, 207, 0.5)',
    'rgba(255, 0, 0, 0.5)',
    'rgba(0, 255, 0, 0.5)',
    'rgba(0, 0, 255, 0.5)',
];
// Opaque line/border colors, index-aligned 1:1 with backgroundColors so each
// dataset's border matches its fill hue.
const borderColors = [
    'rgb(54, 162, 235)',
    'rgb(255, 99, 132)',
    'rgb(255, 159, 64)',
    'rgb(255, 205, 86)',
    'rgb(75, 192, 192)',
    'rgb(153, 102, 255)',
    'rgb(201, 203, 207)',
    'rgb(255, 0, 0)',
    'rgb(0, 255, 0)',
    'rgb(0, 0, 255)',
];

var charts = [];
for (var i in tdf_data) {
var data = tdf_data[i];
const ctx = document.getElementById('chart-' + i);

var datasets = [];
for (var c in category_labels) {
const reorder = [1, 2, 3, 4, 5, 0, 6, 7, 8, 9];
for (var c in reorder) {
c = reorder[c];
datasets.push({
label: category_labels[c],
cubicInterpolationMode: 'monotone',
data: data[c],
backgroundColor: backgroundColors[datasets.length % backgroundColors.length],
borderColor: borderColors[datasets.length % borderColors.length],
borderWidth: 2,
pointStyle: false,
segment: {
Expand Down
46 changes: 38 additions & 8 deletions tomotopy/viewer/viewer_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,11 +191,12 @@ def __init__(self, model, max_cache_size=5) -> None:
self._cached = {}
self._cached_keys = []

def _sort_and_filter(self, sort_key:int, filter_target:int, filter_value:float, filter_keyword:tuple, filter_metadata:str):
def _sort_and_filter(self, sort_key:int, filter_target:int, filter_value:float, filter_keyword:tuple, filter_metadata:str, filter_numeric_metadata:list):
results = []
for i, doc in enumerate(self.model.docs):
if filter_keyword and not all(kw in doc.raw.lower() for kw in filter_keyword): continue
if filter_metadata is not None and doc.metadata != filter_metadata: continue
if filter_numeric_metadata and not all(s <= v <= e for (s, e), v in zip(filter_numeric_metadata, doc.numeric_metadata)): continue
dist = doc.get_topic_dist()
if dist[filter_target] < filter_value: continue
if sort_key >= 0:
Expand All @@ -208,25 +209,25 @@ def _sort_and_filter(self, sort_key:int, filter_target:int, filter_value:float,
else:
return [i for _, i in sorted(results, reverse=True)]

def _get_cached_filter_result(self, sort_key:int, filter_target:int, filter_value:float, filter_keyword:str, filter_metadata:str):
def _get_cached_filter_result(self, sort_key:int, filter_target:int, filter_value:float, filter_keyword:str, filter_metadata:str, filter_numeric_metadata:list):
filter_keyword = tuple(filter_keyword.lower().split())
if sort_key < 0 and filter_value <= 0 and not filter_keyword and filter_metadata is None:
if sort_key < 0 and filter_value <= 0 and not filter_keyword and filter_metadata is None and not filter_numeric_metadata:
# return None for no filtering nor sorting
return None
key = (sort_key, filter_target, filter_value, filter_keyword, filter_metadata)
key = (sort_key, filter_target, filter_value, filter_keyword, filter_metadata, tuple(filter_numeric_metadata))
if key in self._cached:
return self._cached[key]
else:
result = self._sort_and_filter(sort_key, filter_target, filter_value, filter_keyword, filter_metadata)
result = self._sort_and_filter(sort_key, filter_target, filter_value, filter_keyword, filter_metadata, filter_numeric_metadata)
if len(self._cached_keys) >= self.max_cache_size:
del self._cached[self._cached_keys.pop(0)]
self._cached[key] = result
self._cached_keys.append(key)
return result

def get(self, sort_key:int, filter_target:int, filter_value:float, filter_keyword:str, filter_metadata:str, index:slice):
def get(self, sort_key:int, filter_target:int, filter_value:float, filter_keyword:str, filter_metadata:str, filter_numeric_metadata:list, index:slice):
# return (doc_indices, total_docs_filtered)
result = self._get_cached_filter_result(sort_key, filter_target, filter_value, filter_keyword, filter_metadata)
result = self._get_cached_filter_result(sort_key, filter_target, filter_value, filter_keyword, filter_metadata, filter_numeric_metadata)
if result is None:
return list(range(index.start, min(index.stop, len(self.model.docs)))), len(self.model.docs)
else:
Expand Down Expand Up @@ -431,6 +432,7 @@ def get_document(self):
filter_value = float(self.arguments.get('v', '0'))
filter_keyword = self.arguments.get('sq', '')
filter_metadata = int(self.arguments.get('m', '-1'))
filter_numeric_metadata = self.arguments.get('x', '')
page = int(self.arguments.get('p', '0'))

if not self.available.get('metadata') or filter_metadata < 0:
Expand All @@ -439,12 +441,19 @@ def get_document(self):
else:
md = self.model.metadata_dict[filter_metadata]

if not self.available.get('metadata') or not filter_numeric_metadata:
filter_numeric_metadata = []
else:
filter_numeric_metadata = list(map(float, filter_numeric_metadata.split(',')))
filter_numeric_metadata = list(zip(filter_numeric_metadata[::2], filter_numeric_metadata[1::2]))

doc_indices, filtered_docs = self.server.filter.get(
sort_key,
filter_target,
filter_value / 100,
filter_keyword,
md,
filter_numeric_metadata,
slice(page * self.num_docs_per_page, (page + 1) * self.num_docs_per_page)
)
total_pages = (filtered_docs + self.num_docs_per_page - 1) // self.num_docs_per_page
Expand All @@ -459,7 +468,7 @@ def get_document(self):
self.render(action='document',
page=page,
total_pages=total_pages,
filtered_docs=filtered_docs if filter_value > 0 or filter_keyword or filter_metadata >= 0 else None,
filtered_docs=filtered_docs if filter_value > 0 or filter_keyword or filter_metadata >= 0 or filter_numeric_metadata else None,
total_docs=total_docs,
documents=documents,
sort_key=sort_key,
Expand Down Expand Up @@ -800,6 +809,27 @@ def cache_tdf_map_img(self, topic_id, x, y, w, h, r, contour_interval, smooth):
contour_map[:-int(h * 0.32)] = 0
contour_map[:-int(h * 0.16), is_sub_grid] = 0
contour_map = contour_map.clip(0, 1)
else:
grid_map = np.zeros_like(contour_map)
for i in range(5):
t = (h * i) // 4
bold = i % 2 == 0
if t < h:
grid_map[t] = 0.8 if bold else 0.5
if t > 0:
grid_map[t - 1] = 0.8 if bold else 0.5
if t + 1 < h and bold:
grid_map[t + 1] = 0.8 if bold else 0.5
for i in range(24):
t = (w * i) // 23
bold = i % 5 == 4
if t < w:
grid_map[:, t] = 0.8 if bold else 0.5
if t > 0:
grid_map[:, t - 1] = 0.8 if bold else 0.5
if t + 1 < h and bold:
grid_map[:, t + 1] = 0.8 if bold else 0.5
colorized = colorized * (1 - grid_map[..., None]) + 0.85 * grid_map[..., None]
colorized *= 1 - contour_map[..., None]
img = Image.fromarray((colorized * 255).astype(np.uint8), 'RGB')
img_buf = io.BytesIO()
Expand Down
Loading