Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions py/visdom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,47 @@ def get_window_data(self, win=None, env=None):
create=False,
)

def set_tags(self, tags, env=None, append=False):
"""
This function sets tags for a specified environment.
If append is True, tags are added to the existing ones.
Otherwise, tags are replaced.
"""
if isinstance(tags, str):
tags = [tags]

return self._send(
msg={
"eid": env,
"tags": tags,
"append": append,
},
endpoint="tags",
create=False,
)

def get_tags(self, env=None):
"""
This function returns the tags for a specified environment.
"""
if env is None:
env = self.env

try:
url = "{0}:{1}{2}/tags".format(
self.server, self.port, self.base_url
)
# We use a custom GET or POST here. Our handler currently only has POST.
# But let's check if we should support GET for simpler retrieval.
# For now, let's use the POST endpoint with just the eid.
return self._send(
msg={"eid": env},
endpoint="tags",
create=False,
)
except Exception:
return []

def set_window_data(self, data, win=None, env=None):
"""
This function sets all the window data for a specified window in
Expand Down
34 changes: 32 additions & 2 deletions py/visdom/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
all of the required state about the currently running server.
"""

import json
import logging
import os
import platform
Expand Down Expand Up @@ -42,6 +43,7 @@
SaveHandler,
UpdateHandler,
UserSettingsHandler,
TagsHandler,
)
from visdom.server.defaults import (
DEFAULT_BASE_URL,
Expand Down Expand Up @@ -111,6 +113,7 @@ def __init__(
(r"%s/delete_env" % self.base_url, DeleteEnvHandler, {"app": self}),
(r"%s/env_state" % self.base_url, EnvStateHandler, {"app": self}),
(r"%s/fork_env" % self.base_url, ForkEnvHandler, {"app": self}),
(r"%s/tags" % self.base_url, TagsHandler, {"app": self}),
(r"%s/user/(.*)" % self.base_url, UserSettingsHandler, {"app": self}),
(r"%s(.*)" % self.base_url, IndexHandler, {"app": self}),
]
Expand Down Expand Up @@ -163,7 +166,13 @@ def load_state(self):
)
return {"main": {"jsons": {}, "reload": {}}}
ensure_dir_exists(env_path)
env_jsons = [i for i in os.listdir(env_path) if ".json" in i]
env_jsons = [
i
for i in os.listdir(env_path)
if i.endswith(".json") and i != "tags_index.json"
]
self.tags = self.load_tag_index()

for env_json in env_jsons:
eid = env_json.replace(".json", "")
env_path_file = os.path.join(env_path, env_json)
Expand All @@ -185,11 +194,32 @@ def load_state(self):
state[eid] = LazyEnvData(env_path_file)

if "main" not in state and "main.json" not in env_jsons:
state["main"] = {"jsons": {}, "reload": {}}
state["main"] = {"jsons": {}, "reload": {}, "tags": []}
serialize_env(state, ["main"], env_path=self.env_path)

return state

def load_tag_index(self):
index_path = os.path.join(self.env_path, "tags_index.json")
if os.path.exists(index_path):
try:
with open(index_path, "r") as f:
return json.load(f)
except Exception:
logging.warn(f"Failed to load tag index at {index_path}")
return {}

def save_tag_index(self):
index_path = os.path.join(self.env_path, "tags_index.json")
try:
from visdom.utils.server_utils import atomic_save

atomic_save(index_path, json.dumps(self.tags))
except Exception:
import traceback

logging.warn(f"Failed to save tag index at {index_path}: {traceback.format_exc()}")

def load_user_settings(self):
settings = {}

Expand Down
2 changes: 2 additions & 0 deletions py/visdom/server/handlers/socket_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from visdom.utils.server_utils import (
check_auth,
broadcast_envs,
sync_tags,
serialize_env,
send_to_sources,
broadcast,
Expand Down Expand Up @@ -277,6 +278,7 @@ def open(self):
)
self.broadcast_layouts([self])
broadcast_envs(self, [self])
sync_tags(self, [self])

def broadcast_layouts(self, target_subs=None):
if target_subs is None:
Expand Down
58 changes: 58 additions & 0 deletions py/visdom/server/handlers/web_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
compare_envs,
load_env,
broadcast,
broadcast_tags,
sync_tags,
update_window,
hash_password,
stringify,
Expand Down Expand Up @@ -677,6 +679,62 @@ def get(self, path):
self.write(self.user_settings["user_css"])


class TagsHandler(BaseHandler):
def initialize(self, app):
self.state = app.state
self.env_path = app.env_path
self.subs = app.subs
self.login_enabled = app.login_enabled
self.app = app

@check_auth
def post(self):
logging.info("TagsHandler.post called")
args = tornado.escape.json_decode(
tornado.escape.to_basestring(self.request.body)
)
eid = extract_eid(args)

if "tags" in args:
tags = args.get("tags", [])
append = args.get("append", False)

if eid not in self.state:
self.state[eid] = {"jsons": {}, "reload": {}, "tags": []}

if append:
current_tags = set(self.state[eid].get("tags", []))
current_tags.update(tags)
self.state[eid]["tags"] = list(current_tags)
else:
self.state[eid]["tags"] = list(set(tags))

# Update global index
self.app.tags[eid] = self.state[eid]["tags"]
self.app.save_tag_index()

# Broadcast update
broadcast_tags(self, eid, self.state[eid]["tags"])

# Async save env
serialize_env(self.state, [eid], env_path=self.env_path)

res = json.dumps(self.state[eid]["tags"])
logging.info(f"TagsHandler: (SET) returning {res}")
self.write(res)
else:
# This is a GET request (sent via POST for convenience in SDK)
if eid in self.state:
res = json.dumps(self.state[eid].get("tags", []))
elif eid in self.app.tags:
res = json.dumps(self.app.tags[eid])
else:
res = json.dumps([])

logging.info(f"TagsHandler: (GET) returning {res}")
self.write(res)


class ErrorHandler(BaseHandler):
def get(self, text):
error_text = text or "test error"
Expand Down
60 changes: 54 additions & 6 deletions py/visdom/utils/server_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
import json
import logging
import os
import tempfile
import threading
import time
import tornado.escape
from collections import OrderedDict
from collections import OrderedDict, defaultdict

try:
# for after python 3.8
Expand Down Expand Up @@ -78,6 +80,29 @@ def hash_password(password):
# ------- File management helprs ----- #


WRITE_LOCKS = defaultdict(threading.Lock)


def atomic_save(path, data):
"""
Atomic write to a file using a temporary file and os.replace.
Ensures that the file is either fully written or not written at all.
"""
dir_path = os.path.dirname(path)
with tempfile.NamedTemporaryFile("w", dir=dir_path, delete=False) as tf:
temp_name = tf.name
tf.write(data)
tf.flush()
os.fsync(tf.fileno())

try:
os.replace(temp_name, path)
except Exception:
if os.path.exists(temp_name):
os.remove(temp_name)
raise


class LazyEnvData(Mapping):
def __init__(self, env_path_file):
self._env_path_file = env_path_file
Expand All @@ -96,7 +121,11 @@ def lazy_load_data(self):
self._env_path_file, repr(e)
)
)
self._raw_dict = {"jsons": env_data["jsons"], "reload": env_data["reload"]}
self._raw_dict = {
"jsons": env_data["jsons"],
"reload": env_data["reload"],
"tags": env_data.get("tags", []),
}

def __getitem__(self, key):
self.lazy_load_data()
Expand All @@ -119,12 +148,13 @@ def serialize_env(state, eids, env_path=DEFAULT_ENV_PATH):
env_ids = [i for i in eids if i in state]
if env_path is not None:
for env_id in env_ids:
env_path_file = os.path.join(env_path, "{0}.json".format(env_id))
with open(env_path_file, "w") as fn:
with WRITE_LOCKS[env_id]:
env_path_file = os.path.join(env_path, "{0}.json".format(env_id))
if isinstance(state[env_id], LazyEnvData):
fn.write(json.dumps(state[env_id]._raw_dict))
data = json.dumps(state[env_id]._raw_dict)
else:
fn.write(json.dumps(state[env_id]))
data = json.dumps(state[env_id])
atomic_save(env_path_file, data)
return env_ids


Expand Down Expand Up @@ -370,6 +400,24 @@ def broadcast_envs(handler, target_subs=None):
)


def broadcast_tags(handler, eid, tags, target_subs=None):
if target_subs is None:
target_subs = handler.subs.values()
for sub in target_subs:
sub.write_message(
json.dumps({"command": "tags_update", "data": {"eid": eid, "tags": tags}})
)


def sync_tags(handler, target_subs=None):
if target_subs is None:
target_subs = handler.subs.values()
# Use the app's tag index for fast synchronization
tags_map = handler.app.tags
for sub in target_subs:
sub.write_message(json.dumps({"command": "tags_sync", "data": tags_map}))


def send_to_sources(handler, msg):
target_sources = handler.sources.values()
for source in target_sources:
Expand Down