files-to-prompt datasette -e py -
Gist simonw/1922544763b08c76f0b904e2ece364ea · Created February 5, 2025
<documents> | |
<document index="1"> | |
<source>datasette/__init__.py</source> | |
<document_content> | |
from datasette.permissions import Permission # noqa | |
from datasette.version import __version_info__, __version__ # noqa | |
from datasette.events import Event # noqa | |
from datasette.utils.asgi import Forbidden, NotFound, Request, Response # noqa | |
from datasette.utils import actor_matches_allow # noqa | |
from datasette.views import Context # noqa | |
from .hookspecs import hookimpl # noqa | |
from .hookspecs import hookspec # noqa | |
</document_content> | |
</document> | |
<document index="2"> | |
<source>datasette/__main__.py</source> | |
<document_content> | |
from datasette.cli import cli | |
if __name__ == "__main__": | |
cli() | |
</document_content> | |
</document> | |
<document index="3"> | |
<source>datasette/actor_auth_cookie.py</source> | |
<document_content> | |
from datasette import hookimpl | |
from itsdangerous import BadSignature | |
from datasette.utils import baseconv | |
import time | |
@hookimpl | |
def actor_from_request(datasette, request): | |
if "ds_actor" not in request.cookies: | |
return None | |
try: | |
decoded = datasette.unsign(request.cookies["ds_actor"], "actor") | |
# If it has "e" and "a" keys process the "e" expiry | |
if not isinstance(decoded, dict) or "a" not in decoded: | |
return None | |
expires_at = decoded.get("e") | |
if expires_at: | |
timestamp = int(baseconv.base62.decode(expires_at)) | |
if time.time() > timestamp: | |
return None | |
return decoded["a"] | |
except BadSignature: | |
return None | |
</document_content> | |
</document> | |
<document index="4"> | |
<source>datasette/app.py</source> | |
<document_content> | |
from asgi_csrf import Errors | |
import asyncio | |
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union | |
import asgi_csrf | |
import collections | |
import contextvars | |
import dataclasses | |
import datetime | |
import functools | |
import glob | |
import hashlib | |
import httpx | |
import importlib.metadata | |
import inspect | |
from itsdangerous import BadSignature | |
import json | |
import os | |
import re | |
import secrets | |
import sys | |
import threading | |
import time | |
import types | |
import urllib.parse | |
import uuid | |
from concurrent import futures | |
from pathlib import Path | |
from markupsafe import Markup, escape | |
from itsdangerous import URLSafeSerializer | |
from jinja2 import ( | |
ChoiceLoader, | |
Environment, | |
FileSystemLoader, | |
PrefixLoader, | |
) | |
from jinja2.environment import Template | |
from jinja2.exceptions import TemplateNotFound | |
from .events import Event | |
from .views import Context | |
from .views.database import database_download, DatabaseView, TableCreateView, QueryView | |
from .views.index import IndexView | |
from .views.special import ( | |
JsonDataView, | |
PatternPortfolioView, | |
AuthTokenView, | |
ApiExplorerView, | |
CreateTokenView, | |
LogoutView, | |
AllowDebugView, | |
PermissionsDebugView, | |
MessagesDebugView, | |
) | |
from .views.table import ( | |
TableInsertView, | |
TableUpsertView, | |
TableDropView, | |
table_view, | |
) | |
from .views.row import RowView, RowDeleteView, RowUpdateView | |
from .renderer import json_renderer | |
from .url_builder import Urls | |
from .database import Database, QueryInterrupted | |
from .utils import ( | |
PrefixedUrlString, | |
SPATIALITE_FUNCTIONS, | |
StartupError, | |
async_call_with_supported_arguments, | |
await_me_maybe, | |
baseconv, | |
call_with_supported_arguments, | |
detect_json1, | |
display_actor, | |
escape_css_string, | |
escape_sqlite, | |
find_spatialite, | |
format_bytes, | |
module_from_path, | |
move_plugins_and_allow, | |
move_table_config, | |
parse_metadata, | |
resolve_env_secrets, | |
resolve_routes, | |
tilde_decode, | |
to_css_class, | |
urlsafe_components, | |
redact_keys, | |
row_sql_params_pks, | |
) | |
from .utils.asgi import ( | |
AsgiLifespan, | |
Forbidden, | |
NotFound, | |
DatabaseNotFound, | |
TableNotFound, | |
RowNotFound, | |
Request, | |
Response, | |
AsgiRunOnFirstRequest, | |
asgi_static, | |
asgi_send, | |
asgi_send_file, | |
asgi_send_redirect, | |
) | |
from .utils.internal_db import init_internal_db, populate_schema_tables | |
from .utils.sqlite import ( | |
sqlite3, | |
using_pysqlite3, | |
) | |
from .tracer import AsgiTracer | |
from .plugins import pm, DEFAULT_PLUGINS, get_plugins | |
from .version import __version__ | |
app_root = Path(__file__).parent.parent | |
# https://github.com/simonw/datasette/issues/283#issuecomment-781591015 | |
SQLITE_LIMIT_ATTACHED = 10 | |
Setting = collections.namedtuple("Setting", ("name", "default", "help")) | |
SETTINGS = ( | |
Setting("default_page_size", 100, "Default page size for the table view"), | |
Setting( | |
"max_returned_rows", | |
1000, | |
"Maximum rows that can be returned from a table or custom query", | |
), | |
Setting( | |
"max_insert_rows", | |
100, | |
"Maximum rows that can be inserted at a time using the bulk insert API", | |
), | |
Setting( | |
"num_sql_threads", | |
3, | |
"Number of threads in the thread pool for executing SQLite queries", | |
), | |
Setting("sql_time_limit_ms", 1000, "Time limit for a SQL query in milliseconds"), | |
Setting( | |
"default_facet_size", 30, "Number of values to return for requested facets" | |
), | |
Setting("facet_time_limit_ms", 200, "Time limit for calculating a requested facet"), | |
Setting( | |
"facet_suggest_time_limit_ms", | |
50, | |
"Time limit for calculating a suggested facet", | |
), | |
Setting( | |
"allow_facet", | |
True, | |
"Allow users to specify columns to facet using ?_facet= parameter", | |
), | |
Setting( | |
"allow_download", | |
True, | |
"Allow users to download the original SQLite database files", | |
), | |
Setting( | |
"allow_signed_tokens", | |
True, | |
"Allow users to create and use signed API tokens", | |
), | |
Setting( | |
"default_allow_sql", | |
True, | |
"Allow anyone to run arbitrary SQL queries", | |
), | |
Setting( | |
"max_signed_tokens_ttl", | |
0, | |
"Maximum allowed expiry time for signed API tokens", | |
), | |
Setting("suggest_facets", True, "Calculate and display suggested facets"), | |
Setting( | |
"default_cache_ttl", | |
5, | |
"Default HTTP cache TTL (used in Cache-Control: max-age= header)", | |
), | |
Setting("cache_size_kb", 0, "SQLite cache size in KB (0 == use SQLite default)"), | |
Setting( | |
"allow_csv_stream", | |
True, | |
"Allow .csv?_stream=1 to download all rows (ignoring max_returned_rows)", | |
), | |
Setting( | |
"max_csv_mb", | |
100, | |
"Maximum size allowed for CSV export in MB - set 0 to disable this limit", | |
), | |
Setting( | |
"truncate_cells_html", | |
2048, | |
"Truncate cells longer than this in HTML table view - set 0 to disable", | |
), | |
Setting( | |
"force_https_urls", | |
False, | |
"Force URLs in API output to always use https:// protocol", | |
), | |
Setting( | |
"template_debug", | |
False, | |
"Allow display of template debug information with ?_context=1", | |
), | |
Setting( | |
"trace_debug", | |
False, | |
"Allow display of SQL trace debug information with ?_trace=1", | |
), | |
Setting("base_url", "/", "Datasette URLs should use this base path"), | |
) | |
_HASH_URLS_REMOVED = "The hash_urls setting has been removed, try the datasette-hashed-urls plugin instead" | |
OBSOLETE_SETTINGS = { | |
"hash_urls": _HASH_URLS_REMOVED, | |
"default_cache_ttl_hashed": _HASH_URLS_REMOVED, | |
} | |
DEFAULT_SETTINGS = {option.name: option.default for option in SETTINGS} | |
FAVICON_PATH = app_root / "datasette" / "static" / "favicon.png" | |
DEFAULT_NOT_SET = object() | |
async def favicon(request, send): | |
await asgi_send_file( | |
send, | |
str(FAVICON_PATH), | |
content_type="image/png", | |
headers={"Cache-Control": "max-age=3600, immutable, public"}, | |
) | |
ResolvedTable = collections.namedtuple("ResolvedTable", ("db", "table", "is_view")) | |
ResolvedRow = collections.namedtuple( | |
"ResolvedRow", ("db", "table", "sql", "params", "pks", "pk_values", "row") | |
) | |
def _to_string(value): | |
if isinstance(value, str): | |
return value | |
else: | |
return json.dumps(value, default=str) | |
class Datasette: | |
# Message constants: | |
INFO = 1 | |
WARNING = 2 | |
ERROR = 3 | |
def __init__( | |
self, | |
files=None, | |
immutables=None, | |
cache_headers=True, | |
cors=False, | |
inspect_data=None, | |
config=None, | |
metadata=None, | |
sqlite_extensions=None, | |
template_dir=None, | |
plugins_dir=None, | |
static_mounts=None, | |
memory=False, | |
settings=None, | |
secret=None, | |
version_note=None, | |
config_dir=None, | |
pdb=False, | |
crossdb=False, | |
nolock=False, | |
internal=None, | |
): | |
self._startup_invoked = False | |
self._request_id = contextvars.ContextVar("request_id", default=None) | |
assert config_dir is None or isinstance( | |
config_dir, Path | |
), "config_dir= should be a pathlib.Path" | |
self.config_dir = config_dir | |
self.pdb = pdb | |
self._secret = secret or secrets.token_hex(32) | |
if files is not None and isinstance(files, str): | |
raise ValueError("files= must be a list of paths, not a string") | |
self.files = tuple(files or []) + tuple(immutables or []) | |
if config_dir: | |
db_files = [] | |
for ext in ("db", "sqlite", "sqlite3"): | |
db_files.extend(config_dir.glob("*.{}".format(ext))) | |
self.files += tuple(str(f) for f in db_files) | |
if ( | |
config_dir | |
and (config_dir / "inspect-data.json").exists() | |
and not inspect_data | |
): | |
inspect_data = json.loads((config_dir / "inspect-data.json").read_text()) | |
if not immutables: | |
immutable_filenames = [i["file"] for i in inspect_data.values()] | |
immutables = [ | |
f for f in self.files if Path(f).name in immutable_filenames | |
] | |
self.inspect_data = inspect_data | |
self.immutables = set(immutables or []) | |
self.databases = collections.OrderedDict() | |
self.permissions = {} # .invoke_startup() will populate this | |
try: | |
self._refresh_schemas_lock = asyncio.Lock() | |
except RuntimeError as rex: | |
# Workaround for intermittent test failure, see: | |
# https://github.com/simonw/datasette/issues/1802 | |
if "There is no current event loop in thread" in str(rex): | |
loop = asyncio.new_event_loop() | |
asyncio.set_event_loop(loop) | |
self._refresh_schemas_lock = asyncio.Lock() | |
else: | |
raise | |
self.crossdb = crossdb | |
self.nolock = nolock | |
if memory or crossdb or not self.files: | |
self.add_database( | |
Database(self, is_mutable=False, is_memory=True), name="_memory" | |
) | |
for file in self.files: | |
self.add_database( | |
Database(self, file, is_mutable=file not in self.immutables) | |
) | |
self.internal_db_created = False | |
if internal is None: | |
self._internal_database = Database(self, memory_name=secrets.token_hex()) | |
else: | |
self._internal_database = Database(self, path=internal, mode="rwc") | |
self._internal_database.name = "__INTERNAL__" | |
self.cache_headers = cache_headers | |
self.cors = cors | |
config_files = [] | |
metadata_files = [] | |
if config_dir: | |
metadata_files = [ | |
config_dir / filename | |
for filename in ("metadata.json", "metadata.yaml", "metadata.yml") | |
if (config_dir / filename).exists() | |
] | |
config_files = [ | |
config_dir / filename | |
for filename in ("datasette.json", "datasette.yaml", "datasette.yml") | |
if (config_dir / filename).exists() | |
] | |
if config_dir and metadata_files and not metadata: | |
with metadata_files[0].open() as fp: | |
metadata = parse_metadata(fp.read()) | |
if config_dir and config_files and not config: | |
with config_files[0].open() as fp: | |
config = parse_metadata(fp.read()) | |
# Move any "plugins" and "allow" settings from metadata to config - updates them in place | |
metadata = metadata or {} | |
config = config or {} | |
metadata, config = move_plugins_and_allow(metadata, config) | |
# Now migrate any known table configuration settings over as well | |
metadata, config = move_table_config(metadata, config) | |
self._metadata_local = metadata or {} | |
self.sqlite_extensions = [] | |
for extension in sqlite_extensions or []: | |
# Resolve spatialite, if requested | |
if extension == "spatialite": | |
# Could raise SpatialiteNotFound | |
self.sqlite_extensions.append(find_spatialite()) | |
else: | |
self.sqlite_extensions.append(extension) | |
if config_dir and (config_dir / "templates").is_dir() and not template_dir: | |
template_dir = str((config_dir / "templates").resolve()) | |
self.template_dir = template_dir | |
if config_dir and (config_dir / "plugins").is_dir() and not plugins_dir: | |
plugins_dir = str((config_dir / "plugins").resolve()) | |
self.plugins_dir = plugins_dir | |
if config_dir and (config_dir / "static").is_dir() and not static_mounts: | |
static_mounts = [("static", str((config_dir / "static").resolve()))] | |
self.static_mounts = static_mounts or [] | |
if config_dir and (config_dir / "datasette.json").exists() and not config: | |
config = json.loads((config_dir / "datasette.json").read_text()) | |
config = config or {} | |
config_settings = config.get("settings") or {} | |
# validate "settings" keys in datasette.json | |
for key in config_settings: | |
if key not in DEFAULT_SETTINGS: | |
raise StartupError("Invalid setting '{}' in datasette.json".format(key)) | |
self.config = config | |
# CLI settings should overwrite datasette.json settings | |
self._settings = dict(DEFAULT_SETTINGS, **(config_settings), **(settings or {})) | |
self.renderers = {} # File extension -> (renderer, can_render) functions | |
self.version_note = version_note | |
if self.setting("num_sql_threads") == 0: | |
self.executor = None | |
else: | |
self.executor = futures.ThreadPoolExecutor( | |
max_workers=self.setting("num_sql_threads") | |
) | |
self.max_returned_rows = self.setting("max_returned_rows") | |
self.sql_time_limit_ms = self.setting("sql_time_limit_ms") | |
self.page_size = self.setting("default_page_size") | |
# Execute plugins in constructor, to ensure they are available | |
# when the rest of `datasette inspect` executes | |
if self.plugins_dir: | |
for filepath in glob.glob(os.path.join(self.plugins_dir, "*.py")): | |
if not os.path.isfile(filepath): | |
continue | |
mod = module_from_path(filepath, name=os.path.basename(filepath)) | |
try: | |
pm.register(mod) | |
except ValueError: | |
# Plugin already registered | |
pass | |
# Configure Jinja | |
default_templates = str(app_root / "datasette" / "templates") | |
template_paths = [] | |
if self.template_dir: | |
template_paths.append(self.template_dir) | |
plugin_template_paths = [ | |
plugin["templates_path"] | |
for plugin in get_plugins() | |
if plugin["templates_path"] | |
] | |
template_paths.extend(plugin_template_paths) | |
template_paths.append(default_templates) | |
template_loader = ChoiceLoader( | |
[ | |
FileSystemLoader(template_paths), | |
# Support {% extends "default:table.html" %}: | |
PrefixLoader( | |
{"default": FileSystemLoader(default_templates)}, delimiter=":" | |
), | |
] | |
) | |
environment = Environment( | |
loader=template_loader, | |
autoescape=True, | |
enable_async=True, | |
# undefined=StrictUndefined, | |
) | |
environment.filters["escape_css_string"] = escape_css_string | |
environment.filters["quote_plus"] = urllib.parse.quote_plus | |
self._jinja_env = environment | |
environment.filters["escape_sqlite"] = escape_sqlite | |
environment.filters["to_css_class"] = to_css_class | |
self._register_renderers() | |
self._permission_checks = collections.deque(maxlen=200) | |
self._root_token = secrets.token_hex(32) | |
self.client = DatasetteClient(self) | |
@property | |
def request_id(self): | |
return self._request_id.get() | |
async def apply_metadata_json(self): | |
# Apply any metadata entries from metadata.json to the internal tables | |
# step 1: top-level metadata | |
for key in self._metadata_local or {}: | |
if key == "databases": | |
continue | |
value = self._metadata_local[key] | |
await self.set_instance_metadata(key, _to_string(value)) | |
# step 2: database-level metadata | |
for dbname, db in self._metadata_local.get("databases", {}).items(): | |
for key, value in db.items(): | |
if key in ("tables", "queries"): | |
continue | |
await self.set_database_metadata(dbname, key, _to_string(value)) | |
# step 3: table-level metadata | |
for tablename, table in db.get("tables", {}).items(): | |
for key, value in table.items(): | |
if key == "columns": | |
continue | |
await self.set_resource_metadata( | |
dbname, tablename, key, _to_string(value) | |
) | |
# step 4: column-level metadata (only descriptions in metadata.json) | |
for columnname, column_description in table.get("columns", {}).items(): | |
await self.set_column_metadata( | |
dbname, tablename, columnname, "description", column_description | |
) | |
# TODO(alex) is metadata.json was loaded in, and --internal is not memory, then log | |
# a warning to user that they should delete their metadata.json file | |
def get_jinja_environment(self, request: Request = None) -> Environment: | |
environment = self._jinja_env | |
if request: | |
for environment in pm.hook.jinja2_environment_from_request( | |
datasette=self, request=request, env=environment | |
): | |
pass | |
return environment | |
def get_permission(self, name_or_abbr: str) -> "Permission": | |
""" | |
Returns a Permission object for the given name or abbreviation. Raises KeyError if not found. | |
""" | |
if name_or_abbr in self.permissions: | |
return self.permissions[name_or_abbr] | |
# Try abbreviation | |
for permission in self.permissions.values(): | |
if permission.abbr == name_or_abbr: | |
return permission | |
raise KeyError( | |
"No permission found with name or abbreviation {}".format(name_or_abbr) | |
) | |
async def refresh_schemas(self): | |
if self._refresh_schemas_lock.locked(): | |
return | |
async with self._refresh_schemas_lock: | |
await self._refresh_schemas() | |
async def _refresh_schemas(self): | |
internal_db = self.get_internal_database() | |
if not self.internal_db_created: | |
await init_internal_db(internal_db) | |
await self.apply_metadata_json() | |
self.internal_db_created = True | |
current_schema_versions = { | |
row["database_name"]: row["schema_version"] | |
for row in await internal_db.execute( | |
"select database_name, schema_version from catalog_databases" | |
) | |
} | |
for database_name, db in self.databases.items(): | |
schema_version = (await db.execute("PRAGMA schema_version")).first()[0] | |
# Compare schema versions to see if we should skip it | |
if schema_version == current_schema_versions.get(database_name): | |
continue | |
placeholders = "(?, ?, ?, ?)" | |
values = [database_name, str(db.path), db.is_memory, schema_version] | |
if db.path is None: | |
placeholders = "(?, null, ?, ?)" | |
values = [database_name, db.is_memory, schema_version] | |
await internal_db.execute_write( | |
""" | |
INSERT OR REPLACE INTO catalog_databases (database_name, path, is_memory, schema_version) | |
VALUES {} | |
""".format( | |
placeholders | |
), | |
values, | |
) | |
await populate_schema_tables(internal_db, db) | |
@property | |
def urls(self): | |
return Urls(self) | |
async def invoke_startup(self): | |
# This must be called for Datasette to be in a usable state | |
if self._startup_invoked: | |
return | |
# Register event classes | |
event_classes = [] | |
for hook in pm.hook.register_events(datasette=self): | |
extra_classes = await await_me_maybe(hook) | |
if extra_classes: | |
event_classes.extend(extra_classes) | |
self.event_classes = tuple(event_classes) | |
# Register permissions, but watch out for duplicate name/abbr | |
names = {} | |
abbrs = {} | |
for hook in pm.hook.register_permissions(datasette=self): | |
if hook: | |
for p in hook: | |
if p.name in names and p != names[p.name]: | |
raise StartupError( | |
"Duplicate permission name: {}".format(p.name) | |
) | |
if p.abbr and p.abbr in abbrs and p != abbrs[p.abbr]: | |
raise StartupError( | |
"Duplicate permission abbr: {}".format(p.abbr) | |
) | |
names[p.name] = p | |
if p.abbr: | |
abbrs[p.abbr] = p | |
self.permissions[p.name] = p | |
for hook in pm.hook.prepare_jinja2_environment( | |
env=self._jinja_env, datasette=self | |
): | |
await await_me_maybe(hook) | |
for hook in pm.hook.startup(datasette=self): | |
await await_me_maybe(hook) | |
self._startup_invoked = True | |
def sign(self, value, namespace="default"): | |
return URLSafeSerializer(self._secret, namespace).dumps(value) | |
def unsign(self, signed, namespace="default"): | |
return URLSafeSerializer(self._secret, namespace).loads(signed) | |
def create_token( | |
self, | |
actor_id: str, | |
*, | |
expires_after: Optional[int] = None, | |
restrict_all: Optional[Iterable[str]] = None, | |
restrict_database: Optional[Dict[str, Iterable[str]]] = None, | |
restrict_resource: Optional[Dict[str, Dict[str, Iterable[str]]]] = None, | |
): | |
token = {"a": actor_id, "t": int(time.time())} | |
if expires_after: | |
token["d"] = expires_after | |
def abbreviate_action(action): | |
# rename to abbr if possible | |
permission = self.permissions.get(action) | |
if not permission: | |
return action | |
return permission.abbr or action | |
if expires_after: | |
token["d"] = expires_after | |
if restrict_all or restrict_database or restrict_resource: | |
token["_r"] = {} | |
if restrict_all: | |
token["_r"]["a"] = [abbreviate_action(a) for a in restrict_all] | |
if restrict_database: | |
token["_r"]["d"] = {} | |
for database, actions in restrict_database.items(): | |
token["_r"]["d"][database] = [abbreviate_action(a) for a in actions] | |
if restrict_resource: | |
token["_r"]["r"] = {} | |
for database, resources in restrict_resource.items(): | |
for resource, actions in resources.items(): | |
token["_r"]["r"].setdefault(database, {})[resource] = [ | |
abbreviate_action(a) for a in actions | |
] | |
return "dstok_{}".format(self.sign(token, namespace="token")) | |
def get_database(self, name=None, route=None): | |
if route is not None: | |
matches = [db for db in self.databases.values() if db.route == route] | |
if not matches: | |
raise KeyError | |
return matches[0] | |
if name is None: | |
name = [key for key in self.databases.keys()][0] | |
return self.databases[name] | |
def add_database(self, db, name=None, route=None): | |
new_databases = self.databases.copy() | |
if name is None: | |
# Pick a unique name for this database | |
suggestion = db.suggest_name() | |
name = suggestion | |
else: | |
suggestion = name | |
i = 2 | |
while name in self.databases: | |
name = "{}_{}".format(suggestion, i) | |
i += 1 | |
db.name = name | |
db.route = route or name | |
new_databases[name] = db | |
# don't mutate! that causes race conditions with live import | |
self.databases = new_databases | |
return db | |
def add_memory_database(self, memory_name): | |
return self.add_database(Database(self, memory_name=memory_name)) | |
def remove_database(self, name): | |
new_databases = self.databases.copy() | |
new_databases.pop(name) | |
self.databases = new_databases | |
def setting(self, key): | |
return self._settings.get(key, None) | |
def settings_dict(self): | |
# Returns a fully resolved settings dictionary, useful for templates | |
return {option.name: self.setting(option.name) for option in SETTINGS} | |
def _metadata_recursive_update(self, orig, updated): | |
if not isinstance(orig, dict) or not isinstance(updated, dict): | |
return orig | |
for key, upd_value in updated.items(): | |
if isinstance(upd_value, dict) and isinstance(orig.get(key), dict): | |
orig[key] = self._metadata_recursive_update(orig[key], upd_value) | |
else: | |
orig[key] = upd_value | |
return orig | |
async def get_instance_metadata(self): | |
rows = await self.get_internal_database().execute( | |
""" | |
SELECT | |
key, | |
value | |
FROM metadata_instance | |
""" | |
) | |
return dict(rows) | |
async def get_database_metadata(self, database_name: str): | |
rows = await self.get_internal_database().execute( | |
""" | |
SELECT | |
key, | |
value | |
FROM metadata_databases | |
WHERE database_name = ? | |
""", | |
[database_name], | |
) | |
return dict(rows) | |
async def get_resource_metadata(self, database_name: str, resource_name: str): | |
rows = await self.get_internal_database().execute( | |
""" | |
SELECT | |
key, | |
value | |
FROM metadata_resources | |
WHERE database_name = ? | |
AND resource_name = ? | |
""", | |
[database_name, resource_name], | |
) | |
return dict(rows) | |
async def get_column_metadata( | |
self, database_name: str, resource_name: str, column_name: str | |
): | |
rows = await self.get_internal_database().execute( | |
""" | |
SELECT | |
key, | |
value | |
FROM metadata_columns | |
WHERE database_name = ? | |
AND resource_name = ? | |
AND column_name = ? | |
""", | |
[database_name, resource_name, column_name], | |
) | |
return dict(rows) | |
async def set_instance_metadata(self, key: str, value: str): | |
# TODO upsert only supported on SQLite 3.24.0 (2018-06-04) | |
await self.get_internal_database().execute_write( | |
""" | |
INSERT INTO metadata_instance(key, value) | |
VALUES(?, ?) | |
ON CONFLICT(key) DO UPDATE SET value = excluded.value; | |
""", | |
[key, value], | |
) | |
async def set_database_metadata(self, database_name: str, key: str, value: str): | |
# TODO upsert only supported on SQLite 3.24.0 (2018-06-04) | |
await self.get_internal_database().execute_write( | |
""" | |
INSERT INTO metadata_databases(database_name, key, value) | |
VALUES(?, ?, ?) | |
ON CONFLICT(database_name, key) DO UPDATE SET value = excluded.value; | |
""", | |
[database_name, key, value], | |
) | |
async def set_resource_metadata( | |
self, database_name: str, resource_name: str, key: str, value: str | |
): | |
# TODO upsert only supported on SQLite 3.24.0 (2018-06-04) | |
await self.get_internal_database().execute_write( | |
""" | |
INSERT INTO metadata_resources(database_name, resource_name, key, value) | |
VALUES(?, ?, ?, ?) | |
ON CONFLICT(database_name, resource_name, key) DO UPDATE SET value = excluded.value; | |
""", | |
[database_name, resource_name, key, value], | |
) | |
async def set_column_metadata( | |
self, | |
database_name: str, | |
resource_name: str, | |
column_name: str, | |
key: str, | |
value: str, | |
): | |
# TODO upsert only supported on SQLite 3.24.0 (2018-06-04) | |
await self.get_internal_database().execute_write( | |
""" | |
INSERT INTO metadata_columns(database_name, resource_name, column_name, key, value) | |
VALUES(?, ?, ?, ?, ?) | |
ON CONFLICT(database_name, resource_name, column_name, key) DO UPDATE SET value = excluded.value; | |
""", | |
[database_name, resource_name, column_name, key, value], | |
) | |
def get_internal_database(self): | |
return self._internal_database | |
def plugin_config(self, plugin_name, database=None, table=None, fallback=True): | |
"""Return config for plugin, falling back from specified database/table""" | |
if database is None and table is None: | |
config = self._plugin_config_top(plugin_name) | |
else: | |
config = self._plugin_config_nested(plugin_name, database, table, fallback) | |
return resolve_env_secrets(config, os.environ) | |
def _plugin_config_top(self, plugin_name): | |
"""Returns any top-level plugin configuration for the specified plugin.""" | |
return ((self.config or {}).get("plugins") or {}).get(plugin_name) | |
def _plugin_config_nested(self, plugin_name, database, table=None, fallback=True): | |
"""Returns any database or table-level plugin configuration for the specified plugin.""" | |
db_config = ((self.config or {}).get("databases") or {}).get(database) | |
# if there's no db-level configuration, then return early, falling back to top-level if needed | |
if not db_config: | |
return self._plugin_config_top(plugin_name) if fallback else None | |
db_plugin_config = (db_config.get("plugins") or {}).get(plugin_name) | |
if table: | |
table_plugin_config = ( | |
((db_config.get("tables") or {}).get(table) or {}).get("plugins") or {} | |
).get(plugin_name) | |
# fallback to db_config or top-level config, in that order, if needed | |
if table_plugin_config is None and fallback: | |
return db_plugin_config or self._plugin_config_top(plugin_name) | |
return table_plugin_config | |
# fallback to top-level if needed | |
if db_plugin_config is None and fallback: | |
self._plugin_config_top(plugin_name) | |
return db_plugin_config | |
def app_css_hash(self): | |
if not hasattr(self, "_app_css_hash"): | |
with open(os.path.join(str(app_root), "datasette/static/app.css")) as fp: | |
self._app_css_hash = hashlib.sha1(fp.read().encode("utf8")).hexdigest()[ | |
:6 | |
] | |
return self._app_css_hash | |
async def get_canned_queries(self, database_name, actor): | |
queries = ( | |
((self.config or {}).get("databases") or {}).get(database_name) or {} | |
).get("queries") or {} | |
for more_queries in pm.hook.canned_queries( | |
datasette=self, | |
database=database_name, | |
actor=actor, | |
): | |
more_queries = await await_me_maybe(more_queries) | |
queries.update(more_queries or {}) | |
# Fix any {"name": "select ..."} queries to be {"name": {"sql": "select ..."}} | |
for key in queries: | |
if not isinstance(queries[key], dict): | |
queries[key] = {"sql": queries[key]} | |
# Also make sure "name" is available: | |
queries[key]["name"] = key | |
return queries | |
async def get_canned_query(self, database_name, query_name, actor): | |
queries = await self.get_canned_queries(database_name, actor) | |
query = queries.get(query_name) | |
if query: | |
return query | |
def _prepare_connection(self, conn, database): | |
conn.row_factory = sqlite3.Row | |
conn.text_factory = lambda x: str(x, "utf-8", "replace") | |
if self.sqlite_extensions: | |
conn.enable_load_extension(True) | |
for extension in self.sqlite_extensions: | |
# "extension" is either a string path to the extension | |
# or a 2-item tuple that specifies which entrypoint to load. | |
if isinstance(extension, tuple): | |
path, entrypoint = extension | |
conn.execute("SELECT load_extension(?, ?)", [path, entrypoint]) | |
else: | |
conn.execute("SELECT load_extension(?)", [extension]) | |
if self.setting("cache_size_kb"): | |
conn.execute(f"PRAGMA cache_size=-{self.setting('cache_size_kb')}") | |
# pylint: disable=no-member | |
pm.hook.prepare_connection(conn=conn, database=database, datasette=self) | |
# If self.crossdb and this is _memory, connect the first SQLITE_LIMIT_ATTACHED databases | |
if self.crossdb and database == "_memory": | |
count = 0 | |
for db_name, db in self.databases.items(): | |
if count >= SQLITE_LIMIT_ATTACHED or db.is_memory: | |
continue | |
sql = 'ATTACH DATABASE "file:{path}?{qs}" AS [{name}];'.format( | |
path=db.path, | |
qs="mode=ro" if db.is_mutable else "immutable=1", | |
name=db_name, | |
) | |
conn.execute(sql) | |
count += 1 | |
def add_message(self, request, message, type=INFO): | |
if not hasattr(request, "_messages"): | |
request._messages = [] | |
request._messages_should_clear = False | |
request._messages.append((message, type)) | |
def _write_messages_to_response(self, request, response): | |
if getattr(request, "_messages", None): | |
# Set those messages | |
response.set_cookie("ds_messages", self.sign(request._messages, "messages")) | |
elif getattr(request, "_messages_should_clear", False): | |
response.set_cookie("ds_messages", "", expires=0, max_age=0) | |
def _show_messages(self, request): | |
if getattr(request, "_messages", None): | |
request._messages_should_clear = True | |
messages = request._messages | |
request._messages = [] | |
return messages | |
else: | |
return [] | |
async def _crumb_items(self, request, table=None, database=None): | |
crumbs = [] | |
actor = None | |
if request: | |
actor = request.actor | |
# Top-level link | |
if await self.permission_allowed(actor=actor, action="view-instance"): | |
crumbs.append({"href": self.urls.instance(), "label": "home"}) | |
# Database link | |
if database: | |
if await self.permission_allowed( | |
actor=actor, | |
action="view-database", | |
resource=database, | |
): | |
crumbs.append( | |
{ | |
"href": self.urls.database(database), | |
"label": database, | |
} | |
) | |
# Table link | |
if table: | |
assert database, "table= requires database=" | |
if await self.permission_allowed( | |
actor=actor, | |
action="view-table", | |
resource=(database, table), | |
): | |
crumbs.append( | |
{ | |
"href": self.urls.table(database, table), | |
"label": table, | |
} | |
) | |
return crumbs | |
async def actors_from_ids( | |
self, actor_ids: Iterable[Union[str, int]] | |
) -> Dict[Union[id, str], Dict]: | |
result = pm.hook.actors_from_ids(datasette=self, actor_ids=actor_ids) | |
if result is None: | |
# Do the default thing | |
return {actor_id: {"id": actor_id} for actor_id in actor_ids} | |
result = await await_me_maybe(result) | |
return result | |
async def track_event(self, event: Event): | |
assert isinstance(event, self.event_classes), "Invalid event type: {}".format( | |
type(event) | |
) | |
for hook in pm.hook.track_event(datasette=self, event=event): | |
await await_me_maybe(hook) | |
async def permission_allowed( | |
self, actor, action, resource=None, *, default=DEFAULT_NOT_SET | |
): | |
"""Check permissions using the permissions_allowed plugin hook""" | |
result = None | |
# Use default from registered permission, if available | |
if default is DEFAULT_NOT_SET and action in self.permissions: | |
default = self.permissions[action].default | |
opinions = [] | |
# Every plugin is consulted for their opinion | |
for check in pm.hook.permission_allowed( | |
datasette=self, | |
actor=actor, | |
action=action, | |
resource=resource, | |
): | |
check = await await_me_maybe(check) | |
if check is not None: | |
opinions.append(check) | |
result = None | |
# If any plugin said False it's false - the veto rule | |
if any(not r for r in opinions): | |
result = False | |
elif any(r for r in opinions): | |
# Otherwise, if any plugin said True it's true | |
result = True | |
used_default = False | |
if result is None: | |
# No plugin expressed an opinion, so use the default | |
result = default | |
used_default = True | |
self._permission_checks.append( | |
{ | |
"request_id": self.request_id, | |
"when": datetime.datetime.now(datetime.timezone.utc).isoformat(), | |
"actor": actor, | |
"action": action, | |
"resource": resource, | |
"used_default": used_default, | |
"result": result, | |
} | |
) | |
return result | |
async def ensure_permissions( | |
self, | |
actor: dict, | |
permissions: Sequence[Union[Tuple[str, Union[str, Tuple[str, str]]], str]], | |
): | |
""" | |
permissions is a list of (action, resource) tuples or 'action' strings | |
Raises datasette.Forbidden() if any of the checks fail | |
""" | |
assert actor is None or isinstance(actor, dict), "actor must be None or a dict" | |
for permission in permissions: | |
if isinstance(permission, str): | |
action = permission | |
resource = None | |
elif isinstance(permission, (tuple, list)) and len(permission) == 2: | |
action, resource = permission | |
else: | |
assert ( | |
False | |
), "permission should be string or tuple of two items: {}".format( | |
repr(permission) | |
) | |
ok = await self.permission_allowed( | |
actor, | |
action, | |
resource=resource, | |
default=None, | |
) | |
if ok is not None: | |
if ok: | |
return | |
else: | |
raise Forbidden(action) | |
async def check_visibility( | |
self, | |
actor: dict, | |
action: Optional[str] = None, | |
resource: Optional[Union[str, Tuple[str, str]]] = None, | |
permissions: Optional[ | |
Sequence[Union[Tuple[str, Union[str, Tuple[str, str]]], str]] | |
] = None, | |
): | |
"""Returns (visible, private) - visible = can you see it, private = can others see it too""" | |
if permissions: | |
assert ( | |
not action and not resource | |
), "Can't use action= or resource= with permissions=" | |
else: | |
permissions = [(action, resource)] | |
try: | |
await self.ensure_permissions(actor, permissions) | |
except Forbidden: | |
return False, False | |
# User can see it, but can the anonymous user see it? | |
try: | |
await self.ensure_permissions(None, permissions) | |
except Forbidden: | |
# It's visible but private | |
return True, True | |
# It's visible to everyone | |
return True, False | |
async def execute( | |
self, | |
db_name, | |
sql, | |
params=None, | |
truncate=False, | |
custom_time_limit=None, | |
page_size=None, | |
log_sql_errors=True, | |
): | |
return await self.databases[db_name].execute( | |
sql, | |
params=params, | |
truncate=truncate, | |
custom_time_limit=custom_time_limit, | |
page_size=page_size, | |
log_sql_errors=log_sql_errors, | |
) | |
async def expand_foreign_keys(self, actor, database, table, column, values): | |
"""Returns dict mapping (column, value) -> label""" | |
labeled_fks = {} | |
db = self.databases[database] | |
foreign_keys = await db.foreign_keys_for_table(table) | |
# Find the foreign_key for this column | |
try: | |
fk = [ | |
foreign_key | |
for foreign_key in foreign_keys | |
if foreign_key["column"] == column | |
][0] | |
except IndexError: | |
return {} | |
# Ensure user has permission to view the referenced table | |
other_table = fk["other_table"] | |
other_column = fk["other_column"] | |
visible, _ = await self.check_visibility( | |
actor, | |
permissions=[ | |
("view-table", (database, other_table)), | |
("view-database", database), | |
"view-instance", | |
], | |
) | |
if not visible: | |
return {} | |
label_column = await db.label_column_for_table(other_table) | |
if not label_column: | |
return {(fk["column"], value): str(value) for value in values} | |
labeled_fks = {} | |
sql = """ | |
select {other_column}, {label_column} | |
from {other_table} | |
where {other_column} in ({placeholders}) | |
""".format( | |
other_column=escape_sqlite(other_column), | |
label_column=escape_sqlite(label_column), | |
other_table=escape_sqlite(other_table), | |
placeholders=", ".join(["?"] * len(set(values))), | |
) | |
try: | |
results = await self.execute(database, sql, list(set(values))) | |
except QueryInterrupted: | |
pass | |
else: | |
for id, value in results: | |
labeled_fks[(fk["column"], id)] = value | |
return labeled_fks | |
def absolute_url(self, request, path): | |
url = urllib.parse.urljoin(request.url, path) | |
if url.startswith("https://") and self.setting("force_https_urls"): | |
url = "https://" + url[len("https://") :] | |
return url | |
def _connected_databases(self): | |
return [ | |
{ | |
"name": d.name, | |
"route": d.route, | |
"path": d.path, | |
"size": d.size, | |
"is_mutable": d.is_mutable, | |
"is_memory": d.is_memory, | |
"hash": d.hash, | |
} | |
for name, d in self.databases.items() | |
] | |
def _versions(self): | |
conn = sqlite3.connect(":memory:") | |
self._prepare_connection(conn, "_memory") | |
sqlite_version = conn.execute("select sqlite_version()").fetchone()[0] | |
sqlite_extensions = {"json1": detect_json1(conn)} | |
for extension, testsql, hasversion in ( | |
("spatialite", "SELECT spatialite_version()", True), | |
): | |
try: | |
result = conn.execute(testsql) | |
if hasversion: | |
sqlite_extensions[extension] = result.fetchone()[0] | |
else: | |
sqlite_extensions[extension] = None | |
except Exception: | |
pass | |
# More details on SpatiaLite | |
if "spatialite" in sqlite_extensions: | |
spatialite_details = {} | |
for fn in SPATIALITE_FUNCTIONS: | |
try: | |
result = conn.execute("select {}()".format(fn)) | |
spatialite_details[fn] = result.fetchone()[0] | |
except Exception as e: | |
spatialite_details[fn] = {"error": str(e)} | |
sqlite_extensions["spatialite"] = spatialite_details | |
# Figure out supported FTS versions | |
fts_versions = [] | |
for fts in ("FTS5", "FTS4", "FTS3"): | |
try: | |
conn.execute( | |
"CREATE VIRTUAL TABLE v{fts} USING {fts} (data)".format(fts=fts) | |
) | |
fts_versions.append(fts) | |
except sqlite3.OperationalError: | |
continue | |
datasette_version = {"version": __version__} | |
if self.version_note: | |
datasette_version["note"] = self.version_note | |
try: | |
# Optional import to avoid breaking Pyodide | |
# https://github.com/simonw/datasette/issues/1733#issuecomment-1115268245 | |
import uvicorn | |
uvicorn_version = uvicorn.__version__ | |
except ImportError: | |
uvicorn_version = None | |
info = { | |
"python": { | |
"version": ".".join(map(str, sys.version_info[:3])), | |
"full": sys.version, | |
}, | |
"datasette": datasette_version, | |
"asgi": "3.0", | |
"uvicorn": uvicorn_version, | |
"sqlite": { | |
"version": sqlite_version, | |
"fts_versions": fts_versions, | |
"extensions": sqlite_extensions, | |
"compile_options": [ | |
r[0] for r in conn.execute("pragma compile_options;").fetchall() | |
], | |
}, | |
} | |
if using_pysqlite3: | |
for package in ("pysqlite3", "pysqlite3-binary"): | |
try: | |
info["pysqlite3"] = importlib.metadata.version(package) | |
break | |
except importlib.metadata.PackageNotFoundError: | |
pass | |
return info | |
def _plugins(self, request=None, all=False): | |
ps = list(get_plugins()) | |
should_show_all = False | |
if request is not None: | |
should_show_all = request.args.get("all") | |
else: | |
should_show_all = all | |
if not should_show_all: | |
ps = [p for p in ps if p["name"] not in DEFAULT_PLUGINS] | |
ps.sort(key=lambda p: p["name"]) | |
return [ | |
{ | |
"name": p["name"], | |
"static": p["static_path"] is not None, | |
"templates": p["templates_path"] is not None, | |
"version": p.get("version"), | |
"hooks": list(sorted(set(p["hooks"]))), | |
} | |
for p in ps | |
] | |
def _threads(self): | |
if self.setting("num_sql_threads") == 0: | |
return {"num_threads": 0, "threads": []} | |
threads = list(threading.enumerate()) | |
d = { | |
"num_threads": len(threads), | |
"threads": [ | |
{"name": t.name, "ident": t.ident, "daemon": t.daemon} for t in threads | |
], | |
} | |
tasks = asyncio.all_tasks() | |
d.update( | |
{ | |
"num_tasks": len(tasks), | |
"tasks": [_cleaner_task_str(t) for t in tasks], | |
} | |
) | |
return d | |
def _actor(self, request): | |
return {"actor": request.actor} | |
async def table_config(self, database: str, table: str) -> dict: | |
"""Return dictionary of configuration for specified table""" | |
return ( | |
(self.config or {}) | |
.get("databases", {}) | |
.get(database, {}) | |
.get("tables", {}) | |
.get(table, {}) | |
) | |
def _register_renderers(self): | |
"""Register output renderers which output data in custom formats.""" | |
# Built-in renderers | |
self.renderers["json"] = (json_renderer, lambda: True) | |
# Hooks | |
hook_renderers = [] | |
# pylint: disable=no-member | |
for hook in pm.hook.register_output_renderer(datasette=self): | |
if type(hook) is list: | |
hook_renderers += hook | |
else: | |
hook_renderers.append(hook) | |
for renderer in hook_renderers: | |
self.renderers[renderer["extension"]] = ( | |
# It used to be called "callback" - remove this in Datasette 1.0 | |
renderer.get("render") or renderer["callback"], | |
renderer.get("can_render") or (lambda: True), | |
) | |
async def render_template( | |
self, | |
templates: Union[List[str], str, Template], | |
context: Optional[Union[Dict[str, Any], Context]] = None, | |
request: Optional[Request] = None, | |
view_name: Optional[str] = None, | |
): | |
if not self._startup_invoked: | |
raise Exception("render_template() called before await ds.invoke_startup()") | |
context = context or {} | |
if isinstance(templates, Template): | |
template = templates | |
else: | |
if isinstance(templates, str): | |
templates = [templates] | |
template = self.get_jinja_environment(request).select_template(templates) | |
if dataclasses.is_dataclass(context): | |
context = dataclasses.asdict(context) | |
body_scripts = [] | |
# pylint: disable=no-member | |
for extra_script in pm.hook.extra_body_script( | |
template=template.name, | |
database=context.get("database"), | |
table=context.get("table"), | |
columns=context.get("columns"), | |
view_name=view_name, | |
request=request, | |
datasette=self, | |
): | |
extra_script = await await_me_maybe(extra_script) | |
if isinstance(extra_script, dict): | |
script = extra_script["script"] | |
module = bool(extra_script.get("module")) | |
else: | |
script = extra_script | |
module = False | |
body_scripts.append({"script": Markup(script), "module": module}) | |
extra_template_vars = {} | |
# pylint: disable=no-member | |
for extra_vars in pm.hook.extra_template_vars( | |
template=template.name, | |
database=context.get("database"), | |
table=context.get("table"), | |
columns=context.get("columns"), | |
view_name=view_name, | |
request=request, | |
datasette=self, | |
): | |
extra_vars = await await_me_maybe(extra_vars) | |
assert isinstance(extra_vars, dict), "extra_vars is of type {}".format( | |
type(extra_vars) | |
) | |
extra_template_vars.update(extra_vars) | |
async def menu_links(): | |
links = [] | |
for hook in pm.hook.menu_links( | |
datasette=self, | |
actor=request.actor if request else None, | |
request=request or None, | |
): | |
extra_links = await await_me_maybe(hook) | |
if extra_links: | |
links.extend(extra_links) | |
return links | |
template_context = { | |
**context, | |
**{ | |
"request": request, | |
"crumb_items": self._crumb_items, | |
"urls": self.urls, | |
"actor": request.actor if request else None, | |
"menu_links": menu_links, | |
"display_actor": display_actor, | |
"show_logout": request is not None | |
and "ds_actor" in request.cookies | |
and request.actor, | |
"app_css_hash": self.app_css_hash(), | |
"zip": zip, | |
"body_scripts": body_scripts, | |
"format_bytes": format_bytes, | |
"show_messages": lambda: self._show_messages(request), | |
"extra_css_urls": await self._asset_urls( | |
"extra_css_urls", template, context, request, view_name | |
), | |
"extra_js_urls": await self._asset_urls( | |
"extra_js_urls", template, context, request, view_name | |
), | |
"base_url": self.setting("base_url"), | |
"csrftoken": request.scope["csrftoken"] if request else lambda: "", | |
"datasette_version": __version__, | |
}, | |
**extra_template_vars, | |
} | |
if request and request.args.get("_context") and self.setting("template_debug"): | |
return "<pre>{}</pre>".format( | |
escape(json.dumps(template_context, default=repr, indent=4)) | |
) | |
return await template.render_async(template_context) | |
def set_actor_cookie( | |
self, response: Response, actor: dict, expire_after: Optional[int] = None | |
): | |
data = {"a": actor} | |
if expire_after: | |
expires_at = int(time.time()) + (24 * 60 * 60) | |
data["e"] = baseconv.base62.encode(expires_at) | |
response.set_cookie("ds_actor", self.sign(data, "actor")) | |
def delete_actor_cookie(self, response: Response): | |
response.set_cookie("ds_actor", "", expires=0, max_age=0) | |
async def _asset_urls(self, key, template, context, request, view_name): | |
# Flatten list-of-lists from plugins: | |
seen_urls = set() | |
collected = [] | |
for hook in getattr(pm.hook, key)( | |
template=template.name, | |
database=context.get("database"), | |
table=context.get("table"), | |
columns=context.get("columns"), | |
view_name=view_name, | |
request=request, | |
datasette=self, | |
): | |
hook = await await_me_maybe(hook) | |
collected.extend(hook) | |
collected.extend((self.config or {}).get(key) or []) | |
output = [] | |
for url_or_dict in collected: | |
if isinstance(url_or_dict, dict): | |
url = url_or_dict["url"] | |
sri = url_or_dict.get("sri") | |
module = bool(url_or_dict.get("module")) | |
else: | |
url = url_or_dict | |
sri = None | |
module = False | |
if url in seen_urls: | |
continue | |
seen_urls.add(url) | |
if url.startswith("/"): | |
# Take base_url into account: | |
url = self.urls.path(url) | |
script = {"url": url} | |
if sri: | |
script["sri"] = sri | |
if module: | |
script["module"] = True | |
output.append(script) | |
return output | |
def _config(self): | |
return redact_keys( | |
self.config, ("secret", "key", "password", "token", "hash", "dsn") | |
) | |
def _routes(self): | |
routes = [] | |
for routes_to_add in pm.hook.register_routes(datasette=self): | |
for regex, view_fn in routes_to_add: | |
routes.append((regex, wrap_view(view_fn, self))) | |
def add_route(view, regex): | |
routes.append((regex, view)) | |
add_route(IndexView.as_view(self), r"/(\.(?P<format>jsono?))?$") | |
add_route(IndexView.as_view(self), r"/-/(\.(?P<format>jsono?))?$") | |
add_route(permanent_redirect("/-/"), r"/-$") | |
# TODO: /favicon.ico and /-/static/ deserve far-future cache expires | |
add_route(favicon, "/favicon.ico") | |
add_route( | |
asgi_static(app_root / "datasette" / "static"), r"/-/static/(?P<path>.*)$" | |
) | |
for path, dirname in self.static_mounts: | |
add_route(asgi_static(dirname), r"/" + path + "/(?P<path>.*)$") | |
# Mount any plugin static/ directories | |
for plugin in get_plugins(): | |
if plugin["static_path"]: | |
add_route( | |
asgi_static(plugin["static_path"]), | |
f"/-/static-plugins/{plugin['name']}/(?P<path>.*)$", | |
) | |
# Support underscores in name in addition to hyphens, see https://github.com/simonw/datasette/issues/611 | |
add_route( | |
asgi_static(plugin["static_path"]), | |
"/-/static-plugins/{}/(?P<path>.*)$".format( | |
plugin["name"].replace("-", "_") | |
), | |
) | |
add_route( | |
permanent_redirect( | |
"/_memory", forward_query_string=True, forward_rest=True | |
), | |
r"/:memory:(?P<rest>.*)$", | |
) | |
add_route( | |
JsonDataView.as_view(self, "versions.json", self._versions), | |
r"/-/versions(\.(?P<format>json))?$", | |
) | |
add_route( | |
JsonDataView.as_view( | |
self, "plugins.json", self._plugins, needs_request=True | |
), | |
r"/-/plugins(\.(?P<format>json))?$", | |
) | |
add_route( | |
JsonDataView.as_view(self, "settings.json", lambda: self._settings), | |
r"/-/settings(\.(?P<format>json))?$", | |
) | |
add_route( | |
JsonDataView.as_view(self, "config.json", lambda: self._config()), | |
r"/-/config(\.(?P<format>json))?$", | |
) | |
add_route( | |
JsonDataView.as_view(self, "threads.json", self._threads), | |
r"/-/threads(\.(?P<format>json))?$", | |
) | |
add_route( | |
JsonDataView.as_view(self, "databases.json", self._connected_databases), | |
r"/-/databases(\.(?P<format>json))?$", | |
) | |
add_route( | |
JsonDataView.as_view( | |
self, "actor.json", self._actor, needs_request=True, permission=None | |
), | |
r"/-/actor(\.(?P<format>json))?$", | |
) | |
add_route( | |
AuthTokenView.as_view(self), | |
r"/-/auth-token$", | |
) | |
add_route( | |
CreateTokenView.as_view(self), | |
r"/-/create-token$", | |
) | |
add_route( | |
ApiExplorerView.as_view(self), | |
r"/-/api$", | |
) | |
add_route( | |
LogoutView.as_view(self), | |
r"/-/logout$", | |
) | |
add_route( | |
PermissionsDebugView.as_view(self), | |
r"/-/permissions$", | |
) | |
add_route( | |
MessagesDebugView.as_view(self), | |
r"/-/messages$", | |
) | |
add_route( | |
AllowDebugView.as_view(self), | |
r"/-/allow-debug$", | |
) | |
add_route( | |
wrap_view(PatternPortfolioView, self), | |
r"/-/patterns$", | |
) | |
add_route( | |
wrap_view(database_download, self), | |
r"/(?P<database>[^\/\.]+)\.db$", | |
) | |
add_route( | |
wrap_view(DatabaseView, self), | |
r"/(?P<database>[^\/\.]+)(\.(?P<format>\w+))?$", | |
) | |
add_route(TableCreateView.as_view(self), r"/(?P<database>[^\/\.]+)/-/create$") | |
add_route( | |
wrap_view(QueryView, self), | |
r"/(?P<database>[^\/\.]+)/-/query(\.(?P<format>\w+))?$", | |
) | |
add_route( | |
wrap_view(table_view, self), | |
r"/(?P<database>[^\/\.]+)/(?P<table>[^\/\.]+)(\.(?P<format>\w+))?$", | |
) | |
add_route( | |
RowView.as_view(self), | |
r"/(?P<database>[^\/\.]+)/(?P<table>[^/]+?)/(?P<pks>[^/]+?)(\.(?P<format>\w+))?$", | |
) | |
add_route( | |
TableInsertView.as_view(self), | |
r"/(?P<database>[^\/\.]+)/(?P<table>[^\/\.]+)/-/insert$", | |
) | |
add_route( | |
TableUpsertView.as_view(self), | |
r"/(?P<database>[^\/\.]+)/(?P<table>[^\/\.]+)/-/upsert$", | |
) | |
add_route( | |
TableDropView.as_view(self), | |
r"/(?P<database>[^\/\.]+)/(?P<table>[^\/\.]+)/-/drop$", | |
) | |
add_route( | |
RowDeleteView.as_view(self), | |
r"/(?P<database>[^\/\.]+)/(?P<table>[^/]+?)/(?P<pks>[^/]+?)/-/delete$", | |
) | |
add_route( | |
RowUpdateView.as_view(self), | |
r"/(?P<database>[^\/\.]+)/(?P<table>[^/]+?)/(?P<pks>[^/]+?)/-/update$", | |
) | |
return [ | |
# Compile any strings to regular expressions | |
((re.compile(pattern) if isinstance(pattern, str) else pattern), view) | |
for pattern, view in routes | |
] | |
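    # Illustrative sketch: plugins can contribute to the routes gathered at the
    # top of _routes() via the register_routes() hook. The route and view names
    # below ("hello") are placeholders, not part of this module:
    #
    #   from datasette import hookimpl
    #   from datasette.utils.asgi import Response
    #
    #   @hookimpl
    #   def register_routes():
    #       async def hello(request):
    #           return Response.text("Hello from a plugin")
    #       return [(r"^/-/hello$", hello)]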
async def resolve_database(self, request): | |
database_route = tilde_decode(request.url_vars["database"]) | |
try: | |
return self.get_database(route=database_route) | |
except KeyError: | |
raise DatabaseNotFound(database_route) | |
async def resolve_table(self, request): | |
db = await self.resolve_database(request) | |
table_name = tilde_decode(request.url_vars["table"]) | |
# Table must exist | |
is_view = False | |
table_exists = await db.table_exists(table_name) | |
if not table_exists: | |
is_view = await db.view_exists(table_name) | |
if not (table_exists or is_view): | |
raise TableNotFound(db.name, table_name) | |
return ResolvedTable(db, table_name, is_view) | |
async def resolve_row(self, request): | |
db, table_name, _ = await self.resolve_table(request) | |
pk_values = urlsafe_components(request.url_vars["pks"]) | |
sql, params, pks = await row_sql_params_pks(db, table_name, pk_values) | |
results = await db.execute(sql, params, truncate=True) | |
row = results.first() | |
if row is None: | |
raise RowNotFound(db.name, table_name, pk_values) | |
        return ResolvedRow(db, table_name, sql, params, pks, pk_values, row)
def app(self): | |
"""Returns an ASGI app function that serves the whole of Datasette""" | |
routes = self._routes() | |
async def setup_db(): | |
# First time server starts up, calculate table counts for immutable databases | |
for database in self.databases.values(): | |
if not database.is_mutable: | |
await database.table_counts(limit=60 * 60 * 1000) | |
async def custom_csrf_error(scope, send, message_id): | |
await asgi_send( | |
send, | |
content=await self.render_template( | |
"csrf_error.html", | |
{"message_id": message_id, "message_name": Errors(message_id).name}, | |
), | |
status=403, | |
content_type="text/html; charset=utf-8", | |
) | |
asgi = asgi_csrf.asgi_csrf( | |
DatasetteRouter(self, routes), | |
signing_secret=self._secret, | |
cookie_name="ds_csrftoken", | |
skip_if_scope=lambda scope: any( | |
pm.hook.skip_csrf(datasette=self, scope=scope) | |
), | |
send_csrf_failed=custom_csrf_error, | |
) | |
if self.setting("trace_debug"): | |
asgi = AsgiTracer(asgi) | |
asgi = AsgiLifespan(asgi) | |
asgi = AsgiRunOnFirstRequest(asgi, on_startup=[setup_db, self.invoke_startup]) | |
for wrapper in pm.hook.asgi_wrapper(datasette=self): | |
asgi = wrapper(asgi) | |
return asgi | |
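    # Illustrative sketch: because the asgi_wrapper() hook is applied last in
    # app() above, a plugin-supplied wrapper is the outermost layer and sees
    # each request before the CSRF, tracing and lifespan layers. The wrapper
    # name here is a placeholder:
    #
    #   from datasette import hookimpl
    #
    #   @hookimpl
    #   def asgi_wrapper(datasette):
    #       def wrap(app):
    #           async def middleware(scope, receive, send):
    #               # inspect or annotate scope here, then delegate
    #               await app(scope, receive, send)
    #           return middleware
    #       return wrap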
class DatasetteRouter: | |
def __init__(self, datasette, routes): | |
self.ds = datasette | |
self.routes = routes or [] | |
async def __call__(self, scope, receive, send): | |
# Because we care about "foo/bar" v.s. "foo%2Fbar" we decode raw_path ourselves | |
path = scope["path"] | |
raw_path = scope.get("raw_path") | |
if raw_path: | |
path = raw_path.decode("ascii") | |
path = path.partition("?")[0] | |
token = self.ds._request_id.set(str(uuid.uuid4())) | |
try: | |
return await self.route_path(scope, receive, send, path) | |
finally: | |
self.ds._request_id.reset(token) | |
async def route_path(self, scope, receive, send, path): | |
# Strip off base_url if present before routing | |
base_url = self.ds.setting("base_url") | |
if base_url != "/" and path.startswith(base_url): | |
path = "/" + path[len(base_url) :] | |
scope = dict(scope, route_path=path) | |
request = Request(scope, receive) | |
# Populate request_messages if ds_messages cookie is present | |
try: | |
request._messages = self.ds.unsign( | |
request.cookies.get("ds_messages", ""), "messages" | |
) | |
except BadSignature: | |
pass | |
scope_modifications = {} | |
# Apply force_https_urls, if set | |
if ( | |
self.ds.setting("force_https_urls") | |
and scope["type"] == "http" | |
and scope.get("scheme") != "https" | |
): | |
scope_modifications["scheme"] = "https" | |
# Handle authentication | |
default_actor = scope.get("actor") or None | |
actor = None | |
for actor in pm.hook.actor_from_request(datasette=self.ds, request=request): | |
actor = await await_me_maybe(actor) | |
if actor: | |
break | |
scope_modifications["actor"] = actor or default_actor | |
scope = dict(scope, **scope_modifications) | |
match, view = resolve_routes(self.routes, path) | |
if match is None: | |
return await self.handle_404(request, send) | |
new_scope = dict(scope, url_route={"kwargs": match.groupdict()}) | |
request.scope = new_scope | |
try: | |
response = await view(request, send) | |
if response: | |
self.ds._write_messages_to_response(request, response) | |
await response.asgi_send(send) | |
return | |
except NotFound as exception: | |
return await self.handle_404(request, send, exception) | |
except Forbidden as exception: | |
# Try the forbidden() plugin hook | |
for custom_response in pm.hook.forbidden( | |
datasette=self.ds, request=request, message=exception.args[0] | |
): | |
custom_response = await await_me_maybe(custom_response) | |
assert ( | |
custom_response | |
), "Default forbidden() hook should have been called" | |
return await custom_response.asgi_send(send) | |
except Exception as exception: | |
return await self.handle_exception(request, send, exception) | |
async def handle_404(self, request, send, exception=None): | |
# If path contains % encoding, redirect to tilde encoding | |
if "%" in request.path: | |
# Try the same path but with "%" replaced by "~" | |
# and "~" replaced with "~7E" | |
# and "." replaced with "~2E" | |
new_path = ( | |
request.path.replace("~", "~7E").replace("%", "~").replace(".", "~2E") | |
) | |
if request.query_string: | |
new_path += "?{}".format(request.query_string) | |
await asgi_send_redirect(send, new_path) | |
return | |
# If URL has a trailing slash, redirect to URL without it | |
path = request.scope.get( | |
"raw_path", request.scope["path"].encode("utf8") | |
).partition(b"?")[0] | |
context = {} | |
if path.endswith(b"/"): | |
path = path.rstrip(b"/") | |
if request.scope["query_string"]: | |
path += b"?" + request.scope["query_string"] | |
await asgi_send_redirect(send, path.decode("latin1")) | |
else: | |
# Is there a pages/* template matching this path? | |
route_path = request.scope.get("route_path", request.scope["path"]) | |
# Jinja requires template names to use "/" even on Windows | |
template_name = "pages" + route_path + ".html" | |
# Build a list of pages/blah/{name}.html matching expressions | |
environment = self.ds.get_jinja_environment(request) | |
pattern_templates = [ | |
filepath | |
for filepath in environment.list_templates() | |
if "{" in filepath and filepath.startswith("pages/") | |
] | |
page_routes = [ | |
(route_pattern_from_filepath(filepath[len("pages/") :]), filepath) | |
for filepath in pattern_templates | |
] | |
try: | |
template = environment.select_template([template_name]) | |
except TemplateNotFound: | |
template = None | |
if template is None: | |
# Try for a pages/blah/{name}.html template match | |
for regex, wildcard_template in page_routes: | |
match = regex.match(route_path) | |
if match is not None: | |
context.update(match.groupdict()) | |
template = wildcard_template | |
break | |
if template: | |
headers = {} | |
status = [200] | |
def custom_header(name, value): | |
headers[name] = value | |
return "" | |
def custom_status(code): | |
status[0] = code | |
return "" | |
def custom_redirect(location, code=302): | |
status[0] = code | |
headers["Location"] = location | |
return "" | |
def raise_404(message=""): | |
raise NotFoundExplicit(message) | |
context.update( | |
{ | |
"custom_header": custom_header, | |
"custom_status": custom_status, | |
"custom_redirect": custom_redirect, | |
"raise_404": raise_404, | |
} | |
) | |
try: | |
body = await self.ds.render_template( | |
template, | |
context, | |
request=request, | |
view_name="page", | |
) | |
except NotFoundExplicit as e: | |
await self.handle_exception(request, send, e) | |
return | |
# Pull content-type out into separate parameter | |
content_type = "text/html; charset=utf-8" | |
matches = [k for k in headers if k.lower() == "content-type"] | |
if matches: | |
content_type = headers[matches[0]] | |
await asgi_send( | |
send, | |
body, | |
status=status[0], | |
headers=headers, | |
content_type=content_type, | |
) | |
else: | |
await self.handle_exception(request, send, exception or NotFound("404")) | |
async def handle_exception(self, request, send, exception): | |
responses = [] | |
for hook in pm.hook.handle_exception( | |
datasette=self.ds, | |
request=request, | |
exception=exception, | |
): | |
response = await await_me_maybe(hook) | |
if response is not None: | |
responses.append(response) | |
assert responses, "Default exception handler should have returned something" | |
# Even if there are multiple responses use just the first one | |
response = responses[0] | |
await response.asgi_send(send) | |
_cleaner_task_str_re = re.compile(r"\S*site-packages/") | |
def _cleaner_task_str(task): | |
s = str(task) | |
# This has something like the following in it: | |
# running at /Users/simonw/Dropbox/Development/datasette/venv-3.7.5/lib/python3.7/site-packages/uvicorn/main.py:361> | |
# Clean up everything up to and including site-packages | |
return _cleaner_task_str_re.sub("", s) | |
def wrap_view(view_fn_or_class, datasette): | |
is_function = isinstance(view_fn_or_class, types.FunctionType) | |
if is_function: | |
return wrap_view_function(view_fn_or_class, datasette) | |
else: | |
if not isinstance(view_fn_or_class, type): | |
raise ValueError("view_fn_or_class must be a function or a class") | |
return wrap_view_class(view_fn_or_class, datasette) | |
def wrap_view_class(view_class, datasette): | |
async def async_view_for_class(request, send): | |
instance = view_class() | |
if inspect.iscoroutinefunction(instance.__call__): | |
return await async_call_with_supported_arguments( | |
instance.__call__, | |
scope=request.scope, | |
receive=request.receive, | |
send=send, | |
request=request, | |
datasette=datasette, | |
) | |
else: | |
return call_with_supported_arguments( | |
instance.__call__, | |
scope=request.scope, | |
receive=request.receive, | |
send=send, | |
request=request, | |
datasette=datasette, | |
) | |
async_view_for_class.view_class = view_class | |
return async_view_for_class | |
def wrap_view_function(view_fn, datasette): | |
@functools.wraps(view_fn) | |
async def async_view_fn(request, send): | |
if inspect.iscoroutinefunction(view_fn): | |
response = await async_call_with_supported_arguments( | |
view_fn, | |
scope=request.scope, | |
receive=request.receive, | |
send=send, | |
request=request, | |
datasette=datasette, | |
) | |
else: | |
response = call_with_supported_arguments( | |
view_fn, | |
scope=request.scope, | |
receive=request.receive, | |
send=send, | |
request=request, | |
datasette=datasette, | |
) | |
if response is not None: | |
return response | |
return async_view_fn | |
def permanent_redirect(path, forward_query_string=False, forward_rest=False): | |
return wrap_view( | |
lambda request, send: Response.redirect( | |
path | |
+ (request.url_vars["rest"] if forward_rest else "") | |
+ ( | |
("?" + request.query_string) | |
if forward_query_string and request.query_string | |
else "" | |
), | |
status=301, | |
), | |
datasette=None, | |
) | |
_curly_re = re.compile(r"({.*?})") | |
def route_pattern_from_filepath(filepath): | |
# Drop the ".html" suffix | |
if filepath.endswith(".html"): | |
filepath = filepath[: -len(".html")] | |
re_bits = ["/"] | |
for bit in _curly_re.split(filepath): | |
if _curly_re.match(bit): | |
re_bits.append(f"(?P<{bit[1:-1]}>[^/]*)") | |
else: | |
re_bits.append(re.escape(bit)) | |
return re.compile("^" + "".join(re_bits) + "$") | |
class NotFoundExplicit(NotFound): | |
pass | |
class DatasetteClient: | |
def __init__(self, ds): | |
self.ds = ds | |
self.app = ds.app() | |
def actor_cookie(self, actor): | |
# Utility method, mainly for tests | |
return self.ds.sign({"a": actor}, "actor") | |
def _fix(self, path, avoid_path_rewrites=False): | |
if not isinstance(path, PrefixedUrlString) and not avoid_path_rewrites: | |
path = self.ds.urls.path(path) | |
if path.startswith("/"): | |
path = f"https://localhost{path}" | |
return path | |
async def _request(self, method, path, **kwargs): | |
async with httpx.AsyncClient( | |
transport=httpx.ASGITransport(app=self.app), | |
cookies=kwargs.pop("cookies", None), | |
) as client: | |
return await getattr(client, method)(self._fix(path), **kwargs) | |
async def get(self, path, **kwargs): | |
return await self._request("get", path, **kwargs) | |
async def options(self, path, **kwargs): | |
return await self._request("options", path, **kwargs) | |
async def head(self, path, **kwargs): | |
return await self._request("head", path, **kwargs) | |
async def post(self, path, **kwargs): | |
return await self._request("post", path, **kwargs) | |
async def put(self, path, **kwargs): | |
return await self._request("put", path, **kwargs) | |
async def patch(self, path, **kwargs): | |
return await self._request("patch", path, **kwargs) | |
async def delete(self, path, **kwargs): | |
return await self._request("delete", path, **kwargs) | |
async def request(self, method, path, **kwargs): | |
avoid_path_rewrites = kwargs.pop("avoid_path_rewrites", None) | |
async with httpx.AsyncClient( | |
transport=httpx.ASGITransport(app=self.app), | |
cookies=kwargs.pop("cookies", None), | |
) as client: | |
return await client.request( | |
method, self._fix(path, avoid_path_rewrites), **kwargs | |
) | |
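    # Illustrative usage sketch: DatasetteClient backs the datasette.client
    # interface, sending requests straight into the ASGI app via httpx with no
    # network server. For example, in an async test (instance name is a
    # placeholder):
    #
    #   ds = Datasette(memory=True)
    #   response = await ds.client.get("/-/versions.json")
    #   assert response.status_code == 200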
</document_content> | |
</document> | |
<document index="5"> | |
<source>datasette/blob_renderer.py</source> | |
<document_content> | |
from datasette import hookimpl | |
from datasette.utils.asgi import Response, BadRequest | |
from datasette.utils import to_css_class | |
import hashlib | |
_BLOB_COLUMN = "_blob_column" | |
_BLOB_HASH = "_blob_hash" | |
async def render_blob(datasette, database, rows, columns, request, table, view_name): | |
if _BLOB_COLUMN not in request.args: | |
raise BadRequest(f"?{_BLOB_COLUMN}= is required") | |
blob_column = request.args[_BLOB_COLUMN] | |
if blob_column not in columns: | |
raise BadRequest(f"{blob_column} is not a valid column") | |
# If ?_blob_hash= provided, use that to select the row - otherwise use first row | |
blob_hash = None | |
if _BLOB_HASH in request.args: | |
blob_hash = request.args[_BLOB_HASH] | |
for row in rows: | |
value = row[blob_column] | |
if hashlib.sha256(value).hexdigest() == blob_hash: | |
break | |
else: | |
# Loop did not break | |
raise BadRequest( | |
"Link has expired - the requested binary content has changed or could not be found." | |
) | |
else: | |
row = rows[0] | |
value = row[blob_column] | |
filename_bits = [] | |
if table: | |
filename_bits.append(to_css_class(table)) | |
if "pks" in request.url_vars: | |
filename_bits.append(request.url_vars["pks"]) | |
filename_bits.append(to_css_class(blob_column)) | |
if blob_hash: | |
filename_bits.append(blob_hash[:6]) | |
filename = "-".join(filename_bits) + ".blob" | |
headers = { | |
"X-Content-Type-Options": "nosniff", | |
"Content-Disposition": f'attachment; filename="{filename}"', | |
} | |
return Response( | |
body=value or b"", | |
status=200, | |
headers=headers, | |
content_type="application/binary", | |
) | |
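# Illustrative usage sketch: this renderer is reached through the ".blob"
# extension registered below, e.g. a URL such as
# /mydb/mytable/1.blob?_blob_column=data (database, table and column names are
# placeholders). An optional ?_blob_hash= selects a row by content hash; if the
# binary content has changed, the "Link has expired" error above is raised.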
@hookimpl | |
def register_output_renderer(): | |
return { | |
"extension": "blob", | |
"render": render_blob, | |
"can_render": lambda: False, | |
} | |
</document_content> | |
</document> | |
<document index="6"> | |
<source>datasette/cli.py</source> | |
<document_content> | |
import asyncio | |
import uvicorn | |
import click | |
from click import formatting | |
from click.types import CompositeParamType | |
from click_default_group import DefaultGroup | |
import functools | |
import json | |
import os | |
import pathlib | |
from runpy import run_module | |
import shutil | |
from subprocess import call | |
import sys | |
import textwrap | |
import webbrowser | |
from .app import ( | |
Datasette, | |
DEFAULT_SETTINGS, | |
SETTINGS, | |
SQLITE_LIMIT_ATTACHED, | |
pm, | |
) | |
from .utils import ( | |
LoadExtension, | |
StartupError, | |
check_connection, | |
deep_dict_update, | |
find_spatialite, | |
parse_metadata, | |
ConnectionProblem, | |
SpatialiteConnectionProblem, | |
initial_path_for_datasette, | |
pairs_to_nested_config, | |
temporary_docker_directory, | |
value_as_boolean, | |
SpatialiteNotFound, | |
StaticMount, | |
ValueAsBooleanError, | |
) | |
from .utils.sqlite import sqlite3 | |
from .utils.testing import TestClient | |
from .version import __version__ | |
# Use Rich for tracebacks if it is installed | |
try: | |
from rich.traceback import install | |
install(show_locals=True) | |
except ImportError: | |
pass | |
class Setting(CompositeParamType): | |
name = "setting" | |
arity = 2 | |
def convert(self, config, param, ctx): | |
name, value = config | |
if name in DEFAULT_SETTINGS: | |
# For backwards compatibility with how this worked prior to | |
            # Datasette 1.0, we turn bare setting names into settings.name
# Type checking for those older settings | |
default = DEFAULT_SETTINGS[name] | |
name = "settings.{}".format(name) | |
if isinstance(default, bool): | |
try: | |
return name, "true" if value_as_boolean(value) else "false" | |
except ValueAsBooleanError: | |
self.fail(f'"{name}" should be on/off/true/false/1/0', param, ctx) | |
elif isinstance(default, int): | |
if not value.isdigit(): | |
self.fail(f'"{name}" should be an integer', param, ctx) | |
return name, value | |
elif isinstance(default, str): | |
return name, value | |
else: | |
# Should never happen: | |
self.fail("Invalid option") | |
return name, value | |
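# Illustrative CLI sketch: the Setting type above powers repeated
# "-s/--setting key value" pairs, and bare names of pre-1.0 settings are
# rewritten to settings.<name>. The database file name is a placeholder:
#
#   datasette data.db -s settings.sql_time_limit_ms 3500
#   datasette data.db -s sql_time_limit_ms 3500   # same thing, rewritten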
def sqlite_extensions(fn): | |
fn = click.option( | |
"sqlite_extensions", | |
"--load-extension", | |
type=LoadExtension(), | |
envvar="DATASETTE_LOAD_EXTENSION", | |
multiple=True, | |
help="Path to a SQLite extension to load, and optional entrypoint", | |
)(fn) | |
# Wrap it in a custom error handler | |
@functools.wraps(fn) | |
def wrapped(*args, **kwargs): | |
try: | |
return fn(*args, **kwargs) | |
except AttributeError as e: | |
if "enable_load_extension" in str(e): | |
raise click.ClickException( | |
textwrap.dedent( | |
""" | |
Your Python installation does not have the ability to load SQLite extensions. | |
More information: https://datasette.io/help/extensions | |
""" | |
).strip() | |
) | |
raise | |
return wrapped | |
@click.group(cls=DefaultGroup, default="serve", default_if_no_args=True) | |
@click.version_option(version=__version__) | |
def cli(): | |
""" | |
Datasette is an open source multi-tool for exploring and publishing data | |
\b | |
About Datasette: https://datasette.io/ | |
Full documentation: https://docs.datasette.io/ | |
""" | |
@cli.command() | |
@click.argument("files", type=click.Path(exists=True), nargs=-1) | |
@click.option("--inspect-file", default="-") | |
@sqlite_extensions | |
def inspect(files, inspect_file, sqlite_extensions): | |
""" | |
Generate JSON summary of provided database files | |
This can then be passed to "datasette --inspect-file" to speed up count | |
operations against immutable database files. | |
""" | |
app = Datasette([], immutables=files, sqlite_extensions=sqlite_extensions) | |
loop = asyncio.get_event_loop() | |
inspect_data = loop.run_until_complete(inspect_(files, sqlite_extensions)) | |
if inspect_file == "-": | |
sys.stdout.write(json.dumps(inspect_data, indent=2)) | |
else: | |
with open(inspect_file, "w") as fp: | |
fp.write(json.dumps(inspect_data, indent=2)) | |
async def inspect_(files, sqlite_extensions): | |
app = Datasette([], immutables=files, sqlite_extensions=sqlite_extensions) | |
data = {} | |
for name, database in app.databases.items(): | |
counts = await database.table_counts(limit=3600 * 1000) | |
data[name] = { | |
"hash": database.hash, | |
"size": database.size, | |
"file": database.path, | |
"tables": { | |
table_name: {"count": table_count} | |
for table_name, table_count in counts.items() | |
}, | |
} | |
return data | |
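# Illustrative CLI sketch: a typical round trip for the inspect data produced
# above (file names are placeholders):
#
#   datasette inspect counts.db --inspect-file=inspect-data.json
#   datasette serve -i counts.db --inspect-file=inspect-data.json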
@cli.group() | |
def publish(): | |
"""Publish specified SQLite database files to the internet along with a Datasette-powered interface and API""" | |
pass | |
# Register publish plugins | |
pm.hook.publish_subcommand(publish=publish) | |
@cli.command() | |
@click.option("--all", help="Include built-in default plugins", is_flag=True) | |
@click.option( | |
"--requirements", help="Output requirements.txt of installed plugins", is_flag=True | |
) | |
@click.option( | |
"--plugins-dir", | |
type=click.Path(exists=True, file_okay=False, dir_okay=True), | |
help="Path to directory containing custom plugins", | |
) | |
def plugins(all, requirements, plugins_dir): | |
"""List currently installed plugins""" | |
app = Datasette([], plugins_dir=plugins_dir) | |
if requirements: | |
for plugin in app._plugins(): | |
if plugin["version"]: | |
click.echo("{}=={}".format(plugin["name"], plugin["version"])) | |
else: | |
click.echo(json.dumps(app._plugins(all=all), indent=4)) | |
@cli.command() | |
@click.argument("files", type=click.Path(exists=True), nargs=-1, required=True) | |
@click.option( | |
"-t", | |
"--tag", | |
help="Name for the resulting Docker container, can optionally use name:tag format", | |
) | |
@click.option( | |
"-m", | |
"--metadata", | |
type=click.File(mode="r"), | |
help="Path to JSON/YAML file containing metadata to publish", | |
) | |
@click.option("--extra-options", help="Extra options to pass to datasette serve") | |
@click.option("--branch", help="Install datasette from a GitHub branch e.g. main") | |
@click.option( | |
"--template-dir", | |
type=click.Path(exists=True, file_okay=False, dir_okay=True), | |
help="Path to directory containing custom templates", | |
) | |
@click.option( | |
"--plugins-dir", | |
type=click.Path(exists=True, file_okay=False, dir_okay=True), | |
help="Path to directory containing custom plugins", | |
) | |
@click.option( | |
"--static", | |
type=StaticMount(), | |
help="Serve static files from this directory at /MOUNT/...", | |
multiple=True, | |
) | |
@click.option( | |
"--install", help="Additional packages (e.g. plugins) to install", multiple=True | |
) | |
@click.option("--spatialite", is_flag=True, help="Enable SpatialLite extension") | |
@click.option("--version-note", help="Additional note to show on /-/versions") | |
@click.option( | |
"--secret", | |
help="Secret used for signing secure values, such as signed cookies", | |
envvar="DATASETTE_PUBLISH_SECRET", | |
default=lambda: os.urandom(32).hex(), | |
) | |
@click.option( | |
"-p", | |
"--port", | |
default=8001, | |
type=click.IntRange(1, 65535), | |
help="Port to run the server on, defaults to 8001", | |
) | |
@click.option("--title", help="Title for metadata") | |
@click.option("--license", help="License label for metadata") | |
@click.option("--license_url", help="License URL for metadata") | |
@click.option("--source", help="Source label for metadata") | |
@click.option("--source_url", help="Source URL for metadata") | |
@click.option("--about", help="About label for metadata") | |
@click.option("--about_url", help="About URL for metadata") | |
def package( | |
files, | |
tag, | |
metadata, | |
extra_options, | |
branch, | |
template_dir, | |
plugins_dir, | |
static, | |
install, | |
spatialite, | |
version_note, | |
secret, | |
port, | |
**extra_metadata, | |
): | |
"""Package SQLite files into a Datasette Docker container""" | |
if not shutil.which("docker"): | |
click.secho( | |
' The package command requires "docker" to be installed and configured ', | |
bg="red", | |
fg="white", | |
bold=True, | |
err=True, | |
) | |
sys.exit(1) | |
with temporary_docker_directory( | |
files, | |
"datasette", | |
metadata=metadata, | |
extra_options=extra_options, | |
branch=branch, | |
template_dir=template_dir, | |
plugins_dir=plugins_dir, | |
static=static, | |
install=install, | |
spatialite=spatialite, | |
version_note=version_note, | |
secret=secret, | |
extra_metadata=extra_metadata, | |
port=port, | |
): | |
args = ["docker", "build"] | |
if tag: | |
args.append("-t") | |
args.append(tag) | |
args.append(".") | |
call(args) | |
@cli.command() | |
@click.argument("packages", nargs=-1) | |
@click.option( | |
"-U", "--upgrade", is_flag=True, help="Upgrade packages to latest version" | |
) | |
@click.option( | |
"-r", | |
"--requirement", | |
type=click.Path(exists=True), | |
help="Install from requirements file", | |
) | |
@click.option( | |
"-e", | |
"--editable", | |
help="Install a project in editable mode from this path", | |
) | |
def install(packages, upgrade, requirement, editable): | |
"""Install plugins and packages from PyPI into the same environment as Datasette""" | |
if not packages and not requirement and not editable: | |
raise click.UsageError("Please specify at least one package to install") | |
args = ["pip", "install"] | |
if upgrade: | |
args += ["--upgrade"] | |
if editable: | |
args += ["--editable", editable] | |
if requirement: | |
args += ["-r", requirement] | |
args += list(packages) | |
sys.argv = args | |
run_module("pip", run_name="__main__") | |
@cli.command() | |
@click.argument("packages", nargs=-1, required=True) | |
@click.option("-y", "--yes", is_flag=True, help="Don't ask for confirmation") | |
def uninstall(packages, yes): | |
"""Uninstall plugins and Python packages from the Datasette environment""" | |
sys.argv = ["pip", "uninstall"] + list(packages) + (["-y"] if yes else []) | |
run_module("pip", run_name="__main__") | |
@cli.command() | |
@click.argument("files", type=click.Path(), nargs=-1) | |
@click.option( | |
"-i", | |
"--immutable", | |
type=click.Path(exists=True), | |
help="Database files to open in immutable mode", | |
multiple=True, | |
) | |
@click.option( | |
"-h", | |
"--host", | |
default="127.0.0.1", | |
help=( | |
"Host for server. Defaults to 127.0.0.1 which means only connections " | |
"from the local machine will be allowed. Use 0.0.0.0 to listen to " | |
"all IPs and allow access from other machines." | |
), | |
) | |
@click.option( | |
"-p", | |
"--port", | |
default=8001, | |
type=click.IntRange(0, 65535), | |
help="Port for server, defaults to 8001. Use -p 0 to automatically assign an available port.", | |
) | |
@click.option( | |
"--uds", | |
help="Bind to a Unix domain socket", | |
) | |
@click.option( | |
"--reload", | |
is_flag=True, | |
help="Automatically reload if code or metadata change detected - useful for development", | |
) | |
@click.option( | |
"--cors", is_flag=True, help="Enable CORS by serving Access-Control-Allow-Origin: *" | |
) | |
@sqlite_extensions | |
@click.option( | |
"--inspect-file", help='Path to JSON file created using "datasette inspect"' | |
) | |
@click.option( | |
"-m", | |
"--metadata", | |
type=click.File(mode="r"), | |
help="Path to JSON/YAML file containing license/source metadata", | |
) | |
@click.option( | |
"--template-dir", | |
type=click.Path(exists=True, file_okay=False, dir_okay=True), | |
help="Path to directory containing custom templates", | |
) | |
@click.option( | |
"--plugins-dir", | |
type=click.Path(exists=True, file_okay=False, dir_okay=True), | |
help="Path to directory containing custom plugins", | |
) | |
@click.option( | |
"--static", | |
type=StaticMount(), | |
help="Serve static files from this directory at /MOUNT/...", | |
multiple=True, | |
) | |
@click.option("--memory", is_flag=True, help="Make /_memory database available") | |
@click.option( | |
"-c", | |
"--config", | |
type=click.File(mode="r"), | |
help="Path to JSON/YAML Datasette configuration file", | |
) | |
@click.option( | |
"-s", | |
"--setting", | |
"settings", | |
type=Setting(), | |
help="nested.key, value setting to use in Datasette configuration", | |
multiple=True, | |
) | |
@click.option( | |
"--secret", | |
help="Secret used for signing secure values, such as signed cookies", | |
envvar="DATASETTE_SECRET", | |
) | |
@click.option( | |
"--root", | |
help="Output URL that sets a cookie authenticating the root user", | |
is_flag=True, | |
) | |
@click.option( | |
"--get", | |
help="Run an HTTP GET request against this path, print results and exit", | |
) | |
@click.option( | |
"--token", | |
help="API token to send with --get requests", | |
) | |
@click.option( | |
"--actor", | |
help="Actor to use for --get requests (JSON string)", | |
) | |
@click.option("--version-note", help="Additional note to show on /-/versions") | |
@click.option("--help-settings", is_flag=True, help="Show available settings") | |
@click.option("--pdb", is_flag=True, help="Launch debugger on any errors") | |
@click.option( | |
"-o", | |
"--open", | |
"open_browser", | |
is_flag=True, | |
help="Open Datasette in your web browser", | |
) | |
@click.option( | |
"--create", | |
is_flag=True, | |
help="Create database files if they do not exist", | |
) | |
@click.option( | |
"--crossdb", | |
is_flag=True, | |
help="Enable cross-database joins using the /_memory database", | |
) | |
@click.option( | |
"--nolock", | |
is_flag=True, | |
help="Ignore locking, open locked files in read-only mode", | |
) | |
@click.option( | |
"--ssl-keyfile", | |
help="SSL key file", | |
envvar="DATASETTE_SSL_KEYFILE", | |
) | |
@click.option( | |
"--ssl-certfile", | |
help="SSL certificate file", | |
envvar="DATASETTE_SSL_CERTFILE", | |
) | |
@click.option( | |
"--internal", | |
type=click.Path(), | |
help="Path to a persistent Datasette internal SQLite database", | |
) | |
def serve( | |
files, | |
immutable, | |
host, | |
port, | |
uds, | |
reload, | |
cors, | |
sqlite_extensions, | |
inspect_file, | |
metadata, | |
template_dir, | |
plugins_dir, | |
static, | |
memory, | |
config, | |
settings, | |
secret, | |
root, | |
get, | |
token, | |
actor, | |
version_note, | |
help_settings, | |
pdb, | |
open_browser, | |
create, | |
crossdb, | |
nolock, | |
ssl_keyfile, | |
ssl_certfile, | |
internal, | |
return_instance=False, | |
): | |
"""Serve up specified SQLite database files with a web UI""" | |
if help_settings: | |
formatter = formatting.HelpFormatter() | |
with formatter.section("Settings"): | |
formatter.write_dl( | |
[ | |
(option.name, f"{option.help} (default={option.default})") | |
for option in SETTINGS | |
] | |
) | |
click.echo(formatter.getvalue()) | |
sys.exit(0) | |
if reload: | |
import hupper | |
reloader = hupper.start_reloader("datasette.cli.serve") | |
if immutable: | |
reloader.watch_files(immutable) | |
if config: | |
reloader.watch_files([config.name]) | |
if metadata: | |
reloader.watch_files([metadata.name]) | |
inspect_data = None | |
if inspect_file: | |
with open(inspect_file) as fp: | |
inspect_data = json.load(fp) | |
metadata_data = None | |
if metadata: | |
metadata_data = parse_metadata(metadata.read()) | |
config_data = None | |
if config: | |
config_data = parse_metadata(config.read()) | |
config_data = config_data or {} | |
# Merge in settings from -s/--setting | |
if settings: | |
settings_updates = pairs_to_nested_config(settings) | |
# Merge recursively, to avoid over-writing nested values | |
# https://github.com/simonw/datasette/issues/2389 | |
deep_dict_update(config_data, settings_updates) | |
kwargs = dict( | |
immutables=immutable, | |
cache_headers=not reload, | |
cors=cors, | |
inspect_data=inspect_data, | |
config=config_data, | |
metadata=metadata_data, | |
sqlite_extensions=sqlite_extensions, | |
template_dir=template_dir, | |
plugins_dir=plugins_dir, | |
static_mounts=static, | |
settings=None, # These are passed in config= now | |
memory=memory, | |
secret=secret, | |
version_note=version_note, | |
pdb=pdb, | |
crossdb=crossdb, | |
nolock=nolock, | |
internal=internal, | |
) | |
# if files is a single directory, use that as config_dir= | |
if 1 == len(files) and os.path.isdir(files[0]): | |
kwargs["config_dir"] = pathlib.Path(files[0]) | |
files = [] | |
# Verify list of files, create if needed (and --create) | |
for file in files: | |
if not pathlib.Path(file).exists(): | |
if create: | |
sqlite3.connect(file).execute("vacuum") | |
else: | |
raise click.ClickException( | |
"Invalid value for '[FILES]...': Path '{}' does not exist.".format( | |
file | |
) | |
) | |
# De-duplicate files so 'datasette db.db db.db' only attaches one /db | |
files = list(dict.fromkeys(files)) | |
try: | |
ds = Datasette(files, **kwargs) | |
except SpatialiteNotFound: | |
raise click.ClickException("Could not find SpatiaLite extension") | |
except StartupError as e: | |
raise click.ClickException(e.args[0]) | |
if return_instance: | |
# Private utility mechanism for writing unit tests | |
return ds | |
# Run the "startup" plugin hooks | |
asyncio.get_event_loop().run_until_complete(ds.invoke_startup()) | |
    # Run async soundness checks against the connected databases
asyncio.get_event_loop().run_until_complete(check_databases(ds)) | |
if token and not get: | |
raise click.ClickException("--token can only be used with --get") | |
if get: | |
client = TestClient(ds) | |
headers = {} | |
if token: | |
headers["Authorization"] = "Bearer {}".format(token) | |
cookies = {} | |
if actor: | |
cookies["ds_actor"] = client.actor_cookie(json.loads(actor)) | |
response = client.get(get, headers=headers, cookies=cookies) | |
click.echo(response.text) | |
exit_code = 0 if response.status == 200 else 1 | |
sys.exit(exit_code) | |
return | |
# Start the server | |
url = None | |
if root: | |
url = "https://{}:{}{}?token={}".format( | |
host, port, ds.urls.path("-/auth-token"), ds._root_token | |
) | |
click.echo(url) | |
if open_browser: | |
if url is None: | |
# Figure out most convenient URL - to table, database or homepage | |
path = asyncio.get_event_loop().run_until_complete( | |
initial_path_for_datasette(ds) | |
) | |
url = f"https://{host}:{port}{path}" | |
webbrowser.open(url) | |
uvicorn_kwargs = dict( | |
host=host, port=port, log_level="info", lifespan="on", workers=1 | |
) | |
if uds: | |
uvicorn_kwargs["uds"] = uds | |
if ssl_keyfile: | |
uvicorn_kwargs["ssl_keyfile"] = ssl_keyfile | |
if ssl_certfile: | |
uvicorn_kwargs["ssl_certfile"] = ssl_certfile | |
uvicorn.run(ds.app(), **uvicorn_kwargs) | |
@cli.command() | |
@click.argument("id") | |
@click.option( | |
"--secret", | |
help="Secret used for signing the API tokens", | |
envvar="DATASETTE_SECRET", | |
required=True, | |
) | |
@click.option( | |
"-e", | |
"--expires-after", | |
help="Token should expire after this many seconds", | |
type=int, | |
) | |
@click.option( | |
"alls", | |
"-a", | |
"--all", | |
type=str, | |
metavar="ACTION", | |
multiple=True, | |
help="Restrict token to this action", | |
) | |
@click.option( | |
"databases", | |
"-d", | |
"--database", | |
type=(str, str), | |
metavar="DB ACTION", | |
multiple=True, | |
help="Restrict token to this action on this database", | |
) | |
@click.option( | |
"resources", | |
"-r", | |
"--resource", | |
type=(str, str, str), | |
metavar="DB RESOURCE ACTION", | |
multiple=True, | |
help="Restrict token to this action on this database resource (a table, SQL view or named query)", | |
) | |
@click.option( | |
"--debug", | |
help="Show decoded token", | |
is_flag=True, | |
) | |
@click.option( | |
"--plugins-dir", | |
type=click.Path(exists=True, file_okay=False, dir_okay=True), | |
help="Path to directory containing custom plugins", | |
) | |
def create_token( | |
id, secret, expires_after, alls, databases, resources, debug, plugins_dir | |
): | |
""" | |
Create a signed API token for the specified actor ID | |
Example: | |
datasette create-token root --secret mysecret | |
To allow only "view-database-download" for all databases: | |
\b | |
datasette create-token root --secret mysecret \\ | |
--all view-database-download | |
To allow "create-table" against a specific database: | |
\b | |
datasette create-token root --secret mysecret \\ | |
--database mydb create-table | |
To allow "insert-row" against a specific table: | |
\b | |
        datasette create-token root --secret mysecret \\
--resource mydb mytable insert-row | |
Restricted actions can be specified multiple times using | |
multiple --all, --database, and --resource options. | |
Add --debug to see a decoded version of the token. | |
""" | |
ds = Datasette(secret=secret, plugins_dir=plugins_dir) | |
# Run ds.invoke_startup() in an event loop | |
loop = asyncio.get_event_loop() | |
loop.run_until_complete(ds.invoke_startup()) | |
# Warn about any unknown actions | |
actions = [] | |
actions.extend(alls) | |
actions.extend([p[1] for p in databases]) | |
actions.extend([p[2] for p in resources]) | |
for action in actions: | |
if not ds.permissions.get(action): | |
click.secho( | |
f" Unknown permission: {action} ", | |
fg="red", | |
err=True, | |
) | |
restrict_database = {} | |
for database, action in databases: | |
restrict_database.setdefault(database, []).append(action) | |
restrict_resource = {} | |
for database, resource, action in resources: | |
restrict_resource.setdefault(database, {}).setdefault(resource, []).append( | |
action | |
) | |
token = ds.create_token( | |
id, | |
expires_after=expires_after, | |
restrict_all=alls, | |
restrict_database=restrict_database, | |
restrict_resource=restrict_resource, | |
) | |
click.echo(token) | |
if debug: | |
encoded = token[len("dstok_") :] | |
click.echo("\nDecoded:\n") | |
click.echo(json.dumps(ds.unsign(encoded, namespace="token"), indent=2)) | |
pm.hook.register_commands(cli=cli) | |
async def check_databases(ds): | |
# Run check_connection against every connected database | |
# to confirm they are all usable | |
for database in list(ds.databases.values()): | |
try: | |
await database.execute_fn(check_connection) | |
except SpatialiteConnectionProblem: | |
suggestion = "" | |
try: | |
find_spatialite() | |
suggestion = "\n\nTry adding the --load-extension=spatialite option." | |
except SpatialiteNotFound: | |
pass | |
raise click.UsageError( | |
"It looks like you're trying to load a SpatiaLite" | |
+ " database without first loading the SpatiaLite module." | |
+ suggestion | |
+ "\n\nRead more: https://docs.datasette.io/en/stable/spatialite.html" | |
) | |
except ConnectionProblem as e: | |
raise click.UsageError( | |
f"Connection to {database.path} failed check: {str(e.args[0])}" | |
) | |
# If --crossdb and more than SQLITE_LIMIT_ATTACHED show warning | |
if ( | |
ds.crossdb | |
and len([db for db in ds.databases.values() if not db.is_memory]) | |
> SQLITE_LIMIT_ATTACHED | |
): | |
msg = ( | |
"Warning: --crossdb only works with the first {} attached databases".format( | |
SQLITE_LIMIT_ATTACHED | |
) | |
) | |
click.echo(click.style(msg, bold=True, fg="yellow"), err=True) | |
</document_content> | |
</document> | |
<document index="7"> | |
<source>datasette/database.py</source> | |
<document_content> | |
import asyncio | |
from collections import namedtuple | |
from pathlib import Path | |
import janus | |
import queue | |
import sqlite_utils | |
import sys | |
import threading | |
import uuid | |
from .tracer import trace | |
from .utils import ( | |
detect_fts, | |
detect_primary_keys, | |
detect_spatialite, | |
get_all_foreign_keys, | |
get_outbound_foreign_keys, | |
md5_not_usedforsecurity, | |
sqlite_timelimit, | |
sqlite3, | |
table_columns, | |
table_column_details, | |
) | |
from .utils.sqlite import sqlite_version | |
from .inspect import inspect_hash | |
connections = threading.local() | |
AttachedDatabase = namedtuple("AttachedDatabase", ("seq", "name", "file")) | |
class Database: | |
# For table counts stop at this many rows: | |
count_limit = 10000 | |
def __init__( | |
self, | |
ds, | |
path=None, | |
is_mutable=True, | |
is_memory=False, | |
memory_name=None, | |
mode=None, | |
): | |
self.name = None | |
self.route = None | |
self.ds = ds | |
self.path = path | |
self.is_mutable = is_mutable | |
self.is_memory = is_memory | |
self.memory_name = memory_name | |
if memory_name is not None: | |
self.is_memory = True | |
self.cached_hash = None | |
self.cached_size = None | |
self._cached_table_counts = None | |
self._write_thread = None | |
self._write_queue = None | |
# These are used when in non-threaded mode: | |
self._read_connection = None | |
self._write_connection = None | |
# This is used to track all file connections so they can be closed | |
self._all_file_connections = [] | |
self.mode = mode | |
@property | |
def cached_table_counts(self): | |
if self._cached_table_counts is not None: | |
return self._cached_table_counts | |
# Maybe use self.ds.inspect_data to populate cached_table_counts | |
if self.ds.inspect_data and self.ds.inspect_data.get(self.name): | |
self._cached_table_counts = { | |
key: value["count"] | |
for key, value in self.ds.inspect_data[self.name]["tables"].items() | |
} | |
return self._cached_table_counts | |
@property | |
def color(self): | |
if self.hash: | |
return self.hash[:6] | |
return md5_not_usedforsecurity(self.name)[:6] | |
def suggest_name(self): | |
if self.path: | |
return Path(self.path).stem | |
elif self.memory_name: | |
return self.memory_name | |
else: | |
return "db" | |
def connect(self, write=False): | |
extra_kwargs = {} | |
if write: | |
extra_kwargs["isolation_level"] = "IMMEDIATE" | |
if self.memory_name: | |
uri = "file:{}?mode=memory&cache=shared".format(self.memory_name) | |
conn = sqlite3.connect( | |
uri, uri=True, check_same_thread=False, **extra_kwargs | |
) | |
if not write: | |
conn.execute("PRAGMA query_only=1") | |
return conn | |
if self.is_memory: | |
return sqlite3.connect(":memory:", uri=True) | |
# mode=ro or immutable=1? | |
if self.is_mutable: | |
qs = "?mode=ro" | |
if self.ds.nolock: | |
qs += "&nolock=1" | |
else: | |
qs = "?immutable=1" | |
assert not (write and not self.is_mutable) | |
if write: | |
qs = "" | |
if self.mode is not None: | |
qs = f"?mode={self.mode}" | |
conn = sqlite3.connect( | |
f"file:{self.path}{qs}", uri=True, check_same_thread=False, **extra_kwargs | |
) | |
self._all_file_connections.append(conn) | |
return conn | |
def close(self): | |
# Close all connections - useful to avoid running out of file handles in tests | |
for connection in self._all_file_connections: | |
connection.close() | |
async def execute_write(self, sql, params=None, block=True): | |
def _inner(conn): | |
return conn.execute(sql, params or []) | |
with trace("sql", database=self.name, sql=sql.strip(), params=params): | |
results = await self.execute_write_fn(_inner, block=block) | |
return results | |
async def execute_write_script(self, sql, block=True): | |
def _inner(conn): | |
return conn.executescript(sql) | |
with trace("sql", database=self.name, sql=sql.strip(), executescript=True): | |
results = await self.execute_write_fn(_inner, block=block) | |
return results | |
async def execute_write_many(self, sql, params_seq, block=True): | |
def _inner(conn): | |
count = 0 | |
def count_params(params): | |
nonlocal count | |
for param in params: | |
count += 1 | |
yield param | |
return conn.executemany(sql, count_params(params_seq)), count | |
with trace( | |
"sql", database=self.name, sql=sql.strip(), executemany=True | |
) as kwargs: | |
results, count = await self.execute_write_fn(_inner, block=block) | |
kwargs["count"] = count | |
return results | |
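    # Illustrative usage sketch for the write helpers above; the table and
    # column names are placeholders:
    #
    #   await db.execute_write(
    #       "insert into logs (message) values (?)", ["hello"]
    #   )
    #   await db.execute_write_many(
    #       "insert into logs (message) values (?)", [["a"], ["b"]]
    #   )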
async def execute_isolated_fn(self, fn): | |
# Open a new connection just for the duration of this function | |
# blocking the write queue to avoid any writes occurring during it | |
if self.ds.executor is None: | |
# non-threaded mode | |
isolated_connection = self.connect(write=True) | |
try: | |
result = fn(isolated_connection) | |
finally: | |
isolated_connection.close() | |
try: | |
self._all_file_connections.remove(isolated_connection) | |
except ValueError: | |
# Was probably a memory connection | |
pass | |
return result | |
else: | |
# Threaded mode - send to write thread | |
return await self._send_to_write_thread(fn, isolated_connection=True) | |
async def execute_write_fn(self, fn, block=True, transaction=True): | |
if self.ds.executor is None: | |
# non-threaded mode | |
if self._write_connection is None: | |
self._write_connection = self.connect(write=True) | |
self.ds._prepare_connection(self._write_connection, self.name) | |
if transaction: | |
with self._write_connection: | |
return fn(self._write_connection) | |
else: | |
return fn(self._write_connection) | |
else: | |
return await self._send_to_write_thread( | |
fn, block=block, transaction=transaction | |
) | |
async def _send_to_write_thread( | |
self, fn, block=True, isolated_connection=False, transaction=True | |
): | |
if self._write_queue is None: | |
self._write_queue = queue.Queue() | |
if self._write_thread is None: | |
self._write_thread = threading.Thread( | |
target=self._execute_writes, daemon=True | |
) | |
self._write_thread.name = "_execute_writes for database {}".format( | |
self.name | |
) | |
self._write_thread.start() | |
task_id = uuid.uuid5(uuid.NAMESPACE_DNS, "datasette.io") | |
reply_queue = janus.Queue() | |
self._write_queue.put( | |
WriteTask(fn, task_id, reply_queue, isolated_connection, transaction) | |
) | |
if block: | |
result = await reply_queue.async_q.get() | |
if isinstance(result, Exception): | |
raise result | |
else: | |
return result | |
else: | |
return task_id | |
def _execute_writes(self): | |
# Infinite looping thread that protects the single write connection | |
# to this database | |
conn_exception = None | |
conn = None | |
try: | |
conn = self.connect(write=True) | |
self.ds._prepare_connection(conn, self.name) | |
except Exception as e: | |
conn_exception = e | |
while True: | |
task = self._write_queue.get() | |
if conn_exception is not None: | |
result = conn_exception | |
else: | |
if task.isolated_connection: | |
isolated_connection = self.connect(write=True) | |
try: | |
result = task.fn(isolated_connection) | |
except Exception as e: | |
sys.stderr.write("{}\n".format(e)) | |
sys.stderr.flush() | |
result = e | |
finally: | |
isolated_connection.close() | |
try: | |
self._all_file_connections.remove(isolated_connection) | |
except ValueError: | |
# Was probably a memory connection | |
pass | |
else: | |
try: | |
if task.transaction: | |
with conn: | |
result = task.fn(conn) | |
else: | |
result = task.fn(conn) | |
except Exception as e: | |
sys.stderr.write("{}\n".format(e)) | |
sys.stderr.flush() | |
result = e | |
task.reply_queue.sync_q.put(result) | |
async def execute_fn(self, fn): | |
if self.ds.executor is None: | |
# non-threaded mode | |
if self._read_connection is None: | |
self._read_connection = self.connect() | |
self.ds._prepare_connection(self._read_connection, self.name) | |
return fn(self._read_connection) | |
# threaded mode | |
def in_thread(): | |
conn = getattr(connections, self.name, None) | |
if not conn: | |
conn = self.connect() | |
self.ds._prepare_connection(conn, self.name) | |
setattr(connections, self.name, conn) | |
return fn(conn) | |
return await asyncio.get_event_loop().run_in_executor( | |
self.ds.executor, in_thread | |
) | |
async def execute( | |
self, | |
sql, | |
params=None, | |
truncate=False, | |
custom_time_limit=None, | |
page_size=None, | |
log_sql_errors=True, | |
): | |
"""Executes sql against db_name in a thread""" | |
page_size = page_size or self.ds.page_size | |
def sql_operation_in_thread(conn): | |
time_limit_ms = self.ds.sql_time_limit_ms | |
if custom_time_limit and custom_time_limit < time_limit_ms: | |
time_limit_ms = custom_time_limit | |
with sqlite_timelimit(conn, time_limit_ms): | |
try: | |
cursor = conn.cursor() | |
cursor.execute(sql, params if params is not None else {}) | |
max_returned_rows = self.ds.max_returned_rows | |
if max_returned_rows == page_size: | |
max_returned_rows += 1 | |
if max_returned_rows and truncate: | |
rows = cursor.fetchmany(max_returned_rows + 1) | |
truncated = len(rows) > max_returned_rows | |
rows = rows[:max_returned_rows] | |
else: | |
rows = cursor.fetchall() | |
truncated = False | |
except (sqlite3.OperationalError, sqlite3.DatabaseError) as e: | |
if e.args == ("interrupted",): | |
raise QueryInterrupted(e, sql, params) | |
if log_sql_errors: | |
sys.stderr.write( | |
"ERROR: conn={}, sql = {}, params = {}: {}\n".format( | |
conn, repr(sql), params, e | |
) | |
) | |
sys.stderr.flush() | |
raise | |
if truncate: | |
return Results(rows, truncated, cursor.description) | |
else: | |
return Results(rows, False, cursor.description) | |
with trace("sql", database=self.name, sql=sql.strip(), params=params): | |
results = await self.execute_fn(sql_operation_in_thread) | |
return results | |
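    # Illustrative usage sketch: execute() runs read-only SQL through the read
    # connection pool and returns a Results object. The table name is a
    # placeholder:
    #
    #   results = await db.execute(
    #       "select * from mytable where id = :id", {"id": 1}
    #   )
    #   row = results.first()
    #   dicts = results.dicts()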
@property | |
def hash(self): | |
if self.cached_hash is not None: | |
return self.cached_hash | |
elif self.is_mutable or self.is_memory: | |
return None | |
elif self.ds.inspect_data and self.ds.inspect_data.get(self.name): | |
self.cached_hash = self.ds.inspect_data[self.name]["hash"] | |
return self.cached_hash | |
else: | |
p = Path(self.path) | |
self.cached_hash = inspect_hash(p) | |
return self.cached_hash | |
@property | |
def size(self): | |
if self.cached_size is not None: | |
return self.cached_size | |
elif self.is_memory: | |
return 0 | |
elif self.is_mutable: | |
return Path(self.path).stat().st_size | |
elif self.ds.inspect_data and self.ds.inspect_data.get(self.name): | |
self.cached_size = self.ds.inspect_data[self.name]["size"] | |
return self.cached_size | |
else: | |
self.cached_size = Path(self.path).stat().st_size | |
return self.cached_size | |
async def table_counts(self, limit=10): | |
if not self.is_mutable and self.cached_table_counts is not None: | |
return self.cached_table_counts | |
# Try to get counts for each table, $limit timeout for each count | |
counts = {} | |
for table in await self.table_names(): | |
try: | |
table_count = ( | |
await self.execute( | |
f"select count(*) from (select * from [{table}] limit {self.count_limit + 1})", | |
custom_time_limit=limit, | |
) | |
).rows[0][0] | |
counts[table] = table_count | |
# In some cases I saw "SQL Logic Error" here in addition to | |
# QueryInterrupted - so we catch that too: | |
except (QueryInterrupted, sqlite3.OperationalError, sqlite3.DatabaseError): | |
counts[table] = None | |
if not self.is_mutable: | |
self._cached_table_counts = counts | |
return counts | |
@property | |
def mtime_ns(self): | |
if self.is_memory: | |
return None | |
return Path(self.path).stat().st_mtime_ns | |
async def attached_databases(self): | |
# This used to be: | |
# select seq, name, file from pragma_database_list() where seq > 0 | |
# But SQLite prior to 3.16.0 doesn't support pragma functions | |
results = await self.execute("PRAGMA database_list;") | |
# {'seq': 0, 'name': 'main', 'file': ''} | |
return [AttachedDatabase(*row) for row in results.rows if row["seq"] > 0] | |
async def table_exists(self, table): | |
results = await self.execute( | |
"select 1 from sqlite_master where type='table' and name=?", params=(table,) | |
) | |
return bool(results.rows) | |
async def view_exists(self, table): | |
results = await self.execute( | |
"select 1 from sqlite_master where type='view' and name=?", params=(table,) | |
) | |
return bool(results.rows) | |
async def table_names(self): | |
results = await self.execute( | |
"select name from sqlite_master where type='table'" | |
) | |
return [r[0] for r in results.rows] | |
async def table_columns(self, table): | |
return await self.execute_fn(lambda conn: table_columns(conn, table)) | |
async def table_column_details(self, table): | |
return await self.execute_fn(lambda conn: table_column_details(conn, table)) | |
async def primary_keys(self, table): | |
return await self.execute_fn(lambda conn: detect_primary_keys(conn, table)) | |
async def fts_table(self, table): | |
return await self.execute_fn(lambda conn: detect_fts(conn, table)) | |
async def label_column_for_table(self, table): | |
explicit_label_column = (await self.ds.table_config(self.name, table)).get( | |
"label_column" | |
) | |
if explicit_label_column: | |
return explicit_label_column | |
def column_details(conn): | |
# Returns {column_name: (type, is_unique)} | |
db = sqlite_utils.Database(conn) | |
columns = db[table].columns_dict | |
indexes = db[table].indexes | |
details = {} | |
for name in columns: | |
is_unique = any( | |
index | |
for index in indexes | |
if index.columns == [name] and index.unique | |
) | |
details[name] = (columns[name], is_unique) | |
return details | |
column_details = await self.execute_fn(column_details) | |
# Is there just one unique column that's text? | |
unique_text_columns = [ | |
name | |
for name, (type_, is_unique) in column_details.items() | |
if is_unique and type_ is str | |
] | |
if len(unique_text_columns) == 1: | |
return unique_text_columns[0] | |
column_names = list(column_details.keys()) | |
# Is there a name or title column? | |
name_or_title = [c for c in column_names if c.lower() in ("name", "title")] | |
if name_or_title: | |
return name_or_title[0] | |
# If a table has two columns, one of which is ID, then label_column is the other one | |
if ( | |
column_names | |
and len(column_names) == 2 | |
and ("id" in column_names or "pk" in column_names) | |
and not set(column_names) == {"id", "pk"} | |
): | |
return [c for c in column_names if c not in ("id", "pk")][0] | |
# Couldn't find a label: | |
return None | |
async def foreign_keys_for_table(self, table): | |
return await self.execute_fn( | |
lambda conn: get_outbound_foreign_keys(conn, table) | |
) | |
async def hidden_table_names(self): | |
hidden_tables = [] | |
# Add any tables marked as hidden in config | |
db_config = self.ds.config.get("databases", {}).get(self.name, {}) | |
if "tables" in db_config: | |
hidden_tables += [ | |
t for t in db_config["tables"] if db_config["tables"][t].get("hidden") | |
] | |
if sqlite_version()[1] >= 37: | |
hidden_tables += [ | |
x[0] | |
for x in await self.execute( | |
""" | |
with shadow_tables as ( | |
select name | |
from pragma_table_list | |
where [type] = 'shadow' | |
order by name | |
), | |
core_tables as ( | |
select name | |
from sqlite_master | |
WHERE name in ('sqlite_stat1', 'sqlite_stat2', 'sqlite_stat3', 'sqlite_stat4') | |
OR substr(name, 1, 1) == '_' | |
), | |
combined as ( | |
select name from shadow_tables | |
union all | |
select name from core_tables | |
) | |
select name from combined order by 1 | |
""" | |
) | |
] | |
else: | |
hidden_tables += [ | |
x[0] | |
for x in await self.execute( | |
""" | |
WITH base AS ( | |
SELECT name | |
FROM sqlite_master | |
WHERE name IN ('sqlite_stat1', 'sqlite_stat2', 'sqlite_stat3', 'sqlite_stat4') | |
OR substr(name, 1, 1) == '_' | |
), | |
fts_suffixes AS ( | |
SELECT column1 AS suffix | |
FROM (VALUES ('_data'), ('_idx'), ('_docsize'), ('_content'), ('_config')) | |
), | |
fts5_names AS ( | |
SELECT name | |
FROM sqlite_master | |
WHERE sql LIKE '%VIRTUAL TABLE%USING FTS%' | |
), | |
fts5_shadow_tables AS ( | |
SELECT | |
printf('%s%s', fts5_names.name, fts_suffixes.suffix) AS name | |
FROM fts5_names | |
JOIN fts_suffixes | |
), | |
fts3_suffixes AS ( | |
SELECT column1 AS suffix | |
FROM (VALUES ('_content'), ('_segdir'), ('_segments'), ('_stat'), ('_docsize')) | |
), | |
fts3_names AS ( | |
SELECT name | |
FROM sqlite_master | |
WHERE sql LIKE '%VIRTUAL TABLE%USING FTS3%' | |
OR sql LIKE '%VIRTUAL TABLE%USING FTS4%' | |
), | |
fts3_shadow_tables AS ( | |
SELECT | |
printf('%s%s', fts3_names.name, fts3_suffixes.suffix) AS name | |
FROM fts3_names | |
JOIN fts3_suffixes | |
), | |
final AS ( | |
SELECT name FROM base | |
UNION ALL | |
SELECT name FROM fts5_shadow_tables | |
UNION ALL | |
SELECT name FROM fts3_shadow_tables | |
) | |
SELECT name FROM final ORDER BY 1 | |
""" | |
) | |
] | |
has_spatialite = await self.execute_fn(detect_spatialite) | |
if has_spatialite: | |
# Also hide Spatialite internal tables | |
hidden_tables += [ | |
"ElementaryGeometries", | |
"SpatialIndex", | |
"geometry_columns", | |
"spatial_ref_sys", | |
"spatialite_history", | |
"sql_statements_log", | |
"sqlite_sequence", | |
"views_geometry_columns", | |
"virts_geometry_columns", | |
"data_licenses", | |
"KNN", | |
"KNN2", | |
] + [ | |
r[0] | |
for r in ( | |
await self.execute( | |
""" | |
select name from sqlite_master | |
where name like "idx_%" | |
and type = "table" | |
""" | |
) | |
).rows | |
] | |
return hidden_tables | |
async def view_names(self): | |
results = await self.execute("select name from sqlite_master where type='view'") | |
return [r[0] for r in results.rows] | |
async def get_all_foreign_keys(self): | |
return await self.execute_fn(get_all_foreign_keys) | |
async def get_table_definition(self, table, type_="table"): | |
table_definition_rows = list( | |
await self.execute( | |
"select sql from sqlite_master where name = :n and type=:t", | |
{"n": table, "t": type_}, | |
) | |
) | |
if not table_definition_rows: | |
return None | |
bits = [table_definition_rows[0][0] + ";"] | |
# Add on any indexes | |
index_rows = list( | |
await self.execute( | |
"select sql from sqlite_master where tbl_name = :n and type='index' and sql is not null", | |
{"n": table}, | |
) | |
) | |
for index_row in index_rows: | |
bits.append(index_row[0] + ";") | |
return "\n".join(bits) | |
async def get_view_definition(self, view): | |
return await self.get_table_definition(view, "view") | |
def __repr__(self): | |
tags = [] | |
if self.is_mutable: | |
tags.append("mutable") | |
if self.is_memory: | |
tags.append("memory") | |
if self.hash: | |
tags.append(f"hash={self.hash}") | |
if self.size is not None: | |
tags.append(f"size={self.size}") | |
tags_str = "" | |
if tags: | |
tags_str = f" ({', '.join(tags)})" | |
return f"<Database: {self.name}{tags_str}>" | |
class WriteTask: | |
__slots__ = ("fn", "task_id", "reply_queue", "isolated_connection", "transaction") | |
def __init__(self, fn, task_id, reply_queue, isolated_connection, transaction): | |
self.fn = fn | |
self.task_id = task_id | |
self.reply_queue = reply_queue | |
self.isolated_connection = isolated_connection | |
self.transaction = transaction | |
class QueryInterrupted(Exception): | |
def __init__(self, e, sql, params): | |
self.e = e | |
self.sql = sql | |
self.params = params | |
def __str__(self): | |
return "QueryInterrupted: {}".format(self.e) | |
class MultipleValues(Exception): | |
pass | |
class Results: | |
def __init__(self, rows, truncated, description): | |
self.rows = rows | |
self.truncated = truncated | |
self.description = description | |
@property | |
def columns(self): | |
return [d[0] for d in self.description] | |
def first(self): | |
if self.rows: | |
return self.rows[0] | |
else: | |
return None | |
def single_value(self): | |
if self.rows and 1 == len(self.rows) and 1 == len(self.rows[0]): | |
return self.rows[0][0] | |
else: | |
raise MultipleValues | |
def dicts(self): | |
return [dict(row) for row in self.rows] | |
def __iter__(self): | |
return iter(self.rows) | |
def __len__(self): | |
return len(self.rows) | |
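# Illustrative usage sketch (not part of database.py): how a Results object
# returned by Database.execute() is typically consumed. The database and table
# names below are hypothetical placeholders.
#
#   db = datasette.get_database("mydb")
#   results = await db.execute("select id, name from mytable limit 5")
#   results.columns     # ["id", "name"]
#   results.first()     # first row, or None if there were no rows
#   results.dicts()     # list of plain dicts, one per row
#   (await db.execute("select count(*) from mytable")).single_value()
#                       # single scalar; raises MultipleValues otherwise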
</document_content> | |
</document> | |
<document index="8"> | |
<source>datasette/default_magic_parameters.py</source> | |
<document_content> | |
from datasette import hookimpl | |
import datetime | |
import os | |
import time | |
def header(key, request): | |
key = key.replace("_", "-").encode("utf-8") | |
headers_dict = dict(request.scope["headers"]) | |
return headers_dict.get(key, b"").decode("utf-8") | |
def actor(key, request): | |
if request.actor is None: | |
raise KeyError | |
return request.actor[key] | |
def cookie(key, request): | |
return request.cookies[key] | |
def now(key, request): | |
if key == "epoch": | |
return int(time.time()) | |
elif key == "date_utc": | |
return datetime.datetime.now(datetime.timezone.utc).date().isoformat() | |
elif key == "datetime_utc": | |
return ( | |
datetime.datetime.now(datetime.timezone.utc).strftime(r"%Y-%m-%dT%H:%M:%S") | |
+ "Z" | |
) | |
else: | |
raise KeyError | |
def random(key, request): | |
if key.startswith("chars_") and key.split("chars_")[-1].isdigit(): | |
num_chars = int(key.split("chars_")[-1]) | |
if num_chars % 2 == 1: | |
urandom_len = (num_chars + 1) / 2 | |
else: | |
urandom_len = num_chars / 2 | |
return os.urandom(int(urandom_len)).hex()[:num_chars] | |
else: | |
raise KeyError | |
@hookimpl | |
def register_magic_parameters(): | |
return [ | |
("header", header), | |
("actor", actor), | |
("cookie", cookie), | |
("now", now), | |
("random", random), | |
] | |
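# Illustrative usage sketch (not part of this module): the magic parameters
# registered above are resolved when a canned query references them as
# :_<name>_<key>. Example canned-query SQL (table and column names are
# hypothetical):
#
#   insert into log (actor_id, at_epoch, token)
#   values (:_actor_id, :_now_epoch, :_random_chars_16)
#
# :_actor_id calls actor("id", request), :_now_epoch calls now("epoch", request)
# and :_random_chars_16 calls random("chars_16", request).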
</document_content> | |
</document> | |
<document index="9"> | |
<source>datasette/default_menu_links.py</source> | |
<document_content> | |
from datasette import hookimpl | |
@hookimpl | |
def menu_links(datasette, actor): | |
async def inner(): | |
if not await datasette.permission_allowed(actor, "debug-menu"): | |
return [] | |
return [ | |
{"href": datasette.urls.path("/-/databases"), "label": "Databases"}, | |
{ | |
"href": datasette.urls.path("/-/plugins"), | |
"label": "Installed plugins", | |
}, | |
{ | |
"href": datasette.urls.path("/-/versions"), | |
"label": "Version info", | |
}, | |
{ | |
"href": datasette.urls.path("/-/settings"), | |
"label": "Settings", | |
}, | |
{ | |
"href": datasette.urls.path("/-/permissions"), | |
"label": "Debug permissions", | |
}, | |
{ | |
"href": datasette.urls.path("/-/messages"), | |
"label": "Debug messages", | |
}, | |
{ | |
"href": datasette.urls.path("/-/allow-debug"), | |
"label": "Debug allow rules", | |
}, | |
{"href": datasette.urls.path("/-/threads"), "label": "Debug threads"}, | |
{"href": datasette.urls.path("/-/actor"), "label": "Debug actor"}, | |
{"href": datasette.urls.path("/-/patterns"), "label": "Pattern portfolio"}, | |
] | |
return inner | |
</document_content> | |
</document> | |
<document index="10"> | |
<source>datasette/default_permissions.py</source> | |
<document_content> | |
from datasette import hookimpl, Permission | |
from datasette.utils import actor_matches_allow | |
import itsdangerous | |
import time | |
from typing import Union, Tuple | |
@hookimpl | |
def register_permissions(): | |
return ( | |
Permission( | |
name="view-instance", | |
abbr="vi", | |
description="View Datasette instance", | |
takes_database=False, | |
takes_resource=False, | |
default=True, | |
), | |
Permission( | |
name="view-database", | |
abbr="vd", | |
description="View database", | |
takes_database=True, | |
takes_resource=False, | |
default=True, | |
implies_can_view=True, | |
), | |
Permission( | |
name="view-database-download", | |
abbr="vdd", | |
description="Download database file", | |
takes_database=True, | |
takes_resource=False, | |
default=True, | |
), | |
Permission( | |
name="view-table", | |
abbr="vt", | |
description="View table", | |
takes_database=True, | |
takes_resource=True, | |
default=True, | |
implies_can_view=True, | |
), | |
Permission( | |
name="view-query", | |
abbr="vq", | |
description="View named query results", | |
takes_database=True, | |
takes_resource=True, | |
default=True, | |
implies_can_view=True, | |
), | |
Permission( | |
name="execute-sql", | |
abbr="es", | |
description="Execute read-only SQL queries", | |
takes_database=True, | |
takes_resource=False, | |
default=True, | |
implies_can_view=True, | |
), | |
Permission( | |
name="permissions-debug", | |
abbr="pd", | |
description="Access permission debug tool", | |
takes_database=False, | |
takes_resource=False, | |
default=False, | |
), | |
Permission( | |
name="debug-menu", | |
abbr="dm", | |
description="View debug menu items", | |
takes_database=False, | |
takes_resource=False, | |
default=False, | |
), | |
Permission( | |
name="insert-row", | |
abbr="ir", | |
description="Insert rows", | |
takes_database=True, | |
takes_resource=True, | |
default=False, | |
), | |
Permission( | |
name="delete-row", | |
abbr="dr", | |
description="Delete rows", | |
takes_database=True, | |
takes_resource=True, | |
default=False, | |
), | |
Permission( | |
name="update-row", | |
abbr="ur", | |
description="Update rows", | |
takes_database=True, | |
takes_resource=True, | |
default=False, | |
), | |
Permission( | |
name="create-table", | |
abbr="ct", | |
description="Create tables", | |
takes_database=True, | |
takes_resource=False, | |
default=False, | |
), | |
Permission( | |
name="alter-table", | |
abbr="at", | |
description="Alter tables", | |
takes_database=True, | |
takes_resource=True, | |
default=False, | |
), | |
Permission( | |
name="drop-table", | |
abbr="dt", | |
description="Drop tables", | |
takes_database=True, | |
takes_resource=True, | |
default=False, | |
), | |
) | |
@hookimpl(tryfirst=True, specname="permission_allowed") | |
def permission_allowed_default(datasette, actor, action, resource): | |
async def inner(): | |
# id=root gets some special permissions: | |
if action in ( | |
"permissions-debug", | |
"debug-menu", | |
"insert-row", | |
"create-table", | |
"alter-table", | |
"drop-table", | |
"delete-row", | |
"update-row", | |
): | |
if actor and actor.get("id") == "root": | |
return True | |
# Resolve view permissions in allow blocks in configuration | |
if action in ( | |
"view-instance", | |
"view-database", | |
"view-table", | |
"view-query", | |
"execute-sql", | |
): | |
result = await _resolve_config_view_permissions( | |
datasette, actor, action, resource | |
) | |
if result is not None: | |
return result | |
# Resolve custom permissions: blocks in configuration | |
result = await _resolve_config_permissions_blocks( | |
datasette, actor, action, resource | |
) | |
if result is not None: | |
return result | |
# --setting default_allow_sql | |
if action == "execute-sql" and not datasette.setting("default_allow_sql"): | |
return False | |
return inner | |
async def _resolve_config_permissions_blocks(datasette, actor, action, resource): | |
# Check custom permissions: blocks | |
config = datasette.config or {} | |
root_block = (config.get("permissions", None) or {}).get(action) | |
if root_block: | |
root_result = actor_matches_allow(actor, root_block) | |
if root_result is not None: | |
return root_result | |
# Now try database-specific blocks | |
if not resource: | |
return None | |
if isinstance(resource, str): | |
database = resource | |
else: | |
database = resource[0] | |
database_block = ( | |
(config.get("databases", {}).get(database, {}).get("permissions", None)) or {} | |
).get(action) | |
if database_block: | |
database_result = actor_matches_allow(actor, database_block) | |
if database_result is not None: | |
return database_result | |
# Finally try table/query specific blocks | |
if not isinstance(resource, tuple): | |
return None | |
database, table_or_query = resource | |
table_block = ( | |
( | |
config.get("databases", {}) | |
.get(database, {}) | |
.get("tables", {}) | |
.get(table_or_query, {}) | |
.get("permissions", None) | |
) | |
or {} | |
).get(action) | |
if table_block: | |
table_result = actor_matches_allow(actor, table_block) | |
if table_result is not None: | |
return table_result | |
# Finally the canned queries | |
query_block = ( | |
( | |
config.get("databases", {}) | |
.get(database, {}) | |
.get("queries", {}) | |
.get(table_or_query, {}) | |
.get("permissions", None) | |
) | |
or {} | |
).get(action) | |
if query_block: | |
query_result = actor_matches_allow(actor, query_block) | |
if query_result is not None: | |
return query_result | |
return None | |
async def _resolve_config_view_permissions(datasette, actor, action, resource): | |
config = datasette.config or {} | |
if action == "view-instance": | |
allow = config.get("allow") | |
if allow is not None: | |
return actor_matches_allow(actor, allow) | |
elif action == "view-database": | |
database_allow = ((config.get("databases") or {}).get(resource) or {}).get( | |
"allow" | |
) | |
if database_allow is None: | |
return None | |
return actor_matches_allow(actor, database_allow) | |
elif action == "view-table": | |
database, table = resource | |
tables = ((config.get("databases") or {}).get(database) or {}).get( | |
"tables" | |
) or {} | |
table_allow = (tables.get(table) or {}).get("allow") | |
if table_allow is None: | |
return None | |
return actor_matches_allow(actor, table_allow) | |
elif action == "view-query": | |
# Check if this query has an "allow" block in config
database, query_name = resource | |
query = await datasette.get_canned_query(database, query_name, actor) | |
assert query is not None | |
allow = query.get("allow") | |
if allow is None: | |
return None | |
return actor_matches_allow(actor, allow) | |
elif action == "execute-sql": | |
# Use allow_sql block from database block, or from top-level | |
database_allow_sql = ((config.get("databases") or {}).get(resource) or {}).get( | |
"allow_sql" | |
) | |
if database_allow_sql is None: | |
database_allow_sql = config.get("allow_sql") | |
if database_allow_sql is None: | |
return None | |
return actor_matches_allow(actor, database_allow_sql) | |
def restrictions_allow_action( | |
datasette: "Datasette", | |
restrictions: dict, | |
action: str, | |
resource: Union[str, Tuple[str, str]], | |
): | |
"Do these restrictions allow the requested action against the requested resource?" | |
if action == "view-instance": | |
# Special case for view-instance: it's allowed if the restrictions include any | |
# permissions that have the implies_can_view=True flag set | |
all_rules = restrictions.get("a") or [] | |
for database_rules in (restrictions.get("d") or {}).values(): | |
all_rules += database_rules | |
for database_resource_rules in (restrictions.get("r") or {}).values(): | |
for resource_rules in database_resource_rules.values(): | |
all_rules += resource_rules | |
permissions = [datasette.get_permission(action) for action in all_rules] | |
if any(p for p in permissions if p.implies_can_view): | |
return True | |
if action == "view-database": | |
# Special case for view-database: it's allowed if the restrictions include any | |
# permissions that have the implies_can_view=True flag set AND takes_database | |
all_rules = restrictions.get("a") or [] | |
database_rules = list((restrictions.get("d") or {}).get(resource) or []) | |
all_rules += database_rules | |
for resource_rules in (restrictions.get("r") or {}).values():
for table_rules in resource_rules.values(): | |
all_rules += table_rules | |
permissions = [datasette.get_permission(action) for action in all_rules] | |
if any(p for p in permissions if p.implies_can_view and p.takes_database): | |
return True | |
# Does this action have an abbreviation? | |
to_check = {action} | |
permission = datasette.permissions.get(action) | |
if permission and permission.abbr: | |
to_check.add(permission.abbr) | |
# If restrictions are defined then we use them to further restrict the actor.
# Crucially, we only use this to say NO (return False) - we never use it to
# return YES (True) because that might override other restrictions placed on
# this actor.
all_allowed = restrictions.get("a") | |
if all_allowed is not None: | |
assert isinstance(all_allowed, list) | |
if to_check.intersection(all_allowed): | |
return True | |
# How about for the current database? | |
if resource: | |
if isinstance(resource, str): | |
database_name = resource | |
else: | |
database_name = resource[0] | |
database_allowed = restrictions.get("d", {}).get(database_name) | |
if database_allowed is not None: | |
assert isinstance(database_allowed, list) | |
if to_check.intersection(database_allowed): | |
return True | |
# Or the current table? That's any time the resource is (database, table) | |
if resource is not None and not isinstance(resource, str) and len(resource) == 2: | |
database, table = resource | |
table_allowed = restrictions.get("r", {}).get(database, {}).get(table) | |
# TODO: What should this do for canned queries? | |
if table_allowed is not None: | |
assert isinstance(table_allowed, list) | |
if to_check.intersection(table_allowed): | |
return True | |
# This action is not specifically allowed, so reject it | |
return False | |
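# Illustrative sketch (not part of this module): the shape of the "_r"
# restrictions dictionary that restrictions_allow_action() operates on, using
# hypothetical database and table names. "a" holds instance-wide allowed
# actions, "d" holds per-database actions and "r" holds per-resource
# (database, table or query) actions; full action names or their
# abbreviations may appear.
#
#   _r = {
#       "a": ["view-instance"],
#       "d": {"mydb": ["vd", "es"]},
#       "r": {"mydb": {"mytable": ["vt", "ir"]}},
#   }
#
# With these restrictions, restrictions_allow_action(datasette, _r,
# "insert-row", ("mydb", "mytable")) returns True, while write actions not
# listed anywhere are rejected.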
@hookimpl(specname="permission_allowed") | |
def permission_allowed_actor_restrictions(datasette, actor, action, resource): | |
if actor is None: | |
return None | |
if "_r" not in actor: | |
# No restrictions, so we have no opinion | |
return None | |
_r = actor.get("_r") | |
if restrictions_allow_action(datasette, _r, action, resource): | |
# Return None because we do not have an opinion here | |
return None | |
else: | |
# Block this permission check | |
return False | |
@hookimpl | |
def actor_from_request(datasette, request): | |
prefix = "dstok_" | |
if not datasette.setting("allow_signed_tokens"): | |
return None | |
max_signed_tokens_ttl = datasette.setting("max_signed_tokens_ttl") | |
authorization = request.headers.get("authorization") | |
if not authorization: | |
return None | |
if not authorization.startswith("Bearer "): | |
return None | |
token = authorization[len("Bearer ") :] | |
if not token.startswith(prefix): | |
return None | |
token = token[len(prefix) :] | |
try: | |
decoded = datasette.unsign(token, namespace="token") | |
except itsdangerous.BadSignature: | |
return None | |
if "t" not in decoded: | |
# Missing timestamp | |
return None | |
created = decoded["t"] | |
if not isinstance(created, int): | |
# Invalid timestamp | |
return None | |
duration = decoded.get("d") | |
if duration is not None and not isinstance(duration, int): | |
# Invalid duration | |
return None | |
if (duration is None and max_signed_tokens_ttl) or ( | |
duration is not None | |
and max_signed_tokens_ttl | |
and duration > max_signed_tokens_ttl | |
): | |
duration = max_signed_tokens_ttl | |
if duration: | |
if time.time() - created > duration: | |
# Expired | |
return None | |
actor = {"id": decoded["a"], "token": "dstok"} | |
if "_r" in decoded: | |
actor["_r"] = decoded["_r"] | |
if duration: | |
actor["token_expires"] = created + duration | |
return actor | |
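# Illustrative sketch (not part of this module): the decoded payload of a
# "dstok_" signed API token as consumed by actor_from_request() above. Field
# names follow the code: "a" is the actor id, "t" the creation timestamp,
# "d" an optional duration in seconds and "_r" optional restrictions.
#
#   decoded = {
#       "a": "root",          # becomes actor["id"]
#       "t": 1700000000,      # int timestamp recorded when the token was signed
#       "d": 3600,            # token expires 3600 seconds after "t"
#       "_r": {"a": ["vi"]},  # optional restrictions, copied to actor["_r"]
#   }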
@hookimpl | |
def skip_csrf(scope): | |
# Skip CSRF check for requests with content-type: application/json | |
if scope["type"] == "http": | |
headers = scope.get("headers") or {} | |
if dict(headers).get(b"content-type") == b"application/json": | |
return True | |
</document_content> | |
</document> | |
<document index="11"> | |
<source>datasette/events.py</source> | |
<document_content> | |
from abc import ABC, abstractproperty | |
from dataclasses import asdict, dataclass, field | |
from datasette.hookspecs import hookimpl | |
from datetime import datetime, timezone | |
from typing import Optional | |
@dataclass | |
class Event(ABC): | |
@abstractproperty | |
def name(self): | |
pass | |
created: datetime = field( | |
init=False, default_factory=lambda: datetime.now(timezone.utc) | |
) | |
actor: Optional[dict] | |
def properties(self): | |
properties = asdict(self) | |
properties.pop("actor", None) | |
properties.pop("created", None) | |
return properties | |
@dataclass | |
class LoginEvent(Event): | |
""" | |
Event name: ``login`` | |
A user (represented by ``event.actor``) has logged in. | |
""" | |
name = "login" | |
@dataclass | |
class LogoutEvent(Event): | |
""" | |
Event name: ``logout`` | |
A user (represented by ``event.actor``) has logged out. | |
""" | |
name = "logout" | |
@dataclass | |
class CreateTokenEvent(Event): | |
""" | |
Event name: ``create-token`` | |
A user created an API token. | |
:ivar expires_after: Number of seconds after which this token will expire. | |
:type expires_after: int or None | |
:ivar restrict_all: Restricted permissions for this token. | |
:type restrict_all: list | |
:ivar restrict_database: Restricted database permissions for this token. | |
:type restrict_database: dict | |
:ivar restrict_resource: Restricted resource permissions for this token. | |
:type restrict_resource: dict | |
""" | |
name = "create-token" | |
expires_after: Optional[int] | |
restrict_all: list | |
restrict_database: dict | |
restrict_resource: dict | |
@dataclass | |
class CreateTableEvent(Event): | |
""" | |
Event name: ``create-table`` | |
A new table has been created in the database. | |
:ivar database: The name of the database where the table was created. | |
:type database: str | |
:ivar table: The name of the table that was created | |
:type table: str | |
:ivar schema: The SQL schema definition for the new table. | |
:type schema: str | |
""" | |
name = "create-table" | |
database: str | |
table: str | |
schema: str | |
@dataclass | |
class DropTableEvent(Event): | |
""" | |
Event name: ``drop-table`` | |
A table has been dropped from the database. | |
:ivar database: The name of the database where the table was dropped. | |
:type database: str | |
:ivar table: The name of the table that was dropped | |
:type table: str | |
""" | |
name = "drop-table" | |
database: str | |
table: str | |
@dataclass | |
class AlterTableEvent(Event): | |
""" | |
Event name: ``alter-table`` | |
A table has been altered. | |
:ivar database: The name of the database where the table was altered | |
:type database: str | |
:ivar table: The name of the table that was altered | |
:type table: str | |
:ivar before_schema: The table's SQL schema before the alteration | |
:type before_schema: str | |
:ivar after_schema: The table's SQL schema after the alteration | |
:type after_schema: str | |
""" | |
name = "alter-table" | |
database: str | |
table: str | |
before_schema: str | |
after_schema: str | |
@dataclass | |
class InsertRowsEvent(Event): | |
""" | |
Event name: ``insert-rows`` | |
Rows were inserted into a table. | |
:ivar database: The name of the database where the rows were inserted. | |
:type database: str | |
:ivar table: The name of the table where the rows were inserted. | |
:type table: str | |
:ivar num_rows: The number of rows that were requested to be inserted. | |
:type num_rows: int | |
:ivar ignore: Was ignore set? | |
:type ignore: bool | |
:ivar replace: Was replace set? | |
:type replace: bool | |
""" | |
name = "insert-rows" | |
database: str | |
table: str | |
num_rows: int | |
ignore: bool | |
replace: bool | |
@dataclass | |
class UpsertRowsEvent(Event): | |
""" | |
Event name: ``upsert-rows`` | |
Rows were upserted into a table. | |
:ivar database: The name of the database where the rows were inserted. | |
:type database: str | |
:ivar table: The name of the table where the rows were inserted. | |
:type table: str | |
:ivar num_rows: The number of rows that were requested to be inserted. | |
:type num_rows: int | |
""" | |
name = "upsert-rows" | |
database: str | |
table: str | |
num_rows: int | |
@dataclass | |
class UpdateRowEvent(Event): | |
""" | |
Event name: ``update-row`` | |
A row was updated in a table. | |
:ivar database: The name of the database where the row was updated. | |
:type database: str | |
:ivar table: The name of the table where the row was updated. | |
:type table: str | |
:ivar pks: The primary key values of the updated row. | |
""" | |
name = "update-row" | |
database: str | |
table: str | |
pks: list | |
@dataclass | |
class DeleteRowEvent(Event): | |
""" | |
Event name: ``delete-row`` | |
A row was deleted from a table. | |
:ivar database: The name of the database where the row was deleted. | |
:type database: str | |
:ivar table: The name of the table where the row was deleted. | |
:type table: str | |
:ivar pks: The primary key values of the deleted row. | |
""" | |
name = "delete-row" | |
database: str | |
table: str | |
pks: list | |
@hookimpl | |
def register_events(): | |
return [ | |
LoginEvent, | |
LogoutEvent, | |
CreateTableEvent, | |
CreateTokenEvent, | |
AlterTableEvent, | |
DropTableEvent, | |
InsertRowsEvent, | |
UpsertRowsEvent, | |
UpdateRowEvent, | |
DeleteRowEvent, | |
] | |
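# Illustrative usage sketch (not part of this module): instantiating one of the
# event classes above and sending it to the track_event() plugin hook. The
# database and table names are hypothetical, and datasette.track_event() is
# assumed to be the dispatch method on the Datasette application object.
#
#   await datasette.track_event(
#       CreateTableEvent(
#           actor=request.actor,
#           database="mydb",
#           table="mytable",
#           schema="CREATE TABLE mytable (id INTEGER PRIMARY KEY)",
#       )
#   )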
</document_content> | |
</document> | |
<document index="12"> | |
<source>datasette/facets.py</source> | |
<document_content> | |
import json | |
import urllib | |
from datasette import hookimpl | |
from datasette.database import QueryInterrupted | |
from datasette.utils import ( | |
escape_sqlite, | |
path_with_added_args, | |
path_with_removed_args, | |
detect_json1, | |
sqlite3, | |
) | |
def load_facet_configs(request, table_config): | |
# Given a request and the configuration for a table, return a dictionary
# mapping each selected facet type to a list of configs, recording for each
# config whether it came from the request or from the metadata.
#
# Returns {type: [
#     {"source": "metadata", "config": config1},
#     {"source": "request", "config": config2}]}
facet_configs = {} | |
table_config = table_config or {} | |
table_facet_configs = table_config.get("facets", []) | |
for facet_config in table_facet_configs: | |
if isinstance(facet_config, str): | |
type = "column" | |
facet_config = {"simple": facet_config} | |
else: | |
assert ( | |
len(facet_config.values()) == 1 | |
), "Metadata config dicts should be {type: config}" | |
type, facet_config = list(facet_config.items())[0] | |
if isinstance(facet_config, str): | |
facet_config = {"simple": facet_config} | |
facet_configs.setdefault(type, []).append( | |
{"source": "metadata", "config": facet_config} | |
) | |
qs_pairs = urllib.parse.parse_qs(request.query_string, keep_blank_values=True) | |
for key, values in qs_pairs.items(): | |
if key.startswith("_facet"): | |
# Figure out the facet type | |
if key == "_facet": | |
type = "column" | |
elif key.startswith("_facet_"): | |
type = key[len("_facet_") :] | |
for value in values: | |
# The value is the facet_config - either JSON or not | |
facet_config = ( | |
json.loads(value) if value.startswith("{") else {"simple": value} | |
) | |
facet_configs.setdefault(type, []).append( | |
{"source": "request", "config": facet_config} | |
) | |
return facet_configs | |
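# Illustrative sketch (not part of this module): example input and output for
# load_facet_configs(), with a hypothetical table config and querystring.
#
#   table_config = {"facets": ["state", {"array": "tags"}]}
#   # request querystring: ?_facet=city_id&_facet_date=created
#
#   load_facet_configs(request, table_config) == {
#       "column": [
#           {"source": "metadata", "config": {"simple": "state"}},
#           {"source": "request", "config": {"simple": "city_id"}},
#       ],
#       "array": [{"source": "metadata", "config": {"simple": "tags"}}],
#       "date": [{"source": "request", "config": {"simple": "created"}}],
#   }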
@hookimpl | |
def register_facet_classes(): | |
classes = [ColumnFacet, DateFacet] | |
if detect_json1(): | |
classes.append(ArrayFacet) | |
return classes | |
class Facet: | |
type = None | |
# How many rows to consider when suggesting facets: | |
suggest_consider = 1000 | |
def __init__( | |
self, | |
ds, | |
request, | |
database, | |
sql=None, | |
table=None, | |
params=None, | |
table_config=None, | |
row_count=None, | |
): | |
assert table or sql, "Must provide either table= or sql=" | |
self.ds = ds | |
self.request = request | |
self.database = database | |
# For foreign key expansion. Can be None for e.g. canned SQL queries: | |
self.table = table | |
self.sql = sql or f"select * from [{table}]" | |
self.params = params or [] | |
self.table_config = table_config | |
# row_count can be None, in which case we calculate it ourselves: | |
self.row_count = row_count | |
def get_configs(self): | |
configs = load_facet_configs(self.request, self.table_config) | |
return configs.get(self.type) or [] | |
def get_querystring_pairs(self): | |
# ?_foo=bar&_foo=2&empty= becomes: | |
# [('_foo', 'bar'), ('_foo', '2'), ('empty', '')] | |
return urllib.parse.parse_qsl(self.request.query_string, keep_blank_values=True) | |
def get_facet_size(self): | |
facet_size = self.ds.setting("default_facet_size") | |
max_returned_rows = self.ds.setting("max_returned_rows") | |
table_facet_size = None | |
if self.table: | |
config_facet_size = ( | |
self.ds.config.get("databases", {}) | |
.get(self.database, {}) | |
.get("tables", {}) | |
.get(self.table, {}) | |
.get("facet_size") | |
) | |
if config_facet_size: | |
table_facet_size = config_facet_size | |
custom_facet_size = self.request.args.get("_facet_size") | |
if custom_facet_size: | |
if custom_facet_size == "max": | |
facet_size = max_returned_rows | |
elif custom_facet_size.isdigit(): | |
facet_size = int(custom_facet_size) | |
else: | |
# Invalid value, ignore it | |
custom_facet_size = None | |
if table_facet_size and not custom_facet_size: | |
if table_facet_size == "max": | |
facet_size = max_returned_rows | |
else: | |
facet_size = table_facet_size | |
return min(facet_size, max_returned_rows) | |
async def suggest(self): | |
return [] | |
async def facet_results(self): | |
# returns ([results], [timed_out]) | |
# TODO: Include "hideable" with each one somehow, which indicates if it was | |
# defined in metadata (in which case you cannot turn it off) | |
raise NotImplementedError | |
async def get_columns(self, sql, params=None): | |
# Detect column names using the "limit 0" trick | |
return ( | |
await self.ds.execute( | |
self.database, f"select * from ({sql}) limit 0", params or [] | |
) | |
).columns | |
class ColumnFacet(Facet): | |
type = "column" | |
async def suggest(self): | |
row_count = await self.get_row_count() | |
columns = await self.get_columns(self.sql, self.params) | |
facet_size = self.get_facet_size() | |
suggested_facets = [] | |
already_enabled = [c["config"]["simple"] for c in self.get_configs()] | |
for column in columns: | |
if column in already_enabled: | |
continue | |
suggested_facet_sql = """ | |
with limited as (select * from ({sql}) limit {suggest_consider}) | |
select {column} as value, count(*) as n from limited | |
where value is not null | |
group by value | |
limit {limit} | |
""".format( | |
column=escape_sqlite(column), | |
sql=self.sql, | |
limit=facet_size + 1, | |
suggest_consider=self.suggest_consider, | |
) | |
distinct_values = None | |
try: | |
distinct_values = await self.ds.execute( | |
self.database, | |
suggested_facet_sql, | |
self.params, | |
truncate=False, | |
custom_time_limit=self.ds.setting("facet_suggest_time_limit_ms"), | |
) | |
num_distinct_values = len(distinct_values) | |
if ( | |
1 < num_distinct_values < row_count | |
and num_distinct_values <= facet_size | |
# And at least one has n > 1 | |
and any(r["n"] > 1 for r in distinct_values) | |
): | |
suggested_facets.append( | |
{ | |
"name": column, | |
"toggle_url": self.ds.absolute_url( | |
self.request, | |
self.ds.urls.path( | |
path_with_added_args( | |
self.request, {"_facet": column} | |
) | |
), | |
), | |
} | |
) | |
except QueryInterrupted: | |
continue | |
return suggested_facets | |
async def get_row_count(self): | |
if self.row_count is None: | |
self.row_count = ( | |
await self.ds.execute( | |
self.database, | |
f"select count(*) from (select * from ({self.sql}) limit {self.suggest_consider})", | |
self.params, | |
) | |
).rows[0][0] | |
return self.row_count | |
async def facet_results(self): | |
facet_results = [] | |
facets_timed_out = [] | |
qs_pairs = self.get_querystring_pairs() | |
facet_size = self.get_facet_size() | |
for source_and_config in self.get_configs(): | |
config = source_and_config["config"] | |
source = source_and_config["source"] | |
column = config.get("column") or config["simple"] | |
facet_sql = """ | |
select {col} as value, count(*) as count from ( | |
{sql} | |
) | |
where {col} is not null | |
group by {col} order by count desc, value limit {limit} | |
""".format( | |
col=escape_sqlite(column), sql=self.sql, limit=facet_size + 1 | |
) | |
try: | |
facet_rows_results = await self.ds.execute( | |
self.database, | |
facet_sql, | |
self.params, | |
truncate=False, | |
custom_time_limit=self.ds.setting("facet_time_limit_ms"), | |
) | |
facet_results_values = [] | |
facet_results.append( | |
{ | |
"name": column, | |
"type": self.type, | |
"hideable": source != "metadata", | |
"toggle_url": self.ds.urls.path( | |
path_with_removed_args(self.request, {"_facet": column}) | |
), | |
"results": facet_results_values, | |
"truncated": len(facet_rows_results) > facet_size, | |
} | |
) | |
facet_rows = facet_rows_results.rows[:facet_size] | |
if self.table: | |
# Attempt to expand foreign keys into labels | |
values = [row["value"] for row in facet_rows] | |
expanded = await self.ds.expand_foreign_keys( | |
self.request.actor, self.database, self.table, column, values | |
) | |
else: | |
expanded = {} | |
for row in facet_rows: | |
column_qs = column | |
if column.startswith("_"): | |
column_qs = "{}__exact".format(column) | |
selected = (column_qs, str(row["value"])) in qs_pairs | |
if selected: | |
toggle_path = path_with_removed_args( | |
self.request, {column_qs: str(row["value"])} | |
) | |
else: | |
toggle_path = path_with_added_args( | |
self.request, {column_qs: row["value"]} | |
) | |
facet_results_values.append( | |
{ | |
"value": row["value"], | |
"label": expanded.get((column, row["value"]), row["value"]), | |
"count": row["count"], | |
"toggle_url": self.ds.absolute_url( | |
self.request, self.ds.urls.path(toggle_path) | |
), | |
"selected": selected, | |
} | |
) | |
except QueryInterrupted: | |
facets_timed_out.append(column) | |
return facet_results, facets_timed_out | |
class ArrayFacet(Facet): | |
type = "array" | |
def _is_json_array_of_strings(self, json_string): | |
try: | |
array = json.loads(json_string) | |
except ValueError: | |
return False | |
for item in array: | |
if not isinstance(item, str): | |
return False | |
return True | |
async def suggest(self): | |
columns = await self.get_columns(self.sql, self.params) | |
suggested_facets = [] | |
already_enabled = [c["config"]["simple"] for c in self.get_configs()] | |
for column in columns: | |
if column in already_enabled: | |
continue | |
# Is every value in this column either null or a JSON array? | |
suggested_facet_sql = """ | |
with limited as (select * from ({sql}) limit {suggest_consider}) | |
select distinct json_type({column}) | |
from limited | |
where {column} is not null and {column} != '' | |
""".format( | |
column=escape_sqlite(column), | |
sql=self.sql, | |
suggest_consider=self.suggest_consider, | |
) | |
try: | |
results = await self.ds.execute( | |
self.database, | |
suggested_facet_sql, | |
self.params, | |
truncate=False, | |
custom_time_limit=self.ds.setting("facet_suggest_time_limit_ms"), | |
log_sql_errors=False, | |
) | |
types = tuple(r[0] for r in results.rows) | |
if types in (("array",), ("array", None)): | |
# Now check that first 100 arrays contain only strings | |
first_100 = [ | |
v[0] | |
for v in await self.ds.execute( | |
self.database, | |
( | |
"select {column} from ({sql}) " | |
"where {column} is not null " | |
"and {column} != '' " | |
"and json_array_length({column}) > 0 " | |
"limit 100" | |
).format(column=escape_sqlite(column), sql=self.sql), | |
self.params, | |
truncate=False, | |
custom_time_limit=self.ds.setting( | |
"facet_suggest_time_limit_ms" | |
), | |
log_sql_errors=False, | |
) | |
] | |
if first_100 and all( | |
self._is_json_array_of_strings(r) for r in first_100 | |
): | |
suggested_facets.append( | |
{ | |
"name": column, | |
"type": "array", | |
"toggle_url": self.ds.absolute_url( | |
self.request, | |
self.ds.urls.path( | |
path_with_added_args( | |
self.request, {"_facet_array": column} | |
) | |
), | |
), | |
} | |
) | |
except (QueryInterrupted, sqlite3.OperationalError): | |
continue | |
return suggested_facets | |
async def facet_results(self): | |
# self.get_configs() returns a list of {"source": ..., "config": ...} dicts for this facet type
facet_results = [] | |
facets_timed_out = [] | |
facet_size = self.get_facet_size() | |
for source_and_config in self.get_configs(): | |
config = source_and_config["config"] | |
source = source_and_config["source"] | |
column = config.get("column") or config["simple"] | |
# https://github.com/simonw/datasette/issues/448 | |
facet_sql = """ | |
with inner as ({sql}), | |
deduped_array_items as ( | |
select | |
distinct j.value, | |
inner.* | |
from | |
json_each([inner].{col}) j | |
join inner | |
) | |
select | |
value as value, | |
count(*) as count | |
from | |
deduped_array_items | |
group by | |
value | |
order by | |
count(*) desc, value limit {limit} | |
""".format( | |
col=escape_sqlite(column), | |
sql=self.sql, | |
limit=facet_size + 1, | |
) | |
try: | |
facet_rows_results = await self.ds.execute( | |
self.database, | |
facet_sql, | |
self.params, | |
truncate=False, | |
custom_time_limit=self.ds.setting("facet_time_limit_ms"), | |
) | |
facet_results_values = [] | |
facet_results.append( | |
{ | |
"name": column, | |
"type": self.type, | |
"results": facet_results_values, | |
"hideable": source != "metadata", | |
"toggle_url": self.ds.urls.path( | |
path_with_removed_args( | |
self.request, {"_facet_array": column} | |
) | |
), | |
"truncated": len(facet_rows_results) > facet_size, | |
} | |
) | |
facet_rows = facet_rows_results.rows[:facet_size] | |
pairs = self.get_querystring_pairs() | |
for row in facet_rows: | |
value = str(row["value"]) | |
selected = (f"{column}__arraycontains", value) in pairs | |
if selected: | |
toggle_path = path_with_removed_args( | |
self.request, {f"{column}__arraycontains": value} | |
) | |
else: | |
toggle_path = path_with_added_args( | |
self.request, {f"{column}__arraycontains": value} | |
) | |
facet_results_values.append( | |
{ | |
"value": value, | |
"label": value, | |
"count": row["count"], | |
"toggle_url": self.ds.absolute_url( | |
self.request, toggle_path | |
), | |
"selected": selected, | |
} | |
) | |
except QueryInterrupted: | |
facets_timed_out.append(column) | |
return facet_results, facets_timed_out | |
class DateFacet(Facet): | |
type = "date" | |
async def suggest(self): | |
columns = await self.get_columns(self.sql, self.params) | |
already_enabled = [c["config"]["simple"] for c in self.get_configs()] | |
suggested_facets = [] | |
for column in columns: | |
if column in already_enabled: | |
continue | |
# Does this column contain any dates in the first 100 rows? | |
suggested_facet_sql = """ | |
select date({column}) from ( | |
select * from ({sql}) limit 100 | |
) where {column} glob "????-??-*" | |
""".format( | |
column=escape_sqlite(column), sql=self.sql | |
) | |
try: | |
results = await self.ds.execute( | |
self.database, | |
suggested_facet_sql, | |
self.params, | |
truncate=False, | |
custom_time_limit=self.ds.setting("facet_suggest_time_limit_ms"), | |
log_sql_errors=False, | |
) | |
values = tuple(r[0] for r in results.rows) | |
if any(values): | |
suggested_facets.append( | |
{ | |
"name": column, | |
"type": "date", | |
"toggle_url": self.ds.absolute_url( | |
self.request, | |
self.ds.urls.path( | |
path_with_added_args( | |
self.request, {"_facet_date": column} | |
) | |
), | |
), | |
} | |
) | |
except (QueryInterrupted, sqlite3.OperationalError): | |
continue | |
return suggested_facets | |
async def facet_results(self): | |
facet_results = [] | |
facets_timed_out = [] | |
args = dict(self.get_querystring_pairs()) | |
facet_size = self.get_facet_size() | |
for source_and_config in self.get_configs(): | |
config = source_and_config["config"] | |
source = source_and_config["source"] | |
column = config.get("column") or config["simple"] | |
# TODO: does this query break if inner sql produces value or count columns? | |
facet_sql = """ | |
select date({col}) as value, count(*) as count from ( | |
{sql} | |
) | |
where date({col}) is not null | |
group by date({col}) order by count desc, value limit {limit} | |
""".format( | |
col=escape_sqlite(column), sql=self.sql, limit=facet_size + 1 | |
) | |
try: | |
facet_rows_results = await self.ds.execute( | |
self.database, | |
facet_sql, | |
self.params, | |
truncate=False, | |
custom_time_limit=self.ds.setting("facet_time_limit_ms"), | |
) | |
facet_results_values = [] | |
facet_results.append( | |
{ | |
"name": column, | |
"type": self.type, | |
"results": facet_results_values, | |
"hideable": source != "metadata", | |
"toggle_url": path_with_removed_args( | |
self.request, {"_facet_date": column} | |
), | |
"truncated": len(facet_rows_results) > facet_size, | |
} | |
) | |
facet_rows = facet_rows_results.rows[:facet_size] | |
for row in facet_rows: | |
selected = str(args.get(f"{column}__date")) == str(row["value"]) | |
if selected: | |
toggle_path = path_with_removed_args( | |
self.request, {f"{column}__date": str(row["value"])} | |
) | |
else: | |
toggle_path = path_with_added_args( | |
self.request, {f"{column}__date": row["value"]} | |
) | |
facet_results_values.append( | |
{ | |
"value": row["value"], | |
"label": row["value"], | |
"count": row["count"], | |
"toggle_url": self.ds.absolute_url( | |
self.request, toggle_path | |
), | |
"selected": selected, | |
} | |
) | |
except QueryInterrupted: | |
facets_timed_out.append(column) | |
return facet_results, facets_timed_out | |
</document_content> | |
</document> | |
<document index="13"> | |
<source>datasette/filters.py</source> | |
<document_content> | |
from datasette import hookimpl | |
from datasette.views.base import DatasetteError | |
from datasette.utils.asgi import BadRequest | |
import json | |
import numbers | |
from .utils import detect_json1, escape_sqlite, path_with_removed_args | |
@hookimpl(specname="filters_from_request") | |
def where_filters(request, database, datasette): | |
# This one deals with ?_where= | |
async def inner(): | |
where_clauses = [] | |
extra_wheres_for_ui = [] | |
if "_where" in request.args: | |
if not await datasette.permission_allowed( | |
request.actor, | |
"execute-sql", | |
resource=database, | |
default=True, | |
): | |
raise DatasetteError("_where= is not allowed", status=403) | |
else: | |
where_clauses.extend(request.args.getlist("_where")) | |
extra_wheres_for_ui = [ | |
{ | |
"text": text, | |
"remove_url": path_with_removed_args(request, {"_where": text}), | |
} | |
for text in request.args.getlist("_where") | |
] | |
return FilterArguments( | |
where_clauses, | |
extra_context={ | |
"extra_wheres_for_ui": extra_wheres_for_ui, | |
}, | |
) | |
return inner | |
@hookimpl(specname="filters_from_request") | |
def search_filters(request, database, table, datasette): | |
# ?_search= and _search_colname= | |
async def inner(): | |
where_clauses = [] | |
params = {} | |
human_descriptions = [] | |
extra_context = {} | |
# Figure out which fts_table to use | |
table_metadata = await datasette.table_config(database, table) | |
db = datasette.get_database(database) | |
fts_table = request.args.get("_fts_table") | |
fts_table = fts_table or table_metadata.get("fts_table") | |
fts_table = fts_table or await db.fts_table(table) | |
fts_pk = request.args.get("_fts_pk", table_metadata.get("fts_pk", "rowid")) | |
search_args = { | |
key: request.args[key] | |
for key in request.args | |
if key.startswith("_search") and key != "_searchmode" | |
} | |
search = "" | |
search_mode_raw = table_metadata.get("searchmode") == "raw" | |
# Or set search mode from the querystring | |
qs_searchmode = request.args.get("_searchmode") | |
if qs_searchmode == "escaped": | |
search_mode_raw = False | |
if qs_searchmode == "raw": | |
search_mode_raw = True | |
extra_context["supports_search"] = bool(fts_table) | |
if fts_table and search_args: | |
if "_search" in search_args: | |
# Simple ?_search=xxx | |
search = search_args["_search"] | |
where_clauses.append( | |
"{fts_pk} in (select rowid from {fts_table} where {fts_table} match {match_clause})".format( | |
fts_table=escape_sqlite(fts_table), | |
fts_pk=escape_sqlite(fts_pk), | |
match_clause=( | |
":search" if search_mode_raw else "escape_fts(:search)" | |
), | |
) | |
) | |
human_descriptions.append(f'search matches "{search}"') | |
params["search"] = search | |
extra_context["search"] = search | |
else: | |
# More complex: search against specific columns | |
for i, (key, search_text) in enumerate(search_args.items()): | |
search_col = key.split("_search_", 1)[1] | |
if search_col not in await db.table_columns(fts_table): | |
raise BadRequest("Cannot search by that column") | |
where_clauses.append( | |
"rowid in (select rowid from {fts_table} where {search_col} match {match_clause})".format( | |
fts_table=escape_sqlite(fts_table), | |
search_col=escape_sqlite(search_col), | |
match_clause=( | |
":search_{}".format(i) | |
if search_mode_raw | |
else "escape_fts(:search_{})".format(i) | |
), | |
) | |
) | |
human_descriptions.append( | |
f'search column "{search_col}" matches "{search_text}"' | |
) | |
params[f"search_{i}"] = search_text | |
extra_context["search"] = search_text | |
return FilterArguments(where_clauses, params, human_descriptions, extra_context) | |
return inner | |
@hookimpl(specname="filters_from_request") | |
def through_filters(request, database, table, datasette): | |
# Handles ?_through={"table": ..., "column": ..., "value": ...}
async def inner(): | |
where_clauses = [] | |
params = {} | |
human_descriptions = [] | |
extra_context = {} | |
# Support for ?_through={table, column, value} | |
if "_through" in request.args: | |
for through in request.args.getlist("_through"): | |
through_data = json.loads(through) | |
through_table = through_data["table"] | |
other_column = through_data["column"] | |
value = through_data["value"] | |
db = datasette.get_database(database) | |
outgoing_foreign_keys = await db.foreign_keys_for_table(through_table) | |
try: | |
fk_to_us = [ | |
fk for fk in outgoing_foreign_keys if fk["other_table"] == table | |
][0] | |
except IndexError: | |
raise DatasetteError( | |
"Invalid _through - could not find corresponding foreign key" | |
) | |
param = f"p{len(params)}" | |
where_clauses.append( | |
"{our_pk} in (select {our_column} from {through_table} where {other_column} = :{param})".format( | |
through_table=escape_sqlite(through_table), | |
our_pk=escape_sqlite(fk_to_us["other_column"]), | |
our_column=escape_sqlite(fk_to_us["column"]), | |
other_column=escape_sqlite(other_column), | |
param=param, | |
) | |
) | |
params[param] = value | |
human_descriptions.append(f'{through_table}.{other_column} = "{value}"') | |
return FilterArguments(where_clauses, params, human_descriptions, extra_context) | |
return inner | |
class FilterArguments: | |
def __init__( | |
self, where_clauses, params=None, human_descriptions=None, extra_context=None | |
): | |
self.where_clauses = where_clauses | |
self.params = params or {} | |
self.human_descriptions = human_descriptions or [] | |
self.extra_context = extra_context or {} | |
class Filter: | |
key = None | |
display = None | |
no_argument = False | |
def where_clause(self, table, column, value, param_counter): | |
raise NotImplementedError | |
def human_clause(self, column, value): | |
raise NotImplementedError | |
class TemplatedFilter(Filter): | |
def __init__( | |
self, | |
key, | |
display, | |
sql_template, | |
human_template, | |
format="{}", | |
numeric=False, | |
no_argument=False, | |
): | |
self.key = key | |
self.display = display | |
self.sql_template = sql_template | |
self.human_template = human_template | |
self.format = format | |
self.numeric = numeric | |
self.no_argument = no_argument | |
def where_clause(self, table, column, value, param_counter): | |
converted = self.format.format(value) | |
if self.numeric and converted.isdigit(): | |
converted = int(converted) | |
if self.no_argument: | |
kwargs = {"c": column} | |
converted = None | |
else: | |
kwargs = {"c": column, "p": f"p{param_counter}", "t": table} | |
return self.sql_template.format(**kwargs), converted | |
def human_clause(self, column, value): | |
if callable(self.human_template): | |
template = self.human_template(column, value) | |
else: | |
template = self.human_template | |
if self.no_argument: | |
return template.format(c=column) | |
else: | |
return template.format(c=column, v=value) | |
class InFilter(Filter): | |
key = "in" | |
display = "in" | |
def split_value(self, value): | |
if value.startswith("["): | |
return json.loads(value) | |
else: | |
return [v.strip() for v in value.split(",")] | |
def where_clause(self, table, column, value, param_counter): | |
values = self.split_value(value) | |
params = [f":p{param_counter + i}" for i in range(len(values))] | |
sql = f"{escape_sqlite(column)} in ({', '.join(params)})" | |
return sql, values | |
def human_clause(self, column, value): | |
return f"{column} in {json.dumps(self.split_value(value))}" | |
class NotInFilter(InFilter): | |
key = "notin" | |
display = "not in" | |
def where_clause(self, table, column, value, param_counter): | |
values = self.split_value(value) | |
params = [f":p{param_counter + i}" for i in range(len(values))] | |
sql = f"{escape_sqlite(column)} not in ({', '.join(params)})" | |
return sql, values | |
def human_clause(self, column, value): | |
return f"{column} not in {json.dumps(self.split_value(value))}" | |
class Filters: | |
_filters = ( | |
[ | |
# key, display, sql_template, human_template, format=, numeric=, no_argument= | |
TemplatedFilter( | |
"exact", | |
"=", | |
'"{c}" = :{p}', | |
lambda c, v: "{c} = {v}" if v.isdigit() else '{c} = "{v}"', | |
), | |
TemplatedFilter( | |
"not", | |
"!=", | |
'"{c}" != :{p}', | |
lambda c, v: "{c} != {v}" if v.isdigit() else '{c} != "{v}"', | |
), | |
TemplatedFilter( | |
"contains", | |
"contains", | |
'"{c}" like :{p}', | |
'{c} contains "{v}"', | |
format="%{}%", | |
), | |
TemplatedFilter( | |
"notcontains", | |
"does not contain", | |
'"{c}" not like :{p}', | |
'{c} does not contain "{v}"', | |
format="%{}%", | |
), | |
TemplatedFilter( | |
"endswith", | |
"ends with", | |
'"{c}" like :{p}', | |
'{c} ends with "{v}"', | |
format="%{}", | |
), | |
TemplatedFilter( | |
"startswith", | |
"starts with", | |
'"{c}" like :{p}', | |
'{c} starts with "{v}"', | |
format="{}%", | |
), | |
TemplatedFilter("gt", ">", '"{c}" > :{p}', "{c} > {v}", numeric=True), | |
TemplatedFilter( | |
"gte", "\u2265", '"{c}" >= :{p}', "{c} \u2265 {v}", numeric=True | |
), | |
TemplatedFilter("lt", "<", '"{c}" < :{p}', "{c} < {v}", numeric=True), | |
TemplatedFilter( | |
"lte", "\u2264", '"{c}" <= :{p}', "{c} \u2264 {v}", numeric=True | |
), | |
TemplatedFilter("like", "like", '"{c}" like :{p}', '{c} like "{v}"'), | |
TemplatedFilter( | |
"notlike", "not like", '"{c}" not like :{p}', '{c} not like "{v}"' | |
), | |
TemplatedFilter("glob", "glob", '"{c}" glob :{p}', '{c} glob "{v}"'), | |
InFilter(), | |
NotInFilter(), | |
] | |
+ ( | |
[ | |
TemplatedFilter( | |
"arraycontains", | |
"array contains", | |
""":{p} in (select value from json_each([{t}].[{c}]))""", | |
'{c} contains "{v}"', | |
), | |
TemplatedFilter( | |
"arraynotcontains", | |
"array does not contain", | |
""":{p} not in (select value from json_each([{t}].[{c}]))""", | |
'{c} does not contain "{v}"', | |
), | |
] | |
if detect_json1() | |
else [] | |
) | |
+ [ | |
TemplatedFilter( | |
"date", "date", 'date("{c}") = :{p}', '"{c}" is on date {v}' | |
), | |
TemplatedFilter( | |
"isnull", "is null", '"{c}" is null', "{c} is null", no_argument=True | |
), | |
TemplatedFilter( | |
"notnull", | |
"is not null", | |
'"{c}" is not null', | |
"{c} is not null", | |
no_argument=True, | |
), | |
TemplatedFilter( | |
"isblank", | |
"is blank", | |
'("{c}" is null or "{c}" = "")', | |
"{c} is blank", | |
no_argument=True, | |
), | |
TemplatedFilter( | |
"notblank", | |
"is not blank", | |
'("{c}" is not null and "{c}" != "")', | |
"{c} is not blank", | |
no_argument=True, | |
), | |
] | |
) | |
_filters_by_key = {f.key: f for f in _filters} | |
def __init__(self, pairs): | |
self.pairs = pairs | |
def lookups(self): | |
"""Yields (lookup, display, no_argument) pairs""" | |
for filter in self._filters: | |
yield filter.key, filter.display, filter.no_argument | |
def human_description_en(self, extra=None): | |
bits = [] | |
if extra: | |
bits.extend(extra) | |
for column, lookup, value in self.selections(): | |
filter = self._filters_by_key.get(lookup, None) | |
if filter: | |
bits.append(filter.human_clause(column, value)) | |
# Comma separated, with an ' and ' at the end | |
and_bits = [] | |
commas, tail = bits[:-1], bits[-1:] | |
if commas: | |
and_bits.append(", ".join(commas)) | |
if tail: | |
and_bits.append(tail[0]) | |
s = " and ".join(and_bits) | |
if not s: | |
return "" | |
return f"where {s}" | |
def selections(self): | |
"""Yields (column, lookup, value) tuples""" | |
for key, value in self.pairs: | |
if "__" in key: | |
column, lookup = key.rsplit("__", 1) | |
else: | |
column = key | |
lookup = "exact" | |
yield column, lookup, value | |
def has_selections(self): | |
return bool(self.pairs) | |
def build_where_clauses(self, table): | |
sql_bits = [] | |
params = {} | |
i = 0 | |
for column, lookup, value in self.selections(): | |
filter = self._filters_by_key.get(lookup, None) | |
if filter: | |
sql_bit, param = filter.where_clause(table, column, value, i) | |
sql_bits.append(sql_bit) | |
if param is not None: | |
if not isinstance(param, list): | |
param = [param] | |
for individual_param in param: | |
param_id = f"p{i}" | |
params[param_id] = individual_param | |
i += 1 | |
return sql_bits, params | |
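# Illustrative usage sketch (not part of this module): building where clauses
# from querystring pairs with the Filters class above. Column and table names
# are hypothetical.
#
#   f = Filters([("age__gt", "25"), ("name__contains", "sam"), ("bio__isnull", "1")])
#   sql_bits, params = f.build_where_clauses("people")
#   # sql_bits == ['"age" > :p0', '"name" like :p1', '"bio" is null']
#   # params == {"p0": 25, "p1": "%sam%"}
#   f.human_description_en()
#   # == 'where age > 25, name contains "sam" and bio is null'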
</document_content> | |
</document> | |
<document index="14"> | |
<source>datasette/forbidden.py</source> | |
<document_content> | |
from datasette import hookimpl, Response | |
@hookimpl(trylast=True) | |
def forbidden(datasette, request, message): | |
async def inner(): | |
return Response.html( | |
await datasette.render_template( | |
"error.html", | |
{ | |
"title": "Forbidden", | |
"error": message, | |
}, | |
request=request, | |
), | |
status=403, | |
) | |
return inner | |
</document_content> | |
</document> | |
<document index="15"> | |
<source>datasette/handle_exception.py</source> | |
<document_content> | |
from datasette import hookimpl, Response | |
from .utils import add_cors_headers | |
from .utils.asgi import ( | |
Base400, | |
) | |
from .views.base import DatasetteError | |
from markupsafe import Markup | |
import traceback | |
try: | |
import ipdb as pdb | |
except ImportError: | |
import pdb | |
try: | |
import rich | |
except ImportError: | |
rich = None | |
@hookimpl(trylast=True) | |
def handle_exception(datasette, request, exception): | |
async def inner(): | |
if datasette.pdb: | |
pdb.post_mortem(exception.__traceback__) | |
if rich is not None: | |
rich.get_console().print_exception(show_locals=True) | |
title = None | |
if isinstance(exception, Base400): | |
status = exception.status | |
info = {} | |
message = exception.args[0] | |
elif isinstance(exception, DatasetteError): | |
status = exception.status | |
info = exception.error_dict | |
message = exception.message | |
if exception.message_is_html: | |
message = Markup(message) | |
title = exception.title | |
else: | |
status = 500 | |
info = {} | |
message = str(exception) | |
traceback.print_exc() | |
templates = [f"{status}.html", "error.html"] | |
info.update( | |
{ | |
"ok": False, | |
"error": message, | |
"status": status, | |
"title": title, | |
} | |
) | |
headers = {} | |
if datasette.cors: | |
add_cors_headers(headers) | |
if request.path.split("?")[0].endswith(".json"): | |
return Response.json(info, status=status, headers=headers) | |
else: | |
environment = datasette.get_jinja_environment(request) | |
template = environment.select_template(templates) | |
return Response.html( | |
await template.render_async( | |
dict( | |
info, | |
urls=datasette.urls, | |
app_css_hash=datasette.app_css_hash(), | |
menu_links=lambda: [], | |
) | |
), | |
status=status, | |
headers=headers, | |
) | |
return inner | |
</document_content> | |
</document> | |
<document index="16"> | |
<source>datasette/hookspecs.py</source> | |
<document_content> | |
from pluggy import HookimplMarker | |
from pluggy import HookspecMarker | |
hookspec = HookspecMarker("datasette") | |
hookimpl = HookimplMarker("datasette") | |
@hookspec | |
def startup(datasette): | |
"""Fires directly after Datasette first starts running""" | |
@hookspec | |
def asgi_wrapper(datasette): | |
"""Returns an ASGI middleware callable to wrap our ASGI application with""" | |
@hookspec | |
def prepare_connection(conn, database, datasette): | |
"""Modify SQLite connection in some way e.g. register custom SQL functions""" | |
@hookspec | |
def prepare_jinja2_environment(env, datasette): | |
"""Modify Jinja2 template environment e.g. register custom template tags""" | |
@hookspec | |
def extra_css_urls(template, database, table, columns, view_name, request, datasette): | |
"""Extra CSS URLs added by this plugin""" | |
@hookspec | |
def extra_js_urls(template, database, table, columns, view_name, request, datasette): | |
"""Extra JavaScript URLs added by this plugin""" | |
@hookspec | |
def extra_body_script( | |
template, database, table, columns, view_name, request, datasette | |
): | |
"""Extra JavaScript code to be included in <script> at bottom of body""" | |
@hookspec | |
def extra_template_vars( | |
template, database, table, columns, view_name, request, datasette | |
): | |
"""Extra template variables to be made available to the template - can return dict or callable or awaitable""" | |
@hookspec | |
def publish_subcommand(publish): | |
"""Subcommands for 'datasette publish'""" | |
@hookspec | |
def render_cell(row, value, column, table, database, datasette, request): | |
"""Customize rendering of HTML table cell values""" | |
@hookspec | |
def register_output_renderer(datasette): | |
"""Register a renderer to output data in a different format""" | |
@hookspec | |
def register_facet_classes(): | |
"""Register Facet subclasses""" | |
@hookspec | |
def register_permissions(datasette): | |
"""Register permissions: returns a list of datasette.permission.Permission named tuples""" | |
@hookspec | |
def register_routes(datasette): | |
"""Register URL routes: return a list of (regex, view_function) pairs""" | |
@hookspec | |
def register_commands(cli): | |
"""Register additional CLI commands, e.g. 'datasette mycommand ...'""" | |
@hookspec | |
def actor_from_request(datasette, request): | |
"""Return an actor dictionary based on the incoming request""" | |
@hookspec(firstresult=True) | |
def actors_from_ids(datasette, actor_ids): | |
"""Returns a dictionary mapping those IDs to actor dictionaries""" | |
@hookspec | |
def jinja2_environment_from_request(datasette, request, env): | |
"""Return a Jinja2 environment based on the incoming request""" | |
@hookspec | |
def filters_from_request(request, database, table, datasette): | |
""" | |
Return datasette.filters.FilterArguments( | |
where_clauses=[str, str, str], | |
params={}, | |
human_descriptions=[str, str, str], | |
extra_context={} | |
) based on the request""" | |
@hookspec | |
def permission_allowed(datasette, actor, action, resource): | |
"""Check if actor is allowed to perform this action - return True, False or None""" | |
@hookspec | |
def canned_queries(datasette, database, actor): | |
"""Return a dictionary of canned query definitions or an awaitable function that returns them""" | |
@hookspec | |
def register_magic_parameters(datasette): | |
"""Return a list of (name, function) magic parameter functions""" | |
@hookspec | |
def forbidden(datasette, request, message): | |
"""Custom response for a 403 forbidden error""" | |
@hookspec | |
def menu_links(datasette, actor, request): | |
"""Links for the navigation menu""" | |
@hookspec | |
def row_actions(datasette, actor, request, database, table, row): | |
"""Links for the row actions menu""" | |
@hookspec | |
def table_actions(datasette, actor, database, table, request): | |
"""Links for the table actions menu""" | |
@hookspec | |
def view_actions(datasette, actor, database, view, request): | |
"""Links for the view actions menu""" | |
@hookspec | |
def query_actions(datasette, actor, database, query_name, request, sql, params): | |
"""Links for the query and canned query actions menu""" | |
@hookspec | |
def database_actions(datasette, actor, database, request): | |
"""Links for the database actions menu""" | |
@hookspec | |
def homepage_actions(datasette, actor, request): | |
"""Links for the homepage actions menu""" | |
@hookspec | |
def skip_csrf(datasette, scope): | |
"""Mechanism for skipping CSRF checks for certain requests""" | |
@hookspec | |
def handle_exception(datasette, request, exception): | |
"""Handle an uncaught exception. Can return a Response or None.""" | |
@hookspec | |
def track_event(datasette, event): | |
"""Respond to an event tracked by Datasette""" | |
@hookspec | |
def register_events(datasette): | |
"""Return a list of Event subclasses to use with track_event()""" | |
@hookspec | |
def top_homepage(datasette, request): | |
"""HTML to include at the top of the homepage""" | |
@hookspec | |
def top_database(datasette, request, database): | |
"""HTML to include at the top of the database page""" | |
@hookspec | |
def top_table(datasette, request, database, table): | |
"""HTML to include at the top of the table page""" | |
@hookspec | |
def top_row(datasette, request, database, table, row): | |
"""HTML to include at the top of the row page""" | |
@hookspec | |
def top_query(datasette, request, database, sql): | |
"""HTML to include at the top of the query results page""" | |
@hookspec | |
def top_canned_query(datasette, request, database, query_name): | |
"""HTML to include at the top of the canned query page""" | |
</document_content> | |
</document> | |
<document index="17"> | |
<source>datasette/inspect.py</source> | |
<document_content> | |
import hashlib | |
from .utils import ( | |
detect_spatialite, | |
detect_fts, | |
detect_primary_keys, | |
escape_sqlite, | |
get_all_foreign_keys, | |
table_columns, | |
sqlite3, | |
) | |
HASH_BLOCK_SIZE = 1024 * 1024 | |
def inspect_hash(path): | |
"""Calculate the hash of a database, efficiently.""" | |
m = hashlib.sha256() | |
with path.open("rb") as fp: | |
while True: | |
data = fp.read(HASH_BLOCK_SIZE) | |
if not data: | |
break | |
m.update(data) | |
return m.hexdigest() | |
def inspect_views(conn): | |
"""List views in a database.""" | |
return [ | |
v[0] for v in conn.execute('select name from sqlite_master where type = "view"') | |
] | |
def inspect_tables(conn, database_metadata): | |
"""List tables and their row counts, excluding uninteresting tables.""" | |
tables = {} | |
table_names = [ | |
r["name"] | |
for r in conn.execute('select * from sqlite_master where type="table"') | |
] | |
for table in table_names: | |
table_metadata = database_metadata.get("tables", {}).get(table, {}) | |
try: | |
count = conn.execute( | |
f"select count(*) from {escape_sqlite(table)}" | |
).fetchone()[0] | |
except sqlite3.OperationalError: | |
# This can happen when running against a FTS virtual table | |
# e.g. "select count(*) from some_fts;" | |
count = 0 | |
column_names = table_columns(conn, table) | |
tables[table] = { | |
"name": table, | |
"columns": column_names, | |
"primary_keys": detect_primary_keys(conn, table), | |
"count": count, | |
"hidden": table_metadata.get("hidden") or False, | |
"fts_table": detect_fts(conn, table), | |
} | |
foreign_keys = get_all_foreign_keys(conn) | |
for table, info in foreign_keys.items(): | |
tables[table]["foreign_keys"] = info | |
# Mark tables 'hidden' if they relate to FTS virtual tables | |
hidden_tables = [ | |
r["name"] | |
for r in conn.execute( | |
""" | |
select name from sqlite_master | |
where rootpage = 0 | |
and sql like '%VIRTUAL TABLE%USING FTS%' | |
""" | |
) | |
] | |
if detect_spatialite(conn): | |
# Also hide Spatialite internal tables | |
hidden_tables += [ | |
"ElementaryGeometries", | |
"SpatialIndex", | |
"geometry_columns", | |
"spatial_ref_sys", | |
"spatialite_history", | |
"sql_statements_log", | |
"sqlite_sequence", | |
"views_geometry_columns", | |
"virts_geometry_columns", | |
] + [ | |
r["name"] | |
for r in conn.execute( | |
""" | |
select name from sqlite_master | |
where name like "idx_%" | |
and type = "table" | |
""" | |
) | |
] | |
for t in tables.keys(): | |
for hidden_table in hidden_tables: | |
if t == hidden_table or t.startswith(hidden_table): | |
tables[t]["hidden"] = True | |
continue | |
return tables | |
</document_content> | |
</document> | |
<document index="18"> | |
<source>datasette/permissions.py</source> | |
<document_content> | |
from dataclasses import dataclass | |
from typing import Optional | |
@dataclass | |
class Permission: | |
name: str | |
abbr: Optional[str] | |
description: Optional[str] | |
takes_database: bool | |
takes_resource: bool | |
default: bool | |
# This is deliberately undocumented: it's considered an internal | |
# implementation detail for view-table/view-database and should | |
# not be used by plugins as it may change in the future. | |
implies_can_view: bool = False | |
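# Illustrative sketch (hypothetical permission name and values): a
# register_permissions(datasette) hook returns a list of these objects, e.g.
#
#   [Permission(
#       name="upload-data",
#       abbr=None,
#       description="Upload data files",
#       takes_database=True,
#       takes_resource=False,
#       default=False,
#   )]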
</document_content> | |
</document> | |
<document index="19"> | |
<source>datasette/plugins.py</source> | |
<document_content> | |
import importlib | |
import os | |
import pluggy | |
from pprint import pprint | |
import sys | |
from . import hookspecs | |
if sys.version_info >= (3, 9): | |
import importlib.resources as importlib_resources | |
else: | |
import importlib_resources | |
if sys.version_info >= (3, 10): | |
import importlib.metadata as importlib_metadata | |
else: | |
import importlib_metadata | |
DEFAULT_PLUGINS = ( | |
"datasette.publish.heroku", | |
"datasette.publish.cloudrun", | |
"datasette.facets", | |
"datasette.filters", | |
"datasette.sql_functions", | |
"datasette.actor_auth_cookie", | |
"datasette.default_permissions", | |
"datasette.default_magic_parameters", | |
"datasette.blob_renderer", | |
"datasette.default_menu_links", | |
"datasette.handle_exception", | |
"datasette.forbidden", | |
"datasette.events", | |
) | |
pm = pluggy.PluginManager("datasette") | |
pm.add_hookspecs(hookspecs) | |
DATASETTE_TRACE_PLUGINS = os.environ.get("DATASETTE_TRACE_PLUGINS", None) | |
def before(hook_name, hook_impls, kwargs): | |
print(file=sys.stderr) | |
print(f"{hook_name}:", file=sys.stderr) | |
pprint(kwargs, width=40, indent=4, stream=sys.stderr) | |
print("Hook implementations:", file=sys.stderr) | |
pprint(hook_impls, width=40, indent=4, stream=sys.stderr) | |
def after(outcome, hook_name, hook_impls, kwargs): | |
results = outcome.get_result() | |
if not isinstance(results, list): | |
results = [results] | |
print(f"Results:", file=sys.stderr) | |
pprint(results, width=40, indent=4, stream=sys.stderr) | |
if DATASETTE_TRACE_PLUGINS: | |
pm.add_hookcall_monitoring(before, after) | |
DATASETTE_LOAD_PLUGINS = os.environ.get("DATASETTE_LOAD_PLUGINS", None) | |
if not hasattr(sys, "_called_from_test") and DATASETTE_LOAD_PLUGINS is None: | |
# Only load plugins if not running tests | |
pm.load_setuptools_entrypoints("datasette") | |
# Load any plugins specified in DATASETTE_LOAD_PLUGINS
if DATASETTE_LOAD_PLUGINS is not None: | |
for package_name in [ | |
name for name in DATASETTE_LOAD_PLUGINS.split(",") if name.strip() | |
]: | |
try: | |
distribution = importlib_metadata.distribution(package_name) | |
entry_points = distribution.entry_points | |
for entry_point in entry_points: | |
if entry_point.group == "datasette": | |
mod = entry_point.load() | |
pm.register(mod, name=entry_point.name) | |
# Ensure name can be found in plugin_to_distinfo later: | |
pm._plugin_distinfo.append((mod, distribution)) | |
except importlib_metadata.PackageNotFoundError: | |
sys.stderr.write("Plugin {} could not be found\n".format(package_name)) | |
# Load default plugins | |
for plugin in DEFAULT_PLUGINS: | |
mod = importlib.import_module(plugin) | |
pm.register(mod, plugin) | |
def get_plugins(): | |
plugins = [] | |
plugin_to_distinfo = dict(pm.list_plugin_distinfo()) | |
for plugin in pm.get_plugins(): | |
static_path = None | |
templates_path = None | |
if plugin.__name__ not in DEFAULT_PLUGINS: | |
try: | |
if (importlib_resources.files(plugin.__name__) / "static").is_dir(): | |
static_path = str( | |
importlib_resources.files(plugin.__name__) / "static" | |
) | |
if (importlib_resources.files(plugin.__name__) / "templates").is_dir(): | |
templates_path = str( | |
importlib_resources.files(plugin.__name__) / "templates" | |
) | |
except (TypeError, ModuleNotFoundError): | |
# Caused by --plugins_dir= plugins | |
pass | |
plugin_info = { | |
"name": plugin.__name__, | |
"static_path": static_path, | |
"templates_path": templates_path, | |
"hooks": [h.name for h in pm.get_hookcallers(plugin)], | |
} | |
distinfo = plugin_to_distinfo.get(plugin) | |
if distinfo: | |
plugin_info["version"] = distinfo.version | |
plugin_info["name"] = distinfo.name or distinfo.project_name | |
plugins.append(plugin_info) | |
return plugins | |
</document_content> | |
</document> | |
<document index="20"> | |
<source>datasette/renderer.py</source> | |
<document_content> | |
import json | |
from datasette.utils import ( | |
value_as_boolean, | |
remove_infinites, | |
CustomJSONEncoder, | |
path_from_row_pks, | |
sqlite3, | |
) | |
from datasette.utils.asgi import Response | |
def convert_specific_columns_to_json(rows, columns, json_cols): | |
json_cols = set(json_cols) | |
if not json_cols.intersection(columns): | |
return rows | |
new_rows = [] | |
for row in rows: | |
new_row = [] | |
for value, column in zip(row, columns): | |
if column in json_cols: | |
try: | |
value = json.loads(value) | |
except (TypeError, ValueError) as e: | |
pass | |
new_row.append(value) | |
new_rows.append(new_row) | |
return new_rows | |
def json_renderer(request, args, data, error, truncated=None): | |
"""Render a response as JSON""" | |
status_code = 200 | |
# Handle the _json= parameter which may modify data["rows"] | |
json_cols = [] | |
if "_json" in args: | |
json_cols = args.getlist("_json") | |
if json_cols and "rows" in data and "columns" in data: | |
data["rows"] = convert_specific_columns_to_json( | |
data["rows"], data["columns"], json_cols | |
) | |
# unless _json_infinity=1 requested, replace infinity with None | |
if "rows" in data and not value_as_boolean(args.get("_json_infinity", "0")): | |
data["rows"] = [remove_infinites(row) for row in data["rows"]] | |
# Deal with the _shape option | |
shape = args.get("_shape", "objects") | |
# if there's an error, ignore the shape entirely | |
data["ok"] = True | |
if error: | |
shape = "objects" | |
status_code = 400 | |
data["error"] = error | |
data["ok"] = False | |
if truncated is not None: | |
data["truncated"] = truncated | |
if shape == "arrayfirst": | |
if not data["rows"]: | |
data = [] | |
elif isinstance(data["rows"][0], sqlite3.Row): | |
data = [row[0] for row in data["rows"]] | |
else: | |
assert isinstance(data["rows"][0], dict) | |
data = [next(iter(row.values())) for row in data["rows"]] | |
elif shape in ("objects", "object", "array"): | |
columns = data.get("columns") | |
rows = data.get("rows") | |
if rows and columns and not isinstance(rows[0], dict): | |
data["rows"] = [dict(zip(columns, row)) for row in rows] | |
if shape == "object": | |
shape_error = None | |
if "primary_keys" not in data: | |
shape_error = "_shape=object is only available on tables" | |
else: | |
pks = data["primary_keys"] | |
if not pks: | |
shape_error = ( | |
"_shape=object not available for tables with no primary keys" | |
) | |
else: | |
object_rows = {} | |
for row in data["rows"]: | |
pk_string = path_from_row_pks(row, pks, not pks) | |
object_rows[pk_string] = row | |
data = object_rows | |
if shape_error: | |
data = {"ok": False, "error": shape_error} | |
elif shape == "array": | |
data = data["rows"] | |
elif shape == "arrays": | |
if not data["rows"]: | |
pass | |
elif isinstance(data["rows"][0], sqlite3.Row): | |
data["rows"] = [list(row) for row in data["rows"]] | |
else: | |
data["rows"] = [list(row.values()) for row in data["rows"]] | |
else: | |
status_code = 400 | |
data = { | |
"ok": False, | |
"error": f"Invalid _shape: {shape}", | |
"status": 400, | |
"title": None, | |
} | |
# Don't include "columns" in output | |
# https://github.com/simonw/datasette/issues/2136 | |
if isinstance(data, dict) and "columns" not in request.args.getlist("_extra"): | |
data.pop("columns", None) | |
# Handle _nl option for _shape=array | |
nl = args.get("_nl", "") | |
if nl and shape == "array": | |
body = "\n".join(json.dumps(item, cls=CustomJSONEncoder) for item in data) | |
content_type = "text/plain" | |
else: | |
body = json.dumps(data, cls=CustomJSONEncoder) | |
content_type = "application/json; charset=utf-8" | |
headers = {} | |
return Response( | |
body, status=status_code, headers=headers, content_type=content_type | |
) | |
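# Illustrative sketch of the _shape variants handled above, for two rows with
# columns ["id", "name"] (approximate output, other keys elided):
#   _shape=objects (default)  {"ok": true, "rows": [{"id": 1, "name": "a"}, ...]}
#   _shape=arrays             {"ok": true, "rows": [[1, "a"], [2, "b"]]}
#   _shape=array              [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}]
#   _shape=array&_nl=on       newline-delimited JSON objects, served as text/plain
#   _shape=arrayfirst         [1, 2]
#   _shape=object             {"1": {"id": 1, ...}, "2": {"id": 2, ...}}
#                             (keyed by primary key, tables only)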
</document_content> | |
</document> | |
<document index="21"> | |
<source>datasette/sql_functions.py</source> | |
<document_content> | |
from datasette import hookimpl | |
from datasette.utils import escape_fts | |
@hookimpl | |
def prepare_connection(conn): | |
conn.create_function("escape_fts", 1, escape_fts) | |
</document_content> | |
</document> | |
<document index="22"> | |
<source>datasette/tracer.py</source> | |
<document_content> | |
import asyncio | |
from contextlib import contextmanager | |
from contextvars import ContextVar | |
from markupsafe import escape | |
import time | |
import json | |
import traceback | |
tracers = {} | |
TRACE_RESERVED_KEYS = {"type", "start", "end", "duration_ms", "traceback"} | |
trace_task_id = ContextVar("trace_task_id", default=None) | |
def get_task_id(): | |
current = trace_task_id.get(None) | |
if current is not None: | |
return current | |
try: | |
loop = asyncio.get_event_loop() | |
except RuntimeError: | |
return None | |
return id(asyncio.current_task(loop=loop)) | |
@contextmanager | |
def trace_child_tasks(): | |
token = trace_task_id.set(get_task_id()) | |
yield | |
trace_task_id.reset(token) | |
@contextmanager | |
def trace(trace_type, **kwargs): | |
assert not TRACE_RESERVED_KEYS.intersection( | |
kwargs.keys() | |
), f".trace() keyword parameters cannot include {TRACE_RESERVED_KEYS}" | |
task_id = get_task_id() | |
if task_id is None: | |
yield kwargs | |
return | |
tracer = tracers.get(task_id) | |
if tracer is None: | |
yield kwargs | |
return | |
start = time.perf_counter() | |
captured_error = None | |
try: | |
yield kwargs | |
except Exception as ex: | |
captured_error = ex | |
raise | |
finally: | |
end = time.perf_counter() | |
trace_info = { | |
"type": trace_type, | |
"start": start, | |
"end": end, | |
"duration_ms": (end - start) * 1000, | |
"traceback": traceback.format_list(traceback.extract_stack(limit=6)[:-3]), | |
"error": str(captured_error) if captured_error else None, | |
} | |
trace_info.update(kwargs) | |
tracer.append(trace_info) | |
@contextmanager | |
def capture_traces(tracer): | |
# tracer is a list | |
task_id = get_task_id() | |
if task_id is None: | |
yield | |
return | |
tracers[task_id] = tracer | |
yield | |
del tracers[task_id] | |
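# Illustrative sketch (assumed names; normally this is driven by Datasette's
# async request handling): traces are gathered by nesting trace() inside
# capture_traces():
#
#   traces = []
#   with capture_traces(traces):
#       with trace("sql", database="fixtures", sql="select 1"):
#           ...  # run the query here
#   # traces now holds dicts with type/start/end/duration_ms/traceback/error
#   # plus the keyword arguments that were passed to trace()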
class AsgiTracer: | |
# If the body is larger than this we don't attempt to append the trace | |
max_body_bytes = 1024 * 256 # 256 KB | |
def __init__(self, app): | |
self.app = app | |
async def __call__(self, scope, receive, send): | |
if b"_trace=1" not in scope.get("query_string", b"").split(b"&"): | |
await self.app(scope, receive, send) | |
return | |
trace_start = time.perf_counter() | |
traces = [] | |
accumulated_body = b"" | |
size_limit_exceeded = False | |
response_headers = [] | |
async def wrapped_send(message): | |
nonlocal accumulated_body, size_limit_exceeded, response_headers | |
if message["type"] == "http.response.start": | |
response_headers = message["headers"] | |
await send(message) | |
return | |
if message["type"] != "http.response.body" or size_limit_exceeded: | |
await send(message) | |
return | |
# Accumulate body until the end or until size is exceeded | |
accumulated_body += message["body"] | |
if len(accumulated_body) > self.max_body_bytes: | |
# Send what we have accumulated so far | |
await send( | |
{ | |
"type": "http.response.body", | |
"body": accumulated_body, | |
"more_body": bool(message.get("more_body")), | |
} | |
) | |
size_limit_exceeded = True | |
return | |
if not message.get("more_body"): | |
# We have all the body - modify it and send the result | |
# TODO: What to do about Content-Type or other cases? | |
trace_info = { | |
"request_duration_ms": 1000 * (time.perf_counter() - trace_start), | |
"sum_trace_duration_ms": sum(t["duration_ms"] for t in traces), | |
"num_traces": len(traces), | |
"traces": traces, | |
} | |
try: | |
content_type = [ | |
v.decode("utf8") | |
for k, v in response_headers | |
if k.lower() == b"content-type" | |
][0] | |
except IndexError: | |
content_type = "" | |
if "text/html" in content_type and b"</body>" in accumulated_body: | |
extra = escape(json.dumps(trace_info, indent=2)) | |
extra_html = f"<pre>{extra}</pre></body>".encode("utf8") | |
accumulated_body = accumulated_body.replace(b"</body>", extra_html) | |
elif "json" in content_type and accumulated_body.startswith(b"{"): | |
data = json.loads(accumulated_body.decode("utf8")) | |
if "_trace" not in data: | |
data["_trace"] = trace_info | |
accumulated_body = json.dumps(data).encode("utf8") | |
await send({"type": "http.response.body", "body": accumulated_body}) | |
with capture_traces(traces): | |
await self.app(scope, receive, wrapped_send) | |
</document_content> | |
</document> | |
<document index="23"> | |
<source>datasette/url_builder.py</source> | |
<document_content> | |
from .utils import tilde_encode, path_with_format, PrefixedUrlString | |
import urllib | |
class Urls: | |
def __init__(self, ds): | |
self.ds = ds | |
def path(self, path, format=None): | |
if not isinstance(path, PrefixedUrlString): | |
if path.startswith("/"): | |
path = path[1:] | |
path = self.ds.setting("base_url") + path | |
if format is not None: | |
path = path_with_format(path=path, format=format) | |
return PrefixedUrlString(path) | |
def instance(self, format=None): | |
return self.path("", format=format) | |
def static(self, path): | |
return self.path(f"-/static/{path}") | |
def static_plugins(self, plugin, path): | |
return self.path(f"-/static-plugins/{plugin}/{path}") | |
def logout(self): | |
return self.path("-/logout") | |
def database(self, database, format=None): | |
db = self.ds.get_database(database) | |
return self.path(tilde_encode(db.route), format=format) | |
def database_query(self, database, sql, format=None): | |
path = f"{self.database(database)}/-/query?" + urllib.parse.urlencode( | |
{"sql": sql} | |
) | |
return self.path(path, format=format) | |
def table(self, database, table, format=None): | |
path = f"{self.database(database)}/{tilde_encode(table)}" | |
if format is not None: | |
path = path_with_format(path=path, format=format) | |
return PrefixedUrlString(path) | |
def query(self, database, query, format=None): | |
path = f"{self.database(database)}/{tilde_encode(query)}" | |
if format is not None: | |
path = path_with_format(path=path, format=format) | |
return PrefixedUrlString(path) | |
def row(self, database, table, row_path, format=None): | |
path = f"{self.table(database, table)}/{row_path}" | |
if format is not None: | |
path = path_with_format(path=path, format=format) | |
return PrefixedUrlString(path) | |
def row_blob(self, database, table, row_path, column): | |
return self.table(database, table) + "/{}.blob?_blob_column={}".format( | |
row_path, urllib.parse.quote_plus(column) | |
) | |
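# Illustrative examples (assuming base_url is "/" and a database named
# "fixtures" containing a table named "facetable"), with urls = datasette.urls:
#   urls.instance()                                  -> "/"
#   urls.database("fixtures")                        -> "/fixtures"
#   urls.table("fixtures", "facetable")              -> "/fixtures/facetable"
#   urls.table("fixtures", "facetable", format="json")
#                                                    -> "/fixtures/facetable.json"
#   urls.static_plugins("datasette_vega", "main.js")
#                                                    -> "/-/static-plugins/datasette_vega/main.js"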
</document_content> | |
</document> | |
<document index="24"> | |
<source>datasette/version.py</source> | |
<document_content> | |
__version__ = "1.0a16" | |
__version_info__ = tuple(__version__.split(".")) | |
</document_content> | |
</document> | |
<document index="25"> | |
<source>datasette/utils/__init__.py</source> | |
<document_content> | |
import asyncio | |
from contextlib import contextmanager | |
import aiofiles | |
import click | |
from collections import OrderedDict, namedtuple, Counter | |
import copy | |
import base64 | |
import hashlib | |
import inspect | |
import json | |
import markupsafe | |
import mergedeep | |
import os | |
import re | |
import shlex | |
import tempfile | |
import typing | |
import time | |
import types | |
import secrets | |
import shutil | |
from typing import Iterable, List, Tuple | |
import urllib | |
import yaml | |
from .shutil_backport import copytree | |
from .sqlite import sqlite3, supports_table_xinfo | |
if typing.TYPE_CHECKING: | |
from datasette.database import Database | |
# From https://www.sqlite.org/lang_keywords.html | |
reserved_words = set( | |
( | |
"abort action add after all alter analyze and as asc attach autoincrement " | |
"before begin between by cascade case cast check collate column commit " | |
"conflict constraint create cross current_date current_time " | |
"current_timestamp database default deferrable deferred delete desc detach " | |
"distinct drop each else end escape except exclusive exists explain fail " | |
"for foreign from full glob group having if ignore immediate in index " | |
"indexed initially inner insert instead intersect into is isnull join key " | |
"left like limit match natural no not notnull null of offset on or order " | |
"outer plan pragma primary query raise recursive references regexp reindex " | |
"release rename replace restrict right rollback row savepoint select set " | |
"table temp temporary then to transaction trigger union unique update using " | |
"vacuum values view virtual when where with without" | |
).split() | |
) | |
APT_GET_DOCKERFILE_EXTRAS = r""" | |
RUN apt-get update && \ | |
apt-get install -y {} && \ | |
rm -rf /var/lib/apt/lists/* | |
""" | |
# Can replace with sqlite-utils when I add that dependency | |
SPATIALITE_PATHS = ( | |
"/usr/lib/x86_64-linux-gnu/mod_spatialite.so", | |
"/usr/local/lib/mod_spatialite.dylib", | |
"/usr/local/lib/mod_spatialite.so", | |
"/opt/homebrew/lib/mod_spatialite.dylib", | |
) | |
# Used to display /-/versions.json SpatiaLite information | |
SPATIALITE_FUNCTIONS = ( | |
"spatialite_version", | |
"spatialite_target_cpu", | |
"check_strict_sql_quoting", | |
"freexl_version", | |
"proj_version", | |
"geos_version", | |
"rttopo_version", | |
"libxml2_version", | |
"HasIconv", | |
"HasMathSQL", | |
"HasGeoCallbacks", | |
"HasProj", | |
"HasProj6", | |
"HasGeos", | |
"HasGeosAdvanced", | |
"HasGeosTrunk", | |
"HasGeosReentrant", | |
"HasGeosOnlyReentrant", | |
"HasMiniZip", | |
"HasRtTopo", | |
"HasLibXML2", | |
"HasEpsg", | |
"HasFreeXL", | |
"HasGeoPackage", | |
"HasGCP", | |
"HasTopology", | |
"HasKNN", | |
"HasRouting", | |
) | |
# Length of hash subset used in hashed URLs: | |
HASH_LENGTH = 7 | |
# Can replace this with Column from sqlite_utils when I add that dependency | |
Column = namedtuple( | |
"Column", ("cid", "name", "type", "notnull", "default_value", "is_pk", "hidden") | |
) | |
functions_marked_as_documented = [] | |
def documented(fn): | |
functions_marked_as_documented.append(fn) | |
return fn | |
@documented | |
async def await_me_maybe(value: typing.Any) -> typing.Any: | |
"If value is callable, call it. If awaitable, await it. Otherwise return it." | |
if callable(value): | |
value = value() | |
if asyncio.iscoroutine(value): | |
value = await value | |
return value | |
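# Illustrative examples: each of these evaluates to 42 when awaited
# (async_fn is a hypothetical "async def async_fn(): return 42"):
#   await await_me_maybe(42)            # plain value
#   await await_me_maybe(lambda: 42)    # callable returning a value
#   await await_me_maybe(async_fn)      # callable returning a coroutine
#   await await_me_maybe(async_fn())    # coroutine object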
def urlsafe_components(token): | |
"""Splits token on commas and tilde-decodes each component""" | |
return [tilde_decode(b) for b in token.split(",")] | |
def path_from_row_pks(row, pks, use_rowid, quote=True): | |
"""Generate an optionally tilde-encoded unique identifier | |
for a row from its primary keys.""" | |
if use_rowid: | |
bits = [row["rowid"]] | |
else: | |
bits = [ | |
row[pk]["value"] if isinstance(row[pk], dict) else row[pk] for pk in pks | |
] | |
if quote: | |
bits = [tilde_encode(str(bit)) for bit in bits] | |
else: | |
bits = [str(bit) for bit in bits] | |
return ",".join(bits) | |
def compound_keys_after_sql(pks, start_index=0): | |
# Implementation of keyset pagination | |
# See https://github.com/simonw/datasette/issues/190 | |
# For pk1/pk2/pk3 returns: | |
# | |
# ([pk1] > :p0) | |
# or | |
# ([pk1] = :p0 and [pk2] > :p1) | |
# or | |
# ([pk1] = :p0 and [pk2] = :p1 and [pk3] > :p2) | |
or_clauses = [] | |
pks_left = pks[:] | |
while pks_left: | |
and_clauses = [] | |
last = pks_left[-1] | |
rest = pks_left[:-1] | |
and_clauses = [ | |
f"{escape_sqlite(pk)} = :p{i + start_index}" for i, pk in enumerate(rest) | |
] | |
and_clauses.append(f"{escape_sqlite(last)} > :p{len(rest) + start_index}") | |
or_clauses.append(f"({' and '.join(and_clauses)})") | |
pks_left.pop() | |
or_clauses.reverse() | |
return "({})".format("\n or\n".join(or_clauses)) | |
class CustomJSONEncoder(json.JSONEncoder): | |
def default(self, obj): | |
if isinstance(obj, sqlite3.Row): | |
return tuple(obj) | |
if isinstance(obj, sqlite3.Cursor): | |
return list(obj) | |
if isinstance(obj, bytes): | |
# Does it encode to utf8? | |
try: | |
return obj.decode("utf8") | |
except UnicodeDecodeError: | |
return { | |
"$base64": True, | |
"encoded": base64.b64encode(obj).decode("latin1"), | |
} | |
return json.JSONEncoder.default(self, obj) | |
@contextmanager | |
def sqlite_timelimit(conn, ms): | |
deadline = time.perf_counter() + (ms / 1000) | |
# n is the number of SQLite virtual machine instructions that will be | |
# executed between each check. It takes about 0.08ms to execute 1000. | |
# https://github.com/simonw/datasette/issues/1679 | |
n = 1000 | |
if ms <= 20: | |
# This mainly happens while executing our test suite | |
n = 1 | |
def handler(): | |
if time.perf_counter() >= deadline: | |
# Returning 1 terminates the query with an error | |
return 1 | |
conn.set_progress_handler(handler, n) | |
try: | |
yield | |
finally: | |
conn.set_progress_handler(None, n) | |
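# Illustrative sketch (slow_sql is a placeholder): abort any query on conn that
# runs longer than one second:
#
#   with sqlite_timelimit(conn, 1000):
#       conn.execute(slow_sql)  # raises sqlite3.OperationalError if interrupted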
class InvalidSql(Exception): | |
pass | |
# Allow SQL to start with a /* */ or -- comment | |
comment_re = ( | |
# Start of string, then any amount of whitespace | |
r"^\s*(" | |
+ | |
# Comment that starts with -- and ends at a newline | |
r"(?:\-\-.*?\n\s*)" | |
+ | |
# Comment that starts with /* and ends with */ - but does not have */ in it | |
r"|(?:\/\*((?!\*\/)[\s\S])*\*\/)" | |
+ | |
# Whitespace | |
r"\s*)*\s*" | |
) | |
allowed_sql_res = [ | |
re.compile(comment_re + r"select\b"), | |
re.compile(comment_re + r"explain\s+select\b"), | |
re.compile(comment_re + r"explain\s+query\s+plan\s+select\b"), | |
re.compile(comment_re + r"with\b"), | |
re.compile(comment_re + r"explain\s+with\b"), | |
re.compile(comment_re + r"explain\s+query\s+plan\s+with\b"), | |
] | |
allowed_pragmas = ( | |
"database_list", | |
"foreign_key_list", | |
"function_list", | |
"index_info", | |
"index_list", | |
"index_xinfo", | |
"page_count", | |
"max_page_count", | |
"page_size", | |
"schema_version", | |
"table_info", | |
"table_xinfo", | |
"table_list", | |
) | |
disallawed_sql_res = [ | |
( | |
re.compile(f"pragma(?!_({'|'.join(allowed_pragmas)}))"), | |
"Statement contained a disallowed PRAGMA. Allowed pragma functions are {}".format( | |
", ".join("pragma_{}()".format(pragma) for pragma in allowed_pragmas) | |
), | |
) | |
] | |
def validate_sql_select(sql): | |
sql = "\n".join( | |
line for line in sql.split("\n") if not line.strip().startswith("--") | |
) | |
sql = sql.strip().lower() | |
if not any(r.match(sql) for r in allowed_sql_res): | |
raise InvalidSql("Statement must be a SELECT") | |
for r, msg in disallawed_sql_res: | |
if r.search(sql): | |
raise InvalidSql(msg) | |
def append_querystring(url, querystring): | |
op = "&" if ("?" in url) else "?" | |
return f"{url}{op}{querystring}" | |
def path_with_added_args(request, args, path=None): | |
path = path or request.path | |
if isinstance(args, dict): | |
args = args.items() | |
args_to_remove = {k for k, v in args if v is None} | |
current = [] | |
for key, value in urllib.parse.parse_qsl(request.query_string): | |
if key not in args_to_remove: | |
current.append((key, value)) | |
current.extend([(key, value) for key, value in args if value is not None]) | |
query_string = urllib.parse.urlencode(current) | |
if query_string: | |
query_string = f"?{query_string}" | |
return path + query_string | |
def path_with_removed_args(request, args, path=None): | |
query_string = request.query_string | |
if path is None: | |
path = request.path | |
else: | |
if "?" in path: | |
bits = path.split("?", 1) | |
path, query_string = bits | |
# args can be a dict or a set | |
current = [] | |
if isinstance(args, set): | |
def should_remove(key, value): | |
return key in args | |
elif isinstance(args, dict): | |
# Must match key AND value | |
def should_remove(key, value): | |
return args.get(key) == value | |
for key, value in urllib.parse.parse_qsl(query_string): | |
if not should_remove(key, value): | |
current.append((key, value)) | |
query_string = urllib.parse.urlencode(current) | |
if query_string: | |
query_string = f"?{query_string}" | |
return path + query_string | |
def path_with_replaced_args(request, args, path=None): | |
path = path or request.path | |
if isinstance(args, dict): | |
args = args.items() | |
keys_to_replace = {p[0] for p in args} | |
current = [] | |
for key, value in urllib.parse.parse_qsl(request.query_string): | |
if key not in keys_to_replace: | |
current.append((key, value)) | |
current.extend([p for p in args if p[1] is not None]) | |
query_string = urllib.parse.urlencode(current) | |
if query_string: | |
query_string = f"?{query_string}" | |
return path + query_string | |
_css_re = re.compile(r"""['"\n\\]""") | |
_boring_keyword_re = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") | |
def escape_css_string(s): | |
return _css_re.sub( | |
lambda m: "\\" + (f"{ord(m.group()):X}".zfill(6)), | |
s.replace("\r\n", "\n"), | |
) | |
def escape_sqlite(s): | |
if _boring_keyword_re.match(s) and (s.lower() not in reserved_words): | |
return s | |
else: | |
return f"[{s}]" | |
def make_dockerfile( | |
files, | |
metadata_file, | |
extra_options, | |
branch, | |
template_dir, | |
plugins_dir, | |
static, | |
install, | |
spatialite, | |
version_note, | |
secret, | |
environment_variables=None, | |
port=8001, | |
apt_get_extras=None, | |
): | |
cmd = ["datasette", "serve", "--host", "0.0.0.0"] | |
environment_variables = environment_variables or {} | |
environment_variables["DATASETTE_SECRET"] = secret | |
apt_get_extras = apt_get_extras or [] | |
for filename in files: | |
cmd.extend(["-i", filename]) | |
cmd.extend(["--cors", "--inspect-file", "inspect-data.json"]) | |
if metadata_file: | |
cmd.extend(["--metadata", f"{metadata_file}"]) | |
if template_dir: | |
cmd.extend(["--template-dir", "templates/"]) | |
if plugins_dir: | |
cmd.extend(["--plugins-dir", "plugins/"]) | |
if version_note: | |
cmd.extend(["--version-note", f"{version_note}"]) | |
if static: | |
for mount_point, _ in static: | |
cmd.extend(["--static", f"{mount_point}:{mount_point}"]) | |
if extra_options: | |
for opt in extra_options.split(): | |
cmd.append(f"{opt}") | |
cmd = [shlex.quote(part) for part in cmd] | |
# port attribute is a (fixed) env variable and should not be quoted | |
cmd.extend(["--port", "$PORT"]) | |
cmd = " ".join(cmd) | |
if branch: | |
install = [f"https://github.com/simonw/datasette/archive/{branch}.zip"] + list( | |
install | |
) | |
else: | |
install = ["datasette"] + list(install) | |
apt_get_extras_ = [] | |
apt_get_extras_.extend(apt_get_extras) | |
apt_get_extras = apt_get_extras_ | |
if spatialite: | |
apt_get_extras.extend(["python3-dev", "gcc", "libsqlite3-mod-spatialite"]) | |
environment_variables["SQLITE_EXTENSIONS"] = ( | |
"/usr/lib/x86_64-linux-gnu/mod_spatialite.so" | |
) | |
return """ | |
FROM python:3.11.0-slim-bullseye | |
COPY . /app | |
WORKDIR /app | |
{apt_get_extras} | |
{environment_variables} | |
RUN pip install -U {install_from} | |
RUN datasette inspect {files} --inspect-file inspect-data.json | |
ENV PORT {port} | |
EXPOSE {port} | |
CMD {cmd}""".format( | |
apt_get_extras=( | |
APT_GET_DOCKERFILE_EXTRAS.format(" ".join(apt_get_extras)) | |
if apt_get_extras | |
else "" | |
), | |
environment_variables="\n".join( | |
[ | |
"ENV {} '{}'".format(key, value) | |
for key, value in environment_variables.items() | |
] | |
), | |
install_from=" ".join(install), | |
files=" ".join(files), | |
port=port, | |
cmd=cmd, | |
).strip() | |
@contextmanager | |
def temporary_docker_directory( | |
files, | |
name, | |
metadata, | |
extra_options, | |
branch, | |
template_dir, | |
plugins_dir, | |
static, | |
install, | |
spatialite, | |
version_note, | |
secret, | |
extra_metadata=None, | |
environment_variables=None, | |
port=8001, | |
apt_get_extras=None, | |
): | |
extra_metadata = extra_metadata or {} | |
tmp = tempfile.TemporaryDirectory() | |
# We create a datasette folder in there to get a nicer now deploy name | |
datasette_dir = os.path.join(tmp.name, name) | |
os.mkdir(datasette_dir) | |
saved_cwd = os.getcwd() | |
file_paths = [os.path.join(saved_cwd, file_path) for file_path in files] | |
file_names = [os.path.split(f)[-1] for f in files] | |
if metadata: | |
metadata_content = parse_metadata(metadata.read()) | |
else: | |
metadata_content = {} | |
# Merge in the non-null values in extra_metadata | |
mergedeep.merge( | |
metadata_content, | |
{key: value for key, value in extra_metadata.items() if value is not None}, | |
) | |
try: | |
dockerfile = make_dockerfile( | |
file_names, | |
metadata_content and "metadata.json", | |
extra_options, | |
branch, | |
template_dir, | |
plugins_dir, | |
static, | |
install, | |
spatialite, | |
version_note, | |
secret, | |
environment_variables, | |
port=port, | |
apt_get_extras=apt_get_extras, | |
) | |
os.chdir(datasette_dir) | |
if metadata_content: | |
with open("metadata.json", "w") as fp: | |
fp.write(json.dumps(metadata_content, indent=2)) | |
with open("Dockerfile", "w") as fp: | |
fp.write(dockerfile) | |
for path, filename in zip(file_paths, file_names): | |
link_or_copy(path, os.path.join(datasette_dir, filename)) | |
if template_dir: | |
link_or_copy_directory( | |
os.path.join(saved_cwd, template_dir), | |
os.path.join(datasette_dir, "templates"), | |
) | |
if plugins_dir: | |
link_or_copy_directory( | |
os.path.join(saved_cwd, plugins_dir), | |
os.path.join(datasette_dir, "plugins"), | |
) | |
for mount_point, path in static: | |
link_or_copy_directory( | |
os.path.join(saved_cwd, path), os.path.join(datasette_dir, mount_point) | |
) | |
yield datasette_dir | |
finally: | |
tmp.cleanup() | |
os.chdir(saved_cwd) | |
def detect_primary_keys(conn, table): | |
"""Figure out primary keys for a table.""" | |
columns = table_column_details(conn, table) | |
pks = [column for column in columns if column.is_pk] | |
pks.sort(key=lambda column: column.is_pk) | |
return [column.name for column in pks] | |
def get_outbound_foreign_keys(conn, table): | |
infos = conn.execute(f"PRAGMA foreign_key_list([{table}])").fetchall() | |
fks = [] | |
for info in infos: | |
if info is not None: | |
id, seq, table_name, from_, to_, on_update, on_delete, match = info | |
fks.append( | |
{ | |
"column": from_, | |
"other_table": table_name, | |
"other_column": to_, | |
"id": id, | |
"seq": seq, | |
} | |
) | |
# Filter out compound foreign keys by removing any where "id" is not unique | |
id_counts = Counter(fk["id"] for fk in fks) | |
return [ | |
{ | |
"column": fk["column"], | |
"other_table": fk["other_table"], | |
"other_column": fk["other_column"], | |
} | |
for fk in fks | |
if id_counts[fk["id"]] == 1 | |
] | |
def get_all_foreign_keys(conn): | |
tables = [ | |
r[0] for r in conn.execute('select name from sqlite_master where type="table"') | |
] | |
table_to_foreign_keys = {} | |
for table in tables: | |
table_to_foreign_keys[table] = {"incoming": [], "outgoing": []} | |
for table in tables: | |
fks = get_outbound_foreign_keys(conn, table) | |
for fk in fks: | |
table_name = fk["other_table"] | |
from_ = fk["column"] | |
to_ = fk["other_column"] | |
if table_name not in table_to_foreign_keys: | |
# Weird edge case where something refers to a table that does | |
# not actually exist | |
continue | |
table_to_foreign_keys[table_name]["incoming"].append( | |
{"other_table": table, "column": to_, "other_column": from_} | |
) | |
table_to_foreign_keys[table]["outgoing"].append( | |
{"other_table": table_name, "column": from_, "other_column": to_} | |
) | |
return table_to_foreign_keys | |
def detect_spatialite(conn): | |
rows = conn.execute( | |
'select 1 from sqlite_master where tbl_name = "geometry_columns"' | |
).fetchall() | |
return len(rows) > 0 | |
def detect_fts(conn, table): | |
"""Detect if table has a corresponding FTS virtual table and return it""" | |
rows = conn.execute(detect_fts_sql(table)).fetchall() | |
if len(rows) == 0: | |
return None | |
else: | |
return rows[0][0] | |
def detect_fts_sql(table): | |
return r""" | |
select name from sqlite_master | |
where rootpage = 0 | |
and ( | |
sql like '%VIRTUAL TABLE%USING FTS%content="{table}"%' | |
or sql like '%VIRTUAL TABLE%USING FTS%content=[{table}]%' | |
or ( | |
tbl_name = "{table}" | |
and sql like '%VIRTUAL TABLE%USING FTS%' | |
) | |
) | |
""".format( | |
table=table.replace("'", "''") | |
) | |
def detect_json1(conn=None): | |
if conn is None: | |
conn = sqlite3.connect(":memory:") | |
try: | |
conn.execute("SELECT json('{}')") | |
return True | |
except Exception: | |
return False | |
def table_columns(conn, table): | |
return [column.name for column in table_column_details(conn, table)] | |
def table_column_details(conn, table): | |
if supports_table_xinfo(): | |
# table_xinfo was added in 3.26.0 | |
return [ | |
Column(*r) | |
for r in conn.execute( | |
f"PRAGMA table_xinfo({escape_sqlite(table)});" | |
).fetchall() | |
] | |
else: | |
# Treat hidden as 0 for all columns | |
return [ | |
Column(*(list(r) + [0])) | |
for r in conn.execute( | |
f"PRAGMA table_info({escape_sqlite(table)});" | |
).fetchall() | |
] | |
filter_column_re = re.compile(r"^_filter_column_\d+$") | |
def filters_should_redirect(special_args): | |
redirect_params = [] | |
# Handle _filter_column=foo&_filter_op=exact&_filter_value=... | |
filter_column = special_args.get("_filter_column") | |
filter_op = special_args.get("_filter_op") or "" | |
filter_value = special_args.get("_filter_value") or "" | |
if "__" in filter_op: | |
filter_op, filter_value = filter_op.split("__", 1) | |
if filter_column: | |
redirect_params.append((f"{filter_column}__{filter_op}", filter_value)) | |
for key in ("_filter_column", "_filter_op", "_filter_value"): | |
if key in special_args: | |
redirect_params.append((key, None)) | |
# Now handle _filter_column_1=name&_filter_op_1=contains&_filter_value_1=hello | |
column_keys = [k for k in special_args if filter_column_re.match(k)] | |
for column_key in column_keys: | |
number = column_key.split("_")[-1] | |
column = special_args[column_key] | |
op = special_args.get(f"_filter_op_{number}") or "exact" | |
value = special_args.get(f"_filter_value_{number}") or "" | |
if "__" in op: | |
op, value = op.split("__", 1) | |
if column: | |
redirect_params.append((f"{column}__{op}", value)) | |
redirect_params.extend( | |
[ | |
(f"_filter_column_{number}", None), | |
(f"_filter_op_{number}", None), | |
(f"_filter_value_{number}", None), | |
] | |
) | |
return redirect_params | |
whitespace_re = re.compile(r"\s") | |
def is_url(value): | |
"""Must start with https:// or https:// and contain JUST a URL""" | |
if not isinstance(value, str): | |
return False | |
if not value.startswith("http://") and not value.startswith("https://"):
return False | |
# Any whitespace at all is invalid | |
if whitespace_re.search(value): | |
return False | |
return True | |
css_class_re = re.compile(r"^[a-zA-Z]+[_a-zA-Z0-9-]*$") | |
css_invalid_chars_re = re.compile(r"[^a-zA-Z0-9_\-]") | |
def to_css_class(s): | |
""" | |
Given a string (e.g. a table name) returns a valid unique CSS class. | |
For simple cases, just returns the string again. If the string is not a | |
valid CSS class (we disallow - and _ prefixes even though they are valid | |
as they may be confused with browser prefixes) we strip invalid characters | |
and add a 6 char md5 sum suffix, to make sure two tables with identical | |
names after stripping characters don't end up with the same CSS class. | |
""" | |
if css_class_re.match(s): | |
return s | |
md5_suffix = md5_not_usedforsecurity(s)[:6] | |
# Strip leading _, - | |
s = s.lstrip("_").lstrip("-") | |
# Replace any whitespace with hyphens | |
s = "-".join(s.split()) | |
# Remove any remaining invalid characters | |
s = css_invalid_chars_re.sub("", s) | |
# Attach the md5 suffix | |
bits = [b for b in (s, md5_suffix) if b] | |
return "-".join(bits) | |
def link_or_copy(src, dst): | |
# Intended for use in populating a temp directory. We link if possible, | |
# but fall back to copying if the temp directory is on a different device | |
# https://github.com/simonw/datasette/issues/141 | |
try: | |
os.link(src, dst) | |
except OSError: | |
shutil.copyfile(src, dst) | |
def link_or_copy_directory(src, dst): | |
try: | |
copytree(src, dst, copy_function=os.link, dirs_exist_ok=True) | |
except OSError: | |
copytree(src, dst, dirs_exist_ok=True) | |
def module_from_path(path, name): | |
# Adapted from https://sayspy.blogspot.com/2011/07/how-to-import-module-from-just-file.html | |
mod = types.ModuleType(name) | |
mod.__file__ = path | |
with open(path, "r") as file: | |
code = compile(file.read(), path, "exec", dont_inherit=True) | |
exec(code, mod.__dict__) | |
return mod | |
def path_with_format( | |
*, request=None, path=None, format=None, extra_qs=None, replace_format=None | |
): | |
qs = extra_qs or {} | |
path = request.path if request else path | |
if replace_format and path.endswith(f".{replace_format}"): | |
path = path[: -(1 + len(replace_format))] | |
if "." in path: | |
qs["_format"] = format | |
else: | |
path = f"{path}.{format}" | |
if qs: | |
extra = urllib.parse.urlencode(sorted(qs.items())) | |
if request and request.query_string: | |
path = f"{path}?{request.query_string}&{extra}" | |
else: | |
path = f"{path}?{extra}" | |
elif request and request.query_string: | |
path = f"{path}?{request.query_string}" | |
return path | |
class CustomRow(OrderedDict): | |
# Loose imitation of sqlite3.Row which offers | |
# both index-based AND key-based lookups | |
def __init__(self, columns, values=None): | |
self.columns = columns | |
if values: | |
self.update(values) | |
def __getitem__(self, key): | |
if isinstance(key, int): | |
return super().__getitem__(self.columns[key]) | |
else: | |
return super().__getitem__(key) | |
def __iter__(self): | |
for column in self.columns: | |
yield self[column] | |
def value_as_boolean(value): | |
if value.lower() not in ("on", "off", "true", "false", "1", "0"): | |
raise ValueAsBooleanError | |
return value.lower() in ("on", "true", "1") | |
class ValueAsBooleanError(ValueError): | |
pass | |
class WriteLimitExceeded(Exception): | |
pass | |
class LimitedWriter: | |
def __init__(self, writer, limit_mb): | |
self.writer = writer | |
self.limit_bytes = limit_mb * 1024 * 1024 | |
self.bytes_count = 0 | |
async def write(self, bytes): | |
self.bytes_count += len(bytes) | |
if self.limit_bytes and (self.bytes_count > self.limit_bytes): | |
raise WriteLimitExceeded(f"CSV contains more than {self.limit_bytes} bytes") | |
await self.writer.write(bytes) | |
class EscapeHtmlWriter: | |
def __init__(self, writer): | |
self.writer = writer | |
async def write(self, content): | |
await self.writer.write(markupsafe.escape(content)) | |
_infinities = {float("inf"), float("-inf")} | |
def remove_infinites(row): | |
to_check = row | |
if isinstance(row, dict): | |
to_check = row.values() | |
if not any((c in _infinities) if isinstance(c, float) else 0 for c in to_check): | |
return row | |
if isinstance(row, dict): | |
return { | |
k: (None if (isinstance(v, float) and v in _infinities) else v) | |
for k, v in row.items() | |
} | |
else: | |
return [None if (isinstance(c, float) and c in _infinities) else c for c in row] | |
class StaticMount(click.ParamType): | |
name = "mount:directory" | |
def convert(self, value, param, ctx): | |
if ":" not in value: | |
self.fail( | |
f'"{value}" should be of format mountpoint:directory', | |
param, | |
ctx, | |
) | |
path, dirpath = value.split(":", 1) | |
dirpath = os.path.abspath(dirpath) | |
if not os.path.exists(dirpath) or not os.path.isdir(dirpath): | |
self.fail(f"{value} is not a valid directory path", param, ctx) | |
return path, dirpath | |
# The --load-extension parameter can optionally include a specific entrypoint. | |
# This is done by appending ":entrypoint_name" after supplying the path to the extension | |
class LoadExtension(click.ParamType): | |
name = "path:entrypoint?" | |
def convert(self, value, param, ctx): | |
if ":" not in value: | |
return value | |
path, entrypoint = value.split(":", 1) | |
return path, entrypoint | |
def format_bytes(bytes): | |
current = float(bytes) | |
for unit in ("bytes", "KB", "MB", "GB", "TB"): | |
if current < 1024: | |
break | |
current = current / 1024 | |
if unit == "bytes": | |
return f"{int(current)} {unit}" | |
else: | |
return f"{current:.1f} {unit}" | |
_escape_fts_re = re.compile(r'\s+|(".*?")') | |
def escape_fts(query): | |
# If query has unbalanced ", add one at end | |
if query.count('"') % 2: | |
query += '"' | |
bits = _escape_fts_re.split(query) | |
bits = [b for b in bits if b and b != '""'] | |
return " ".join( | |
'"{}"'.format(bit) if not bit.startswith('"') else bit for bit in bits | |
) | |
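# Illustrative examples:
#   escape_fts('dog cat')                  -> '"dog" "cat"'
#   escape_fts('say "hello world" please') -> '"say" "hello world" "please"'
#   escape_fts('unbalanced "quote')        -> '"unbalanced" "quote"'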
class MultiParams: | |
def __init__(self, data): | |
# data is a dictionary of key => [list, of, values] or a list of [["key", "value"]] pairs | |
if isinstance(data, dict): | |
for key in data: | |
assert isinstance( | |
data[key], (list, tuple) | |
), "dictionary data should be a dictionary of key => [list]" | |
self._data = data | |
elif isinstance(data, list) or isinstance(data, tuple): | |
new_data = {} | |
for item in data: | |
assert ( | |
isinstance(item, (list, tuple)) and len(item) == 2 | |
), "list data should be a list of [key, value] pairs" | |
key, value = item | |
new_data.setdefault(key, []).append(value) | |
self._data = new_data | |
def __repr__(self): | |
return f"<MultiParams: {self._data}>" | |
def __contains__(self, key): | |
return key in self._data | |
def __getitem__(self, key): | |
return self._data[key][0] | |
def keys(self): | |
return self._data.keys() | |
def __iter__(self): | |
yield from self._data.keys() | |
def __len__(self): | |
return len(self._data) | |
def get(self, name, default=None): | |
"""Return first value in the list, if available""" | |
try: | |
return self._data.get(name)[0] | |
except (KeyError, TypeError): | |
return default | |
def getlist(self, name): | |
"""Return full list""" | |
return self._data.get(name) or [] | |
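# Illustrative examples:
#   params = MultiParams({"foo": ["1", "2"], "bar": ["3"]})
#   params["foo"]                  -> "1"
#   params.getlist("foo")          -> ["1", "2"]
#   params.get("missing", "none")  -> "none"
#   MultiParams([["key", "a"], ["key", "b"]]).getlist("key") -> ["a", "b"]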
class ConnectionProblem(Exception): | |
pass | |
class SpatialiteConnectionProblem(ConnectionProblem): | |
pass | |
def check_connection(conn): | |
tables = [ | |
r[0] | |
for r in conn.execute( | |
"select name from sqlite_master where type='table'" | |
).fetchall() | |
] | |
for table in tables: | |
try: | |
conn.execute( | |
f"PRAGMA table_info({escape_sqlite(table)});", | |
) | |
except sqlite3.OperationalError as e: | |
if e.args[0] == "no such module: VirtualSpatialIndex": | |
raise SpatialiteConnectionProblem(e) | |
else: | |
raise ConnectionProblem(e) | |
class BadMetadataError(Exception): | |
pass | |
@documented | |
def parse_metadata(content: str) -> dict: | |
"Detects if content is JSON or YAML and parses it appropriately." | |
# content can be JSON or YAML | |
try: | |
return json.loads(content) | |
except json.JSONDecodeError: | |
try: | |
return yaml.safe_load(content) | |
except yaml.YAMLError: | |
raise BadMetadataError("Metadata is not valid JSON or YAML") | |
def _gather_arguments(fn, kwargs): | |
parameters = inspect.signature(fn).parameters.keys() | |
call_with = [] | |
for parameter in parameters: | |
if parameter not in kwargs: | |
raise TypeError( | |
"{} requires parameters {}, missing: {}".format( | |
fn, tuple(parameters), set(parameters) - set(kwargs.keys()) | |
) | |
) | |
call_with.append(kwargs[parameter]) | |
return call_with | |
def call_with_supported_arguments(fn, **kwargs): | |
call_with = _gather_arguments(fn, kwargs) | |
return fn(*call_with) | |
async def async_call_with_supported_arguments(fn, **kwargs): | |
call_with = _gather_arguments(fn, kwargs) | |
return await fn(*call_with) | |
def actor_matches_allow(actor, allow): | |
if allow is True: | |
return True | |
if allow is False: | |
return False | |
if actor is None and allow and allow.get("unauthenticated") is True: | |
return True | |
if allow is None: | |
return True | |
actor = actor or {} | |
for key, values in allow.items(): | |
if values == "*" and key in actor: | |
return True | |
if not isinstance(values, list): | |
values = [values] | |
actor_values = actor.get(key) | |
if actor_values is None: | |
continue | |
if not isinstance(actor_values, list): | |
actor_values = [actor_values] | |
actor_values = set(actor_values) | |
if actor_values.intersection(values): | |
return True | |
return False | |
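# Illustrative examples:
#   actor_matches_allow({"id": "root"}, {"id": "root"})            -> True
#   actor_matches_allow({"id": "alex"}, {"id": ["root", "alex"]})  -> True
#   actor_matches_allow(None, {"unauthenticated": True})           -> True
#   actor_matches_allow({"id": "alex"}, {"id": "root"})            -> False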
def resolve_env_secrets(config, environ): | |
"""Create copy that recursively replaces {"$env": "NAME"} with values from environ""" | |
if isinstance(config, dict): | |
if list(config.keys()) == ["$env"]: | |
return environ.get(list(config.values())[0]) | |
elif list(config.keys()) == ["$file"]: | |
with open(list(config.values())[0]) as fp: | |
return fp.read() | |
else: | |
return { | |
key: resolve_env_secrets(value, environ) | |
for key, value in config.items() | |
} | |
elif isinstance(config, list): | |
return [resolve_env_secrets(value, environ) for value in config] | |
else: | |
return config | |
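# Illustrative example (hypothetical plugin name and environment):
#   resolve_env_secrets(
#       {"plugins": {"datasette-auth": {"secret": {"$env": "AUTH_SECRET"}}}},
#       {"AUTH_SECRET": "xyz"},
#   )
#   -> {"plugins": {"datasette-auth": {"secret": "xyz"}}}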
def display_actor(actor): | |
for key in ("display", "name", "username", "login", "id"): | |
if actor.get(key): | |
return actor[key] | |
return str(actor) | |
class SpatialiteNotFound(Exception): | |
pass | |
# Can replace with sqlite-utils when I add that dependency | |
def find_spatialite(): | |
for path in SPATIALITE_PATHS: | |
if os.path.exists(path): | |
return path | |
raise SpatialiteNotFound | |
async def initial_path_for_datasette(datasette): | |
"""Return suggested path for opening this Datasette, based on number of DBs and tables""" | |
databases = dict([p for p in datasette.databases.items() if p[0] != "_internal"]) | |
if len(databases) == 1: | |
db_name = next(iter(databases.keys())) | |
path = datasette.urls.database(db_name) | |
# Does this DB only have one table? | |
db = next(iter(databases.values())) | |
tables = await db.table_names() | |
if len(tables) == 1: | |
path = datasette.urls.table(db_name, tables[0]) | |
else: | |
path = datasette.urls.instance() | |
return path | |
class PrefixedUrlString(str): | |
def __add__(self, other): | |
return type(self)(super().__add__(other)) | |
def __str__(self): | |
return super().__str__() | |
def __getattribute__(self, name): | |
if not name.startswith("__") and name in dir(str): | |
def method(self, *args, **kwargs): | |
value = getattr(super(), name)(*args, **kwargs) | |
if isinstance(value, str): | |
return type(self)(value) | |
elif isinstance(value, list): | |
return [type(self)(i) for i in value] | |
elif isinstance(value, tuple): | |
return tuple(type(self)(i) for i in value) | |
else: | |
return value | |
return method.__get__(self) | |
else: | |
return super().__getattribute__(name) | |
class StartupError(Exception): | |
pass | |
_single_line_comment_re = re.compile(r"--.*") | |
_multi_line_comment_re = re.compile(r"/\*.*?\*/", re.DOTALL) | |
_single_quote_re = re.compile(r"'(?:''|[^'])*'") | |
_double_quote_re = re.compile(r'"(?:\"\"|[^"])*"') | |
_named_param_re = re.compile(r":(\w+)") | |
@documented | |
def named_parameters(sql: str) -> List[str]: | |
""" | |
Given a SQL statement, return a list of named parameters that are used in the statement | |
e.g. for ``select * from foo where id=:id`` this would return ``["id"]`` | |
""" | |
sql = _single_line_comment_re.sub("", sql) | |
sql = _multi_line_comment_re.sub("", sql) | |
sql = _single_quote_re.sub("", sql) | |
sql = _double_quote_re.sub("", sql) | |
# Extract parameters from what is left | |
return _named_param_re.findall(sql) | |
async def derive_named_parameters(db: "Database", sql: str) -> List[str]: | |
""" | |
This undocumented but stable method exists for backwards compatibility | |
with plugins that were using it before it switched to named_parameters() | |
""" | |
return named_parameters(sql) | |
def add_cors_headers(headers): | |
headers["Access-Control-Allow-Origin"] = "*" | |
headers["Access-Control-Allow-Headers"] = "Authorization, Content-Type" | |
headers["Access-Control-Expose-Headers"] = "Link" | |
headers["Access-Control-Allow-Methods"] = "GET, POST, HEAD, OPTIONS" | |
headers["Access-Control-Max-Age"] = "3600" | |
_TILDE_ENCODING_SAFE = frozenset( | |
b"ABCDEFGHIJKLMNOPQRSTUVWXYZ" | |
b"abcdefghijklmnopqrstuvwxyz" | |
b"0123456789_-" | |
# This is the same as Python percent-encoding but I removed | |
# '.' and '~' | |
) | |
_space = ord(" ") | |
class TildeEncoder(dict): | |
# Keeps a cache internally, via __missing__ | |
def __missing__(self, b): | |
# Handle a cache miss, store encoded string in cache and return. | |
if b in _TILDE_ENCODING_SAFE: | |
res = chr(b) | |
elif b == _space: | |
res = "+" | |
else: | |
res = "~{:02X}".format(b) | |
self[b] = res | |
return res | |
_tilde_encoder = TildeEncoder().__getitem__ | |
@documented | |
def tilde_encode(s: str) -> str: | |
"Returns tilde-encoded string - for example ``/foo/bar`` -> ``~2Ffoo~2Fbar``" | |
return "".join(_tilde_encoder(char) for char in s.encode("utf-8")) | |
@documented | |
def tilde_decode(s: str) -> str: | |
"Decodes a tilde-encoded string, so ``~2Ffoo~2Fbar`` -> ``/foo/bar``" | |
# Avoid accidentally decoding a %2f style sequence | |
temp = secrets.token_hex(16) | |
s = s.replace("%", temp) | |
decoded = urllib.parse.unquote_plus(s.replace("~", "%")) | |
return decoded.replace(temp, "%") | |
def resolve_routes(routes, path): | |
for regex, view in routes: | |
match = regex.match(path) | |
if match is not None: | |
return match, view | |
return None, None | |
def truncate_url(url, length): | |
if (not length) or (len(url) <= length): | |
return url | |
bits = url.rsplit(".", 1) | |
if len(bits) == 2 and 1 <= len(bits[1]) <= 4 and "/" not in bits[1]: | |
rest, ext = bits | |
return rest[: length - 1 - len(ext)] + "…." + ext | |
return url[: length - 1] + "…" | |
async def row_sql_params_pks(db, table, pk_values): | |
pks = await db.primary_keys(table) | |
use_rowid = not pks | |
select = "*" | |
if use_rowid: | |
select = "rowid, *" | |
pks = ["rowid"] | |
wheres = [f'"{pk}"=:p{i}' for i, pk in enumerate(pks)] | |
sql = f"select {select} from {escape_sqlite(table)} where {' AND '.join(wheres)}" | |
params = {} | |
for i, pk_value in enumerate(pk_values): | |
params[f"p{i}"] = pk_value | |
return sql, params, pks | |
def _handle_pair(key: str, value: str) -> dict: | |
""" | |
Turn a key-value pair into a nested dictionary. | |
foo, bar => {'foo': 'bar'} | |
foo.bar, baz => {'foo': {'bar': 'baz'}} | |
foo.bar, [1, 2, 3] => {'foo': {'bar': [1, 2, 3]}} | |
foo.bar, "baz" => {'foo': {'bar': 'baz'}} | |
foo.bar, '{"baz": "qux"}' => {'foo': {'bar': "{'baz': 'qux'}"}} | |
""" | |
try: | |
value = json.loads(value) | |
except json.JSONDecodeError: | |
# If it doesn't parse as JSON, treat it as a string | |
pass | |
keys = key.split(".") | |
result = current_dict = {} | |
for k in keys[:-1]: | |
current_dict[k] = {} | |
current_dict = current_dict[k] | |
current_dict[keys[-1]] = value | |
return result | |
def _combine(base: dict, update: dict) -> dict: | |
""" | |
Recursively merge two dictionaries. | |
""" | |
for key, value in update.items(): | |
if isinstance(value, dict) and key in base and isinstance(base[key], dict): | |
base[key] = _combine(base[key], value) | |
else: | |
base[key] = value | |
return base | |
def pairs_to_nested_config(pairs: typing.List[typing.Tuple[str, typing.Any]]) -> dict: | |
""" | |
Parse a list of key-value pairs into a nested dictionary. | |
""" | |
result = {} | |
for key, value in pairs: | |
parsed_pair = _handle_pair(key, value) | |
result = _combine(result, parsed_pair) | |
return result | |
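# Illustrative sketch (hypothetical key names): dotted keys become nested
# dictionaries and values are parsed as JSON where possible - the kind of
# transformation applied to command-line key/value options.
#
#     pairs_to_nested_config([
#         ("settings.default_page_size", "50"),
#         ("plugins.my-plugin.depth", '{"max": 3}'),
#     ])
#     # -> {
#     #     "settings": {"default_page_size": 50},
#     #     "plugins": {"my-plugin": {"depth": {"max": 3}}},
#     # }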
def make_slot_function(name, datasette, request, **kwargs): | |
from datasette.plugins import pm | |
method = getattr(pm.hook, name, None) | |
assert method is not None, "No hook found for {}".format(name) | |
async def inner(): | |
html_bits = [] | |
for hook in method(datasette=datasette, request=request, **kwargs): | |
html = await await_me_maybe(hook) | |
if html is not None: | |
html_bits.append(html) | |
return markupsafe.Markup("".join(html_bits)) | |
return inner | |
def prune_empty_dicts(d: dict): | |
""" | |
Recursively prune all empty dictionaries from a given dictionary. | |
""" | |
for key, value in list(d.items()): | |
if isinstance(value, dict): | |
prune_empty_dicts(value) | |
if value == {}: | |
d.pop(key, None) | |
def move_plugins_and_allow(source: dict, destination: dict) -> Tuple[dict, dict]: | |
""" | |
Move 'plugins' and 'allow' keys from source to destination dictionary. Creates | |
hierarchy in destination if needed. After moving, recursively remove any keys | |
in the source that are left empty. | |
""" | |
source = copy.deepcopy(source) | |
destination = copy.deepcopy(destination) | |
def recursive_move(src, dest, path=None): | |
if path is None: | |
path = [] | |
for key, value in list(src.items()): | |
new_path = path + [key] | |
if key in ("plugins", "allow"): | |
# Navigate and create the hierarchy in destination if needed | |
d = dest | |
for step in path: | |
d = d.setdefault(step, {}) | |
# Move the plugins/allow block into the destination
d[key] = value
# Remove it from the source
src.pop(key, None)
elif isinstance(value, dict): | |
recursive_move(value, dest, new_path) | |
# After moving, check if the current dictionary is empty and remove it if so | |
if not value: | |
src.pop(key, None) | |
recursive_move(source, destination) | |
prune_empty_dicts(source) | |
return source, destination | |
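# Illustrative sketch (hypothetical input): "plugins" and "allow" blocks are
# lifted out of a metadata-shaped dictionary into a config-shaped one, keeping
# their position in the databases/tables hierarchy.
#
#     source = {"databases": {"demo": {"allow": {"id": "root"}, "title": "Demo"}}}
#     source, destination = move_plugins_and_allow(source, {})
#     # source      -> {"databases": {"demo": {"title": "Demo"}}}
#     # destination -> {"databases": {"demo": {"allow": {"id": "root"}}}}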
_table_config_keys = ( | |
"hidden", | |
"sort", | |
"sort_desc", | |
"size", | |
"sortable_columns", | |
"label_column", | |
"facets", | |
"fts_table", | |
"fts_pk", | |
"searchmode", | |
) | |
def move_table_config(metadata: dict, config: dict): | |
""" | |
Move all known table configuration keys from metadata to config. | |
""" | |
if "databases" not in metadata: | |
return metadata, config | |
metadata = copy.deepcopy(metadata) | |
config = copy.deepcopy(config) | |
for database_name, database in metadata["databases"].items(): | |
if "tables" not in database: | |
continue | |
for table_name, table in database["tables"].items(): | |
for key in _table_config_keys: | |
if key in table: | |
config.setdefault("databases", {}).setdefault( | |
database_name, {} | |
).setdefault("tables", {}).setdefault(table_name, {})[ | |
key | |
] = table.pop( | |
key | |
) | |
prune_empty_dicts(metadata) | |
return metadata, config | |
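# Illustrative sketch (hypothetical input): only the known table configuration
# keys listed in _table_config_keys are moved; everything else stays behind.
#
#     metadata = {"databases": {"demo": {"tables": {"t": {"hidden": True, "title": "T"}}}}}
#     metadata, config = move_table_config(metadata, {})
#     # metadata -> {"databases": {"demo": {"tables": {"t": {"title": "T"}}}}}
#     # config   -> {"databases": {"demo": {"tables": {"t": {"hidden": True}}}}}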
def redact_keys(original: dict, key_patterns: Iterable) -> dict: | |
""" | |
Recursively redact sensitive keys in a dictionary based on given patterns | |
:param original: The original dictionary | |
:param key_patterns: A list of substring patterns to redact | |
:return: A copy of the original dictionary with sensitive values redacted | |
""" | |
def redact(data): | |
if isinstance(data, dict): | |
return { | |
k: ( | |
redact(v) | |
if not any(pattern in k for pattern in key_patterns) | |
else "***" | |
) | |
for k, v in data.items() | |
} | |
elif isinstance(data, list): | |
return [redact(item) for item in data] | |
else: | |
return data | |
return redact(original) | |
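# Illustrative sketch (hypothetical plugin and key names): any key whose name
# contains one of the given substrings has its value replaced with "***", at
# any depth and inside lists as well.
#
#     redact_keys(
#         {"plugins": {"my-auth-plugin": {"client_secret": "abc", "client_id": "x"}}},
#         ("secret", "password", "token"),
#     )
#     # -> {"plugins": {"my-auth-plugin": {"client_secret": "***", "client_id": "x"}}}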
def md5_not_usedforsecurity(s): | |
try: | |
return hashlib.md5(s.encode("utf8"), usedforsecurity=False).hexdigest() | |
except TypeError: | |
# For Python 3.8 which does not support usedforsecurity=False | |
return hashlib.md5(s.encode("utf8")).hexdigest() | |
_etag_cache = {} | |
async def calculate_etag(filepath, chunk_size=4096): | |
if filepath in _etag_cache: | |
return _etag_cache[filepath] | |
hasher = hashlib.md5() | |
async with aiofiles.open(filepath, "rb") as f: | |
while True: | |
chunk = await f.read(chunk_size) | |
if not chunk: | |
break | |
hasher.update(chunk) | |
etag = f'"{hasher.hexdigest()}"' | |
_etag_cache[filepath] = etag | |
return etag | |
def deep_dict_update(dict1, dict2): | |
for key, value in dict2.items(): | |
if isinstance(value, dict): | |
dict1[key] = deep_dict_update(dict1.get(key, type(value)()), value) | |
else: | |
dict1[key] = value | |
return dict1 | |
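# Illustrative sketch (hypothetical values): nested dictionaries are merged key
# by key; note that dict1 is modified in place and returned.
#
#     deep_dict_update({"a": {"b": 1, "c": 2}}, {"a": {"b": 3}, "d": 4})
#     # -> {"a": {"b": 3, "c": 2}, "d": 4}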
</document_content> | |
</document> | |
<document index="26"> | |
<source>datasette/utils/asgi.py</source> | |
<document_content> | |
import hashlib | |
import json | |
from datasette.utils import MultiParams, calculate_etag | |
from mimetypes import guess_type | |
from urllib.parse import parse_qs, urlunparse, parse_qsl | |
from pathlib import Path | |
from http.cookies import SimpleCookie, Morsel | |
import aiofiles | |
import aiofiles.os | |
# Workaround for adding samesite support to pre 3.8 python | |
Morsel._reserved["samesite"] = "SameSite" | |
# Thanks, Starlette: | |
# https://github.com/encode/starlette/blob/519f575/starlette/responses.py#L17 | |
class Base400(Exception): | |
status = 400 | |
class NotFound(Base400): | |
status = 404 | |
class DatabaseNotFound(NotFound): | |
def __init__(self, database_name): | |
self.database_name = database_name | |
super().__init__("Database not found") | |
class TableNotFound(NotFound): | |
def __init__(self, database_name, table): | |
super().__init__("Table not found") | |
self.database_name = database_name | |
self.table = table | |
class RowNotFound(NotFound): | |
def __init__(self, database_name, table, pk_values): | |
super().__init__("Row not found") | |
self.database_name = database_name | |
self.table_name = table | |
self.pk_values = pk_values | |
class Forbidden(Base400): | |
status = 403 | |
class BadRequest(Base400): | |
status = 400 | |
SAMESITE_VALUES = ("strict", "lax", "none") | |
class Request: | |
def __init__(self, scope, receive): | |
self.scope = scope | |
self.receive = receive | |
def __repr__(self): | |
return '<asgi.Request method="{}" url="{}">'.format(self.method, self.url) | |
@property | |
def method(self): | |
return self.scope["method"] | |
@property | |
def url(self): | |
return urlunparse( | |
(self.scheme, self.host, self.path, None, self.query_string, None) | |
) | |
@property | |
def url_vars(self): | |
return (self.scope.get("url_route") or {}).get("kwargs") or {} | |
@property | |
def scheme(self): | |
return self.scope.get("scheme") or "http" | |
@property | |
def headers(self): | |
return { | |
k.decode("latin-1").lower(): v.decode("latin-1") | |
for k, v in self.scope.get("headers") or [] | |
} | |
@property | |
def host(self): | |
return self.headers.get("host") or "localhost" | |
@property | |
def cookies(self): | |
cookies = SimpleCookie() | |
cookies.load(self.headers.get("cookie", "")) | |
return {key: value.value for key, value in cookies.items()} | |
@property | |
def path(self): | |
if self.scope.get("raw_path") is not None: | |
return self.scope["raw_path"].decode("latin-1").partition("?")[0] | |
else: | |
path = self.scope["path"] | |
if isinstance(path, str): | |
return path | |
else: | |
return path.decode("utf-8") | |
@property | |
def query_string(self): | |
return (self.scope.get("query_string") or b"").decode("latin-1") | |
@property | |
def full_path(self): | |
qs = self.query_string | |
return "{}{}".format(self.path, ("?" + qs) if qs else "") | |
@property | |
def args(self): | |
return MultiParams(parse_qs(qs=self.query_string, keep_blank_values=True)) | |
@property | |
def actor(self): | |
return self.scope.get("actor", None) | |
async def post_body(self): | |
body = b"" | |
more_body = True | |
while more_body: | |
message = await self.receive() | |
assert message["type"] == "http.request", message | |
body += message.get("body", b"") | |
more_body = message.get("more_body", False) | |
return body | |
async def post_vars(self): | |
body = await self.post_body() | |
return dict(parse_qsl(body.decode("utf-8"), keep_blank_values=True)) | |
@classmethod | |
def fake(cls, path_with_query_string, method="GET", scheme="http", url_vars=None): | |
"""Useful for constructing Request objects for tests""" | |
path, _, query_string = path_with_query_string.partition("?") | |
scope = { | |
"http_version": "1.1", | |
"method": method, | |
"path": path, | |
"raw_path": path_with_query_string.encode("latin-1"), | |
"query_string": query_string.encode("latin-1"), | |
"scheme": scheme, | |
"type": "http", | |
} | |
if url_vars: | |
scope["url_route"] = {"kwargs": url_vars} | |
return cls(scope, None) | |
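# Illustrative sketch of building a fake request in a test (hypothetical path
# and query string):
#
#     request = Request.fake("/demo/books.json?_shape=array")
#     request.method          # "GET"
#     request.path            # "/demo/books.json"
#     request.query_string    # "_shape=array"
#     request.args["_shape"]  # "array"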
class AsgiLifespan: | |
def __init__(self, app, on_startup=None, on_shutdown=None): | |
self.app = app | |
on_startup = on_startup or [] | |
on_shutdown = on_shutdown or [] | |
if not isinstance(on_startup or [], list): | |
on_startup = [on_startup] | |
if not isinstance(on_shutdown or [], list): | |
on_shutdown = [on_shutdown] | |
self.on_startup = on_startup | |
self.on_shutdown = on_shutdown | |
async def __call__(self, scope, receive, send): | |
if scope["type"] == "lifespan": | |
while True: | |
message = await receive() | |
if message["type"] == "lifespan.startup": | |
for fn in self.on_startup: | |
await fn() | |
await send({"type": "lifespan.startup.complete"}) | |
elif message["type"] == "lifespan.shutdown": | |
for fn in self.on_shutdown: | |
await fn() | |
await send({"type": "lifespan.shutdown.complete"}) | |
return | |
else: | |
await self.app(scope, receive, send) | |
class AsgiStream: | |
def __init__(self, stream_fn, status=200, headers=None, content_type="text/plain"): | |
self.stream_fn = stream_fn | |
self.status = status | |
self.headers = headers or {} | |
self.content_type = content_type | |
async def asgi_send(self, send): | |
# Remove any existing content-type header | |
headers = {k: v for k, v in self.headers.items() if k.lower() != "content-type"} | |
headers["content-type"] = self.content_type | |
await send( | |
{ | |
"type": "http.response.start", | |
"status": self.status, | |
"headers": [ | |
[key.encode("utf-8"), value.encode("utf-8")] | |
for key, value in headers.items() | |
], | |
} | |
) | |
w = AsgiWriter(send) | |
await self.stream_fn(w) | |
await send({"type": "http.response.body", "body": b""}) | |
class AsgiWriter: | |
def __init__(self, send): | |
self.send = send | |
async def write(self, chunk): | |
await self.send( | |
{ | |
"type": "http.response.body", | |
"body": chunk.encode("utf-8"), | |
"more_body": True, | |
} | |
) | |
async def asgi_send_json(send, info, status=200, headers=None): | |
headers = headers or {} | |
await asgi_send( | |
send, | |
json.dumps(info), | |
status=status, | |
headers=headers, | |
content_type="application/json; charset=utf-8", | |
) | |
async def asgi_send_html(send, html, status=200, headers=None): | |
headers = headers or {} | |
await asgi_send( | |
send, | |
html, | |
status=status, | |
headers=headers, | |
content_type="text/html; charset=utf-8", | |
) | |
async def asgi_send_redirect(send, location, status=302): | |
await asgi_send( | |
send, | |
"", | |
status=status, | |
headers={"Location": location}, | |
content_type="text/html; charset=utf-8", | |
) | |
async def asgi_send(send, content, status, headers=None, content_type="text/plain"): | |
await asgi_start(send, status, headers, content_type) | |
await send({"type": "http.response.body", "body": content.encode("utf-8")}) | |
async def asgi_start(send, status, headers=None, content_type="text/plain"): | |
headers = headers or {} | |
# Remove any existing content-type header | |
headers = {k: v for k, v in headers.items() if k.lower() != "content-type"} | |
headers["content-type"] = content_type | |
await send( | |
{ | |
"type": "http.response.start", | |
"status": status, | |
"headers": [ | |
[key.encode("latin1"), value.encode("latin1")] | |
for key, value in headers.items() | |
], | |
} | |
) | |
async def asgi_send_file( | |
send, filepath, filename=None, content_type=None, chunk_size=4096, headers=None | |
): | |
headers = headers or {} | |
if filename: | |
headers["content-disposition"] = f'attachment; filename="{filename}"' | |
first = True | |
headers["content-length"] = str((await aiofiles.os.stat(str(filepath))).st_size) | |
async with aiofiles.open(str(filepath), mode="rb") as fp: | |
if first: | |
await asgi_start( | |
send, | |
200, | |
headers, | |
content_type or guess_type(str(filepath))[0] or "text/plain", | |
) | |
first = False | |
more_body = True | |
while more_body: | |
chunk = await fp.read(chunk_size) | |
more_body = len(chunk) == chunk_size | |
await send( | |
{"type": "http.response.body", "body": chunk, "more_body": more_body} | |
) | |
def asgi_static(root_path, chunk_size=4096, headers=None, content_type=None): | |
root_path = Path(root_path) | |
static_headers = {} | |
if headers: | |
static_headers = headers.copy() | |
async def inner_static(request, send): | |
path = request.scope["url_route"]["kwargs"]["path"] | |
headers = static_headers.copy() | |
try: | |
full_path = (root_path / path).resolve().absolute() | |
except FileNotFoundError: | |
await asgi_send_html(send, "404: Directory not found", 404) | |
return | |
if full_path.is_dir(): | |
await asgi_send_html(send, "403: Directory listing is not allowed", 403) | |
return | |
# Ensure full_path is within root_path to avoid weird "../" tricks | |
try: | |
full_path.relative_to(root_path.resolve()) | |
except ValueError: | |
await asgi_send_html(send, "404: Path not inside root path", 404) | |
return | |
try: | |
# Calculate ETag for filepath | |
etag = await calculate_etag(full_path, chunk_size=chunk_size) | |
headers["ETag"] = etag | |
if_none_match = request.headers.get("if-none-match") | |
if if_none_match and if_none_match == etag: | |
return await asgi_send(send, "", 304) | |
await asgi_send_file( | |
send, full_path, chunk_size=chunk_size, headers=headers | |
) | |
except FileNotFoundError: | |
await asgi_send_html(send, "404: File not found", 404) | |
return | |
return inner_static | |
class Response: | |
def __init__(self, body=None, status=200, headers=None, content_type="text/plain"): | |
self.body = body | |
self.status = status | |
self.headers = headers or {} | |
self._set_cookie_headers = [] | |
self.content_type = content_type | |
async def asgi_send(self, send): | |
headers = {} | |
headers.update(self.headers) | |
headers["content-type"] = self.content_type | |
raw_headers = [ | |
[key.encode("utf-8"), value.encode("utf-8")] | |
for key, value in headers.items() | |
] | |
for set_cookie in self._set_cookie_headers: | |
raw_headers.append([b"set-cookie", set_cookie.encode("utf-8")]) | |
await send( | |
{ | |
"type": "http.response.start", | |
"status": self.status, | |
"headers": raw_headers, | |
} | |
) | |
body = self.body | |
if not isinstance(body, bytes): | |
body = body.encode("utf-8") | |
await send({"type": "http.response.body", "body": body}) | |
def set_cookie( | |
self, | |
key, | |
value="", | |
max_age=None, | |
expires=None, | |
path="/", | |
domain=None, | |
secure=False, | |
httponly=False, | |
samesite="lax", | |
): | |
assert samesite in SAMESITE_VALUES, "samesite should be one of {}".format( | |
SAMESITE_VALUES | |
) | |
cookie = SimpleCookie() | |
cookie[key] = value | |
for prop_name, prop_value in ( | |
("max_age", max_age), | |
("expires", expires), | |
("path", path), | |
("domain", domain), | |
("samesite", samesite), | |
): | |
if prop_value is not None: | |
cookie[key][prop_name.replace("_", "-")] = prop_value | |
for prop_name, prop_value in (("secure", secure), ("httponly", httponly)): | |
if prop_value: | |
cookie[key][prop_name] = True | |
self._set_cookie_headers.append(cookie.output(header="").strip()) | |
@classmethod | |
def html(cls, body, status=200, headers=None): | |
return cls( | |
body, | |
status=status, | |
headers=headers, | |
content_type="text/html; charset=utf-8", | |
) | |
@classmethod | |
def text(cls, body, status=200, headers=None): | |
return cls( | |
str(body), | |
status=status, | |
headers=headers, | |
content_type="text/plain; charset=utf-8", | |
) | |
@classmethod | |
def json(cls, body, status=200, headers=None, default=None): | |
return cls( | |
json.dumps(body, default=default), | |
status=status, | |
headers=headers, | |
content_type="application/json; charset=utf-8", | |
) | |
@classmethod | |
def redirect(cls, path, status=302, headers=None): | |
headers = headers or {} | |
headers["Location"] = path | |
return cls("", status=status, headers=headers) | |
class AsgiFileDownload: | |
def __init__( | |
self, | |
filepath, | |
filename=None, | |
content_type="application/octet-stream", | |
headers=None, | |
): | |
self.headers = headers or {} | |
self.filepath = filepath | |
self.filename = filename | |
self.content_type = content_type | |
async def asgi_send(self, send): | |
return await asgi_send_file( | |
send, | |
self.filepath, | |
filename=self.filename, | |
content_type=self.content_type, | |
headers=self.headers, | |
) | |
class AsgiRunOnFirstRequest: | |
def __init__(self, asgi, on_startup): | |
assert isinstance(on_startup, list) | |
self.asgi = asgi | |
self.on_startup = on_startup | |
self._started = False | |
async def __call__(self, scope, receive, send): | |
if not self._started: | |
self._started = True | |
for hook in self.on_startup: | |
await hook() | |
return await self.asgi(scope, receive, send) | |
</document_content> | |
</document> | |
<document index="27"> | |
<source>datasette/utils/baseconv.py</source> | |
<document_content> | |
""" | |
Convert numbers from base 10 integers to base X strings and back again. | |
Sample usage: | |
>>> base20 = BaseConverter('0123456789abcdefghij') | |
>>> base20.encode(1234)
'31e'
>>> base20.decode('31e')
1234 | |
Originally shared here: https://www.djangosnippets.org/snippets/1431/ | |
""" | |
class BaseConverter(object): | |
decimal_digits = "0123456789" | |
def __init__(self, digits): | |
self.digits = digits | |
def encode(self, i): | |
return self.convert(i, self.decimal_digits, self.digits) | |
def decode(self, s): | |
return int(self.convert(s, self.digits, self.decimal_digits)) | |
def convert(number, fromdigits, todigits): | |
# Based on https://code.activestate.com/recipes/111286/ | |
if str(number)[0] == "-": | |
number = str(number)[1:] | |
neg = 1 | |
else: | |
neg = 0 | |
# make an integer out of the number | |
x = 0 | |
for digit in str(number): | |
x = x * len(fromdigits) + fromdigits.index(digit) | |
# create the result in base 'len(todigits)' | |
if x == 0: | |
res = todigits[0] | |
else: | |
res = "" | |
while x > 0: | |
digit = x % len(todigits) | |
res = todigits[digit] + res | |
x = int(x / len(todigits)) | |
if neg: | |
res = "-" + res | |
return res | |
convert = staticmethod(convert) | |
bin = BaseConverter("01") | |
hexconv = BaseConverter("0123456789ABCDEF") | |
base62 = BaseConverter("ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789abcdefghijklmnopqrstuvwxyz") | |
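# Illustrative sketch (hypothetical values): encode() converts a base 10
# integer into the converter's alphabet and decode() reverses it. Note that
# base62 here puts uppercase letters first, so its zero digit is "A".
#
#     bin.encode(5)        # -> "101"
#     base62.encode(0)     # -> "A"
#     base62.encode(62)    # -> "BA"
#     base62.decode("BA")  # -> 62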
</document_content> | |
</document> | |
<document index="28"> | |
<source>datasette/utils/check_callable.py</source> | |
<document_content> | |
import asyncio | |
import types | |
from typing import NamedTuple, Any | |
class CallableStatus(NamedTuple): | |
is_callable: bool | |
is_async_callable: bool | |
def check_callable(obj: Any) -> CallableStatus: | |
if not callable(obj): | |
return CallableStatus(False, False) | |
if isinstance(obj, type): | |
# It's a class | |
return CallableStatus(True, False) | |
if isinstance(obj, types.FunctionType): | |
return CallableStatus(True, asyncio.iscoroutinefunction(obj)) | |
if hasattr(obj, "__call__"): | |
return CallableStatus(True, asyncio.iscoroutinefunction(obj.__call__)) | |
assert False, "obj {} is somehow callable with no __call__ method".format(repr(obj)) | |
</document_content> | |
</document> | |
<document index="29"> | |
<source>datasette/utils/internal_db.py</source> | |
<document_content> | |
import textwrap | |
from datasette.utils import table_column_details | |
async def init_internal_db(db): | |
create_tables_sql = textwrap.dedent( | |
""" | |
CREATE TABLE IF NOT EXISTS catalog_databases ( | |
database_name TEXT PRIMARY KEY, | |
path TEXT, | |
is_memory INTEGER, | |
schema_version INTEGER | |
); | |
CREATE TABLE IF NOT EXISTS catalog_tables ( | |
database_name TEXT, | |
table_name TEXT, | |
rootpage INTEGER, | |
sql TEXT, | |
PRIMARY KEY (database_name, table_name), | |
FOREIGN KEY (database_name) REFERENCES catalog_databases(database_name)
); | |
CREATE TABLE IF NOT EXISTS catalog_columns ( | |
database_name TEXT, | |
table_name TEXT, | |
cid INTEGER, | |
name TEXT, | |
type TEXT, | |
"notnull" INTEGER, | |
default_value TEXT, -- renamed from dflt_value | |
is_pk INTEGER, -- renamed from pk | |
hidden INTEGER, | |
PRIMARY KEY (database_name, table_name, name), | |
FOREIGN KEY (database_name) REFERENCES catalog_databases(database_name),
FOREIGN KEY (database_name, table_name) REFERENCES catalog_tables(database_name, table_name)
); | |
CREATE TABLE IF NOT EXISTS catalog_indexes ( | |
database_name TEXT, | |
table_name TEXT, | |
seq INTEGER, | |
name TEXT, | |
"unique" INTEGER, | |
origin TEXT, | |
partial INTEGER, | |
PRIMARY KEY (database_name, table_name, name), | |
FOREIGN KEY (database_name) REFERENCES catalog_databases(database_name),
FOREIGN KEY (database_name, table_name) REFERENCES catalog_tables(database_name, table_name)
); | |
CREATE TABLE IF NOT EXISTS catalog_foreign_keys ( | |
database_name TEXT, | |
table_name TEXT, | |
id INTEGER, | |
seq INTEGER, | |
"table" TEXT, | |
"from" TEXT, | |
"to" TEXT, | |
on_update TEXT, | |
on_delete TEXT, | |
match TEXT, | |
PRIMARY KEY (database_name, table_name, id, seq), | |
FOREIGN KEY (database_name) REFERENCES catalog_databases(database_name),
FOREIGN KEY (database_name, table_name) REFERENCES catalog_tables(database_name, table_name)
); | |
""" | |
).strip() | |
await db.execute_write_script(create_tables_sql) | |
await initialize_metadata_tables(db) | |
async def initialize_metadata_tables(db): | |
await db.execute_write_script( | |
textwrap.dedent( | |
""" | |
CREATE TABLE IF NOT EXISTS metadata_instance ( | |
key text, | |
value text, | |
unique(key) | |
); | |
CREATE TABLE IF NOT EXISTS metadata_databases ( | |
database_name text, | |
key text, | |
value text, | |
unique(database_name, key) | |
); | |
CREATE TABLE IF NOT EXISTS metadata_resources ( | |
database_name text, | |
resource_name text, | |
key text, | |
value text, | |
unique(database_name, resource_name, key) | |
); | |
CREATE TABLE IF NOT EXISTS metadata_columns ( | |
database_name text, | |
resource_name text, | |
column_name text, | |
key text, | |
value text, | |
unique(database_name, resource_name, column_name, key) | |
); | |
""" | |
) | |
) | |
async def populate_schema_tables(internal_db, db): | |
database_name = db.name | |
def delete_everything(conn): | |
conn.execute( | |
"DELETE FROM catalog_tables WHERE database_name = ?", [database_name] | |
) | |
conn.execute( | |
"DELETE FROM catalog_columns WHERE database_name = ?", [database_name] | |
) | |
conn.execute( | |
"DELETE FROM catalog_foreign_keys WHERE database_name = ?", | |
[database_name], | |
) | |
conn.execute( | |
"DELETE FROM catalog_indexes WHERE database_name = ?", [database_name] | |
) | |
await internal_db.execute_write_fn(delete_everything) | |
tables = (await db.execute("select * from sqlite_master WHERE type = 'table'")).rows | |
def collect_info(conn): | |
tables_to_insert = [] | |
columns_to_insert = [] | |
foreign_keys_to_insert = [] | |
indexes_to_insert = [] | |
for table in tables: | |
table_name = table["name"] | |
tables_to_insert.append( | |
(database_name, table_name, table["rootpage"], table["sql"]) | |
) | |
columns = table_column_details(conn, table_name) | |
columns_to_insert.extend( | |
{ | |
**{"database_name": database_name, "table_name": table_name}, | |
**column._asdict(), | |
} | |
for column in columns | |
) | |
foreign_keys = conn.execute( | |
f"PRAGMA foreign_key_list([{table_name}])" | |
).fetchall() | |
foreign_keys_to_insert.extend( | |
{ | |
**{"database_name": database_name, "table_name": table_name}, | |
**dict(foreign_key), | |
} | |
for foreign_key in foreign_keys | |
) | |
indexes = conn.execute(f"PRAGMA index_list([{table_name}])").fetchall() | |
indexes_to_insert.extend( | |
{ | |
**{"database_name": database_name, "table_name": table_name}, | |
**dict(index), | |
} | |
for index in indexes | |
) | |
return ( | |
tables_to_insert, | |
columns_to_insert, | |
foreign_keys_to_insert, | |
indexes_to_insert, | |
) | |
( | |
tables_to_insert, | |
columns_to_insert, | |
foreign_keys_to_insert, | |
indexes_to_insert, | |
) = await db.execute_fn(collect_info) | |
await internal_db.execute_write_many( | |
""" | |
INSERT INTO catalog_tables (database_name, table_name, rootpage, sql) | |
values (?, ?, ?, ?) | |
""", | |
tables_to_insert, | |
) | |
await internal_db.execute_write_many( | |
""" | |
INSERT INTO catalog_columns ( | |
database_name, table_name, cid, name, type, "notnull", default_value, is_pk, hidden | |
) VALUES ( | |
:database_name, :table_name, :cid, :name, :type, :notnull, :default_value, :is_pk, :hidden | |
) | |
""", | |
columns_to_insert, | |
) | |
await internal_db.execute_write_many( | |
""" | |
INSERT INTO catalog_foreign_keys ( | |
database_name, table_name, "id", seq, "table", "from", "to", on_update, on_delete, match | |
) VALUES ( | |
:database_name, :table_name, :id, :seq, :table, :from, :to, :on_update, :on_delete, :match | |
) | |
""", | |
foreign_keys_to_insert, | |
) | |
await internal_db.execute_write_many( | |
""" | |
INSERT INTO catalog_indexes ( | |
database_name, table_name, seq, name, "unique", origin, partial | |
) VALUES ( | |
:database_name, :table_name, :seq, :name, :unique, :origin, :partial | |
) | |
""", | |
indexes_to_insert, | |
) | |
</document_content> | |
</document> | |
<document index="30"> | |
<source>datasette/utils/shutil_backport.py</source> | |
<document_content> | |
""" | |
Backported from Python 3.8. | |
This code is licensed under the Python License: | |
https://github.com/python/cpython/blob/v3.8.3/LICENSE | |
""" | |
import os | |
from shutil import copy, copy2, copystat, Error | |
def _copytree( | |
entries, | |
src, | |
dst, | |
symlinks, | |
ignore, | |
copy_function, | |
ignore_dangling_symlinks, | |
dirs_exist_ok=False, | |
): | |
if ignore is not None: | |
ignored_names = ignore(src, set(os.listdir(src))) | |
else: | |
ignored_names = set() | |
os.makedirs(dst, exist_ok=dirs_exist_ok) | |
errors = [] | |
use_srcentry = copy_function is copy2 or copy_function is copy | |
for srcentry in entries: | |
if srcentry.name in ignored_names: | |
continue | |
srcname = os.path.join(src, srcentry.name) | |
dstname = os.path.join(dst, srcentry.name) | |
srcobj = srcentry if use_srcentry else srcname | |
try: | |
if srcentry.is_symlink(): | |
linkto = os.readlink(srcname) | |
if symlinks: | |
os.symlink(linkto, dstname) | |
copystat(srcobj, dstname, follow_symlinks=not symlinks) | |
else: | |
if not os.path.exists(linkto) and ignore_dangling_symlinks: | |
continue | |
if srcentry.is_dir(): | |
copytree( | |
srcobj, | |
dstname, | |
symlinks, | |
ignore, | |
copy_function, | |
dirs_exist_ok=dirs_exist_ok, | |
) | |
else: | |
copy_function(srcobj, dstname) | |
elif srcentry.is_dir(): | |
copytree( | |
srcobj, | |
dstname, | |
symlinks, | |
ignore, | |
copy_function, | |
dirs_exist_ok=dirs_exist_ok, | |
) | |
else: | |
copy_function(srcentry, dstname) | |
except Error as err: | |
errors.extend(err.args[0]) | |
except OSError as why: | |
errors.append((srcname, dstname, str(why))) | |
try: | |
copystat(src, dst) | |
except OSError as why: | |
# Copying file access times may fail on Windows | |
if getattr(why, "winerror", None) is None: | |
errors.append((src, dst, str(why))) | |
if errors: | |
raise Error(errors) | |
return dst | |
def copytree( | |
src, | |
dst, | |
symlinks=False, | |
ignore=None, | |
copy_function=copy2, | |
ignore_dangling_symlinks=False, | |
dirs_exist_ok=False, | |
): | |
with os.scandir(src) as entries: | |
return _copytree( | |
entries=entries, | |
src=src, | |
dst=dst, | |
symlinks=symlinks, | |
ignore=ignore, | |
copy_function=copy_function, | |
ignore_dangling_symlinks=ignore_dangling_symlinks, | |
dirs_exist_ok=dirs_exist_ok, | |
) | |
</document_content> | |
</document> | |
<document index="31"> | |
<source>datasette/utils/sqlite.py</source> | |
<document_content> | |
using_pysqlite3 = False | |
try: | |
import pysqlite3 as sqlite3 | |
using_pysqlite3 = True | |
except ImportError: | |
import sqlite3 | |
if hasattr(sqlite3, "enable_callback_tracebacks"): | |
sqlite3.enable_callback_tracebacks(True) | |
_cached_sqlite_version = None | |
def sqlite_version(): | |
global _cached_sqlite_version | |
if _cached_sqlite_version is None: | |
_cached_sqlite_version = _sqlite_version() | |
return _cached_sqlite_version | |
def _sqlite_version(): | |
return tuple( | |
map( | |
int, | |
sqlite3.connect(":memory:") | |
.execute("select sqlite_version()") | |
.fetchone()[0] | |
.split("."), | |
) | |
) | |
def supports_table_xinfo(): | |
return sqlite_version() >= (3, 26, 0) | |
def supports_generated_columns(): | |
return sqlite_version() >= (3, 31, 0) | |
</document_content> | |
</document> | |
<document index="32"> | |
<source>datasette/utils/testing.py</source> | |
<document_content> | |
from asgiref.sync import async_to_sync | |
from urllib.parse import urlencode | |
import json | |
# These wrapper classes pre-date the introduction of | |
# datasette.client and httpx to Datasette. They could | |
# be removed if the Datasette tests are modified to | |
# call datasette.client directly. | |
class TestResponse: | |
def __init__(self, httpx_response): | |
self.httpx_response = httpx_response | |
@property | |
def status(self): | |
return self.httpx_response.status_code | |
# Supports both for test-writing convenience | |
@property | |
def status_code(self): | |
return self.status | |
@property | |
def headers(self): | |
return self.httpx_response.headers | |
@property | |
def body(self): | |
return self.httpx_response.content | |
@property | |
def content(self): | |
return self.body | |
@property | |
def cookies(self): | |
return dict(self.httpx_response.cookies) | |
@property | |
def json(self): | |
return json.loads(self.text) | |
@property | |
def text(self): | |
return self.body.decode("utf8") | |
class TestClient: | |
max_redirects = 5 | |
def __init__(self, ds): | |
self.ds = ds | |
def actor_cookie(self, actor): | |
return self.ds.sign({"a": actor}, "actor") | |
@async_to_sync | |
async def get( | |
self, | |
path, | |
follow_redirects=False, | |
redirect_count=0, | |
method="GET", | |
params=None, | |
cookies=None, | |
if_none_match=None, | |
headers=None, | |
): | |
if params: | |
path += "?" + urlencode(params, doseq=True) | |
return await self._request( | |
path=path, | |
follow_redirects=follow_redirects, | |
redirect_count=redirect_count, | |
method=method, | |
cookies=cookies, | |
if_none_match=if_none_match, | |
headers=headers, | |
) | |
@async_to_sync | |
async def post( | |
self, | |
path, | |
post_data=None, | |
body=None, | |
follow_redirects=False, | |
redirect_count=0, | |
content_type="application/x-www-form-urlencoded", | |
cookies=None, | |
headers=None, | |
csrftoken_from=None, | |
): | |
cookies = cookies or {} | |
post_data = post_data or {} | |
assert not (post_data and body), "Provide either body= or post_data=, not both"
# Maybe fetch a csrftoken first | |
if csrftoken_from is not None: | |
assert body is None, "body= is not compatible with csrftoken_from=" | |
if csrftoken_from is True: | |
csrftoken_from = path | |
token_response = await self._request(csrftoken_from, cookies=cookies) | |
csrftoken = token_response.cookies["ds_csrftoken"] | |
cookies["ds_csrftoken"] = csrftoken | |
post_data["csrftoken"] = csrftoken | |
if post_data: | |
body = urlencode(post_data, doseq=True) | |
return await self._request( | |
path=path, | |
follow_redirects=follow_redirects, | |
redirect_count=redirect_count, | |
method="POST", | |
cookies=cookies, | |
headers=headers, | |
post_body=body, | |
content_type=content_type, | |
) | |
@async_to_sync | |
async def request( | |
self, | |
path, | |
follow_redirects=True, | |
redirect_count=0, | |
method="GET", | |
cookies=None, | |
headers=None, | |
post_body=None, | |
content_type=None, | |
if_none_match=None, | |
): | |
return await self._request( | |
path, | |
follow_redirects=follow_redirects, | |
redirect_count=redirect_count, | |
method=method, | |
cookies=cookies, | |
headers=headers, | |
post_body=post_body, | |
content_type=content_type, | |
if_none_match=if_none_match, | |
) | |
async def _request( | |
self, | |
path, | |
follow_redirects=True, | |
redirect_count=0, | |
method="GET", | |
cookies=None, | |
headers=None, | |
post_body=None, | |
content_type=None, | |
if_none_match=None, | |
): | |
await self.ds.invoke_startup() | |
headers = headers or {} | |
if content_type: | |
headers["content-type"] = content_type | |
if if_none_match: | |
headers["if-none-match"] = if_none_match | |
httpx_response = await self.ds.client.request( | |
method, | |
path, | |
follow_redirects=follow_redirects, | |
avoid_path_rewrites=True, | |
cookies=cookies, | |
headers=headers, | |
content=post_body, | |
) | |
response = TestResponse(httpx_response) | |
if follow_redirects and response.status in (301, 302): | |
assert ( | |
redirect_count < self.max_redirects | |
), f"Redirected {redirect_count} times, max_redirects={self.max_redirects}" | |
location = response.headers["Location"] | |
return await self._request( | |
location, follow_redirects=True, redirect_count=redirect_count + 1 | |
) | |
return response | |
</document_content> | |
</document> | |
<document index="33"> | |
<source>datasette/publish/__init__.py</source> | |
<document_content> | |
</document_content> | |
</document> | |
<document index="34"> | |
<source>datasette/publish/cloudrun.py</source> | |
<document_content> | |
from datasette import hookimpl | |
import click | |
import json | |
import os | |
import re | |
from subprocess import check_call, check_output | |
from .common import ( | |
add_common_publish_arguments_and_options, | |
fail_if_publish_binary_not_installed, | |
) | |
from ..utils import temporary_docker_directory | |
@hookimpl | |
def publish_subcommand(publish): | |
@publish.command() | |
@add_common_publish_arguments_and_options | |
@click.option( | |
"-n", | |
"--name", | |
default="datasette", | |
help="Application name to use when building", | |
) | |
@click.option( | |
"--service", default="", help="Cloud Run service to deploy (or over-write)" | |
) | |
@click.option("--spatialite", is_flag=True, help="Enable SpatialLite extension") | |
@click.option( | |
"--show-files", | |
is_flag=True, | |
help="Output the generated Dockerfile and metadata.json", | |
) | |
@click.option( | |
"--memory", | |
callback=_validate_memory, | |
help="Memory to allocate in Cloud Run, e.g. 1Gi", | |
) | |
@click.option( | |
"--cpu", | |
type=click.Choice(["1", "2", "4"]), | |
help="Number of vCPUs to allocate in Cloud Run", | |
) | |
@click.option( | |
"--timeout", | |
type=int, | |
help="Build timeout in seconds", | |
) | |
@click.option( | |
"--apt-get-install", | |
"apt_get_extras", | |
multiple=True, | |
help="Additional packages to apt-get install", | |
) | |
@click.option( | |
"--max-instances", | |
type=int, | |
help="Maximum Cloud Run instances", | |
) | |
@click.option( | |
"--min-instances", | |
type=int, | |
help="Minimum Cloud Run instances", | |
) | |
def cloudrun( | |
files, | |
metadata, | |
extra_options, | |
branch, | |
template_dir, | |
plugins_dir, | |
static, | |
install, | |
plugin_secret, | |
version_note, | |
secret, | |
title, | |
license, | |
license_url, | |
source, | |
source_url, | |
about, | |
about_url, | |
name, | |
service, | |
spatialite, | |
show_files, | |
memory, | |
cpu, | |
timeout, | |
apt_get_extras, | |
max_instances, | |
min_instances, | |
): | |
"Publish databases to Datasette running on Cloud Run" | |
fail_if_publish_binary_not_installed( | |
"gcloud", "Google Cloud", "https://cloud.google.com/sdk/" | |
) | |
project = check_output( | |
"gcloud config get-value project", shell=True, universal_newlines=True | |
).strip() | |
if not service: | |
# Show the user their current services, then prompt for one | |
click.echo("Please provide a service name for this deployment\n") | |
click.echo("Using an existing service name will over-write it") | |
click.echo("") | |
existing_services = get_existing_services() | |
if existing_services: | |
click.echo("Your existing services:\n") | |
for existing_service in existing_services: | |
click.echo( | |
" {name} - created {created} - {url}".format( | |
**existing_service | |
) | |
) | |
click.echo("") | |
service = click.prompt("Service name", type=str) | |
extra_metadata = { | |
"title": title, | |
"license": license, | |
"license_url": license_url, | |
"source": source, | |
"source_url": source_url, | |
"about": about, | |
"about_url": about_url, | |
} | |
if not extra_options: | |
extra_options = "" | |
if "force_https_urls" not in extra_options: | |
if extra_options: | |
extra_options += " " | |
extra_options += "--setting force_https_urls on" | |
environment_variables = {} | |
if plugin_secret: | |
extra_metadata["plugins"] = {} | |
for plugin_name, plugin_setting, setting_value in plugin_secret: | |
environment_variable = ( | |
f"{plugin_name}_{plugin_setting}".upper().replace("-", "_") | |
) | |
environment_variables[environment_variable] = setting_value | |
extra_metadata["plugins"].setdefault(plugin_name, {})[ | |
plugin_setting | |
] = {"$env": environment_variable} | |
with temporary_docker_directory( | |
files, | |
name, | |
metadata, | |
extra_options, | |
branch, | |
template_dir, | |
plugins_dir, | |
static, | |
install, | |
spatialite, | |
version_note, | |
secret, | |
extra_metadata, | |
environment_variables, | |
apt_get_extras=apt_get_extras, | |
): | |
if show_files: | |
if os.path.exists("metadata.json"): | |
print("=== metadata.json ===\n") | |
with open("metadata.json") as fp: | |
print(fp.read()) | |
print("\n==== Dockerfile ====\n") | |
with open("Dockerfile") as fp: | |
print(fp.read()) | |
print("\n====================\n") | |
image_id = f"gcr.io/{project}/datasette-{service}" | |
check_call( | |
"gcloud builds submit --tag {}{}".format( | |
image_id, " --timeout {}".format(timeout) if timeout else "" | |
), | |
shell=True, | |
) | |
extra_deploy_options = [] | |
for option, value in ( | |
("--memory", memory), | |
("--cpu", cpu), | |
("--max-instances", max_instances), | |
("--min-instances", min_instances), | |
): | |
if value: | |
extra_deploy_options.append("{} {}".format(option, value)) | |
check_call( | |
"gcloud run deploy --allow-unauthenticated --platform=managed --image {} {}{}".format( | |
image_id, | |
service, | |
" " + " ".join(extra_deploy_options) if extra_deploy_options else "", | |
), | |
shell=True, | |
) | |
def get_existing_services(): | |
services = json.loads( | |
check_output( | |
"gcloud run services list --platform=managed --format json", | |
shell=True, | |
universal_newlines=True, | |
) | |
) | |
return [ | |
{ | |
"name": service["metadata"]["name"], | |
"created": service["metadata"]["creationTimestamp"], | |
"url": service["status"]["address"]["url"], | |
} | |
for service in services | |
] | |
def _validate_memory(ctx, param, value): | |
if value and re.match(r"^\d+(Gi|G|Mi|M)$", value) is None: | |
raise click.BadParameter("--memory should be a number followed by Gi/G/Mi/M, e.g. 1Gi")
return value | |
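# Illustrative sketch (hypothetical values): the --memory callback accepts a
# number followed by one of Gi/G/Mi/M and rejects anything else.
#
#     _validate_memory(None, None, "2Gi")    # -> "2Gi"
#     _validate_memory(None, None, "512Mi")  # -> "512Mi"
#     _validate_memory(None, None, "2GB")    # raises click.BadParameter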
</document_content> | |
</document> | |
<document index="35"> | |
<source>datasette/publish/common.py</source> | |
<document_content> | |
from ..utils import StaticMount | |
import click | |
import os | |
import shutil | |
import sys | |
def add_common_publish_arguments_and_options(subcommand): | |
for decorator in reversed( | |
( | |
click.argument("files", type=click.Path(exists=True), nargs=-1), | |
click.option( | |
"-m", | |
"--metadata", | |
type=click.File(mode="r"), | |
help="Path to JSON/YAML file containing metadata to publish", | |
), | |
click.option( | |
"--extra-options", help="Extra options to pass to datasette serve" | |
), | |
click.option( | |
"--branch", help="Install datasette from a GitHub branch e.g. main" | |
), | |
click.option( | |
"--template-dir", | |
type=click.Path(exists=True, file_okay=False, dir_okay=True), | |
help="Path to directory containing custom templates", | |
), | |
click.option( | |
"--plugins-dir", | |
type=click.Path(exists=True, file_okay=False, dir_okay=True), | |
help="Path to directory containing custom plugins", | |
), | |
click.option( | |
"--static", | |
type=StaticMount(), | |
help="Serve static files from this directory at /MOUNT/...", | |
multiple=True, | |
), | |
click.option( | |
"--install", | |
help="Additional packages (e.g. plugins) to install", | |
multiple=True, | |
), | |
click.option( | |
"--plugin-secret", | |
nargs=3, | |
type=(str, str, str), | |
callback=validate_plugin_secret, | |
multiple=True, | |
help="Secrets to pass to plugins, e.g. --plugin-secret datasette-auth-github client_id xxx", | |
), | |
click.option( | |
"--version-note", help="Additional note to show on /-/versions" | |
), | |
click.option( | |
"--secret", | |
help="Secret used for signing secure values, such as signed cookies", | |
envvar="DATASETTE_PUBLISH_SECRET", | |
default=lambda: os.urandom(32).hex(), | |
), | |
click.option("--title", help="Title for metadata"), | |
click.option("--license", help="License label for metadata"), | |
click.option("--license_url", help="License URL for metadata"), | |
click.option("--source", help="Source label for metadata"), | |
click.option("--source_url", help="Source URL for metadata"), | |
click.option("--about", help="About label for metadata"), | |
click.option("--about_url", help="About URL for metadata"), | |
) | |
): | |
subcommand = decorator(subcommand) | |
return subcommand | |
def fail_if_publish_binary_not_installed(binary, publish_target, install_link): | |
"""Exit (with error message) if ``binary` isn't installed""" | |
if not shutil.which(binary): | |
click.secho( | |
"Publishing to {publish_target} requires {binary} to be installed and configured".format( | |
publish_target=publish_target, binary=binary | |
), | |
bg="red", | |
fg="white", | |
bold=True, | |
err=True, | |
) | |
click.echo( | |
f"Follow the instructions at {install_link}", | |
err=True, | |
) | |
sys.exit(1) | |
def validate_plugin_secret(ctx, param, value): | |
for plugin_name, plugin_setting, setting_value in value: | |
if "'" in setting_value: | |
raise click.BadParameter("--plugin-secret cannot contain single quotes") | |
return value | |
</document_content> | |
</document> | |
<document index="36"> | |
<source>datasette/publish/heroku.py</source> | |
<document_content> | |
from contextlib import contextmanager | |
from datasette import hookimpl | |
import click | |
import json | |
import os | |
import pathlib | |
import shlex | |
import shutil | |
from subprocess import call, check_output | |
import tempfile | |
from .common import ( | |
add_common_publish_arguments_and_options, | |
fail_if_publish_binary_not_installed, | |
) | |
from datasette.utils import link_or_copy, link_or_copy_directory, parse_metadata | |
@hookimpl | |
def publish_subcommand(publish): | |
@publish.command() | |
@add_common_publish_arguments_and_options | |
@click.option( | |
"-n", | |
"--name", | |
default="datasette", | |
help="Application name to use when deploying", | |
) | |
@click.option( | |
"--tar", | |
help="--tar option to pass to Heroku, e.g. --tar=/usr/local/bin/gtar", | |
) | |
@click.option( | |
"--generate-dir", | |
type=click.Path(dir_okay=True, file_okay=False), | |
help="Output generated application files and stop without deploying", | |
) | |
def heroku( | |
files, | |
metadata, | |
extra_options, | |
branch, | |
template_dir, | |
plugins_dir, | |
static, | |
install, | |
plugin_secret, | |
version_note, | |
secret, | |
title, | |
license, | |
license_url, | |
source, | |
source_url, | |
about, | |
about_url, | |
name, | |
tar, | |
generate_dir, | |
): | |
"Publish databases to Datasette running on Heroku" | |
fail_if_publish_binary_not_installed( | |
"heroku", "Heroku", "https://cli.heroku.com" | |
) | |
# Check for heroku-builds plugin | |
plugins = [ | |
line.split()[0] for line in check_output(["heroku", "plugins"]).splitlines() | |
] | |
if b"heroku-builds" not in plugins: | |
click.echo( | |
"Publishing to Heroku requires the heroku-builds plugin to be installed." | |
) | |
click.confirm( | |
"Install it? (this will run `heroku plugins:install heroku-builds`)", | |
abort=True, | |
) | |
call(["heroku", "plugins:install", "heroku-builds"]) | |
extra_metadata = { | |
"title": title, | |
"license": license, | |
"license_url": license_url, | |
"source": source, | |
"source_url": source_url, | |
"about": about, | |
"about_url": about_url, | |
} | |
environment_variables = {} | |
if plugin_secret: | |
extra_metadata["plugins"] = {} | |
for plugin_name, plugin_setting, setting_value in plugin_secret: | |
environment_variable = ( | |
f"{plugin_name}_{plugin_setting}".upper().replace("-", "_") | |
) | |
environment_variables[environment_variable] = setting_value | |
extra_metadata["plugins"].setdefault(plugin_name, {})[ | |
plugin_setting | |
] = {"$env": environment_variable} | |
with temporary_heroku_directory( | |
files, | |
name, | |
metadata, | |
extra_options, | |
branch, | |
template_dir, | |
plugins_dir, | |
static, | |
install, | |
version_note, | |
secret, | |
extra_metadata, | |
): | |
if generate_dir: | |
# Recursively copy files from current working directory to it | |
if pathlib.Path(generate_dir).exists(): | |
raise click.ClickException("Directory already exists") | |
shutil.copytree(".", generate_dir) | |
click.echo( | |
f"Generated files written to {generate_dir}, stopping without deploying", | |
err=True, | |
) | |
return | |
app_name = None | |
if name: | |
# Check to see if this app already exists | |
list_output = check_output(["heroku", "apps:list", "--json"]).decode( | |
"utf8" | |
) | |
apps = json.loads(list_output) | |
for app in apps: | |
if app["name"] == name: | |
app_name = name | |
break | |
if not app_name: | |
# Create a new app | |
cmd = ["heroku", "apps:create"] | |
if name: | |
cmd.append(name) | |
cmd.append("--json") | |
create_output = check_output(cmd).decode("utf8") | |
app_name = json.loads(create_output)["name"] | |
for key, value in environment_variables.items(): | |
call(["heroku", "config:set", "-a", app_name, f"{key}={value}"]) | |
tar_option = [] | |
if tar: | |
tar_option = ["--tar", tar] | |
call( | |
["heroku", "builds:create", "-a", app_name, "--include-vcs-ignore"] | |
+ tar_option | |
) | |
@contextmanager | |
def temporary_heroku_directory( | |
files, | |
name, | |
metadata, | |
extra_options, | |
branch, | |
template_dir, | |
plugins_dir, | |
static, | |
install, | |
version_note, | |
secret, | |
extra_metadata=None, | |
): | |
extra_metadata = extra_metadata or {} | |
tmp = tempfile.TemporaryDirectory() | |
saved_cwd = os.getcwd() | |
file_paths = [os.path.join(saved_cwd, file_path) for file_path in files] | |
file_names = [os.path.split(f)[-1] for f in files] | |
if metadata: | |
metadata_content = parse_metadata(metadata.read()) | |
else: | |
metadata_content = {} | |
for key, value in extra_metadata.items(): | |
if value: | |
metadata_content[key] = value | |
try: | |
os.chdir(tmp.name) | |
if metadata_content: | |
with open("metadata.json", "w") as fp: | |
fp.write(json.dumps(metadata_content, indent=2)) | |
with open("runtime.txt", "w") as fp: | |
fp.write("python-3.11.0") | |
if branch: | |
install = [ | |
f"https://github.com/simonw/datasette/archive/{branch}.zip" | |
] + list(install) | |
else: | |
install = ["datasette"] + list(install) | |
with open("requirements.txt", "w") as fp: | |
fp.write("\n".join(install)) | |
os.mkdir("bin") | |
with open("bin/post_compile", "w") as fp: | |
fp.write("datasette inspect --inspect-file inspect-data.json") | |
extras = [] | |
if template_dir: | |
link_or_copy_directory( | |
os.path.join(saved_cwd, template_dir), | |
os.path.join(tmp.name, "templates"), | |
) | |
extras.extend(["--template-dir", "templates/"]) | |
if plugins_dir: | |
link_or_copy_directory( | |
os.path.join(saved_cwd, plugins_dir), os.path.join(tmp.name, "plugins") | |
) | |
extras.extend(["--plugins-dir", "plugins/"]) | |
if version_note: | |
extras.extend(["--version-note", version_note]) | |
if metadata_content: | |
extras.extend(["--metadata", "metadata.json"]) | |
if extra_options: | |
extras.extend(extra_options.split()) | |
for mount_point, path in static: | |
link_or_copy_directory( | |
os.path.join(saved_cwd, path), os.path.join(tmp.name, mount_point) | |
) | |
extras.extend(["--static", f"{mount_point}:{mount_point}"]) | |
quoted_files = " ".join( | |
["-i {}".format(shlex.quote(file_name)) for file_name in file_names] | |
) | |
procfile_cmd = "web: datasette serve --host 0.0.0.0 {quoted_files} --cors --port $PORT --inspect-file inspect-data.json {extras}".format( | |
quoted_files=quoted_files, extras=" ".join(extras) | |
) | |
with open("Procfile", "w") as fp: | |
fp.write(procfile_cmd) | |
for path, filename in zip(file_paths, file_names): | |
link_or_copy(path, os.path.join(tmp.name, filename)) | |
yield | |
finally: | |
tmp.cleanup() | |
os.chdir(saved_cwd) | |
</document_content> | |
</document> | |
<document index="37"> | |
<source>datasette/views/__init__.py</source> | |
<document_content> | |
class Context: | |
"Base class for all documented contexts" | |
pass | |
</document_content> | |
</document> | |
<document index="38"> | |
<source>datasette/views/base.py</source> | |
<document_content> | |
import asyncio | |
import csv | |
import hashlib | |
import sys | |
import textwrap | |
import time | |
import urllib | |
from markupsafe import escape | |
from datasette.database import QueryInterrupted | |
from datasette.utils.asgi import Request | |
from datasette.utils import ( | |
add_cors_headers, | |
await_me_maybe, | |
EscapeHtmlWriter, | |
InvalidSql, | |
LimitedWriter, | |
call_with_supported_arguments, | |
path_from_row_pks, | |
path_with_added_args, | |
path_with_removed_args, | |
path_with_format, | |
sqlite3, | |
) | |
from datasette.utils.asgi import ( | |
AsgiStream, | |
NotFound, | |
Response, | |
BadRequest, | |
) | |
class DatasetteError(Exception): | |
def __init__( | |
self, | |
message, | |
title=None, | |
error_dict=None, | |
status=500, | |
template=None, | |
message_is_html=False, | |
): | |
self.message = message | |
self.title = title | |
self.error_dict = error_dict or {} | |
self.status = status | |
self.message_is_html = message_is_html | |
class View: | |
async def head(self, request, datasette): | |
if not hasattr(self, "get"): | |
return await self.method_not_allowed(request) | |
response = await self.get(request, datasette) | |
response.body = "" | |
return response | |
async def method_not_allowed(self, request): | |
if ( | |
request.path.endswith(".json") | |
or request.headers.get("content-type") == "application/json" | |
): | |
response = Response.json( | |
{"ok": False, "error": "Method not allowed"}, status=405 | |
) | |
else: | |
response = Response.text("Method not allowed", status=405) | |
return response | |
async def options(self, request, datasette): | |
response = Response.text("ok") | |
response.headers["allow"] = ", ".join( | |
method.upper() | |
for method in ("head", "get", "post", "put", "patch", "delete") | |
if hasattr(self, method) | |
) | |
return response | |
async def __call__(self, request, datasette): | |
try: | |
handler = getattr(self, request.method.lower()) | |
except AttributeError: | |
return await self.method_not_allowed(request) | |
return await handler(request, datasette) | |
class BaseView: | |
ds = None | |
has_json_alternate = True | |
def __init__(self, datasette): | |
self.ds = datasette | |
async def head(self, *args, **kwargs): | |
response = await self.get(*args, **kwargs) | |
response.body = b"" | |
return response | |
async def method_not_allowed(self, request): | |
if ( | |
request.path.endswith(".json") | |
or request.headers.get("content-type") == "application/json" | |
): | |
response = Response.json( | |
{"ok": False, "error": "Method not allowed"}, status=405 | |
) | |
else: | |
response = Response.text("Method not allowed", status=405) | |
return response | |
async def options(self, request, *args, **kwargs): | |
return Response.text("ok") | |
async def get(self, request, *args, **kwargs): | |
return await self.method_not_allowed(request) | |
async def post(self, request, *args, **kwargs): | |
return await self.method_not_allowed(request) | |
async def put(self, request, *args, **kwargs): | |
return await self.method_not_allowed(request) | |
async def patch(self, request, *args, **kwargs): | |
return await self.method_not_allowed(request) | |
async def delete(self, request, *args, **kwargs): | |
return await self.method_not_allowed(request) | |
async def dispatch_request(self, request): | |
if self.ds: | |
await self.ds.refresh_schemas() | |
handler = getattr(self, request.method.lower(), None) | |
response = await handler(request) | |
if self.ds.cors: | |
add_cors_headers(response.headers) | |
return response | |
async def render(self, templates, request, context=None): | |
context = context or {} | |
environment = self.ds.get_jinja_environment(request) | |
template = environment.select_template(templates) | |
template_context = { | |
**context, | |
**{ | |
"select_templates": [ | |
f"{'*' if template_name == template.name else ''}{template_name}" | |
for template_name in templates | |
], | |
}, | |
} | |
headers = {} | |
if self.has_json_alternate: | |
alternate_url_json = self.ds.absolute_url( | |
request, | |
self.ds.urls.path(path_with_format(request=request, format="json")), | |
) | |
template_context["alternate_url_json"] = alternate_url_json | |
headers.update( | |
{ | |
"Link": '{}; rel="alternate"; type="application/json+datasette"'.format( | |
alternate_url_json | |
) | |
} | |
) | |
return Response.html( | |
await self.ds.render_template( | |
template, | |
template_context, | |
request=request, | |
view_name=self.name, | |
), | |
headers=headers, | |
) | |
@classmethod | |
def as_view(cls, *class_args, **class_kwargs): | |
async def view(request, send): | |
self = view.view_class(*class_args, **class_kwargs) | |
return await self.dispatch_request(request) | |
view.view_class = cls | |
view.__doc__ = cls.__doc__ | |
view.__module__ = cls.__module__ | |
view.__name__ = cls.__name__ | |
return view | |
class DataView(BaseView): | |
name = "" | |
def redirect(self, request, path, forward_querystring=True, remove_args=None): | |
if request.query_string and "?" not in path and forward_querystring: | |
path = f"{path}?{request.query_string}" | |
if remove_args: | |
path = path_with_removed_args(request, remove_args, path=path) | |
r = Response.redirect(path) | |
r.headers["Link"] = f"<{path}>; rel=preload" | |
if self.ds.cors: | |
add_cors_headers(r.headers) | |
return r | |
async def data(self, request): | |
raise NotImplementedError | |
async def as_csv(self, request, database): | |
return await stream_csv(self.ds, self.data, request, database) | |
async def get(self, request): | |
db = await self.ds.resolve_database(request) | |
database = db.name | |
database_route = db.route | |
_format = request.url_vars["format"] | |
data_kwargs = {} | |
if _format == "csv": | |
return await self.as_csv(request, database_route) | |
if _format is None: | |
# HTML views default to expanding all foreign key labels | |
data_kwargs["default_labels"] = True | |
extra_template_data = {} | |
start = time.perf_counter() | |
status_code = None | |
templates = [] | |
try: | |
response_or_template_contexts = await self.data(request, **data_kwargs) | |
if isinstance(response_or_template_contexts, Response): | |
return response_or_template_contexts | |
# If it has four items, it includes an HTTP status code | |
if len(response_or_template_contexts) == 4: | |
( | |
data, | |
extra_template_data, | |
templates, | |
status_code, | |
) = response_or_template_contexts | |
else: | |
data, extra_template_data, templates = response_or_template_contexts | |
except QueryInterrupted as ex: | |
raise DatasetteError( | |
textwrap.dedent( | |
""" | |
<p>SQL query took too long. The time limit is controlled by the | |
<a href="https://docs.datasette.io/en/stable/settings.html#sql-time-limit-ms">sql_time_limit_ms</a> | |
configuration option.</p> | |
<textarea style="width: 90%">{}</textarea> | |
<script> | |
let ta = document.querySelector("textarea"); | |
ta.style.height = ta.scrollHeight + "px"; | |
</script> | |
""".format( | |
escape(ex.sql) | |
) | |
).strip(), | |
title="SQL Interrupted", | |
status=400, | |
message_is_html=True, | |
) | |
except (sqlite3.OperationalError, InvalidSql) as e: | |
raise DatasetteError(str(e), title="Invalid SQL", status=400) | |
except sqlite3.OperationalError as e: | |
raise DatasetteError(str(e)) | |
except DatasetteError: | |
raise | |
end = time.perf_counter() | |
data["query_ms"] = (end - start) * 1000 | |
# Special case for .jsono extension - redirect to _shape=objects | |
if _format == "jsono": | |
return self.redirect( | |
request, | |
path_with_added_args( | |
request, | |
{"_shape": "objects"}, | |
path=request.path.rsplit(".jsono", 1)[0] + ".json", | |
), | |
forward_querystring=False, | |
) | |
if _format in self.ds.renderers.keys(): | |
# Dispatch request to the correct output format renderer | |
# (CSV is not handled here due to streaming) | |
result = call_with_supported_arguments( | |
self.ds.renderers[_format][0], | |
datasette=self.ds, | |
columns=data.get("columns") or [], | |
rows=data.get("rows") or [], | |
sql=data.get("query", {}).get("sql", None), | |
query_name=data.get("query_name"), | |
database=database, | |
table=data.get("table"), | |
request=request, | |
view_name=self.name, | |
truncated=False, # TODO: support this | |
error=data.get("error"), | |
# These will be deprecated in Datasette 1.0: | |
args=request.args, | |
data=data, | |
) | |
if asyncio.iscoroutine(result): | |
result = await result | |
if result is None: | |
raise NotFound("No data") | |
if isinstance(result, dict): | |
r = Response( | |
body=result.get("body"), | |
status=result.get("status_code", status_code or 200), | |
content_type=result.get("content_type", "text/plain"), | |
headers=result.get("headers"), | |
) | |
elif isinstance(result, Response): | |
r = result | |
if status_code is not None: | |
# Override the status code
r.status = status_code | |
else: | |
assert False, f"{result} should be dict or Response" | |
else: | |
extras = {} | |
if callable(extra_template_data): | |
extras = extra_template_data() | |
if asyncio.iscoroutine(extras): | |
extras = await extras | |
else: | |
extras = extra_template_data | |
url_labels_extra = {} | |
if data.get("expandable_columns"): | |
url_labels_extra = {"_labels": "on"} | |
renderers = {} | |
for key, (_, can_render) in self.ds.renderers.items(): | |
it_can_render = call_with_supported_arguments( | |
can_render, | |
datasette=self.ds, | |
columns=data.get("columns") or [], | |
rows=data.get("rows") or [], | |
sql=data.get("query", {}).get("sql", None), | |
query_name=data.get("query_name"), | |
database=database, | |
table=data.get("table"), | |
request=request, | |
view_name=self.name, | |
) | |
it_can_render = await await_me_maybe(it_can_render) | |
if it_can_render: | |
renderers[key] = self.ds.urls.path( | |
path_with_format( | |
request=request, format=key, extra_qs={**url_labels_extra} | |
) | |
) | |
url_csv_args = {"_size": "max", **url_labels_extra} | |
url_csv = self.ds.urls.path( | |
path_with_format(request=request, format="csv", extra_qs=url_csv_args) | |
) | |
url_csv_path = url_csv.split("?")[0] | |
context = { | |
**data, | |
**extras, | |
**{ | |
"renderers": renderers, | |
"url_csv": url_csv, | |
"url_csv_path": url_csv_path, | |
"url_csv_hidden_args": [ | |
(key, value) | |
for key, value in urllib.parse.parse_qsl(request.query_string) | |
if key not in ("_labels", "_facet", "_size") | |
] | |
+ [("_size", "max")], | |
"settings": self.ds.settings_dict(), | |
}, | |
} | |
if "metadata" not in context: | |
context["metadata"] = await self.ds.get_instance_metadata() | |
r = await self.render(templates, request=request, context=context) | |
if status_code is not None: | |
r.status = status_code | |
ttl = request.args.get("_ttl", None) | |
if ttl is None or not ttl.isdigit(): | |
ttl = self.ds.setting("default_cache_ttl") | |
return self.set_response_headers(r, ttl) | |
def set_response_headers(self, response, ttl): | |
# Set far-future cache expiry | |
if self.ds.cache_headers and response.status == 200: | |
ttl = int(ttl) | |
if ttl == 0: | |
ttl_header = "no-cache" | |
else: | |
ttl_header = f"max-age={ttl}" | |
response.headers["Cache-Control"] = ttl_header | |
response.headers["Referrer-Policy"] = "no-referrer" | |
if self.ds.cors: | |
add_cors_headers(response.headers) | |
return response | |
def _error(messages, status=400): | |
return Response.json({"ok": False, "errors": messages}, status=status) | |
async def stream_csv(datasette, fetch_data, request, database): | |
kwargs = {} | |
stream = request.args.get("_stream") | |
# Do not calculate facets or counts: | |
extra_parameters = [ | |
"{}=1".format(key) | |
for key in ("_nofacet", "_nocount") | |
if not request.args.get(key) | |
] | |
if extra_parameters: | |
# Replace request object with a new one with modified scope | |
if not request.query_string: | |
new_query_string = "&".join(extra_parameters) | |
else: | |
new_query_string = request.query_string + "&" + "&".join(extra_parameters) | |
new_scope = dict(request.scope, query_string=new_query_string.encode("latin-1")) | |
receive = request.receive | |
request = Request(new_scope, receive) | |
if stream: | |
# Some quick soundness checks | |
if not datasette.setting("allow_csv_stream"): | |
raise BadRequest("CSV streaming is disabled") | |
if request.args.get("_next"): | |
raise BadRequest("_next not allowed for CSV streaming") | |
kwargs["_size"] = "max" | |
# Fetch the first page | |
try: | |
response_or_template_contexts = await fetch_data(request) | |
if isinstance(response_or_template_contexts, Response): | |
return response_or_template_contexts | |
elif len(response_or_template_contexts) == 4: | |
data, _, _, _ = response_or_template_contexts | |
else: | |
data, _, _ = response_or_template_contexts | |
except (sqlite3.OperationalError, InvalidSql) as e: | |
raise DatasetteError(str(e), title="Invalid SQL", status=400) | |
except sqlite3.OperationalError as e: | |
raise DatasetteError(str(e)) | |
except DatasetteError: | |
raise | |
# Convert rows and columns to CSV | |
headings = data["columns"] | |
# if there are expanded_columns we need to add additional headings | |
expanded_columns = set(data.get("expanded_columns") or []) | |
if expanded_columns: | |
headings = [] | |
for column in data["columns"]: | |
headings.append(column) | |
if column in expanded_columns: | |
headings.append(f"{column}_label") | |
content_type = "text/plain; charset=utf-8" | |
preamble = "" | |
postamble = "" | |
trace = request.args.get("_trace") | |
if trace: | |
content_type = "text/html; charset=utf-8" | |
preamble = ( | |
"<html><head><title>CSV debug</title></head>" | |
'<body><textarea style="width: 90%; height: 70vh">' | |
) | |
postamble = "</textarea></body></html>" | |
async def stream_fn(r): | |
nonlocal data, trace | |
limited_writer = LimitedWriter(r, datasette.setting("max_csv_mb")) | |
if trace: | |
await limited_writer.write(preamble) | |
writer = csv.writer(EscapeHtmlWriter(limited_writer)) | |
else: | |
writer = csv.writer(limited_writer) | |
first = True | |
next = None | |
while first or (next and stream): | |
try: | |
kwargs = {} | |
if next: | |
kwargs["_next"] = next | |
if not first: | |
data, _, _ = await fetch_data(request, **kwargs) | |
if first: | |
if request.args.get("_header") != "off": | |
await writer.writerow(headings) | |
first = False | |
next = data.get("next") | |
for row in data["rows"]: | |
if any(isinstance(r, bytes) for r in row): | |
new_row = [] | |
for column, cell in zip(headings, row): | |
if isinstance(cell, bytes): | |
# If this is a table page, use .urls.row_blob() | |
if data.get("table"): | |
pks = data.get("primary_keys") or [] | |
cell = datasette.absolute_url( | |
request, | |
datasette.urls.row_blob( | |
database, | |
data["table"], | |
path_from_row_pks(row, pks, not pks), | |
column, | |
), | |
) | |
else: | |
# Otherwise generate URL for this query | |
url = datasette.absolute_url( | |
request, | |
path_with_format( | |
request=request, | |
format="blob", | |
extra_qs={ | |
"_blob_column": column, | |
"_blob_hash": hashlib.sha256( | |
cell | |
).hexdigest(), | |
}, | |
replace_format="csv", | |
), | |
) | |
cell = url.replace("&_nocount=1", "").replace( | |
"&_nofacet=1", "" | |
) | |
new_row.append(cell) | |
row = new_row | |
if not expanded_columns: | |
# Simple path | |
await writer.writerow(row) | |
else: | |
# Look for {"value": "label": } dicts and expand | |
new_row = [] | |
for heading, cell in zip(data["columns"], row): | |
if heading in expanded_columns: | |
if cell is None: | |
new_row.extend(("", "")) | |
else: | |
if not isinstance(cell, dict): | |
new_row.extend((cell, "")) | |
else: | |
new_row.append(cell["value"]) | |
new_row.append(cell["label"]) | |
else: | |
new_row.append(cell) | |
await writer.writerow(new_row) | |
except Exception as ex: | |
sys.stderr.write("Caught this error: {}\n".format(ex)) | |
sys.stderr.flush() | |
await r.write(str(ex)) | |
return | |
await limited_writer.write(postamble) | |
headers = {} | |
if datasette.cors: | |
add_cors_headers(headers) | |
if request.args.get("_dl", None): | |
if not trace: | |
content_type = "text/csv; charset=utf-8" | |
disposition = 'attachment; filename="{}.csv"'.format( | |
request.url_vars.get("table", database) | |
) | |
headers["content-disposition"] = disposition | |
return AsgiStream(stream_fn, headers=headers, content_type=content_type) | |
</document_content>
</document>
<document index="39">
<source>datasette/views/database.py</source>
<document_content>
from dataclasses import dataclass, field | |
from urllib.parse import parse_qsl, urlencode | |
import asyncio | |
import hashlib | |
import itertools | |
import json | |
import markupsafe | |
import os | |
import re | |
import sqlite_utils | |
import textwrap | |
from typing import List | |
from datasette.events import AlterTableEvent, CreateTableEvent, InsertRowsEvent | |
from datasette.database import QueryInterrupted | |
from datasette.utils import ( | |
add_cors_headers, | |
await_me_maybe, | |
call_with_supported_arguments, | |
named_parameters as derive_named_parameters, | |
format_bytes, | |
make_slot_function, | |
tilde_decode, | |
to_css_class, | |
validate_sql_select, | |
is_url, | |
path_with_added_args, | |
path_with_format, | |
path_with_removed_args, | |
sqlite3, | |
truncate_url, | |
InvalidSql, | |
) | |
from datasette.utils.asgi import AsgiFileDownload, NotFound, Response, Forbidden | |
from datasette.plugins import pm | |
from .base import BaseView, DatasetteError, View, _error, stream_csv | |
class DatabaseView(View): | |
async def get(self, request, datasette): | |
format_ = request.url_vars.get("format") or "html" | |
await datasette.refresh_schemas() | |
db = await datasette.resolve_database(request) | |
database = db.name | |
visible, private = await datasette.check_visibility( | |
request.actor, | |
permissions=[ | |
("view-database", database), | |
"view-instance", | |
], | |
) | |
if not visible: | |
raise Forbidden("You do not have permission to view this database") | |
sql = (request.args.get("sql") or "").strip() | |
if sql: | |
redirect_url = "/" + request.url_vars.get("database") + "/-/query" | |
if request.url_vars.get("format"): | |
redirect_url += "." + request.url_vars.get("format") | |
redirect_url += "?" + request.query_string | |
return Response.redirect(redirect_url) | |
return await QueryView()(request, datasette) | |
if format_ not in ("html", "json"): | |
raise NotFound("Invalid format: {}".format(format_)) | |
metadata = await datasette.get_database_metadata(database) | |
sql_views = [] | |
for view_name in await db.view_names(): | |
view_visible, view_private = await datasette.check_visibility( | |
request.actor, | |
permissions=[ | |
("view-table", (database, view_name)), | |
("view-database", database), | |
"view-instance", | |
], | |
) | |
if view_visible: | |
sql_views.append( | |
{ | |
"name": view_name, | |
"private": view_private, | |
} | |
) | |
tables = await get_tables(datasette, request, db) | |
canned_queries = [] | |
for query in ( | |
await datasette.get_canned_queries(database, request.actor) | |
).values(): | |
query_visible, query_private = await datasette.check_visibility( | |
request.actor, | |
permissions=[ | |
("view-query", (database, query["name"])), | |
("view-database", database), | |
"view-instance", | |
], | |
) | |
if query_visible: | |
canned_queries.append(dict(query, private=query_private)) | |
async def database_actions(): | |
links = [] | |
for hook in pm.hook.database_actions( | |
datasette=datasette, | |
database=database, | |
actor=request.actor, | |
request=request, | |
): | |
extra_links = await await_me_maybe(hook) | |
if extra_links: | |
links.extend(extra_links) | |
return links | |
attached_databases = [d.name for d in await db.attached_databases()] | |
allow_execute_sql = await datasette.permission_allowed( | |
request.actor, "execute-sql", database | |
) | |
json_data = { | |
"database": database, | |
"private": private, | |
"path": datasette.urls.database(database), | |
"size": db.size, | |
"tables": tables, | |
"hidden_count": len([t for t in tables if t["hidden"]]), | |
"views": sql_views, | |
"queries": canned_queries, | |
"allow_execute_sql": allow_execute_sql, | |
"table_columns": ( | |
await _table_columns(datasette, database) if allow_execute_sql else {} | |
), | |
"metadata": await datasette.get_database_metadata(database), | |
} | |
if format_ == "json": | |
response = Response.json(json_data) | |
if datasette.cors: | |
add_cors_headers(response.headers) | |
return response | |
assert format_ == "html" | |
alternate_url_json = datasette.absolute_url( | |
request, | |
datasette.urls.path(path_with_format(request=request, format="json")), | |
) | |
templates = (f"database-{to_css_class(database)}.html", "database.html") | |
environment = datasette.get_jinja_environment(request) | |
template = environment.select_template(templates) | |
context = { | |
**json_data, | |
"database_color": db.color, | |
"database_actions": database_actions, | |
"show_hidden": request.args.get("_show_hidden"), | |
"editable": True, | |
"metadata": metadata, | |
"count_limit": db.count_limit, | |
"allow_download": datasette.setting("allow_download") | |
and not db.is_mutable | |
and not db.is_memory, | |
"attached_databases": attached_databases, | |
"alternate_url_json": alternate_url_json, | |
"select_templates": [ | |
f"{'*' if template_name == template.name else ''}{template_name}" | |
for template_name in templates | |
], | |
"top_database": make_slot_function( | |
"top_database", datasette, request, database=database | |
), | |
} | |
return Response.html( | |
await datasette.render_template( | |
templates, | |
context, | |
request=request, | |
view_name="database", | |
), | |
headers={ | |
"Link": '{}; rel="alternate"; type="application/json+datasette"'.format( | |
alternate_url_json | |
) | |
}, | |
) | |
@dataclass | |
class QueryContext: | |
database: str = field(metadata={"help": "The name of the database being queried"}) | |
database_color: str = field(metadata={"help": "The color of the database"}) | |
query: dict = field( | |
metadata={"help": "The SQL query object containing the `sql` string"} | |
) | |
canned_query: str = field( | |
metadata={"help": "The name of the canned query if this is a canned query"} | |
) | |
private: bool = field( | |
metadata={"help": "Boolean indicating if this is a private database"} | |
) | |
# urls: dict = field( | |
# metadata={"help": "Object containing URL helpers like `database()`"} | |
# ) | |
canned_query_write: bool = field( | |
metadata={ | |
"help": "Boolean indicating if this is a canned query that allows writes" | |
} | |
) | |
metadata: dict = field( | |
metadata={"help": "Metadata about the database or the canned query"} | |
) | |
db_is_immutable: bool = field( | |
metadata={"help": "Boolean indicating if this database is immutable"} | |
) | |
error: str = field(metadata={"help": "Any query error message"}) | |
hide_sql: bool = field( | |
metadata={"help": "Boolean indicating if the SQL should be hidden"} | |
) | |
show_hide_link: str = field( | |
metadata={"help": "The URL to toggle showing/hiding the SQL"} | |
) | |
show_hide_text: str = field( | |
metadata={"help": "The text for the show/hide SQL link"} | |
) | |
editable: bool = field( | |
metadata={"help": "Boolean indicating if the SQL can be edited"} | |
) | |
allow_execute_sql: bool = field( | |
metadata={"help": "Boolean indicating if custom SQL can be executed"} | |
) | |
tables: list = field(metadata={"help": "List of table objects in the database"}) | |
named_parameter_values: dict = field( | |
metadata={"help": "Dictionary of parameter names/values"} | |
) | |
edit_sql_url: str = field( | |
metadata={"help": "URL to edit the SQL for a canned query"} | |
) | |
display_rows: list = field(metadata={"help": "List of result rows to display"}) | |
columns: list = field(metadata={"help": "List of column names"}) | |
renderers: dict = field(metadata={"help": "Dictionary of renderer name to URL"}) | |
url_csv: str = field(metadata={"help": "URL for CSV export"}) | |
show_hide_hidden: str = field( | |
metadata={"help": "Hidden input field for the _show_sql parameter"} | |
) | |
table_columns: dict = field( | |
metadata={"help": "Dictionary of table name to list of column names"} | |
) | |
alternate_url_json: str = field( | |
metadata={"help": "URL for alternate JSON version of this page"} | |
) | |
# TODO: refactor this to somewhere else, probably ds.render_template() | |
select_templates: list = field( | |
metadata={ | |
"help": "List of templates that were considered for rendering this page" | |
} | |
) | |
top_query: callable = field( | |
metadata={"help": "Callable to render the top_query slot"} | |
) | |
top_canned_query: callable = field( | |
metadata={"help": "Callable to render the top_canned_query slot"} | |
) | |
query_actions: callable = field( | |
metadata={ | |
"help": "Callable returning a list of links for the query action menu" | |
} | |
) | |
async def get_tables(datasette, request, db): | |
tables = [] | |
database = db.name | |
table_counts = await db.table_counts(100) | |
hidden_table_names = set(await db.hidden_table_names()) | |
all_foreign_keys = await db.get_all_foreign_keys() | |
for table in table_counts: | |
table_visible, table_private = await datasette.check_visibility( | |
request.actor, | |
permissions=[ | |
("view-table", (database, table)), | |
("view-database", database), | |
"view-instance", | |
], | |
) | |
if not table_visible: | |
continue | |
table_columns = await db.table_columns(table) | |
tables.append( | |
{ | |
"name": table, | |
"columns": table_columns, | |
"primary_keys": await db.primary_keys(table), | |
"count": table_counts[table], | |
"hidden": table in hidden_table_names, | |
"fts_table": await db.fts_table(table), | |
"foreign_keys": all_foreign_keys[table], | |
"private": table_private, | |
} | |
) | |
tables.sort(key=lambda t: (t["hidden"], t["name"])) | |
return tables | |
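# Serve the SQLite database file itself as a download. Only allowed when the
# allow_download setting is on and the database is an immutable file on disk;
# the database hash (if any) is used as an ETag for conditional requests.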
async def database_download(request, datasette): | |
database = tilde_decode(request.url_vars["database"]) | |
await datasette.ensure_permissions( | |
request.actor, | |
[ | |
("view-database-download", database), | |
("view-database", database), | |
"view-instance", | |
], | |
) | |
try: | |
db = datasette.get_database(route=database) | |
except KeyError: | |
raise DatasetteError("Invalid database", status=404) | |
if db.is_memory: | |
raise DatasetteError("Cannot download in-memory databases", status=404) | |
if not datasette.setting("allow_download") or db.is_mutable: | |
raise Forbidden("Database download is forbidden") | |
if not db.path: | |
raise DatasetteError("Cannot download database", status=404) | |
filepath = db.path | |
headers = {} | |
if datasette.cors: | |
add_cors_headers(headers) | |
if db.hash: | |
etag = '"{}"'.format(db.hash) | |
headers["Etag"] = etag | |
# Has user seen this already? | |
if_none_match = request.headers.get("if-none-match") | |
if if_none_match and if_none_match == etag: | |
return Response("", status=304) | |
headers["Transfer-Encoding"] = "chunked" | |
return AsgiFileDownload( | |
filepath, | |
filename=os.path.basename(filepath), | |
content_type="application/octet-stream", | |
headers=headers, | |
) | |
class QueryView(View): | |
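# Execute SQL against a database. GET handles arbitrary ?sql= queries (SELECT
# only) and canned queries; POST handles canned queries that write to the
# database.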
async def post(self, request, datasette): | |
from datasette.app import TableNotFound | |
db = await datasette.resolve_database(request) | |
# A POST here must target a canned query, not a table
table_found = False | |
try: | |
await datasette.resolve_table(request) | |
table_found = True | |
except TableNotFound as table_not_found: | |
canned_query = await datasette.get_canned_query( | |
table_not_found.database_name, table_not_found.table, request.actor | |
) | |
if canned_query is None: | |
raise | |
if table_found: | |
# That should not have happened | |
raise DatasetteError("Unexpected table found on POST", status=404) | |
# If database is immutable, return an error | |
if not db.is_mutable: | |
raise Forbidden("Database is immutable") | |
# Process the POST | |
body = await request.post_body() | |
body = body.decode("utf-8").strip() | |
if body.startswith("{") and body.endswith("}"): | |
params = json.loads(body) | |
# But we want key=value strings | |
for key, value in params.items(): | |
params[key] = str(value) | |
else: | |
params = dict(parse_qsl(body, keep_blank_values=True)) | |
# Don't ever send csrftoken as a SQL parameter | |
params.pop("csrftoken", None) | |
# Should we return JSON? | |
should_return_json = ( | |
request.headers.get("accept") == "application/json" | |
or request.args.get("_json") | |
or params.get("_json") | |
) | |
params_for_query = MagicParameters( | |
canned_query["sql"], params, request, datasette | |
) | |
await params_for_query.execute_params() | |
ok = None | |
redirect_url = None | |
try: | |
cursor = await db.execute_write(canned_query["sql"], params_for_query) | |
# success message can come from on_success_message or on_success_message_sql | |
message = None | |
message_type = datasette.INFO | |
on_success_message_sql = canned_query.get("on_success_message_sql") | |
if on_success_message_sql: | |
try: | |
message_result = ( | |
await db.execute(on_success_message_sql, params_for_query) | |
).first() | |
if message_result: | |
message = message_result[0] | |
except Exception as ex: | |
message = "Error running on_success_message_sql: {}".format(ex) | |
message_type = datasette.ERROR | |
if not message: | |
message = canned_query.get( | |
"on_success_message" | |
) or "Query executed, {} row{} affected".format( | |
cursor.rowcount, "" if cursor.rowcount == 1 else "s" | |
) | |
redirect_url = canned_query.get("on_success_redirect") | |
ok = True | |
except Exception as ex: | |
message = canned_query.get("on_error_message") or str(ex) | |
message_type = datasette.ERROR | |
redirect_url = canned_query.get("on_error_redirect") | |
ok = False | |
if should_return_json: | |
return Response.json( | |
{ | |
"ok": ok, | |
"message": message, | |
"redirect": redirect_url, | |
} | |
) | |
else: | |
datasette.add_message(request, message, message_type) | |
return Response.redirect(redirect_url or request.path) | |
async def get(self, request, datasette): | |
from datasette.app import TableNotFound | |
await datasette.refresh_schemas() | |
db = await datasette.resolve_database(request) | |
database = db.name | |
# Are we a canned query? | |
canned_query = None | |
canned_query_write = False | |
if "table" in request.url_vars: | |
try: | |
await datasette.resolve_table(request) | |
except TableNotFound as table_not_found: | |
# Was this actually a canned query? | |
canned_query = await datasette.get_canned_query( | |
table_not_found.database_name, table_not_found.table, request.actor | |
) | |
if canned_query is None: | |
raise | |
canned_query_write = bool(canned_query.get("write")) | |
private = False | |
if canned_query: | |
# Respect canned query permissions | |
visible, private = await datasette.check_visibility( | |
request.actor, | |
permissions=[ | |
("view-query", (database, canned_query["name"])), | |
("view-database", database), | |
"view-instance", | |
], | |
) | |
if not visible: | |
raise Forbidden("You do not have permission to view this query") | |
else: | |
await datasette.ensure_permissions( | |
request.actor, [("execute-sql", database)] | |
) | |
# Flattened because of ?sql=&name1=value1&name2=value2 feature | |
params = {key: request.args.get(key) for key in request.args} | |
sql = None | |
if canned_query: | |
sql = canned_query["sql"] | |
elif "sql" in params: | |
sql = params.pop("sql") | |
# Extract any :named parameters | |
named_parameters = [] | |
if canned_query and canned_query.get("params"): | |
named_parameters = canned_query["params"] | |
if not named_parameters: | |
named_parameters = derive_named_parameters(sql) | |
named_parameter_values = { | |
named_parameter: params.get(named_parameter) or "" | |
for named_parameter in named_parameters | |
if not named_parameter.startswith("_") | |
} | |
# Set to blank string if missing from params | |
for named_parameter in named_parameters: | |
if named_parameter not in params and not named_parameter.startswith("_"): | |
params[named_parameter] = "" | |
extra_args = {} | |
if params.get("_timelimit"): | |
extra_args["custom_time_limit"] = int(params["_timelimit"]) | |
format_ = request.url_vars.get("format") or "html" | |
query_error = None | |
results = None | |
rows = [] | |
columns = [] | |
params_for_query = params | |
if not canned_query_write: | |
try: | |
if not canned_query: | |
# For regular queries we only allow SELECT, plus other rules | |
validate_sql_select(sql) | |
else: | |
# Canned queries can run magic parameters | |
params_for_query = MagicParameters(sql, params, request, datasette) | |
await params_for_query.execute_params() | |
results = await datasette.execute( | |
database, sql, params_for_query, truncate=True, **extra_args | |
) | |
columns = results.columns | |
rows = results.rows | |
except QueryInterrupted as ex: | |
raise DatasetteError( | |
textwrap.dedent( | |
""" | |
<p>SQL query took too long. The time limit is controlled by the | |
<a href="https://docs.datasette.io/en/stable/settings.html#sql-time-limit-ms">sql_time_limit_ms</a> | |
configuration option.</p> | |
<textarea style="width: 90%">{}</textarea> | |
<script> | |
let ta = document.querySelector("textarea"); | |
ta.style.height = ta.scrollHeight + "px"; | |
</script> | |
""".format( | |
markupsafe.escape(ex.sql) | |
) | |
).strip(), | |
title="SQL Interrupted", | |
status=400, | |
message_is_html=True, | |
) | |
except sqlite3.DatabaseError as ex: | |
query_error = str(ex) | |
results = None | |
rows = [] | |
columns = [] | |
except (sqlite3.OperationalError, InvalidSql) as ex: | |
raise DatasetteError(str(ex), title="Invalid SQL", status=400) | |
except sqlite3.OperationalError as ex: | |
raise DatasetteError(str(ex)) | |
except DatasetteError: | |
raise | |
# Handle formats from plugins | |
if format_ == "csv": | |
async def fetch_data_for_csv(request, _next=None): | |
results = await db.execute(sql, params, truncate=True) | |
data = {"rows": results.rows, "columns": results.columns} | |
return data, None, None | |
return await stream_csv(datasette, fetch_data_for_csv, request, db.name) | |
elif format_ in datasette.renderers.keys(): | |
# Dispatch request to the correct output format renderer | |
# (CSV is not handled here due to streaming) | |
result = call_with_supported_arguments( | |
datasette.renderers[format_][0], | |
datasette=datasette, | |
columns=columns, | |
rows=rows, | |
sql=sql, | |
query_name=canned_query["name"] if canned_query else None, | |
database=database, | |
table=None, | |
request=request, | |
view_name="table", | |
truncated=results.truncated if results else False, | |
error=query_error, | |
# These will be deprecated in Datasette 1.0: | |
args=request.args, | |
data={"ok": True, "rows": rows, "columns": columns}, | |
) | |
if asyncio.iscoroutine(result): | |
result = await result | |
if result is None: | |
raise NotFound("No data") | |
if isinstance(result, dict): | |
r = Response( | |
body=result.get("body"), | |
status=result.get("status_code") or 200, | |
content_type=result.get("content_type", "text/plain"), | |
headers=result.get("headers"), | |
) | |
elif isinstance(result, Response): | |
r = result | |
# if status_code is not None: | |
# # Over-ride the status code | |
# r.status = status_code | |
else: | |
assert False, f"{result} should be dict or Response" | |
elif format_ == "html": | |
headers = {} | |
templates = [f"query-{to_css_class(database)}.html", "query.html"] | |
if canned_query: | |
templates.insert( | |
0, | |
f"query-{to_css_class(database)}-{to_css_class(canned_query['name'])}.html", | |
) | |
environment = datasette.get_jinja_environment(request) | |
template = environment.select_template(templates) | |
alternate_url_json = datasette.absolute_url( | |
request, | |
datasette.urls.path(path_with_format(request=request, format="json")), | |
) | |
data = {} | |
headers.update( | |
{ | |
"Link": '{}; rel="alternate"; type="application/json+datasette"'.format( | |
alternate_url_json | |
) | |
} | |
) | |
metadata = await datasette.get_database_metadata(database) | |
renderers = {} | |
for key, (_, can_render) in datasette.renderers.items(): | |
it_can_render = call_with_supported_arguments( | |
can_render, | |
datasette=datasette, | |
columns=data.get("columns") or [], | |
rows=data.get("rows") or [], | |
sql=data.get("query", {}).get("sql", None), | |
query_name=data.get("query_name"), | |
database=database, | |
table=data.get("table"), | |
request=request, | |
view_name="database", | |
) | |
it_can_render = await await_me_maybe(it_can_render) | |
if it_can_render: | |
renderers[key] = datasette.urls.path( | |
path_with_format(request=request, format=key) | |
) | |
allow_execute_sql = await datasette.permission_allowed( | |
request.actor, "execute-sql", database | |
) | |
show_hide_hidden = "" | |
if canned_query and canned_query.get("hide_sql"): | |
if bool(params.get("_show_sql")): | |
show_hide_link = path_with_removed_args(request, {"_show_sql"}) | |
show_hide_text = "hide" | |
show_hide_hidden = ( | |
'<input type="hidden" name="_show_sql" value="1">' | |
) | |
else: | |
show_hide_link = path_with_added_args(request, {"_show_sql": 1}) | |
show_hide_text = "show" | |
else: | |
if bool(params.get("_hide_sql")): | |
show_hide_link = path_with_removed_args(request, {"_hide_sql"}) | |
show_hide_text = "show" | |
show_hide_hidden = ( | |
'<input type="hidden" name="_hide_sql" value="1">' | |
) | |
else: | |
show_hide_link = path_with_added_args(request, {"_hide_sql": 1}) | |
show_hide_text = "hide" | |
hide_sql = show_hide_text == "show" | |
# Show 'Edit SQL' button only if: | |
# - User is allowed to execute SQL | |
# - SQL is an approved SELECT statement | |
# - No magic parameters, so no :_ in the SQL string | |
edit_sql_url = None | |
is_validated_sql = False | |
try: | |
validate_sql_select(sql) | |
is_validated_sql = True | |
except InvalidSql: | |
pass | |
if allow_execute_sql and is_validated_sql and ":_" not in sql: | |
edit_sql_url = ( | |
datasette.urls.database(database) | |
+ "/-/query" | |
+ "?" | |
+ urlencode( | |
{ | |
**{ | |
"sql": sql, | |
}, | |
**named_parameter_values, | |
} | |
) | |
) | |
async def query_actions(): | |
query_actions = [] | |
for hook in pm.hook.query_actions( | |
datasette=datasette, | |
actor=request.actor, | |
database=database, | |
query_name=canned_query["name"] if canned_query else None, | |
request=request, | |
sql=sql, | |
params=params, | |
): | |
extra_links = await await_me_maybe(hook) | |
if extra_links: | |
query_actions.extend(extra_links) | |
return query_actions | |
r = Response.html( | |
await datasette.render_template( | |
template, | |
QueryContext( | |
database=database, | |
database_color=db.color, | |
query={ | |
"sql": sql, | |
"params": params, | |
}, | |
canned_query=canned_query["name"] if canned_query else None, | |
private=private, | |
canned_query_write=canned_query_write, | |
db_is_immutable=not db.is_mutable, | |
error=query_error, | |
hide_sql=hide_sql, | |
show_hide_link=datasette.urls.path(show_hide_link), | |
show_hide_text=show_hide_text, | |
editable=not canned_query, | |
allow_execute_sql=allow_execute_sql, | |
tables=await get_tables(datasette, request, db), | |
named_parameter_values=named_parameter_values, | |
edit_sql_url=edit_sql_url, | |
display_rows=await display_rows( | |
datasette, database, request, rows, columns | |
), | |
table_columns=( | |
await _table_columns(datasette, database) | |
if allow_execute_sql | |
else {} | |
), | |
columns=columns, | |
renderers=renderers, | |
url_csv=datasette.urls.path( | |
path_with_format( | |
request=request, format="csv", extra_qs={"_size": "max"} | |
) | |
), | |
show_hide_hidden=markupsafe.Markup(show_hide_hidden), | |
metadata=canned_query or metadata, | |
alternate_url_json=alternate_url_json, | |
select_templates=[ | |
f"{'*' if template_name == template.name else ''}{template_name}" | |
for template_name in templates | |
], | |
top_query=make_slot_function( | |
"top_query", datasette, request, database=database, sql=sql | |
), | |
top_canned_query=make_slot_function( | |
"top_canned_query", | |
datasette, | |
request, | |
database=database, | |
query_name=canned_query["name"] if canned_query else None, | |
), | |
query_actions=query_actions, | |
), | |
request=request, | |
view_name="database", | |
), | |
headers=headers, | |
) | |
else: | |
assert False, "Invalid format: {}".format(format_) | |
if datasette.cors: | |
add_cors_headers(r.headers) | |
return r | |
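# Dict subclass that resolves "magic" parameters such as :_actor_id at query
# time: keys of the form _prefix_suffix are looked up against the functions
# registered by the register_magic_parameters plugin hook.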
class MagicParameters(dict): | |
def __init__(self, sql, data, request, datasette): | |
super().__init__(data) | |
self._sql = sql | |
self._request = request | |
self._magics = dict( | |
itertools.chain.from_iterable( | |
pm.hook.register_magic_parameters(datasette=datasette) | |
) | |
) | |
self._prepared = {} | |
async def execute_params(self): | |
for key in derive_named_parameters(self._sql): | |
if key.startswith("_") and key.count("_") >= 2: | |
prefix, suffix = key[1:].split("_", 1) | |
if prefix in self._magics: | |
result = await await_me_maybe( | |
self._magics[prefix](suffix, self._request) | |
) | |
self._prepared[key] = result | |
def __len__(self): | |
# Workaround for 'Incorrect number of bindings' error | |
# https://github.com/simonw/datasette/issues/967#issuecomment-692951144 | |
return super().__len__() or 1 | |
def __getitem__(self, key): | |
if key.startswith("_") and key.count("_") >= 2: | |
if key in self._prepared: | |
return self._prepared[key] | |
# Not prepared in advance - call the magic parameter function directly
prefix, suffix = key[1:].split("_", 1) | |
if prefix in self._magics: | |
try: | |
return self._magics[prefix](suffix, self._request) | |
except KeyError: | |
return super().__getitem__(key) | |
else: | |
return super().__getitem__(key) | |
class TableCreateView(BaseView): | |
name = "table-create" | |
_valid_keys = { | |
"table", | |
"rows", | |
"row", | |
"columns", | |
"pk", | |
"pks", | |
"ignore", | |
"replace", | |
"alter", | |
} | |
_supported_column_types = { | |
"text", | |
"integer", | |
"float", | |
"blob", | |
} | |
# Any string that does not contain a newline or start with sqlite_ | |
_table_name_re = re.compile(r"^(?!sqlite_)[^\n]+$") | |
def __init__(self, datasette): | |
self.ds = datasette | |
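# Illustrative POST bodies (field names taken from _valid_keys above; the
# values are just examples):
#   {"table": "my_table", "columns": [{"name": "id", "type": "integer"}], "pk": "id"}
# or, to create the table directly from data:
#   {"table": "my_table", "rows": [{"id": 1, "name": "Cleo"}], "pk": "id"}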
async def post(self, request): | |
db = await self.ds.resolve_database(request) | |
database_name = db.name | |
# Must have create-table permission | |
if not await self.ds.permission_allowed( | |
request.actor, "create-table", resource=database_name | |
): | |
return _error(["Permission denied"], 403) | |
body = await request.post_body() | |
try: | |
data = json.loads(body) | |
except json.JSONDecodeError as e: | |
return _error(["Invalid JSON: {}".format(e)]) | |
if not isinstance(data, dict): | |
return _error(["JSON must be an object"]) | |
invalid_keys = set(data.keys()) - self._valid_keys | |
if invalid_keys: | |
return _error(["Invalid keys: {}".format(", ".join(invalid_keys))]) | |
# ignore and replace are mutually exclusive | |
if data.get("ignore") and data.get("replace"): | |
return _error(["ignore and replace are mutually exclusive"]) | |
# ignore and replace only allowed with row or rows | |
if "ignore" in data or "replace" in data: | |
if not data.get("row") and not data.get("rows"): | |
return _error(["ignore and replace require row or rows"]) | |
# ignore and replace require pk or pks | |
if "ignore" in data or "replace" in data: | |
if not data.get("pk") and not data.get("pks"): | |
return _error(["ignore and replace require pk or pks"]) | |
ignore = data.get("ignore") | |
replace = data.get("replace") | |
if replace: | |
# Must have update-row permission | |
if not await self.ds.permission_allowed( | |
request.actor, "update-row", resource=database_name | |
): | |
return _error(["Permission denied: need update-row"], 403) | |
table_name = data.get("table") | |
if not table_name: | |
return _error(["Table is required"]) | |
if not self._table_name_re.match(table_name): | |
return _error(["Invalid table name"]) | |
table_exists = await db.table_exists(data["table"]) | |
columns = data.get("columns") | |
rows = data.get("rows") | |
row = data.get("row") | |
if not columns and not rows and not row: | |
return _error(["columns, rows or row is required"]) | |
if rows and row: | |
return _error(["Cannot specify both rows and row"]) | |
if rows or row: | |
# Must have insert-row permission | |
if not await self.ds.permission_allowed( | |
request.actor, "insert-row", resource=database_name | |
): | |
return _error(["Permission denied: need insert-row"], 403) | |
alter = False | |
if rows or row: | |
if not table_exists: | |
# if table is being created for the first time, alter=True | |
alter = True | |
else: | |
# alter=True only if they request it AND they have permission | |
if data.get("alter"): | |
if not await self.ds.permission_allowed( | |
request.actor, "alter-table", resource=database_name | |
): | |
return _error(["Permission denied: need alter-table"], 403) | |
alter = True | |
if columns: | |
if rows or row: | |
return _error(["Cannot specify columns with rows or row"]) | |
if not isinstance(columns, list): | |
return _error(["columns must be a list"]) | |
for column in columns: | |
if not isinstance(column, dict): | |
return _error(["columns must be a list of objects"]) | |
if not column.get("name") or not isinstance(column.get("name"), str): | |
return _error(["Column name is required"]) | |
if not column.get("type"): | |
column["type"] = "text" | |
if column["type"] not in self._supported_column_types: | |
return _error( | |
["Unsupported column type: {}".format(column["type"])] | |
) | |
# No duplicate column names | |
dupes = {c["name"] for c in columns if columns.count(c) > 1} | |
if dupes: | |
return _error(["Duplicate column name: {}".format(", ".join(dupes))]) | |
if row: | |
rows = [row] | |
if rows: | |
if not isinstance(rows, list): | |
return _error(["rows must be a list"]) | |
for row in rows: | |
if not isinstance(row, dict): | |
return _error(["rows must be a list of objects"]) | |
pk = data.get("pk") | |
pks = data.get("pks") | |
if pk and pks: | |
return _error(["Cannot specify both pk and pks"]) | |
if pk: | |
if not isinstance(pk, str): | |
return _error(["pk must be a string"]) | |
if pks: | |
if not isinstance(pks, list): | |
return _error(["pks must be a list"]) | |
for pk in pks: | |
if not isinstance(pk, str): | |
return _error(["pks must be a list of strings"]) | |
# If table exists already, read pks from that instead | |
if table_exists: | |
actual_pks = await db.primary_keys(table_name) | |
# If a pk was passed and the table already exists, check that it has not changed
bad_pks = False | |
if len(actual_pks) == 1 and data.get("pk") and data["pk"] != actual_pks[0]: | |
bad_pks = True | |
elif ( | |
len(actual_pks) > 1 | |
and data.get("pks") | |
and set(data["pks"]) != set(actual_pks) | |
): | |
bad_pks = True | |
if bad_pks: | |
return _error(["pk cannot be changed for existing table"]) | |
pks = actual_pks | |
initial_schema = None | |
if table_exists: | |
initial_schema = await db.execute_fn( | |
lambda conn: sqlite_utils.Database(conn)[table_name].schema | |
) | |
def create_table(conn): | |
table = sqlite_utils.Database(conn)[table_name] | |
if rows: | |
table.insert_all( | |
rows, pk=pks or pk, ignore=ignore, replace=replace, alter=alter | |
) | |
else: | |
table.create( | |
{c["name"]: c["type"] for c in columns}, | |
pk=pks or pk, | |
) | |
return table.schema | |
try: | |
schema = await db.execute_write_fn(create_table) | |
except Exception as e: | |
return _error([str(e)]) | |
if initial_schema is not None and initial_schema != schema: | |
await self.ds.track_event( | |
AlterTableEvent( | |
request.actor, | |
database=database_name, | |
table=table_name, | |
before_schema=initial_schema, | |
after_schema=schema, | |
) | |
) | |
table_url = self.ds.absolute_url( | |
request, self.ds.urls.table(db.name, table_name) | |
) | |
table_api_url = self.ds.absolute_url( | |
request, self.ds.urls.table(db.name, table_name, format="json") | |
) | |
details = { | |
"ok": True, | |
"database": db.name, | |
"table": table_name, | |
"table_url": table_url, | |
"table_api_url": table_api_url, | |
"schema": schema, | |
} | |
if rows: | |
details["row_count"] = len(rows) | |
if not table_exists: | |
# Only log creation if we created a table | |
await self.ds.track_event( | |
CreateTableEvent( | |
request.actor, database=db.name, table=table_name, schema=schema | |
) | |
) | |
if rows: | |
await self.ds.track_event( | |
InsertRowsEvent( | |
request.actor, | |
database=db.name, | |
table=table_name, | |
num_rows=len(rows), | |
ignore=ignore, | |
replace=replace, | |
) | |
) | |
return Response.json(details, status=201) | |
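# Build a {table_name: [column, ...]} mapping from the internal catalog_columns
# table (views are included with empty column lists); used for the
# table_columns context on query pages.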
async def _table_columns(datasette, database_name): | |
internal_db = datasette.get_internal_database() | |
result = await internal_db.execute( | |
"select table_name, name from catalog_columns where database_name = ?", | |
[database_name], | |
) | |
table_columns = {} | |
for row in result.rows: | |
table_columns.setdefault(row["table_name"], []).append(row["name"]) | |
# Add views | |
db = datasette.get_database(database_name) | |
for view_name in await db.view_names(): | |
table_columns[view_name] = [] | |
return table_columns | |
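# Prepare rows for HTML display: the render_cell plugin hook gets first refusal
# on each value, then URLs become links, binary values link to their blob
# download, and long values are truncated to the truncate_cells_html setting.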
async def display_rows(datasette, database, request, rows, columns): | |
display_rows = [] | |
truncate_cells = datasette.setting("truncate_cells_html") | |
for row in rows: | |
display_row = [] | |
for column, value in zip(columns, row): | |
display_value = value | |
# Let the plugins have a go | |
# pylint: disable=no-member | |
plugin_display_value = None | |
for candidate in pm.hook.render_cell( | |
row=row, | |
value=value, | |
column=column, | |
table=None, | |
database=database, | |
datasette=datasette, | |
request=request, | |
): | |
candidate = await await_me_maybe(candidate) | |
if candidate is not None: | |
plugin_display_value = candidate | |
break | |
if plugin_display_value is not None: | |
display_value = plugin_display_value | |
else: | |
if value in ("", None): | |
display_value = markupsafe.Markup(" ") | |
elif is_url(str(display_value).strip()): | |
display_value = markupsafe.Markup( | |
'<a href="{url}">{truncated_url}</a>'.format( | |
url=markupsafe.escape(value.strip()), | |
truncated_url=markupsafe.escape( | |
truncate_url(value.strip(), truncate_cells) | |
), | |
) | |
) | |
elif isinstance(display_value, bytes): | |
blob_url = path_with_format( | |
request=request, | |
format="blob", | |
extra_qs={ | |
"_blob_column": column, | |
"_blob_hash": hashlib.sha256(display_value).hexdigest(), | |
}, | |
) | |
formatted = format_bytes(len(value)) | |
display_value = markupsafe.Markup( | |
'<a class="blob-download" href="{}"{}><Binary: {:,} byte{}></a>'.format( | |
blob_url, | |
( | |
' title="{}"'.format(formatted) | |
if "bytes" not in formatted | |
else "" | |
), | |
len(value), | |
"" if len(value) == 1 else "s", | |
) | |
) | |
else: | |
display_value = str(value) | |
if truncate_cells and len(display_value) > truncate_cells: | |
display_value = display_value[:truncate_cells] + "\u2026" | |
display_row.append(display_value) | |
display_rows.append(display_row) | |
return display_rows | |
</document_content>
</document>
<document index="40">
<source>datasette/views/index.py</source>
<document_content>
import json | |
from datasette.plugins import pm | |
from datasette.utils import ( | |
add_cors_headers, | |
await_me_maybe, | |
make_slot_function, | |
CustomJSONEncoder, | |
) | |
from datasette.utils.asgi import Response | |
from datasette.version import __version__ | |
from .base import BaseView | |
# Truncate table list on homepage at: | |
TRUNCATE_AT = 5 | |
# Only attempt counts if database less than this size in bytes: | |
COUNT_DB_SIZE_LIMIT = 100 * 1024 * 1024 | |
class IndexView(BaseView): | |
name = "index" | |
async def get(self, request): | |
as_format = request.url_vars["format"] | |
await self.ds.ensure_permissions(request.actor, ["view-instance"]) | |
databases = [] | |
for name, db in self.ds.databases.items(): | |
database_visible, database_private = await self.ds.check_visibility( | |
request.actor, | |
"view-database", | |
name, | |
) | |
if not database_visible: | |
continue | |
table_names = await db.table_names() | |
hidden_table_names = set(await db.hidden_table_names()) | |
views = [] | |
for view_name in await db.view_names(): | |
view_visible, view_private = await self.ds.check_visibility( | |
request.actor, | |
"view-table", | |
(name, view_name), | |
) | |
if view_visible: | |
views.append({"name": view_name, "private": view_private}) | |
# Only attempt counts for immutable databases or databases smaller than COUNT_DB_SIZE_LIMIT bytes
table_counts = {} | |
if not db.is_mutable or db.size < COUNT_DB_SIZE_LIMIT: | |
table_counts = await db.table_counts(10) | |
# If any of these are None it means at least one timed out - ignore them all | |
if any(v is None for v in table_counts.values()): | |
table_counts = {} | |
tables = {} | |
for table in table_names: | |
visible, private = await self.ds.check_visibility( | |
request.actor, | |
"view-table", | |
(name, table), | |
) | |
if not visible: | |
continue | |
table_columns = await db.table_columns(table) | |
tables[table] = { | |
"name": table, | |
"columns": table_columns, | |
"primary_keys": await db.primary_keys(table), | |
"count": table_counts.get(table), | |
"hidden": table in hidden_table_names, | |
"fts_table": await db.fts_table(table), | |
"num_relationships_for_sorting": 0, | |
"private": private, | |
} | |
if request.args.get("_sort") == "relationships" or not table_counts: | |
# We will be sorting by number of relationships, so populate that field | |
all_foreign_keys = await db.get_all_foreign_keys() | |
for table, foreign_keys in all_foreign_keys.items(): | |
if table in tables.keys(): | |
count = len(foreign_keys["incoming"] + foreign_keys["outgoing"]) | |
tables[table]["num_relationships_for_sorting"] = count | |
hidden_tables = [t for t in tables.values() if t["hidden"]] | |
visible_tables = [t for t in tables.values() if not t["hidden"]] | |
tables_and_views_truncated = list( | |
sorted( | |
(t for t in tables.values() if t not in hidden_tables), | |
key=lambda t: ( | |
t["num_relationships_for_sorting"], | |
t["count"] or 0, | |
t["name"], | |
), | |
reverse=True, | |
)[:TRUNCATE_AT] | |
) | |
# Only add views if this is less than TRUNCATE_AT | |
if len(tables_and_views_truncated) < TRUNCATE_AT: | |
num_views_to_add = TRUNCATE_AT - len(tables_and_views_truncated) | |
for view in views[:num_views_to_add]: | |
tables_and_views_truncated.append(view) | |
databases.append( | |
{ | |
"name": name, | |
"hash": db.hash, | |
"color": db.color, | |
"path": self.ds.urls.database(name), | |
"tables_and_views_truncated": tables_and_views_truncated, | |
"tables_and_views_more": (len(visible_tables) + len(views)) | |
> TRUNCATE_AT, | |
"tables_count": len(visible_tables), | |
"table_rows_sum": sum((t["count"] or 0) for t in visible_tables), | |
"show_table_row_counts": bool(table_counts), | |
"hidden_table_rows_sum": sum( | |
t["count"] for t in hidden_tables if t["count"] is not None | |
), | |
"hidden_tables_count": len(hidden_tables), | |
"views_count": len(views), | |
"private": database_private, | |
} | |
) | |
if as_format: | |
headers = {} | |
if self.ds.cors: | |
add_cors_headers(headers) | |
return Response( | |
json.dumps( | |
{ | |
"databases": {db["name"]: db for db in databases}, | |
"metadata": await self.ds.get_instance_metadata(), | |
}, | |
cls=CustomJSONEncoder, | |
), | |
content_type="application/json; charset=utf-8", | |
headers=headers, | |
) | |
else: | |
homepage_actions = [] | |
for hook in pm.hook.homepage_actions( | |
datasette=self.ds, | |
actor=request.actor, | |
request=request, | |
): | |
extra_links = await await_me_maybe(hook) | |
if extra_links: | |
homepage_actions.extend(extra_links) | |
alternative_homepage = request.path == "/-/" | |
return await self.render( | |
["default:index.html" if alternative_homepage else "index.html"], | |
request=request, | |
context={ | |
"databases": databases, | |
"metadata": await self.ds.get_instance_metadata(), | |
"datasette_version": __version__, | |
"private": not await self.ds.permission_allowed( | |
None, "view-instance" | |
), | |
"top_homepage": make_slot_function( | |
"top_homepage", self.ds, request | |
), | |
"homepage_actions": homepage_actions, | |
"noindex": request.path == "/-/", | |
}, | |
) | |
</document_content>
</document>
<document index="41">
<source>datasette/views/row.py</source>
<document_content>
from datasette.utils.asgi import NotFound, Forbidden, Response | |
from datasette.database import QueryInterrupted | |
from datasette.events import UpdateRowEvent, DeleteRowEvent | |
from .base import DataView, BaseView, _error | |
from datasette.utils import ( | |
await_me_maybe, | |
make_slot_function, | |
to_css_class, | |
escape_sqlite, | |
) | |
from datasette.plugins import pm | |
import json | |
import sqlite_utils | |
from .table import display_columns_and_rows | |
class RowView(DataView): | |
name = "row" | |
async def data(self, request, default_labels=False): | |
resolved = await self.ds.resolve_row(request) | |
db = resolved.db | |
database = db.name | |
table = resolved.table | |
pk_values = resolved.pk_values | |
# Ensure user has permission to view this row | |
visible, private = await self.ds.check_visibility( | |
request.actor, | |
permissions=[ | |
("view-table", (database, table)), | |
("view-database", database), | |
"view-instance", | |
], | |
) | |
if not visible: | |
raise Forbidden("You do not have permission to view this table") | |
results = await resolved.db.execute( | |
resolved.sql, resolved.params, truncate=True | |
) | |
columns = [r[0] for r in results.description] | |
rows = list(results.rows) | |
if not rows: | |
raise NotFound(f"Record not found: {pk_values}") | |
async def template_data(): | |
display_columns, display_rows = await display_columns_and_rows( | |
self.ds, | |
database, | |
table, | |
results.description, | |
rows, | |
link_column=False, | |
truncate_cells=0, | |
request=request, | |
) | |
for column in display_columns: | |
column["sortable"] = False | |
row_actions = [] | |
for hook in pm.hook.row_actions( | |
datasette=self.ds, | |
actor=request.actor, | |
request=request, | |
database=database, | |
table=table, | |
row=rows[0], | |
): | |
extra_links = await await_me_maybe(hook) | |
if extra_links: | |
row_actions.extend(extra_links) | |
return { | |
"private": private, | |
"foreign_key_tables": await self.foreign_key_tables( | |
database, table, pk_values | |
), | |
"database_color": db.color, | |
"display_columns": display_columns, | |
"display_rows": display_rows, | |
"custom_table_templates": [ | |
f"_table-{to_css_class(database)}-{to_css_class(table)}.html", | |
f"_table-row-{to_css_class(database)}-{to_css_class(table)}.html", | |
"_table.html", | |
], | |
"row_actions": row_actions, | |
"top_row": make_slot_function( | |
"top_row", | |
self.ds, | |
request, | |
database=resolved.db.name, | |
table=resolved.table, | |
row=rows[0], | |
), | |
"metadata": {}, | |
} | |
data = { | |
"database": database, | |
"table": table, | |
"rows": rows, | |
"columns": columns, | |
"primary_keys": resolved.pks, | |
"primary_key_values": pk_values, | |
} | |
if "foreign_key_tables" in (request.args.get("_extras") or "").split(","): | |
data["foreign_key_tables"] = await self.foreign_key_tables( | |
database, table, pk_values | |
) | |
return ( | |
data, | |
template_data, | |
( | |
f"row-{to_css_class(database)}-{to_css_class(table)}.html", | |
"row.html", | |
), | |
) | |
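# For a single-primary-key row, count how many rows in other tables reference
# it via incoming foreign keys, and build links to filtered views of those
# tables.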
async def foreign_key_tables(self, database, table, pk_values): | |
if len(pk_values) != 1: | |
return [] | |
db = self.ds.databases[database] | |
all_foreign_keys = await db.get_all_foreign_keys() | |
foreign_keys = all_foreign_keys[table]["incoming"] | |
if len(foreign_keys) == 0: | |
return [] | |
sql = "select " + ", ".join( | |
[ | |
"(select count(*) from {table} where {column}=:id)".format( | |
table=escape_sqlite(fk["other_table"]), | |
column=escape_sqlite(fk["other_column"]), | |
) | |
for fk in foreign_keys | |
] | |
) | |
try: | |
rows = list(await db.execute(sql, {"id": pk_values[0]})) | |
except QueryInterrupted: | |
# Almost certainly hit the timeout | |
return [] | |
foreign_table_counts = dict( | |
zip( | |
[(fk["other_table"], fk["other_column"]) for fk in foreign_keys], | |
list(rows[0]), | |
) | |
) | |
foreign_key_tables = [] | |
for fk in foreign_keys: | |
count = ( | |
foreign_table_counts.get((fk["other_table"], fk["other_column"])) or 0 | |
) | |
key = fk["other_column"] | |
if key.startswith("_"): | |
key += "__exact" | |
link = "{}?{}={}".format( | |
self.ds.urls.table(database, fk["other_table"]), | |
key, | |
",".join(pk_values), | |
) | |
foreign_key_tables.append({**fk, **{"count": count, "link": link}}) | |
return foreign_key_tables | |
class RowError(Exception): | |
def __init__(self, error): | |
self.error = error | |
async def _resolve_row_and_check_permission(datasette, request, permission): | |
from datasette.app import DatabaseNotFound, TableNotFound, RowNotFound | |
try: | |
resolved = await datasette.resolve_row(request) | |
except DatabaseNotFound as e: | |
return False, _error(["Database not found: {}".format(e.database_name)], 404) | |
except TableNotFound as e: | |
return False, _error(["Table not found: {}".format(e.table)], 404) | |
except RowNotFound as e: | |
return False, _error(["Record not found: {}".format(e.pk_values)], 404) | |
# Ensure user has permission to perform this action on the row
if not await datasette.permission_allowed( | |
request.actor, permission, resource=(resolved.db.name, resolved.table) | |
): | |
return False, _error(["Permission denied"], 403) | |
return True, resolved | |
class RowDeleteView(BaseView): | |
name = "row-delete" | |
def __init__(self, datasette): | |
self.ds = datasette | |
async def post(self, request): | |
ok, resolved = await _resolve_row_and_check_permission( | |
self.ds, request, "delete-row" | |
) | |
if not ok: | |
return resolved | |
# Delete the row
def delete_row(conn): | |
sqlite_utils.Database(conn)[resolved.table].delete(resolved.pk_values) | |
try: | |
await resolved.db.execute_write_fn(delete_row) | |
except Exception as e: | |
return _error([str(e)], 500) | |
await self.ds.track_event( | |
DeleteRowEvent( | |
actor=request.actor, | |
database=resolved.db.name, | |
table=resolved.table, | |
pks=resolved.pk_values, | |
) | |
) | |
return Response.json({"ok": True}, status=200) | |
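# RowUpdateView accepts a JSON body of this shape (illustrative example):
#   {"update": {"title": "New title"}, "return": true, "alter": false}
# "update" maps column names to new values, "return" includes the updated row
# in the response, and "alter" permits adding new columns when the actor also
# has the alter-table permission.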
class RowUpdateView(BaseView): | |
name = "row-update" | |
def __init__(self, datasette): | |
self.ds = datasette | |
async def post(self, request): | |
ok, resolved = await _resolve_row_and_check_permission( | |
self.ds, request, "update-row" | |
) | |
if not ok: | |
return resolved | |
body = await request.post_body() | |
try: | |
data = json.loads(body) | |
except json.JSONDecodeError as e: | |
return _error(["Invalid JSON: {}".format(e)]) | |
if not isinstance(data, dict): | |
return _error(["JSON must be a dictionary"]) | |
if "update" not in data or not isinstance(data["update"], dict):
return _error(["JSON must contain an update dictionary"]) | |
invalid_keys = set(data.keys()) - {"update", "return", "alter"} | |
if invalid_keys: | |
return _error(["Invalid keys: {}".format(", ".join(invalid_keys))]) | |
update = data["update"] | |
alter = data.get("alter") | |
if alter and not await self.ds.permission_allowed( | |
request.actor, "alter-table", resource=(resolved.db.name, resolved.table) | |
): | |
return _error(["Permission denied for alter-table"], 403) | |
def update_row(conn): | |
sqlite_utils.Database(conn)[resolved.table].update( | |
resolved.pk_values, update, alter=alter | |
) | |
try: | |
await resolved.db.execute_write_fn(update_row) | |
except Exception as e: | |
return _error([str(e)], 400) | |
result = {"ok": True} | |
if data.get("return"): | |
results = await resolved.db.execute( | |
resolved.sql, resolved.params, truncate=True | |
) | |
result["row"] = results.dicts()[0] | |
await self.ds.track_event( | |
UpdateRowEvent( | |
actor=request.actor, | |
database=resolved.db.name, | |
table=resolved.table, | |
pks=resolved.pk_values, | |
) | |
) | |
return Response.json(result, status=200) | |
</document_content> | |
</document> | |
<document index="42"> | |
<source>datasette/views/special.py</source> | |
<document_content> | |
import json | |
from datasette.events import LogoutEvent, LoginEvent, CreateTokenEvent | |
from datasette.utils.asgi import Response, Forbidden | |
from datasette.utils import ( | |
actor_matches_allow, | |
add_cors_headers, | |
tilde_encode, | |
tilde_decode, | |
) | |
from .base import BaseView, View | |
import secrets | |
import urllib | |
class JsonDataView(BaseView): | |
name = "json_data" | |
def __init__( | |
self, | |
datasette, | |
filename, | |
data_callback, | |
needs_request=False, | |
permission="view-instance", | |
): | |
self.ds = datasette | |
self.filename = filename | |
self.data_callback = data_callback | |
self.needs_request = needs_request | |
self.permission = permission | |
async def get(self, request): | |
as_format = request.url_vars["format"] | |
if self.permission: | |
await self.ds.ensure_permissions(request.actor, [self.permission]) | |
if self.needs_request: | |
data = self.data_callback(request) | |
else: | |
data = self.data_callback() | |
if as_format: | |
headers = {} | |
if self.ds.cors: | |
add_cors_headers(headers) | |
return Response( | |
json.dumps(data, default=repr), | |
content_type="application/json; charset=utf-8", | |
headers=headers, | |
) | |
else: | |
return await self.render( | |
["show_json.html"], | |
request=request, | |
context={ | |
"filename": self.filename, | |
"data_json": json.dumps(data, indent=4, default=repr), | |
}, | |
) | |
class PatternPortfolioView(View): | |
async def get(self, request, datasette): | |
await datasette.ensure_permissions(request.actor, ["view-instance"]) | |
return Response.html( | |
await datasette.render_template( | |
"patterns.html", | |
request=request, | |
view_name="patterns", | |
) | |
) | |
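# AuthTokenView exchanges a one-time ?token= value for a signed "root" actor
# cookie. The stored token is cleared after a successful exchange, so the link
# can only be used once; reuse or an invalid token raises Forbidden.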
class AuthTokenView(BaseView): | |
name = "auth_token" | |
has_json_alternate = False | |
async def get(self, request): | |
# If already signed in as root, redirect | |
if request.actor and request.actor.get("id") == "root": | |
return Response.redirect(self.ds.urls.instance()) | |
token = request.args.get("token") or "" | |
if not self.ds._root_token: | |
raise Forbidden("Root token has already been used") | |
if secrets.compare_digest(token, self.ds._root_token): | |
self.ds._root_token = None | |
response = Response.redirect(self.ds.urls.instance()) | |
root_actor = {"id": "root"} | |
self.ds.set_actor_cookie(response, root_actor) | |
await self.ds.track_event(LoginEvent(actor=root_actor)) | |
return response | |
else: | |
raise Forbidden("Invalid token") | |
class LogoutView(BaseView): | |
name = "logout" | |
has_json_alternate = False | |
async def get(self, request): | |
if not request.actor: | |
return Response.redirect(self.ds.urls.instance()) | |
return await self.render( | |
["logout.html"], | |
request, | |
{"actor": request.actor}, | |
) | |
async def post(self, request): | |
response = Response.redirect(self.ds.urls.instance()) | |
self.ds.delete_actor_cookie(response) | |
self.ds.add_message(request, "You are now logged out", self.ds.WARNING) | |
await self.ds.track_event(LogoutEvent(actor=request.actor)) | |
return response | |
class PermissionsDebugView(BaseView): | |
name = "permissions_debug" | |
has_json_alternate = False | |
async def get(self, request): | |
await self.ds.ensure_permissions(request.actor, ["view-instance"]) | |
if not await self.ds.permission_allowed(request.actor, "permissions-debug"): | |
raise Forbidden("Permission denied") | |
filter_ = request.args.get("filter") or "all" | |
permission_checks = list(reversed(self.ds._permission_checks)) | |
if filter_ == "exclude-yours": | |
permission_checks = [ | |
check | |
for check in permission_checks | |
if (check["actor"] or {}).get("id") != request.actor["id"] | |
] | |
elif filter_ == "only-yours": | |
permission_checks = [ | |
check | |
for check in permission_checks | |
if (check["actor"] or {}).get("id") == request.actor["id"] | |
] | |
return await self.render( | |
["permissions_debug.html"], | |
request, | |
# list() avoids error if check is performed during template render: | |
{ | |
"permission_checks": permission_checks, | |
"filter": filter_, | |
"permissions": [ | |
{ | |
"name": p.name, | |
"abbr": p.abbr, | |
"description": p.description, | |
"takes_database": p.takes_database, | |
"takes_resource": p.takes_resource, | |
"default": p.default, | |
} | |
for p in self.ds.permissions.values() | |
], | |
}, | |
) | |
async def post(self, request): | |
await self.ds.ensure_permissions(request.actor, ["view-instance"]) | |
if not await self.ds.permission_allowed(request.actor, "permissions-debug"): | |
raise Forbidden("Permission denied") | |
vars = await request.post_vars() | |
actor = json.loads(vars["actor"]) | |
permission = vars["permission"] | |
resource_1 = vars["resource_1"] | |
resource_2 = vars["resource_2"] | |
resource = [] | |
if resource_1: | |
resource.append(resource_1) | |
if resource_2: | |
resource.append(resource_2) | |
resource = tuple(resource) | |
if len(resource) == 1: | |
resource = resource[0] | |
result = await self.ds.permission_allowed( | |
actor, permission, resource, default="USE_DEFAULT" | |
) | |
return Response.json( | |
{ | |
"actor": actor, | |
"permission": permission, | |
"resource": resource, | |
"result": result, | |
"default": self.ds.permissions[permission].default, | |
} | |
) | |
class AllowDebugView(BaseView): | |
name = "allow_debug" | |
has_json_alternate = False | |
async def get(self, request): | |
errors = [] | |
actor_input = request.args.get("actor") or '{"id": "root"}' | |
try: | |
actor = json.loads(actor_input) | |
actor_input = json.dumps(actor, indent=4) | |
except json.decoder.JSONDecodeError as ex: | |
errors.append(f"Actor JSON error: {ex}") | |
allow_input = request.args.get("allow") or '{"id": "*"}' | |
try: | |
allow = json.loads(allow_input) | |
allow_input = json.dumps(allow, indent=4) | |
except json.decoder.JSONDecodeError as ex: | |
errors.append(f"Allow JSON error: {ex}") | |
result = None | |
if not errors: | |
result = str(actor_matches_allow(actor, allow)) | |
return await self.render( | |
["allow_debug.html"], | |
request, | |
{ | |
"result": result, | |
"error": "\n\n".join(errors) if errors else "", | |
"actor_input": actor_input, | |
"allow_input": allow_input, | |
}, | |
) | |
class MessagesDebugView(BaseView): | |
name = "messages_debug" | |
has_json_alternate = False | |
async def get(self, request): | |
await self.ds.ensure_permissions(request.actor, ["view-instance"]) | |
return await self.render(["messages_debug.html"], request) | |
async def post(self, request): | |
await self.ds.ensure_permissions(request.actor, ["view-instance"]) | |
post = await request.post_vars() | |
message = post.get("message", "") | |
message_type = post.get("message_type") or "INFO" | |
assert message_type in ("INFO", "WARNING", "ERROR", "all") | |
datasette = self.ds | |
if message_type == "all": | |
datasette.add_message(request, message, datasette.INFO) | |
datasette.add_message(request, message, datasette.WARNING) | |
datasette.add_message(request, message, datasette.ERROR) | |
else: | |
datasette.add_message(request, message, getattr(datasette, message_type)) | |
return Response.redirect(self.ds.urls.instance()) | |
class CreateTokenView(BaseView): | |
name = "create_token" | |
has_json_alternate = False | |
def check_permission(self, request): | |
if not self.ds.setting("allow_signed_tokens"): | |
raise Forbidden("Signed tokens are not enabled for this Datasette instance") | |
if not request.actor: | |
raise Forbidden("You must be logged in to create a token") | |
if not request.actor.get("id"): | |
raise Forbidden( | |
"You must be logged in as an actor with an ID to create a token" | |
) | |
if request.actor.get("token"): | |
raise Forbidden( | |
"Token authentication cannot be used to create additional tokens" | |
) | |
async def shared(self, request): | |
self.check_permission(request) | |
# Build list of databases and tables the user has permission to view | |
database_with_tables = [] | |
for database in self.ds.databases.values(): | |
if database.name == "_memory": | |
continue | |
if not await self.ds.permission_allowed( | |
request.actor, "view-database", database.name | |
): | |
continue | |
hidden_tables = await database.hidden_table_names() | |
tables = [] | |
for table in await database.table_names(): | |
if table in hidden_tables: | |
continue | |
if not await self.ds.permission_allowed( | |
request.actor, | |
"view-table", | |
resource=(database.name, table), | |
): | |
continue | |
tables.append({"name": table, "encoded": tilde_encode(table)}) | |
database_with_tables.append( | |
{ | |
"name": database.name, | |
"encoded": tilde_encode(database.name), | |
"tables": tables, | |
} | |
) | |
return { | |
"actor": request.actor, | |
"all_permissions": self.ds.permissions.keys(), | |
"database_permissions": [ | |
key | |
for key, value in self.ds.permissions.items() | |
if value.takes_database | |
], | |
"resource_permissions": [ | |
key | |
for key, value in self.ds.permissions.items() | |
if value.takes_resource | |
], | |
"database_with_tables": database_with_tables, | |
} | |
async def get(self, request): | |
self.check_permission(request) | |
return await self.render( | |
["create_token.html"], request, await self.shared(request) | |
) | |
async def post(self, request): | |
self.check_permission(request) | |
post = await request.post_vars() | |
errors = [] | |
expires_after = None | |
if post.get("expire_type"): | |
duration_string = post.get("expire_duration") | |
if ( | |
not duration_string | |
or not duration_string.isdigit() | |
or not int(duration_string) > 0 | |
): | |
errors.append("Invalid expire duration") | |
else: | |
unit = post["expire_type"] | |
if unit == "minutes": | |
expires_after = int(duration_string) * 60 | |
elif unit == "hours": | |
expires_after = int(duration_string) * 60 * 60 | |
elif unit == "days": | |
expires_after = int(duration_string) * 60 * 60 * 24 | |
else: | |
errors.append("Invalid expire duration unit") | |
# Are there any restrictions? | |
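# Restriction checkboxes arrive as form keys in these shapes (illustrative):
#   all:<action>
#   database:<tilde-encoded-database>:<action>
#   resource:<tilde-encoded-database>:<tilde-encoded-table>:<action>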
restrict_all = [] | |
restrict_database = {} | |
restrict_resource = {} | |
for key in post: | |
if key.startswith("all:") and key.count(":") == 1: | |
restrict_all.append(key.split(":")[1]) | |
elif key.startswith("database:") and key.count(":") == 2: | |
bits = key.split(":") | |
database = tilde_decode(bits[1]) | |
action = bits[2] | |
restrict_database.setdefault(database, []).append(action) | |
elif key.startswith("resource:") and key.count(":") == 3: | |
bits = key.split(":") | |
database = tilde_decode(bits[1]) | |
resource = tilde_decode(bits[2]) | |
action = bits[3] | |
restrict_resource.setdefault(database, {}).setdefault( | |
resource, [] | |
).append(action) | |
token = self.ds.create_token( | |
request.actor["id"], | |
expires_after=expires_after, | |
restrict_all=restrict_all, | |
restrict_database=restrict_database, | |
restrict_resource=restrict_resource, | |
) | |
token_bits = self.ds.unsign(token[len("dstok_") :], namespace="token") | |
await self.ds.track_event( | |
CreateTokenEvent( | |
actor=request.actor, | |
expires_after=expires_after, | |
restrict_all=restrict_all, | |
restrict_database=restrict_database, | |
restrict_resource=restrict_resource, | |
) | |
) | |
context = await self.shared(request) | |
context.update({"errors": errors, "token": token, "token_bits": token_bits}) | |
return await self.render(["create_token.html"], request, context) | |
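# ApiExplorerView powers the /-/api page: example_links() builds sample GET
# and write API calls (insert, upsert, drop, create) for each database and
# table the current actor is allowed to see and modify.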
class ApiExplorerView(BaseView): | |
name = "api_explorer" | |
has_json_alternate = False | |
async def example_links(self, request): | |
databases = [] | |
for name, db in self.ds.databases.items(): | |
if name == "_internal": | |
continue | |
database_visible, _ = await self.ds.check_visibility( | |
request.actor, permissions=[("view-database", name), "view-instance"] | |
) | |
if not database_visible: | |
continue | |
tables = [] | |
table_names = await db.table_names() | |
for table in table_names: | |
visible, _ = await self.ds.check_visibility( | |
request.actor, | |
permissions=[ | |
("view-table", (name, table)), | |
("view-database", name), | |
"view-instance", | |
], | |
) | |
if not visible: | |
continue | |
table_links = [] | |
tables.append({"name": table, "links": table_links}) | |
table_links.append( | |
{ | |
"label": "Get rows for {}".format(table), | |
"method": "GET", | |
"path": self.ds.urls.table(name, table, format="json"), | |
} | |
) | |
# If not mutable don't show any write APIs | |
if not db.is_mutable: | |
continue | |
if await self.ds.permission_allowed( | |
request.actor, "insert-row", (name, table) | |
): | |
pks = await db.primary_keys(table) | |
table_links.extend( | |
[ | |
{ | |
"path": self.ds.urls.table(name, table) + "/-/insert", | |
"method": "POST", | |
"label": "Insert rows into {}".format(table), | |
"json": { | |
"rows": [ | |
{ | |
column: None | |
for column in await db.table_columns(table) | |
if column not in pks | |
} | |
] | |
}, | |
}, | |
{ | |
"path": self.ds.urls.table(name, table) + "/-/upsert", | |
"method": "POST", | |
"label": "Upsert rows into {}".format(table), | |
"json": { | |
"rows": [ | |
{ | |
column: None | |
for column in await db.table_columns(table) | |
if column not in pks | |
} | |
] | |
}, | |
}, | |
] | |
) | |
if await self.ds.permission_allowed( | |
request.actor, "drop-table", (name, table) | |
): | |
table_links.append( | |
{ | |
"path": self.ds.urls.table(name, table) + "/-/drop", | |
"label": "Drop table {}".format(table), | |
"json": {"confirm": False}, | |
"method": "POST", | |
} | |
) | |
database_links = [] | |
if ( | |
await self.ds.permission_allowed(request.actor, "create-table", name) | |
and db.is_mutable | |
): | |
database_links.append( | |
{ | |
"path": self.ds.urls.database(name) + "/-/create", | |
"label": "Create table in {}".format(name), | |
"json": { | |
"table": "new_table", | |
"columns": [ | |
{"name": "id", "type": "integer"}, | |
{"name": "name", "type": "text"}, | |
], | |
"pk": "id", | |
}, | |
"method": "POST", | |
} | |
) | |
if database_links or tables: | |
databases.append( | |
{ | |
"name": name, | |
"links": database_links, | |
"tables": tables, | |
} | |
) | |
# Sort so that mutable databases are first | |
databases.sort(key=lambda d: not self.ds.databases[d["name"]].is_mutable) | |
return databases | |
async def get(self, request): | |
visible, private = await self.ds.check_visibility( | |
request.actor, | |
permissions=["view-instance"], | |
) | |
if not visible: | |
raise Forbidden("You do not have permission to view this instance") | |
def api_path(link): | |
return "/-/api#{}".format( | |
urllib.parse.urlencode( | |
{ | |
key: json.dumps(value, indent=2) if key == "json" else value | |
for key, value in link.items() | |
if key in ("path", "method", "json") | |
} | |
) | |
) | |
return await self.render( | |
["api_explorer.html"], | |
request, | |
{ | |
"example_links": await self.example_links(request), | |
"api_path": api_path, | |
"private": private, | |
}, | |
) | |
</document_content> | |
</document> | |
<document index="43"> | |
<source>datasette/views/table.py</source> | |
<document_content> | |
import asyncio | |
import itertools | |
import json | |
import urllib | |
from asyncinject import Registry | |
import markupsafe | |
from datasette.plugins import pm | |
from datasette.database import QueryInterrupted | |
from datasette.events import ( | |
AlterTableEvent, | |
DropTableEvent, | |
InsertRowsEvent, | |
UpsertRowsEvent, | |
) | |
from datasette import tracer | |
from datasette.utils import ( | |
add_cors_headers, | |
await_me_maybe, | |
call_with_supported_arguments, | |
CustomRow, | |
append_querystring, | |
compound_keys_after_sql, | |
format_bytes, | |
make_slot_function, | |
tilde_encode, | |
escape_sqlite, | |
filters_should_redirect, | |
is_url, | |
path_from_row_pks, | |
path_with_added_args, | |
path_with_format, | |
path_with_removed_args, | |
path_with_replaced_args, | |
to_css_class, | |
truncate_url, | |
urlsafe_components, | |
value_as_boolean, | |
InvalidSql, | |
sqlite3, | |
) | |
from datasette.utils.asgi import BadRequest, Forbidden, NotFound, Response | |
from datasette.filters import Filters | |
import sqlite_utils | |
from .base import BaseView, DatasetteError, _error, stream_csv | |
from .database import QueryView | |
LINK_WITH_LABEL = ( | |
'<a href="{base_url}{database}/{table}/{link_id}">{label}</a> <em>{id}</em>' | |
) | |
LINK_WITH_VALUE = '<a href="{base_url}{database}/{table}/{link_id}">{id}</a>' | |
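# Row wraps the rendered cells for a single table row so that templates can
# iterate over cells, read raw values with row["column"] and formatted values
# with row.display("column").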
class Row: | |
def __init__(self, cells): | |
self.cells = cells | |
def __iter__(self): | |
return iter(self.cells) | |
def __getitem__(self, key): | |
for cell in self.cells: | |
if cell["column"] == key: | |
return cell["raw"] | |
raise KeyError | |
def display(self, key): | |
for cell in self.cells: | |
if cell["column"] == key: | |
return cell["value"] | |
return None | |
def __str__(self): | |
d = { | |
key: self[key] | |
for key in [ | |
c["column"] for c in self.cells if not c.get("is_special_link_column") | |
] | |
} | |
return json.dumps(d, default=repr, indent=2) | |
async def run_sequential(*args): | |
# This used to be swappable for asyncio.gather() to run things in | |
# parallel, but this led to hard-to-debug locking issues with
# in-memory databases: https://github.com/simonw/datasette/issues/2189 | |
results = [] | |
for fn in args: | |
results.append(await fn) | |
return results | |
def _redirect(datasette, request, path, forward_querystring=True, remove_args=None): | |
if request.query_string and "?" not in path and forward_querystring: | |
path = f"{path}?{request.query_string}" | |
if remove_args: | |
path = path_with_removed_args(request, remove_args, path=path) | |
r = Response.redirect(path) | |
r.headers["Link"] = f"<{path}>; rel=preload" | |
if datasette.cors: | |
add_cors_headers(r.headers) | |
return r | |
async def _redirect_if_needed(datasette, request, resolved): | |
# Handle ?_filter_column | |
redirect_params = filters_should_redirect(request.args) | |
if redirect_params: | |
return _redirect( | |
datasette, | |
request, | |
datasette.urls.path(path_with_added_args(request, redirect_params)), | |
forward_querystring=False, | |
) | |
# If ?_sort_by_desc=on (from checkbox) redirect to _sort_desc=(_sort) | |
if "_sort_by_desc" in request.args: | |
return _redirect( | |
datasette, | |
request, | |
datasette.urls.path( | |
path_with_added_args( | |
request, | |
{ | |
"_sort_desc": request.args.get("_sort"), | |
"_sort_by_desc": None, | |
"_sort": None, | |
}, | |
) | |
), | |
forward_querystring=False, | |
) | |
async def display_columns_and_rows( | |
datasette, | |
database_name, | |
table_name, | |
description, | |
rows, | |
link_column=False, | |
truncate_cells=0, | |
sortable_columns=None, | |
request=None, | |
): | |
"""Returns columns, rows for specified table - including fancy foreign key treatment""" | |
sortable_columns = sortable_columns or set() | |
db = datasette.databases[database_name] | |
column_descriptions = dict( | |
await datasette.get_internal_database().execute( | |
""" | |
SELECT | |
column_name, | |
value | |
FROM metadata_columns | |
WHERE database_name = ? | |
AND resource_name = ? | |
AND key = 'description' | |
""", | |
[database_name, table_name], | |
) | |
) | |
column_details = { | |
col.name: col for col in await db.table_column_details(table_name) | |
} | |
table_config = await datasette.table_config(database_name, table_name) | |
pks = await db.primary_keys(table_name) | |
pks_for_display = pks | |
if not pks_for_display: | |
pks_for_display = ["rowid"] | |
columns = [] | |
for r in description: | |
if r[0] == "rowid" and "rowid" not in column_details: | |
type_ = "integer" | |
notnull = 0 | |
else: | |
type_ = column_details[r[0]].type | |
notnull = column_details[r[0]].notnull | |
columns.append( | |
{ | |
"name": r[0], | |
"sortable": r[0] in sortable_columns, | |
"is_pk": r[0] in pks_for_display, | |
"type": type_, | |
"notnull": notnull, | |
"description": column_descriptions.get(r[0]), | |
} | |
) | |
column_to_foreign_key_table = { | |
fk["column"]: fk["other_table"] | |
for fk in await db.foreign_keys_for_table(table_name) | |
} | |
cell_rows = [] | |
base_url = datasette.setting("base_url") | |
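# Each cell is rendered by the first matching rule below: a render_cell plugin
# result, a binary blob download link, an expanded foreign key link, a blank
# for null/empty values, an auto-linked URL, or the plain string value
# (truncated to truncate_cells characters if necessary).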
for row in rows: | |
cells = [] | |
# Unless we are a view, the first column is a link - either to the rowid | |
# or to the simple or compound primary key | |
if link_column: | |
is_special_link_column = len(pks) != 1 | |
pk_path = path_from_row_pks(row, pks, not pks, False) | |
cells.append( | |
{ | |
"column": pks[0] if len(pks) == 1 else "Link", | |
"value_type": "pk", | |
"is_special_link_column": is_special_link_column, | |
"raw": pk_path, | |
"value": markupsafe.Markup( | |
'<a href="{table_path}/{flat_pks_quoted}">{flat_pks}</a>'.format( | |
table_path=datasette.urls.table(database_name, table_name), | |
flat_pks=str(markupsafe.escape(pk_path)), | |
flat_pks_quoted=path_from_row_pks(row, pks, not pks), | |
) | |
), | |
} | |
) | |
for value, column_dict in zip(row, columns): | |
column = column_dict["name"] | |
if link_column and len(pks) == 1 and column == pks[0]: | |
# If there's a simple primary key, don't repeat the value as it's | |
# already shown in the link column. | |
continue | |
# First let the plugins have a go | |
# pylint: disable=no-member | |
plugin_display_value = None | |
for candidate in pm.hook.render_cell( | |
row=row, | |
value=value, | |
column=column, | |
table=table_name, | |
database=database_name, | |
datasette=datasette, | |
request=request, | |
): | |
candidate = await await_me_maybe(candidate) | |
if candidate is not None: | |
plugin_display_value = candidate | |
break | |
if plugin_display_value: | |
display_value = plugin_display_value | |
elif isinstance(value, bytes): | |
formatted = format_bytes(len(value)) | |
display_value = markupsafe.Markup( | |
'<a class="blob-download" href="{}"{}><Binary: {:,} byte{}></a>'.format( | |
datasette.urls.row_blob( | |
database_name, | |
table_name, | |
path_from_row_pks(row, pks, not pks), | |
column, | |
), | |
( | |
' title="{}"'.format(formatted) | |
if "bytes" not in formatted | |
else "" | |
), | |
len(value), | |
"" if len(value) == 1 else "s", | |
) | |
) | |
elif isinstance(value, dict): | |
# It's an expanded foreign key - display link to other row | |
label = value["label"] | |
value = value["value"] | |
# The table we link to depends on the column | |
other_table = column_to_foreign_key_table[column] | |
link_template = LINK_WITH_LABEL if (label != value) else LINK_WITH_VALUE | |
display_value = markupsafe.Markup( | |
link_template.format( | |
database=database_name, | |
base_url=base_url, | |
table=tilde_encode(other_table), | |
link_id=tilde_encode(str(value)), | |
id=str(markupsafe.escape(value)), | |
label=str(markupsafe.escape(label)) or "-", | |
) | |
) | |
elif value in ("", None): | |
display_value = markupsafe.Markup(" ") | |
elif is_url(str(value).strip()): | |
display_value = markupsafe.Markup( | |
'<a href="{url}">{truncated_url}</a>'.format( | |
url=markupsafe.escape(value.strip()), | |
truncated_url=markupsafe.escape( | |
truncate_url(value.strip(), truncate_cells) | |
), | |
) | |
) | |
else: | |
display_value = str(value) | |
if truncate_cells and len(display_value) > truncate_cells: | |
display_value = display_value[:truncate_cells] + "\u2026" | |
cells.append( | |
{ | |
"column": column, | |
"value": display_value, | |
"raw": value, | |
"value_type": ( | |
"none" if value is None else str(type(value).__name__) | |
), | |
} | |
) | |
cell_rows.append(Row(cells)) | |
if link_column: | |
# Add the link column header. | |
# If it's a simple primary key, we have to remove and re-add that column name at | |
# the beginning of the header row. | |
first_column = None | |
if len(pks) == 1: | |
columns = [col for col in columns if col["name"] != pks[0]] | |
first_column = { | |
"name": pks[0], | |
"sortable": len(pks) == 1, | |
"is_pk": True, | |
"type": column_details[pks[0]].type, | |
"notnull": column_details[pks[0]].notnull, | |
} | |
else: | |
first_column = { | |
"name": "Link", | |
"sortable": False, | |
"is_pk": False, | |
"type": "", | |
"notnull": 0, | |
} | |
columns = [first_column] + columns | |
return columns, cell_rows | |
class TableInsertView(BaseView): | |
name = "table-insert" | |
def __init__(self, datasette): | |
self.ds = datasette | |
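# _validate_data() checks the JSON body for insert/upsert requests.
# Illustrative payloads:
#   {"row": {"name": "Banana"}}
#   {"rows": [{"name": "Banana"}, {"name": "Cherry"}], "ignore": true}
# It returns (rows, errors, extras); rows is None when validation fails, and
# extras carries the optional "return", "ignore", "replace" and "alter" flags.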
async def _validate_data(self, request, db, table_name, pks, upsert): | |
errors = [] | |
pks_list = [] | |
if isinstance(pks, str): | |
pks_list = [pks] | |
else: | |
pks_list = list(pks) | |
if not pks_list: | |
pks_list = ["rowid"] | |
def _errors(errors): | |
return None, errors, {} | |
if not (request.headers.get("content-type") or "").startswith("application/json"):
# TODO: handle form-encoded data | |
return _errors(["Invalid content-type, must be application/json"]) | |
body = await request.post_body() | |
try: | |
data = json.loads(body) | |
except json.JSONDecodeError as e: | |
return _errors(["Invalid JSON: {}".format(e)]) | |
if not isinstance(data, dict): | |
return _errors(["JSON must be a dictionary"]) | |
keys = data.keys() | |
# keys must contain "row" or "rows" | |
if "row" not in keys and "rows" not in keys: | |
return _errors(['JSON must have one or other of "row" or "rows"']) | |
rows = [] | |
if "row" in keys: | |
if "rows" in keys: | |
return _errors(['Cannot use "row" and "rows" at the same time']) | |
row = data["row"] | |
if not isinstance(row, dict): | |
return _errors(['"row" must be a dictionary']) | |
rows = [row] | |
data["return"] = True | |
else: | |
rows = data["rows"] | |
if not isinstance(rows, list): | |
return _errors(['"rows" must be a list']) | |
for row in rows: | |
if not isinstance(row, dict): | |
return _errors(['"rows" must be a list of dictionaries']) | |
# Does this exceed max_insert_rows? | |
max_insert_rows = self.ds.setting("max_insert_rows") | |
if len(rows) > max_insert_rows: | |
return _errors( | |
["Too many rows, maximum allowed is {}".format(max_insert_rows)] | |
) | |
# Validate other parameters | |
extras = { | |
key: value for key, value in data.items() if key not in ("row", "rows") | |
} | |
valid_extras = {"return", "ignore", "replace", "alter"} | |
invalid_extras = extras.keys() - valid_extras | |
if invalid_extras: | |
return _errors( | |
['Invalid parameter: "{}"'.format('", "'.join(sorted(invalid_extras)))] | |
) | |
if extras.get("ignore") and extras.get("replace"): | |
return _errors(['Cannot use "ignore" and "replace" at the same time']) | |
columns = set(await db.table_columns(table_name)) | |
columns.update(pks_list) | |
for i, row in enumerate(rows): | |
if upsert: | |
# It MUST have the primary key | |
missing_pks = [pk for pk in pks_list if pk not in row] | |
if missing_pks: | |
errors.append( | |
'Row {} is missing primary key column(s): "{}"'.format( | |
i, '", "'.join(missing_pks) | |
) | |
) | |
invalid_columns = set(row.keys()) - columns | |
if invalid_columns and not extras.get("alter"): | |
errors.append( | |
"Row {} has invalid columns: {}".format( | |
i, ", ".join(sorted(invalid_columns)) | |
) | |
) | |
if errors: | |
return _errors(errors) | |
return rows, errors, extras | |
async def post(self, request, upsert=False): | |
try: | |
resolved = await self.ds.resolve_table(request) | |
except NotFound as e: | |
return _error([e.args[0]], 404) | |
db = resolved.db | |
database_name = db.name | |
table_name = resolved.table | |
# Table must exist (may handle table creation in the future) | |
db = self.ds.get_database(database_name) | |
if not await db.table_exists(table_name): | |
return _error(["Table not found: {}".format(table_name)], 404) | |
if upsert: | |
# Must have insert-row AND upsert-row permissions | |
if not ( | |
await self.ds.permission_allowed( | |
request.actor, "insert-row", resource=(database_name, table_name) | |
) | |
and await self.ds.permission_allowed( | |
request.actor, "update-row", resource=(database_name, table_name) | |
) | |
): | |
return _error( | |
["Permission denied: need both insert-row and update-row"], 403 | |
) | |
else: | |
# Must have insert-row permission | |
if not await self.ds.permission_allowed( | |
request.actor, "insert-row", resource=(database_name, table_name) | |
): | |
return _error(["Permission denied"], 403) | |
if not db.is_mutable: | |
return _error(["Database is immutable"], 403) | |
pks = await db.primary_keys(table_name) | |
rows, errors, extras = await self._validate_data( | |
request, db, table_name, pks, upsert | |
) | |
if errors: | |
return _error(errors, 400) | |
num_rows = len(rows) | |
# Now that we've passed pks to _validate_data it's safe to
# fix the rowid case:
if not pks: | |
pks = ["rowid"] | |
ignore = extras.get("ignore") | |
replace = extras.get("replace") | |
alter = extras.get("alter") | |
if upsert and (ignore or replace): | |
return _error(["Upsert does not support ignore or replace"], 400) | |
if replace and not await self.ds.permission_allowed( | |
request.actor, "update-row", resource=(database_name, table_name) | |
): | |
return _error(['Permission denied: need update-row to use "replace"'], 403) | |
initial_schema = None | |
if alter: | |
# Must have alter-table permission | |
if not await self.ds.permission_allowed( | |
request.actor, "alter-table", resource=(database_name, table_name) | |
): | |
return _error(["Permission denied for alter-table"], 403) | |
# Track initial schema to check if it changed later | |
initial_schema = await db.execute_fn( | |
lambda conn: sqlite_utils.Database(conn)[table_name].schema | |
) | |
should_return = bool(extras.get("return", False)) | |
row_pk_values_for_later = [] | |
if should_return and upsert: | |
row_pk_values_for_later = [tuple(row[pk] for pk in pks) for row in rows] | |
def insert_or_upsert_rows(conn): | |
table = sqlite_utils.Database(conn)[table_name] | |
kwargs = {} | |
if upsert: | |
kwargs = { | |
"pk": pks[0] if len(pks) == 1 else pks, | |
"alter": alter, | |
} | |
else: | |
# Insert | |
kwargs = {"ignore": ignore, "replace": replace, "alter": alter} | |
if should_return and not upsert: | |
rowids = [] | |
method = table.upsert if upsert else table.insert | |
for row in rows: | |
rowids.append(method(row, **kwargs).last_rowid) | |
return list( | |
table.rows_where( | |
"rowid in ({})".format(",".join("?" for _ in rowids)), | |
rowids, | |
) | |
) | |
else: | |
method_all = table.upsert_all if upsert else table.insert_all | |
method_all(rows, **kwargs) | |
try: | |
rows = await db.execute_write_fn(insert_or_upsert_rows) | |
except Exception as e: | |
return _error([str(e)]) | |
result = {"ok": True} | |
if should_return: | |
if upsert: | |
# Fetch based on initial input IDs | |
where_clause = " OR ".join( | |
["({})".format(" AND ".join("{} = ?".format(pk) for pk in pks))] | |
* len(row_pk_values_for_later) | |
) | |
args = list(itertools.chain.from_iterable(row_pk_values_for_later)) | |
fetched_rows = await db.execute( | |
"select {}* from [{}] where {}".format( | |
"rowid, " if pks == ["rowid"] else "", table_name, where_clause | |
), | |
args, | |
) | |
result["rows"] = fetched_rows.dicts() | |
else: | |
result["rows"] = rows | |
# We track the number of rows requested, but do not attempt to report which
# were actually inserted or upserted vs. ignored
if upsert: | |
await self.ds.track_event( | |
UpsertRowsEvent( | |
actor=request.actor, | |
database=database_name, | |
table=table_name, | |
num_rows=num_rows, | |
) | |
) | |
else: | |
await self.ds.track_event( | |
InsertRowsEvent( | |
actor=request.actor, | |
database=database_name, | |
table=table_name, | |
num_rows=num_rows, | |
ignore=bool(ignore), | |
replace=bool(replace), | |
) | |
) | |
if initial_schema is not None: | |
after_schema = await db.execute_fn( | |
lambda conn: sqlite_utils.Database(conn)[table_name].schema | |
) | |
if initial_schema != after_schema: | |
await self.ds.track_event( | |
AlterTableEvent( | |
request.actor, | |
database=database_name, | |
table=table_name, | |
before_schema=initial_schema, | |
after_schema=after_schema, | |
) | |
) | |
return Response.json(result, status=200 if upsert else 201) | |
class TableUpsertView(TableInsertView): | |
name = "table-upsert" | |
async def post(self, request): | |
return await super().post(request, upsert=True) | |
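# TableDropView requires {"confirm": true} in the POST body. Without it the
# response reports the table's row count and asks for confirmation instead of
# dropping the table.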
class TableDropView(BaseView): | |
name = "table-drop" | |
def __init__(self, datasette): | |
self.ds = datasette | |
async def post(self, request): | |
try: | |
resolved = await self.ds.resolve_table(request) | |
except NotFound as e: | |
return _error([e.args[0]], 404) | |
db = resolved.db | |
database_name = db.name | |
table_name = resolved.table | |
# Table must exist | |
db = self.ds.get_database(database_name) | |
if not await db.table_exists(table_name): | |
return _error(["Table not found: {}".format(table_name)], 404) | |
if not await self.ds.permission_allowed( | |
request.actor, "drop-table", resource=(database_name, table_name) | |
): | |
return _error(["Permission denied"], 403) | |
if not db.is_mutable: | |
return _error(["Database is immutable"], 403) | |
confirm = False | |
try: | |
data = json.loads(await request.post_body()) | |
confirm = data.get("confirm") | |
except json.JSONDecodeError: | |
pass | |
if not confirm: | |
return Response.json( | |
{ | |
"ok": True, | |
"database": database_name, | |
"table": table_name, | |
"row_count": ( | |
await db.execute("select count(*) from [{}]".format(table_name)) | |
).single_value(), | |
"message": 'Pass "confirm": true to confirm', | |
}, | |
status=200, | |
) | |
# Drop table | |
def drop_table(conn): | |
sqlite_utils.Database(conn)[table_name].drop() | |
await db.execute_write_fn(drop_table) | |
await self.ds.track_event( | |
DropTableEvent( | |
actor=request.actor, database=database_name, table=table_name | |
) | |
) | |
return Response.json({"ok": True}, status=200) | |
def _get_extras(request): | |
extra_bits = request.args.getlist("_extra") | |
extras = set() | |
for bit in extra_bits: | |
extras.update(bit.split(",")) | |
return extras | |
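# _columns_to_select() applies ?_col= and ?_nocol= to the table's column list:
# ?_col= narrows output to the primary keys plus the named columns, ?_nocol=
# removes columns (primary keys cannot be removed), and unknown column names
# raise a 400 error.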
async def _columns_to_select(table_columns, pks, request): | |
columns = list(table_columns) | |
if "_col" in request.args: | |
columns = list(pks) | |
_cols = request.args.getlist("_col") | |
bad_columns = [column for column in _cols if column not in table_columns] | |
if bad_columns: | |
raise DatasetteError( | |
"_col={} - invalid columns".format(", ".join(bad_columns)), | |
status=400, | |
) | |
# De-duplicate maintaining order: | |
columns.extend(dict.fromkeys(_cols)) | |
if "_nocol" in request.args: | |
# Return all columns EXCEPT these | |
bad_columns = [ | |
column | |
for column in request.args.getlist("_nocol") | |
if (column not in table_columns) or (column in pks) | |
] | |
if bad_columns: | |
raise DatasetteError( | |
"_nocol={} - invalid columns".format(", ".join(bad_columns)), | |
status=400, | |
) | |
tmp_columns = [ | |
column for column in columns if column not in request.args.getlist("_nocol") | |
] | |
columns = tmp_columns | |
return columns | |
async def _sortable_columns_for_table(datasette, database_name, table_name, use_rowid): | |
db = datasette.databases[database_name] | |
table_metadata = await datasette.table_config(database_name, table_name) | |
if "sortable_columns" in table_metadata: | |
sortable_columns = set(table_metadata["sortable_columns"]) | |
else: | |
sortable_columns = set(await db.table_columns(table_name)) | |
if use_rowid: | |
sortable_columns.add("rowid") | |
return sortable_columns | |
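# _sort_order() resolves the sort column from ?_sort= / ?_sort_desc=, falling
# back to the table's configured "sort" / "sort_desc" metadata. Supplying both
# at once, or a column that is not sortable, raises a 400 error.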
async def _sort_order(table_metadata, sortable_columns, request, order_by): | |
sort = request.args.get("_sort") | |
sort_desc = request.args.get("_sort_desc") | |
if not sort and not sort_desc: | |
sort = table_metadata.get("sort") | |
sort_desc = table_metadata.get("sort_desc") | |
if sort and sort_desc: | |
raise DatasetteError( | |
"Cannot use _sort and _sort_desc at the same time", status=400 | |
) | |
if sort: | |
if sort not in sortable_columns: | |
raise DatasetteError(f"Cannot sort table by {sort}", status=400) | |
order_by = escape_sqlite(sort) | |
if sort_desc: | |
if sort_desc not in sortable_columns: | |
raise DatasetteError(f"Cannot sort table by {sort_desc}", status=400) | |
order_by = f"{escape_sqlite(sort_desc)} desc" | |
return sort, sort_desc, order_by | |
async def table_view(datasette, request): | |
await datasette.refresh_schemas() | |
with tracer.trace_child_tasks(): | |
response = await table_view_traced(datasette, request) | |
# CORS | |
if datasette.cors: | |
add_cors_headers(response.headers) | |
# Cache TTL header | |
ttl = request.args.get("_ttl", None) | |
if ttl is None or not ttl.isdigit(): | |
ttl = datasette.setting("default_cache_ttl") | |
if datasette.cache_headers and response.status == 200: | |
ttl = int(ttl) | |
if ttl == 0: | |
ttl_header = "no-cache" | |
else: | |
ttl_header = f"max-age={ttl}" | |
response.headers["Cache-Control"] = ttl_header | |
# Referrer policy | |
response.headers["Referrer-Policy"] = "no-referrer" | |
return response | |
async def table_view_traced(datasette, request): | |
from datasette.app import TableNotFound | |
try: | |
resolved = await datasette.resolve_table(request) | |
except TableNotFound as not_found: | |
# Was this actually a canned query? | |
canned_query = await datasette.get_canned_query( | |
not_found.database_name, not_found.table, request.actor | |
) | |
# If this is a canned query, not a table, then dispatch to QueryView instead | |
if canned_query: | |
return await QueryView()(request, datasette) | |
else: | |
raise | |
if request.method == "POST": | |
return Response.text("Method not allowed", status=405) | |
format_ = request.url_vars.get("format") or "html" | |
extra_extras = None | |
context_for_html_hack = False | |
default_labels = False | |
if format_ == "html": | |
extra_extras = {"_html"} | |
context_for_html_hack = True | |
default_labels = True | |
view_data = await table_view_data( | |
datasette, | |
request, | |
resolved, | |
extra_extras=extra_extras, | |
context_for_html_hack=context_for_html_hack, | |
default_labels=default_labels, | |
) | |
if isinstance(view_data, Response): | |
return view_data | |
data, rows, columns, expanded_columns, sql, next_url = view_data | |
# Handle formats from plugins | |
if format_ == "csv": | |
async def fetch_data(request, _next=None): | |
( | |
data, | |
rows, | |
columns, | |
expanded_columns, | |
sql, | |
next_url, | |
) = await table_view_data( | |
datasette, | |
request, | |
resolved, | |
extra_extras=extra_extras, | |
context_for_html_hack=context_for_html_hack, | |
default_labels=default_labels, | |
_next=_next, | |
) | |
data["rows"] = rows | |
data["table"] = resolved.table | |
data["columns"] = columns | |
data["expanded_columns"] = expanded_columns | |
return data, None, None | |
return await stream_csv(datasette, fetch_data, request, resolved.db.name) | |
elif format_ in datasette.renderers.keys(): | |
# Dispatch request to the correct output format renderer | |
# (CSV is not handled here due to streaming) | |
result = call_with_supported_arguments( | |
datasette.renderers[format_][0], | |
datasette=datasette, | |
columns=columns, | |
rows=rows, | |
sql=sql, | |
query_name=None, | |
database=resolved.db.name, | |
table=resolved.table, | |
request=request, | |
view_name="table", | |
truncated=False, | |
error=None, | |
# These will be deprecated in Datasette 1.0: | |
args=request.args, | |
data=data, | |
) | |
if asyncio.iscoroutine(result): | |
result = await result | |
if result is None: | |
raise NotFound("No data") | |
if isinstance(result, dict): | |
r = Response( | |
body=result.get("body"), | |
status=result.get("status_code") or 200, | |
content_type=result.get("content_type", "text/plain"), | |
headers=result.get("headers"), | |
) | |
elif isinstance(result, Response): | |
r = result | |
# if status_code is not None: | |
# # Over-ride the status code | |
# r.status = status_code | |
else: | |
assert False, f"{result} should be dict or Response" | |
elif format_ == "html": | |
headers = {} | |
templates = [ | |
f"table-{to_css_class(resolved.db.name)}-{to_css_class(resolved.table)}.html", | |
"table.html", | |
] | |
environment = datasette.get_jinja_environment(request) | |
template = environment.select_template(templates) | |
alternate_url_json = datasette.absolute_url( | |
request, | |
datasette.urls.path(path_with_format(request=request, format="json")), | |
) | |
headers.update( | |
{ | |
"Link": '{}; rel="alternate"; type="application/json+datasette"'.format( | |
alternate_url_json | |
) | |
} | |
) | |
r = Response.html( | |
await datasette.render_template( | |
template, | |
dict( | |
data, | |
append_querystring=append_querystring, | |
path_with_replaced_args=path_with_replaced_args, | |
fix_path=datasette.urls.path, | |
settings=datasette.settings_dict(), | |
# TODO: review all of these hacks:
alternate_url_json=alternate_url_json, | |
datasette_allow_facet=( | |
"true" if datasette.setting("allow_facet") else "false" | |
), | |
is_sortable=any(c["sortable"] for c in data["display_columns"]), | |
allow_execute_sql=await datasette.permission_allowed( | |
request.actor, "execute-sql", resolved.db.name | |
), | |
query_ms=1.2, | |
select_templates=[ | |
f"{'*' if template_name == template.name else ''}{template_name}" | |
for template_name in templates | |
], | |
top_table=make_slot_function( | |
"top_table", | |
datasette, | |
request, | |
database=resolved.db.name, | |
table=resolved.table, | |
), | |
count_limit=resolved.db.count_limit, | |
), | |
request=request, | |
view_name="table", | |
), | |
headers=headers, | |
) | |
else: | |
assert False, "Invalid format: {}".format(format_) | |
if next_url: | |
r.headers["link"] = f'<{next_url}>; rel="next"' | |
return r | |
async def table_view_data( | |
datasette, | |
request, | |
resolved, | |
extra_extras=None, | |
context_for_html_hack=False, | |
default_labels=False, | |
_next=None, | |
): | |
extra_extras = extra_extras or set() | |
# We have a table or view | |
db = resolved.db | |
database_name = resolved.db.name | |
table_name = resolved.table | |
is_view = resolved.is_view | |
# Can this user view it? | |
visible, private = await datasette.check_visibility( | |
request.actor, | |
permissions=[ | |
("view-table", (database_name, table_name)), | |
("view-database", database_name), | |
"view-instance", | |
], | |
) | |
if not visible: | |
raise Forbidden("You do not have permission to view this table") | |
# Redirect based on request.args, if necessary | |
redirect_response = await _redirect_if_needed(datasette, request, resolved) | |
if redirect_response: | |
return redirect_response | |
# Introspect columns and primary keys for table | |
pks = await db.primary_keys(table_name) | |
table_columns = await db.table_columns(table_name) | |
# Take ?_col= and ?_nocol= into account | |
specified_columns = await _columns_to_select(table_columns, pks, request) | |
select_specified_columns = ", ".join(escape_sqlite(t) for t in specified_columns) | |
select_all_columns = ", ".join(escape_sqlite(t) for t in table_columns) | |
# rowid tables (no specified primary key) need a different SELECT | |
use_rowid = not pks and not is_view | |
order_by = "" | |
if use_rowid: | |
select_specified_columns = f"rowid, {select_specified_columns}" | |
select_all_columns = f"rowid, {select_all_columns}" | |
order_by = "rowid" | |
order_by_pks = "rowid" | |
else: | |
order_by_pks = ", ".join([escape_sqlite(pk) for pk in pks]) | |
order_by = order_by_pks | |
if is_view: | |
order_by = "" | |
# TODO: Replace this with logic that decides which ?_extras get
# executed instead:
nocount = request.args.get("_nocount") | |
nofacet = request.args.get("_nofacet") | |
nosuggest = request.args.get("_nosuggest") | |
if request.args.get("_shape") in ("array", "object"): | |
nocount = True | |
nofacet = True | |
table_metadata = await datasette.table_config(database_name, table_name) | |
# Arguments that start with _ and don't contain a __ are | |
# special - things like ?_search= - and should not be | |
# treated as filters. | |
filter_args = [] | |
for key in request.args: | |
if not (key.startswith("_") and "__" not in key): | |
for v in request.args.getlist(key): | |
filter_args.append((key, v)) | |
# Build where clauses from query string arguments | |
filters = Filters(sorted(filter_args)) | |
where_clauses, params = filters.build_where_clauses(table_name) | |
# Execute filters_from_request plugin hooks - including the default | |
# ones that live in datasette/filters.py | |
extra_context_from_filters = {} | |
extra_human_descriptions = [] | |
for hook in pm.hook.filters_from_request( | |
request=request, | |
table=table_name, | |
database=database_name, | |
datasette=datasette, | |
): | |
filter_arguments = await await_me_maybe(hook) | |
if filter_arguments: | |
where_clauses.extend(filter_arguments.where_clauses) | |
params.update(filter_arguments.params) | |
extra_human_descriptions.extend(filter_arguments.human_descriptions) | |
extra_context_from_filters.update(filter_arguments.extra_context) | |
# Deal with custom sort orders | |
sortable_columns = await _sortable_columns_for_table( | |
datasette, database_name, table_name, use_rowid | |
) | |
sort, sort_desc, order_by = await _sort_order( | |
table_metadata, sortable_columns, request, order_by | |
) | |
from_sql = "from {table_name} {where}".format( | |
table_name=escape_sqlite(table_name), | |
where=( | |
("where {} ".format(" and ".join(where_clauses))) if where_clauses else "" | |
), | |
) | |
# Copy of params so we can mutate them later: | |
from_sql_params = dict(**params) | |
count_sql = f"select count(*) {from_sql}" | |
# Handle pagination driven by ?_next= | |
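# For views the ?_next= token is a plain numeric offset. For tables it is a
# keyset cursor: comma-separated, url-safe components holding an optional sort
# value followed by the primary key values, e.g. (illustrative) _next=Banana,15
# for a table sorted by name with a single integer primary key.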
_next = _next or request.args.get("_next") | |
offset = "" | |
if _next: | |
sort_value = None | |
if is_view: | |
# _next is an offset | |
offset = f" offset {int(_next)}" | |
else: | |
components = urlsafe_components(_next) | |
# If a sort order is applied and there are multiple components, | |
# the first of these is the sort value | |
if (sort or sort_desc) and (len(components) > 1): | |
sort_value = components[0] | |
# Special case for when the non-urlencoded first token is $null
if _next.split(",")[0] == "$null": | |
sort_value = None | |
components = components[1:] | |
# Figure out the SQL for next-based-on-primary-key first | |
next_by_pk_clauses = [] | |
if use_rowid: | |
next_by_pk_clauses.append(f"rowid > :p{len(params)}") | |
params[f"p{len(params)}"] = components[0] | |
else: | |
# Apply the tie-breaker based on primary keys | |
if len(components) == len(pks): | |
param_len = len(params) | |
next_by_pk_clauses.append(compound_keys_after_sql(pks, param_len)) | |
for i, pk_value in enumerate(components): | |
params[f"p{param_len + i}"] = pk_value | |
# Now add the sort SQL, which may incorporate next_by_pk_clauses | |
if sort or sort_desc: | |
if sort_value is None: | |
if sort_desc: | |
# Just items where column is null ordered by pk | |
where_clauses.append( | |
"({column} is null and {next_clauses})".format( | |
column=escape_sqlite(sort_desc), | |
next_clauses=" and ".join(next_by_pk_clauses), | |
) | |
) | |
else: | |
where_clauses.append( | |
"({column} is not null or ({column} is null and {next_clauses}))".format( | |
column=escape_sqlite(sort), | |
next_clauses=" and ".join(next_by_pk_clauses), | |
) | |
) | |
else: | |
where_clauses.append( | |
"({column} {op} :p{p}{extra_desc_only} or ({column} = :p{p} and {next_clauses}))".format( | |
column=escape_sqlite(sort or sort_desc), | |
op=">" if sort else "<", | |
p=len(params), | |
extra_desc_only=( | |
"" | |
if sort | |
else " or {column2} is null".format( | |
column2=escape_sqlite(sort or sort_desc) | |
) | |
), | |
next_clauses=" and ".join(next_by_pk_clauses), | |
) | |
) | |
params[f"p{len(params)}"] = sort_value | |
order_by = f"{order_by}, {order_by_pks}" | |
else: | |
where_clauses.extend(next_by_pk_clauses) | |
where_clause = "" | |
if where_clauses: | |
where_clause = f"where {' and '.join(where_clauses)} " | |
if order_by: | |
order_by = f"order by {order_by}" | |
extra_args = {} | |
# Handle ?_size=500 | |
# TODO: This was: | |
# page_size = _size or request.args.get("_size") or table_metadata.get("size") | |
page_size = request.args.get("_size") or table_metadata.get("size") | |
if page_size: | |
if page_size == "max": | |
page_size = datasette.max_returned_rows | |
try: | |
page_size = int(page_size) | |
if page_size < 0: | |
raise ValueError | |
except ValueError: | |
raise BadRequest("_size must be a positive integer") | |
if page_size > datasette.max_returned_rows: | |
raise BadRequest(f"_size must be <= {datasette.max_returned_rows}") | |
extra_args["page_size"] = page_size | |
else: | |
page_size = datasette.page_size | |
# Facets are calculated against SQL without order by or limit | |
sql_no_order_no_limit = ( | |
"select {select_all_columns} from {table_name} {where}".format( | |
select_all_columns=select_all_columns, | |
table_name=escape_sqlite(table_name), | |
where=where_clause, | |
) | |
) | |
# This is the SQL that populates the main table on the page | |
sql = "select {select_specified_columns} from {table_name} {where}{order_by} limit {page_size}{offset}".format( | |
select_specified_columns=select_specified_columns, | |
table_name=escape_sqlite(table_name), | |
where=where_clause, | |
order_by=order_by, | |
page_size=page_size + 1, | |
offset=offset, | |
) | |
if request.args.get("_timelimit"): | |
extra_args["custom_time_limit"] = int(request.args.get("_timelimit")) | |
# Execute the main query! | |
try: | |
results = await db.execute(sql, params, truncate=True, **extra_args) | |
except (sqlite3.OperationalError, InvalidSql) as e:
raise DatasetteError(str(e), title="Invalid SQL", status=400)
columns = [r[0] for r in results.description] | |
rows = list(results.rows) | |
# Expand labeled columns if requested | |
expanded_columns = [] | |
# List of (fk_dict, label_column-or-None) pairs for that table | |
expandable_columns = [] | |
for fk in await db.foreign_keys_for_table(table_name): | |
label_column = await db.label_column_for_table(fk["other_table"]) | |
expandable_columns.append((fk, label_column)) | |
columns_to_expand = None | |
try: | |
all_labels = value_as_boolean(request.args.get("_labels", "")) | |
except ValueError: | |
all_labels = default_labels | |
# Check for explicit _label= | |
if "_label" in request.args: | |
columns_to_expand = request.args.getlist("_label") | |
if columns_to_expand is None and all_labels: | |
# expand all columns with foreign keys | |
columns_to_expand = [fk["column"] for fk, _ in expandable_columns] | |
if columns_to_expand: | |
expanded_labels = {} | |
for fk, _ in expandable_columns: | |
column = fk["column"] | |
if column not in columns_to_expand: | |
continue | |
if column not in columns: | |
continue | |
expanded_columns.append(column) | |
# Gather the values | |
column_index = columns.index(column) | |
values = [row[column_index] for row in rows] | |
# Expand them | |
expanded_labels.update( | |
await datasette.expand_foreign_keys( | |
request.actor, database_name, table_name, column, values | |
) | |
) | |
if expanded_labels: | |
# Rewrite the rows | |
new_rows = [] | |
for row in rows: | |
new_row = CustomRow(columns) | |
for column in row.keys(): | |
value = row[column] | |
if (column, value) in expanded_labels and value is not None: | |
new_row[column] = { | |
"value": value, | |
"label": expanded_labels[(column, value)], | |
} | |
else: | |
new_row[column] = value | |
new_rows.append(new_row) | |
rows = new_rows | |
_next = request.args.get("_next") | |
# Pagination next link | |
next_value, next_url = await _next_value_and_url( | |
datasette, | |
db, | |
request, | |
table_name, | |
_next, | |
rows, | |
pks, | |
use_rowid, | |
sort, | |
sort_desc, | |
page_size, | |
is_view, | |
) | |
rows = rows[:page_size] | |
# Resolve extras | |
extras = _get_extras(request) | |
if any(k for k in request.args.keys() if k == "_facet" or k.startswith("_facet_")): | |
extras.add("facet_results") | |
if request.args.get("_shape") == "object": | |
extras.add("primary_keys") | |
if extra_extras: | |
extras.update(extra_extras) | |
async def extra_count_sql(): | |
return count_sql | |
async def extra_count(): | |
"Total count of rows matching these filters" | |
# Calculate the total count for this query | |
count = None | |
if ( | |
not db.is_mutable | |
and datasette.inspect_data | |
and count_sql == f"select count(*) from {table_name} " | |
): | |
# We can use a previously cached table row count | |
try: | |
count = datasette.inspect_data[database_name]["tables"][table_name][ | |
"count" | |
] | |
except KeyError: | |
pass | |
# Otherwise run a select count(*) ... | |
if count_sql and count is None and not nocount: | |
count_sql_limited = ( | |
f"select count(*) from (select * {from_sql} limit 10001)" | |
) | |
try: | |
count_rows = list(await db.execute(count_sql_limited, from_sql_params)) | |
count = count_rows[0][0] | |
except QueryInterrupted: | |
pass | |
return count | |
async def facet_instances(extra_count): | |
facet_instances = [] | |
facet_classes = list( | |
itertools.chain.from_iterable(pm.hook.register_facet_classes()) | |
) | |
for facet_class in facet_classes: | |
facet_instances.append( | |
facet_class( | |
datasette, | |
request, | |
database_name, | |
sql=sql_no_order_no_limit, | |
params=params, | |
table=table_name, | |
table_config=table_metadata, | |
row_count=extra_count, | |
) | |
) | |
return facet_instances | |
async def extra_facet_results(facet_instances): | |
"Results of facets calculated against this data" | |
facet_results = {} | |
facets_timed_out = [] | |
if not nofacet: | |
# Run the facet queries sequentially (see run_sequential above)
facet_awaitables = [facet.facet_results() for facet in facet_instances] | |
facet_awaitable_results = await run_sequential(*facet_awaitables) | |
for ( | |
instance_facet_results, | |
instance_facets_timed_out, | |
) in facet_awaitable_results: | |
for facet_info in instance_facet_results: | |
base_key = facet_info["name"] | |
key = base_key | |
i = 1 | |
while key in facet_results: | |
i += 1 | |
key = f"{base_key}_{i}" | |
facet_results[key] = facet_info | |
facets_timed_out.extend(instance_facets_timed_out) | |
return { | |
"results": facet_results, | |
"timed_out": facets_timed_out, | |
} | |
async def extra_suggested_facets(facet_instances): | |
"Suggestions for facets that might return interesting results" | |
suggested_facets = [] | |
# Calculate suggested facets | |
if ( | |
datasette.setting("suggest_facets") | |
and datasette.setting("allow_facet") | |
and not _next | |
and not nofacet | |
and not nosuggest | |
): | |
# Run the suggestion queries sequentially (see run_sequential above)
facet_suggest_awaitables = [facet.suggest() for facet in facet_instances] | |
for suggest_result in await run_sequential(*facet_suggest_awaitables): | |
suggested_facets.extend(suggest_result) | |
return suggested_facets | |
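    # Facet suggestions are skipped entirely on paginated pages (_next is
    # set) and when faceting is disabled via settings or the _nofacet /
    # _nosuggest query-string flags, presumably because they cost extra
    # queries that are only worth running on the first page of results.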
    # Faceting
    if not datasette.setting("allow_facet") and any(
        arg.startswith("_facet") for arg in request.args
    ):
        raise BadRequest("_facet= is not allowed")

    # human_description_en combines filters AND search, if provided
    async def extra_human_description_en():
        "Human-readable description of the filters"
        human_description_en = filters.human_description_en(
            extra=extra_human_descriptions
        )
        if sort or sort_desc:
            human_description_en = " ".join(
                [b for b in [human_description_en, sorted_by] if b]
            )
        return human_description_en

    if sort or sort_desc:
        sorted_by = "sorted by {}{}".format(
            (sort or sort_desc), " descending" if sort_desc else ""
        )
    async def extra_next_url():
        "Full URL for the next page of results"
        return next_url

    async def extra_columns():
        "Column names returned by this query"
        return columns

    async def extra_primary_keys():
        "Primary keys for this table"
        return pks

    async def extra_actions():
        async def actions():
            links = []
            kwargs = {
                "datasette": datasette,
                "database": database_name,
                "actor": request.actor,
                "request": request,
            }
            if is_view:
                kwargs["view"] = table_name
                method = pm.hook.view_actions
            else:
                kwargs["table"] = table_name
                method = pm.hook.table_actions
            for hook in method(**kwargs):
                extra_links = await await_me_maybe(hook)
                if extra_links:
                    links.extend(extra_links)
            return links

        return actions

    async def extra_is_view():
        return is_view

    async def extra_debug():
        "Extra debug information"
        return {
            "resolved": repr(resolved),
            "url_vars": request.url_vars,
            "nofacet": nofacet,
            "nosuggest": nosuggest,
        }

    async def extra_request():
        "Full information about the request"
        return {
            "url": request.url,
            "path": request.path,
            "full_path": request.full_path,
            "host": request.host,
            "args": request.args._data,
        }
    async def run_display_columns_and_rows():
        display_columns, display_rows = await display_columns_and_rows(
            datasette,
            database_name,
            table_name,
            results.description,
            rows,
            link_column=not is_view,
            truncate_cells=datasette.setting("truncate_cells_html"),
            sortable_columns=sortable_columns,
            request=request,
        )
        return {
            "columns": display_columns,
            "rows": display_rows,
        }

    async def extra_display_columns(run_display_columns_and_rows):
        return run_display_columns_and_rows["columns"]

    async def extra_display_rows(run_display_columns_and_rows):
        return run_display_columns_and_rows["rows"]

    async def extra_query():
        "Details of the underlying SQL query"
        return {
            "sql": sql,
            "params": params,
        }
    async def extra_metadata():
        "Metadata about the table and database"
        tablemetadata = await datasette.get_resource_metadata(database_name, table_name)
        rows = await datasette.get_internal_database().execute(
            """
            SELECT
                column_name,
                value
            FROM metadata_columns
            WHERE database_name = ?
                AND resource_name = ?
                AND key = 'description'
            """,
            [database_name, table_name],
        )
        tablemetadata["columns"] = dict(rows)
        return tablemetadata
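    # Column descriptions live in the metadata_columns table of Datasette's
    # internal database; the query above folds them into the metadata dict as
    # {"columns": {column_name: description, ...}}. The dict(rows) call
    # assumes each returned row is a (column_name, value) pair.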
    async def extra_database():
        return database_name

    async def extra_table():
        return table_name

    async def extra_database_color():
        return db.color

    async def extra_form_hidden_args():
        form_hidden_args = []
        for key in request.args:
            if (
                key.startswith("_")
                and key not in ("_sort", "_sort_desc", "_search", "_next")
                and "__" not in key
            ):
                for value in request.args.getlist(key):
                    form_hidden_args.append((key, value))
        return form_hidden_args

    async def extra_filters():
        return filters

    async def extra_custom_table_templates():
        return [
            f"_table-{to_css_class(database_name)}-{to_css_class(table_name)}.html",
            f"_table-table-{to_css_class(database_name)}-{to_css_class(table_name)}.html",
            "_table.html",
        ]

    async def extra_sorted_facet_results(extra_facet_results):
        return sorted(
            extra_facet_results["results"].values(),
            key=lambda f: (len(f["results"]), f["name"]),
            reverse=True,
        )

    async def extra_table_definition():
        return await db.get_table_definition(table_name)

    async def extra_view_definition():
        return await db.get_view_definition(table_name)
    async def extra_renderers(extra_expandable_columns, extra_query):
        renderers = {}
        url_labels_extra = {}
        if extra_expandable_columns:
            url_labels_extra = {"_labels": "on"}
        for key, (_, can_render) in datasette.renderers.items():
            it_can_render = call_with_supported_arguments(
                can_render,
                datasette=datasette,
                columns=columns or [],
                rows=rows or [],
                sql=extra_query.get("sql", None),
                query_name=None,
                database=database_name,
                table=table_name,
                request=request,
                view_name="table",
            )
            it_can_render = await await_me_maybe(it_can_render)
            if it_can_render:
                renderers[key] = datasette.urls.path(
                    path_with_format(
                        request=request, format=key, extra_qs={**url_labels_extra}
                    )
                )
        return renderers
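    # The resulting dict maps renderer keys to alternate-format URLs for the
    # current page, something like (illustrative values only):
    #     {"json": "/db/table.json?_labels=on", "csv": "/db/table.csv?_labels=on"}
    # A renderer is only listed if its can_render callable (which may be sync
    # or async, hence await_me_maybe) says it can handle this result set.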
    async def extra_private():
        return private

    async def extra_expandable_columns():
        expandables = []
        db = datasette.databases[database_name]
        for fk in await db.foreign_keys_for_table(table_name):
            label_column = await db.label_column_for_table(fk["other_table"])
            expandables.append((fk, label_column))
        return expandables

    async def extra_extras():
        "Available ?_extra= blocks"
        all_extras = [
            (key[len("extra_") :], fn.__doc__)
            for key, fn in registry._registry.items()
            if key.startswith("extra_")
        ]
        return [
            {
                "name": name,
                "description": doc,
                "toggle_url": datasette.absolute_url(
                    request,
                    datasette.urls.path(
                        path_with_added_args(request, {"_extra": name})
                        if name not in extras
                        else path_with_removed_args(request, {"_extra": name})
                    ),
                ),
                "selected": name in extras,
            }
            for name, doc in all_extras
        ]

    async def extra_facets_timed_out(extra_facet_results):
        return extra_facet_results["timed_out"]
    bundles = {
        "html": [
            "suggested_facets",
            "facet_results",
            "facets_timed_out",
            "count",
            "count_sql",
            "human_description_en",
            "next_url",
            "metadata",
            "query",
            "columns",
            "display_columns",
            "display_rows",
            "database",
            "table",
            "database_color",
            "actions",
            "filters",
            "renderers",
            "custom_table_templates",
            "sorted_facet_results",
            "table_definition",
            "view_definition",
            "is_view",
            "private",
            "primary_keys",
            "expandable_columns",
            "form_hidden_args",
        ]
    }

    for key, values in bundles.items():
        if f"_{key}" in extras:
            extras.update(values)
        extras.discard(f"_{key}")
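    # Bundles are shorthand extras: requesting ?_extra=_html expands to every
    # extra the HTML template needs, instead of listing them one by one.
    # For example (illustrative):
    #     extras == {"_html", "count"}
    #         -> extras == {"count", "suggested_facets", "facet_results", ...}
    # The leading-underscore name itself is always discarded afterwards, so
    # it never reaches the registry lookup below.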
    registry = Registry(
        extra_count,
        extra_count_sql,
        extra_facet_results,
        extra_facets_timed_out,
        extra_suggested_facets,
        facet_instances,
        extra_human_description_en,
        extra_next_url,
        extra_columns,
        extra_primary_keys,
        run_display_columns_and_rows,
        extra_display_columns,
        extra_display_rows,
        extra_debug,
        extra_request,
        extra_query,
        extra_metadata,
        extra_extras,
        extra_database,
        extra_table,
        extra_database_color,
        extra_actions,
        extra_filters,
        extra_renderers,
        extra_custom_table_templates,
        extra_sorted_facet_results,
        extra_table_definition,
        extra_view_definition,
        extra_is_view,
        extra_private,
        extra_expandable_columns,
        extra_form_hidden_args,
    )

    results = await registry.resolve_multi(
        ["extra_{}".format(extra) for extra in extras]
    )
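    # Registry here appears to act as a small dependency-injection resolver:
    # each registered function is keyed by its own name, and its parameter
    # names (e.g. extra_facet_results(facet_instances)) refer to other
    # registered functions whose resolved values are passed in.
    # resolve_multi() then evaluates only the functions needed for the
    # requested extras, plus their dependencies, and returns a
    # {function_name: value} dict.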
    data = {
        "ok": True,
        "next": next_value and str(next_value) or None,
    }
    data.update(
        {
            key.replace("extra_", ""): value
            for key, value in results.items()
            if key.startswith("extra_") and key.replace("extra_", "") in extras
        }
    )
    raw_sqlite_rows = rows[:page_size]
    data["rows"] = [dict(r) for r in raw_sqlite_rows]

    if context_for_html_hack:
        data.update(extra_context_from_filters)
        # filter_columns combines the columns we know are available
        # in the table with any additional columns (such as rowid)
        # which are available in the query
        data["filter_columns"] = list(columns) + [
            table_column
            for table_column in table_columns
            if table_column not in columns
        ]
        url_labels_extra = {}
        if data.get("expandable_columns"):
            url_labels_extra = {"_labels": "on"}
        url_csv_args = {"_size": "max", **url_labels_extra}
        url_csv = datasette.urls.path(
            path_with_format(request=request, format="csv", extra_qs=url_csv_args)
        )
        url_csv_path = url_csv.split("?")[0]
        data.update(
            {
                "url_csv": url_csv,
                "url_csv_path": url_csv_path,
                "url_csv_hidden_args": [
                    (key, value)
                    for key, value in urllib.parse.parse_qsl(request.query_string)
                    if key not in ("_labels", "_facet", "_size")
                ]
                + [("_size", "max")],
            }
        )
        # if no sort specified AND the table has a single primary key,
        # set sort to that so the sort arrow is displayed
        if not sort and not sort_desc:
            if 1 == len(pks):
                sort = pks[0]
            elif use_rowid:
                sort = "rowid"
        data["sort"] = sort
        data["sort_desc"] = sort_desc

    return data, rows[:page_size], columns, expanded_columns, sql, next_url
async def _next_value_and_url(
    datasette,
    db,
    request,
    table_name,
    _next,
    rows,
    pks,
    use_rowid,
    sort,
    sort_desc,
    page_size,
    is_view,
):
    next_value = None
    next_url = None
    if 0 < page_size < len(rows):
        if is_view:
            next_value = int(_next or 0) + page_size
        else:
            next_value = path_from_row_pks(rows[-2], pks, use_rowid)
        # If there's a sort or sort_desc, add that value as a prefix
        if (sort or sort_desc) and not is_view:
            try:
                prefix = rows[-2][sort or sort_desc]
            except IndexError:
                # sort/sort_desc column missing from SELECT - look up value by PK instead
                prefix_where_clause = " and ".join(
                    "[{}] = :pk{}".format(pk, i) for i, pk in enumerate(pks)
                )
                prefix_lookup_sql = "select [{}] from [{}] where {}".format(
                    sort or sort_desc, table_name, prefix_where_clause
                )
                prefix = (
                    await db.execute(
                        prefix_lookup_sql,
                        {
                            **{
                                "pk{}".format(i): rows[-2][pk]
                                for i, pk in enumerate(pks)
                            }
                        },
                    )
                ).single_value()
            if isinstance(prefix, dict) and "value" in prefix:
                prefix = prefix["value"]
            if prefix is None:
                prefix = "$null"
            else:
                prefix = tilde_encode(str(prefix))
            next_value = f"{prefix},{next_value}"
            added_args = {"_next": next_value}
            if sort:
                added_args["_sort"] = sort
            else:
                added_args["_sort_desc"] = sort_desc
        else:
            added_args = {"_next": next_value}
        next_url = datasette.absolute_url(
            request, datasette.urls.path(path_with_replaced_args(request, added_args))
        )
    return next_value, next_url
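# A worked example of the "_next" token built above, assuming a table with a
# single primary key "id", a sort on a "name" column, page_size of 100, and
# 101 fetched rows (illustrative values only):
#     rows[-2]["name"] == "Banana"   rows[-2]["id"] == 57
#     next_value == "Banana,57"  ->  ?_sort=name&_next=Banana,57
# The sort value is tilde-encoded so that reserved characters (including the
# comma separator) appear to survive the round trip, and a null sort value is
# encoded as "$null".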