Initial commit: building blocks for later use

Signed-off-by: Uncle Stinky <uncle.stinky@ghostchain.io>
This commit is contained in:
Uncle Stinky 2024-11-18 16:58:38 +03:00
commit 51e78f29e7
Signed by: st1nky
GPG Key ID: 016064BD97603B40
50 changed files with 10243 additions and 0 deletions

5
.gitignore vendored Normal file
View File

@ -0,0 +1,5 @@
debug/
target/
**/*.rs.bk
*.pdb

2599
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

14
Cargo.toml Normal file
View File

@ -0,0 +1,14 @@
[workspace]
resolver = "2"
members = [
    "common",
    "telemetry-core",
    "telemetry-shard",
]

# NOTE(review): dev builds use full optimization — presumably because
# unoptimized builds are too slow to be useful here; confirm.
[profile.dev]
opt-level = 3

[profile.release]
# Enable link-time optimization for smaller/faster release binaries.
lto = true
# Abort on panic rather than unwinding.
panic = "abort"

35
common/Cargo.toml Normal file
View File

@ -0,0 +1,35 @@
[package]
name = "ghost-telemetry-common"
version = "0.1.0"
authors = ["Uncle Stinky uncle.stinky@ghostchain.io"]
edition = "2021"
[dependencies]
anyhow = "1.0.42"
arrayvec = { version = "0.7.1", features = ["serde"] }
base64 = { version = "0.21", default-features = false, features = ["alloc"] }
bimap = "0.6.1"
bytes = "1.0.1"
flume = "0.10.8"
fnv = "1.0.7"
futures = "0.3.15"
hex = "0.4.3"
http = "0.2.4"
hyper = { version = "0.14.11", features = ["full"] }
log = "0.4.14"
num-traits = "0.2"
pin-project-lite = "0.2.7"
primitive-types = { version = "0.12.1", features = ["serde"] }
rustc-hash = "1.1.0"
serde = { version = "1.0.126", features = ["derive"] }
serde_json = { version = "1.0.64", features = ["raw_value"] }
sha-1 = { version = "0.10.1", default-features = false }
soketto = "0.7.1"
thiserror = "1.0.24"
tokio = { version = "1.10.1", features = ["full"] }
tokio-util = { version = "0.7.4", features = ["compat"] }
tokio-rustls = "0.23.4"
webpki-roots = "0.22.4"
[dev-dependencies]
bincode = "1.3.3"

73
common/src/assign_id.rs Normal file
View File

@ -0,0 +1,73 @@
use bimap::BiMap;
use std::hash::Hash;
/// A struct that allows you to assign an Id to an arbitrary set of
/// details (so long as they are Eq+Hash+Clone), and then access
/// the assigned Id given those details or access the details given
/// the Id.
///
/// The Id can be any type that's convertible to/from a `usize`. Using
/// a custom type is recommended for increased type safety.
#[derive(Debug)]
pub struct AssignId<Id, Details> {
    /// The next raw id to hand out (wraps on overflow).
    current_id: usize,
    /// Two-way mapping between raw ids and the stored details.
    mapping: BiMap<usize, Details>,
    /// Marker so the public API can speak in terms of `Id` rather than `usize`.
    _id_type: std::marker::PhantomData<Id>,
}
impl<Id, Details> AssignId<Id, Details>
where
    Details: Eq + Hash,
    Id: From<usize> + Copy,
    usize: From<Id>,
{
    /// Create an empty mapping; the first assigned Id will be 0.
    pub fn new() -> Self {
        Self {
            current_id: 0,
            mapping: BiMap::new(),
            _id_type: std::marker::PhantomData,
        }
    }

    /// Store `details` and return the Id assigned to them.
    pub fn assign_id(&mut self, details: Details) -> Id {
        let this_id = self.current_id;
        // It's very unlikely we'll ever overflow the ID limit, but in case we do,
        // a wrapping_add will almost certainly be fine:
        self.current_id = self.current_id.wrapping_add(1);
        self.mapping.insert(this_id, details);
        this_id.into()
    }

    /// Look up the details assigned to an Id.
    ///
    /// Takes `&self` rather than `&mut self`: lookups don't mutate anything.
    pub fn get_details(&self, id: Id) -> Option<&Details> {
        self.mapping.get_by_left(&id.into())
    }

    /// Look up the Id assigned to some details.
    ///
    /// Takes `&self` rather than `&mut self`: lookups don't mutate anything.
    pub fn get_id(&self, details: &Details) -> Option<Id> {
        self.mapping.get_by_right(details).map(|&id| id.into())
    }

    /// Remove an entry by its Id, returning the details if they were present.
    pub fn remove_by_id(&mut self, id: Id) -> Option<Details> {
        self.mapping
            .remove_by_left(&id.into())
            .map(|(_, details)| details)
    }

    /// Remove an entry by its details, returning the Id if it was present.
    pub fn remove_by_details(&mut self, details: &Details) -> Option<Id> {
        self.mapping
            // `details` is already a reference; no need for the extra `&`.
            .remove_by_right(details)
            .map(|(id, _)| id.into())
    }

    /// Remove all entries, but keep handing out fresh Ids.
    pub fn clear(&mut self) {
        // Leave the `current_id` as-is. Why? To avoid reusing IDs and risking
        // race conditions where old messages can accidentally screw with new nodes
        // that have been assigned the same ID.
        self.mapping = BiMap::new();
    }

    /// Iterate over all (Id, details) pairs.
    pub fn iter(&self) -> impl Iterator<Item = (Id, &Details)> {
        self.mapping
            .iter()
            .map(|(&id, details)| (id.into(), details))
    }
}

impl<Id, Details> Default for AssignId<Id, Details>
where
    Details: Eq + Hash,
    Id: From<usize> + Copy,
    usize: From<Id>,
{
    fn default() -> Self {
        Self::new()
    }
}

92
common/src/byte_size.rs Normal file
View File

@ -0,0 +1,92 @@
use anyhow::{anyhow, Error};
/// A number of bytes.
#[derive(Copy, Clone, Debug)]
pub struct ByteSize(usize);

impl ByteSize {
    /// Wrap a raw byte count.
    pub fn new(bytes: usize) -> ByteSize {
        ByteSize(bytes)
    }

    /// Return the number of bytes stored within.
    pub fn num_bytes(self) -> usize {
        self.0
    }
}

impl From<ByteSize> for usize {
    fn from(b: ByteSize) -> Self {
        b.num_bytes()
    }
}
impl std::str::FromStr for ByteSize {
    type Err = Error;

    /// Parse strings like `"100"`, `"20kB"` or `"1GiB"` into a byte count.
    /// Decimal suffixes (kB/K/k, MB/M/m, GB/G/g) are powers of 1000; binary
    /// suffixes (KiB/Ki, MiB/Mi, GiB/Gi) are powers of 1024.
    ///
    /// Returns an error (rather than panicking, as the previous `expect`s
    /// did) if the numeric part is missing, does not fit in a `usize`, or
    /// the multiplied total overflows.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let s = s.trim();
        match s.find(|c| !char::is_ascii_digit(&c)) {
            // No non-numeric chars; assume bytes. Parsing can still fail on
            // an empty string or a value too large for usize:
            None => {
                let n = s
                    .parse()
                    .map_err(|e| anyhow!("Cannot parse '{}' into bytes: {}", s, e))?;
                Ok(ByteSize(n))
            }
            // First non-numeric char begins the suffix:
            Some(idx) => {
                let n: usize = s[..idx]
                    .parse()
                    .map_err(|e| anyhow!("Cannot parse '{}' into bytes: {}", s, e))?;
                let suffix = s[idx..].trim();
                let multiplier: usize = match suffix {
                    "B" | "b" => 1,
                    "kB" | "K" | "k" => 1000,
                    "MB" | "M" | "m" => 1000 * 1000,
                    "GB" | "G" | "g" => 1000 * 1000 * 1000,
                    "KiB" | "Ki" => 1024,
                    "MiB" | "Mi" => 1024 * 1024,
                    "GiB" | "Gi" => 1024 * 1024 * 1024,
                    _ => {
                        return Err(anyhow!(
                            "\
                            Cannot parse into bytes; suffix is '{}', but expecting one of \
                            B,b, kB,K,k, MB,M,m, GB,G,g, KiB,Ki, MiB,Mi, GiB,Gi",
                            suffix
                        ))
                    }
                };
                // Guard against overflow (e.g. a huge value with a GiB suffix):
                let n = n.checked_mul(multiplier).ok_or_else(|| {
                    anyhow!("Cannot parse '{}' into bytes: value overflows usize", s)
                })?;
                Ok(ByteSize(n))
            }
        }
    }
}
#[cfg(test)]
mod test {
    use crate::byte_size::ByteSize;

    // Each case pairs an input string with the expected parsed byte count,
    // covering bare numbers, decimal and binary suffixes, and whitespace.
    #[test]
    fn can_parse_valid_strings() {
        let cases = vec![
            ("100", 100),
            ("100B", 100),
            ("100b", 100),
            ("20kB", 20 * 1000),
            ("20 kB", 20 * 1000),
            ("20K", 20 * 1000),
            (" 20k", 20 * 1000),
            ("1MB", 1 * 1000 * 1000),
            ("1M", 1 * 1000 * 1000),
            ("1m", 1 * 1000 * 1000),
            ("1 m", 1 * 1000 * 1000),
            ("1GB", 1 * 1000 * 1000 * 1000),
            ("1G", 1 * 1000 * 1000 * 1000),
            ("1g", 1 * 1000 * 1000 * 1000),
            ("1KiB", 1 * 1024),
            ("1Ki", 1 * 1024),
            ("1MiB", 1 * 1024 * 1024),
            ("1Mi", 1 * 1024 * 1024),
            ("1GiB", 1 * 1024 * 1024 * 1024),
            ("1Gi", 1 * 1024 * 1024 * 1024),
            (" 1 Gi ", 1 * 1024 * 1024 * 1024),
        ];

        for (s, expected) in cases {
            let b: ByteSize = s.parse().unwrap();
            assert_eq!(b.num_bytes(), expected);
        }
    }
}

149
common/src/dense_map.rs Normal file
View File

@ -0,0 +1,149 @@
/// This stores items in contiguous memory, making a note of free
/// slots when items are removed again so that they can be reused.
///
/// This is particularly efficient when items are often added and
/// seldom removed.
///
/// Items are keyed by an Id, which can be any type you wish, but
/// must be convertible to/from a `usize`. This promotes using a
/// custom Id type to talk about items in the map.
pub struct DenseMap<Id, T> {
    /// List of retired indexes that can be re-used
    retired: Vec<usize>,
    /// All items
    items: Vec<Option<T>>,
    /// Our ID type
    _id_type: std::marker::PhantomData<Id>,
}

impl<Id, T> DenseMap<Id, T>
where
    Id: From<usize> + Copy,
    usize: From<Id>,
{
    /// Create an empty map.
    pub fn new() -> Self {
        DenseMap {
            retired: Vec::new(),
            items: Vec::new(),
            _id_type: std::marker::PhantomData,
        }
    }

    /// Insert an item, returning the Id it was stored under.
    pub fn add(&mut self, item: T) -> Id {
        self.add_with(|_| item)
    }

    /// View the underlying storage, including empty (retired) slots.
    pub fn as_slice(&self) -> &[Option<T>] {
        &self.items
    }

    /// Insert the item produced by `f`, which is handed the Id that
    /// the new item will be stored under.
    pub fn add_with<F>(&mut self, f: F) -> Id
    where
        F: FnOnce(Id) -> T,
    {
        // Prefer re-using a retired slot; otherwise grow the storage.
        if let Some(slot) = self.retired.pop() {
            let id = Id::from(slot);
            self.items[slot] = Some(f(id));
            id
        } else {
            let id = Id::from(self.items.len());
            self.items.push(Some(f(id)));
            id
        }
    }

    /// Borrow the item stored under `id`, if one exists.
    pub fn get(&self, id: Id) -> Option<&T> {
        self.items.get(usize::from(id)).and_then(Option::as_ref)
    }

    /// Mutably borrow the item stored under `id`, if one exists.
    pub fn get_mut(&mut self, id: Id) -> Option<&mut T> {
        self.items.get_mut(usize::from(id)).and_then(Option::as_mut)
    }

    /// Remove and return the item under `id`; its slot is retired for re-use.
    pub fn remove(&mut self, id: Id) -> Option<T> {
        let idx = usize::from(id);
        let removed = self.items.get_mut(idx)?.take();
        if removed.is_some() {
            // Something was actually removed, so remember the now-free slot.
            self.retired.push(idx);
        }
        removed
    }

    /// Iterate over `(Id, &item)` pairs for occupied slots.
    pub fn iter(&self) -> impl Iterator<Item = (Id, &T)> + '_ {
        self.items
            .iter()
            .enumerate()
            .filter_map(|(idx, slot)| slot.as_ref().map(|item| (Id::from(idx), item)))
    }

    /// Iterate over `(Id, &mut item)` pairs for occupied slots.
    pub fn iter_mut(&mut self) -> impl Iterator<Item = (Id, &mut T)> + '_ {
        self.items
            .iter_mut()
            .enumerate()
            .filter_map(|(idx, slot)| slot.as_mut().map(|item| (Id::from(idx), item)))
    }

    /// Consume the map, iterating over owned `(Id, item)` pairs.
    pub fn into_iter(self) -> impl Iterator<Item = (Id, T)> {
        self.items
            .into_iter()
            .enumerate()
            .filter_map(|(idx, slot)| slot.map(|item| (Id::from(idx), item)))
    }

    /// Number of items currently stored.
    pub fn len(&self) -> usize {
        self.items.len() - self.retired.len()
    }

    /// True when no items are stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Return the next Id that will be assigned.
    pub fn next_id(&self) -> usize {
        self.retired.last().map_or(self.items.len(), |&idx| idx)
    }
}
#[cfg(test)]
mod test {
    use super::*;

    // `len()` subtracts retired slots from total slots; removing the same id
    // repeatedly must not retire the slot twice (which would underflow).
    #[test]
    fn len_doesnt_panic_if_lots_of_ids_are_retired() {
        let mut map = DenseMap::<usize, usize>::new();

        let id1 = map.add(1);
        let id2 = map.add(2);
        let id3 = map.add(3);

        assert_eq!(map.len(), 3);

        map.remove(id1);
        map.remove(id2);

        assert_eq!(map.len(), 1);

        map.remove(id3);

        assert_eq!(map.len(), 0);

        // Removing an already-empty slot is a no-op:
        map.remove(id1);
        map.remove(id1);
        map.remove(id1);

        assert_eq!(map.len(), 0);
    }
}

66
common/src/either_sink.rs Normal file
View File

@ -0,0 +1,66 @@
use futures::sink::Sink;
use pin_project_lite::pin_project;
pin_project! {
    // Generates a pinned projection enum (`EitherSinkInner`) so that the
    // `Sink` impl below can delegate to whichever variant is active
    // without moving the pinned inner sink.
    #[project = EitherSinkInner]
    pub enum EitherSink<A, B> {
        A { #[pin] inner: A },
        B { #[pin] inner: B }
    }
}
/// A simple enum that delegates implementation to one of
/// the two possible sinks contained within.
impl<A, B> EitherSink<A, B> {
    /// Wrap a sink as the `A` variant.
    pub fn a(val: A) -> Self {
        EitherSink::A { inner: val }
    }

    /// Wrap a sink as the `B` variant.
    pub fn b(val: B) -> Self {
        EitherSink::B { inner: val }
    }
}
// `Sink` is implemented by projecting to the active variant and delegating
// each method to the pinned inner sink. Both variants must share the same
// `Error` type so the delegation is transparent to callers.
impl<Item, Error, A, B> Sink<Item> for EitherSink<A, B>
where
    A: Sink<Item, Error = Error>,
    B: Sink<Item, Error = Error>,
{
    type Error = Error;

    fn poll_ready(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        match self.project() {
            EitherSinkInner::A { inner } => inner.poll_ready(cx),
            EitherSinkInner::B { inner } => inner.poll_ready(cx),
        }
    }

    fn start_send(self: std::pin::Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
        match self.project() {
            EitherSinkInner::A { inner } => inner.start_send(item),
            EitherSinkInner::B { inner } => inner.start_send(item),
        }
    }

    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        match self.project() {
            EitherSinkInner::A { inner } => inner.poll_flush(cx),
            EitherSinkInner::B { inner } => inner.poll_flush(cx),
        }
    }

    fn poll_close(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        match self.project() {
            EitherSinkInner::A { inner } => inner.poll_close(cx),
            EitherSinkInner::B { inner } => inner.poll_close(cx),
        }
    }
}

159
common/src/http_utils.rs Normal file
View File

@ -0,0 +1,159 @@
use futures::io::{BufReader, BufWriter};
use hyper::server::conn::AddrStream;
use hyper::{Body, Request, Response, Server};
use std::future::Future;
use std::net::SocketAddr;
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
/// A convenience function to start up a Hyper server and handle requests.
///
/// `handler` is cloned once per incoming connection and called with the
/// peer's address and each request on that connection. The future resolves
/// only if the server stops (returning any server error).
pub async fn start_server<H, F>(addr: SocketAddr, handler: H) -> Result<(), anyhow::Error>
where
    H: Clone + Send + Sync + 'static + FnMut(SocketAddr, Request<Body>) -> F,
    F: Send + 'static + Future<Output = Result<Response<Body>, anyhow::Error>>,
{
    let service = hyper::service::make_service_fn(move |addr: &AddrStream| {
        // One clone of the handler per connection; the connection's remote
        // address is captured and passed along with every request.
        let mut handler = handler.clone();
        let addr = addr.remote_addr();
        async move { Ok::<_, hyper::Error>(hyper::service::service_fn(move |r| handler(addr, r))) }
    });

    let server = Server::bind(&addr).serve(service);
    log::info!("listening on http://{}", server.local_addr());
    server.await?;
    Ok(())
}
/// The stream our websocket connections run over: the upgraded Hyper
/// connection, adapted to futures IO and buffered in both directions.
type WsStream = BufReader<BufWriter<Compat<hyper::upgrade::Upgraded>>>;
/// Sending half of a Soketto websocket connection.
pub type WsSender = soketto::connection::Sender<WsStream>;
/// Receiving half of a Soketto websocket connection.
pub type WsReceiver = soketto::connection::Receiver<WsStream>;
/// A convenience function to upgrade a Hyper request into a Soketto Websocket.
///
/// Responds 400 if the request is missing the required upgrade headers, and
/// otherwise replies 101 (Switching Protocols) and spawns a Tokio task that
/// hands the websocket sender/receiver halves to `on_upgrade`.
pub fn upgrade_to_websocket<H, F>(req: Request<Body>, on_upgrade: H) -> hyper::Response<Body>
where
    H: 'static + Send + FnOnce(WsSender, WsReceiver) -> F,
    F: Send + Future<Output = ()>,
{
    if !is_upgrade_request(&req) {
        return basic_response(400, "Expecting WebSocket upgrade headers");
    }

    // The client's random key; we must echo a digest of it back (RFC 6455):
    let key = match req.headers().get("Sec-WebSocket-Key") {
        Some(key) => key,
        None => {
            return basic_response(
                400,
                "Upgrade to websocket connection failed; Sec-WebSocket-Key header not provided",
            )
        }
    };

    // We only speak websocket protocol version 13:
    if req
        .headers()
        .get("Sec-WebSocket-Version")
        .map(|v| v.as_bytes())
        != Some(b"13")
    {
        return basic_response(
            400,
            "Sec-WebSocket-Version header should have a value of 13",
        );
    }

    // Just a little ceremony to return the correct response key:
    let mut accept_key_buf = [0; 32];
    let accept_key = generate_websocket_accept_key(key.as_bytes(), &mut accept_key_buf);

    // Tell the client that we accept the upgrade-to-WS request:
    let response = Response::builder()
        .status(hyper::StatusCode::SWITCHING_PROTOCOLS)
        .header(hyper::header::CONNECTION, "upgrade")
        .header(hyper::header::UPGRADE, "websocket")
        .header("Sec-WebSocket-Accept", accept_key)
        .body(Body::empty())
        .expect("bug: failed to build response");

    // Spawn our handler to work with the WS connection:
    tokio::spawn(async move {
        // Get our underlying TCP stream:
        let stream = match hyper::upgrade::on(req).await {
            Ok(stream) => stream,
            Err(e) => {
                log::error!("Error upgrading connection to websocket: {}", e);
                return;
            }
        };

        // Start a Soketto server with it:
        let server =
            soketto::handshake::Server::new(BufReader::new(BufWriter::new(stream.compat())));

        // Get hold of a way to send and receive messages:
        let (sender, receiver) = server.into_builder().finish();

        // Pass these to our when-upgraded handler:
        on_upgrade(sender, receiver).await;
    });

    response
}
/// A helper to return a basic HTTP response with a code and text body.
fn basic_response(code: u16, msg: impl AsRef<str>) -> Response<Body> {
    Response::builder()
        .status(code)
        // Copy the message into an owned body:
        .body(Body::from(msg.as_ref().to_owned()))
        .expect("bug: failed to build response body")
}
/// Defined in RFC 6455. This is how we convert the Sec-WebSocket-Key in a request into a
/// Sec-WebSocket-Accept that we return in the response.
fn generate_websocket_accept_key<'a>(key: &[u8], buf: &'a mut [u8; 32]) -> &'a [u8] {
    // Defined in RFC 6455, we append this to the key to generate the response:
    const KEY: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    // SHA-1 of key + constant...
    use sha1::{Digest, Sha1};
    let mut digest = Sha1::new();
    digest.update(key);
    digest.update(KEY);
    let d = digest.finalize();

    // ...base64-encoded into the caller's buffer. The encoded SHA-1 digest
    // always fits in 32 bytes, hence the `expect`.
    use base64::{engine::general_purpose, Engine as _};
    let n = general_purpose::STANDARD
        .encode_slice(&d, buf)
        .expect("Sha1 must fit into [u8; 32]");
    &buf[..n]
}
/// Check if a request is a websocket upgrade request.
///
/// True when the `Connection` header contains "upgrade" and the `Upgrade`
/// header contains "websocket" (both matched case-insensitively).
fn is_upgrade_request<B>(request: &hyper::Request<B>) -> bool {
    header_contains_value(request.headers(), hyper::header::CONNECTION, b"upgrade")
        && header_contains_value(request.headers(), hyper::header::UPGRADE, b"websocket")
}
/// Check if there is a header of the given name containing the wanted value.
fn header_contains_value(
    headers: &hyper::HeaderMap,
    header: hyper::header::HeaderName,
    value: &[u8],
) -> bool {
    // Strip leading/trailing ASCII whitespace from a byte slice.
    fn trim(bytes: &[u8]) -> &[u8] {
        let from = match bytes.iter().position(|b| !b.is_ascii_whitespace()) {
            Some(i) => i,
            // Entirely whitespace (or empty): nothing remains after trimming.
            None => return &[],
        };
        let to = bytes
            .iter()
            .rposition(|b| !b.is_ascii_whitespace())
            .expect("a non-whitespace byte was found above");
        &bytes[from..=to]
    }

    // A header may appear multiple times, and each occurrence may hold a
    // comma-separated list; match any element, ignoring ASCII case.
    headers.get_all(header).iter().any(|entry| {
        entry
            .as_bytes()
            .split(|&c| c == b',')
            .any(|part| trim(part).eq_ignore_ascii_case(value))
    })
}

74
common/src/id_type.rs Normal file
View File

@ -0,0 +1,74 @@
/// Define a type that can be used as an ID, be converted from/to the inner type,
/// and serialized/deserialized transparently into the inner type.
#[macro_export]
macro_rules! id_type {
    ($( #[$attrs:meta] )* $vis:vis struct $name:ident ( $inner:ident ) $(;)? ) => {
        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
        $( #[$attrs] )*
        $vis struct $name($inner);

        impl From<$inner> for $name {
            fn from(inner: $inner) -> Self {
                Self(inner)
            }
        }

        impl From<$name> for $inner {
            fn from(id: $name) -> Self {
                id.0
            }
        }

        impl $name {
            /// Construct the ID type from its inner value.
            #[allow(dead_code)]
            pub fn new(inner: $inner) -> Self {
                Self(inner)
            }
        }
    }
}
#[cfg(test)]
mod test {
    // Mostly we're just checking that everything compiles OK
    // when the macro is used as expected..

    // A basic definition is possible:
    id_type! {
        struct Foo(usize)
    }

    // We can add a ';' on the end:
    id_type! {
        struct Bar(usize);
    }

    // Visibility qualifiers are allowed:
    id_type! {
        pub struct Wibble(u64)
    }

    // Doc strings are possible
    id_type! {
        /// We can have doc strings, too
        pub(crate) struct Wobble(u16)
    }

    // In fact, any attributes can be added (common
    // derives are added already):
    id_type! {
        /// We can have doc strings, too
        #[derive(serde::Serialize)]
        #[serde(transparent)]
        pub(crate) struct Lark(u16)
    }

    // Check that the constructor and From conversions round-trip:
    #[test]
    fn create_and_use_new_id_type() {
        let _ = Foo::new(123);
        let id = Foo::from(123);
        let id_num: usize = id.into();
        assert_eq!(id_num, 123);
    }
}

View File

@ -0,0 +1,51 @@
//! Internal messages passed between the shard and telemetry core.
use std::net::IpAddr;
use crate::id_type;
use crate::node_message::Payload;
use crate::node_types::{BlockHash, NodeDetails};
use serde::{Deserialize, Serialize};
// `id_type!` generates the newtype plus `From` conversions in both directions.
id_type! {
    /// The shard-local ID of a given node, where a single connection
    /// might send data on behalf of more than one chain.
    #[derive(serde::Serialize, serde::Deserialize)]
    pub struct ShardNodeId(usize);
}
/// Message sent from a telemetry shard to the telemetry core
#[derive(Deserialize, Serialize, Debug, Clone)]
pub enum FromShardAggregator {
    /// Get information about a new node, including its IP
    /// address and chain genesis hash.
    AddNode {
        ip: IpAddr,
        node: NodeDetails,
        local_id: ShardNodeId,
        genesis_hash: BlockHash,
    },
    /// A message payload with updated details for a node
    UpdateNode {
        local_id: ShardNodeId,
        payload: Payload,
    },
    /// Inform the telemetry core that a node has been removed
    RemoveNode { local_id: ShardNodeId },
}
/// Message sent from the telemetry core to a telemetry shard
#[derive(Deserialize, Serialize, Debug, Clone)]
pub enum FromTelemetryCore {
    // NOTE(review): presumably instructs the shard to stop forwarding data
    // for this node — confirm against the shard's handling code.
    Mute {
        local_id: ShardNodeId,
        reason: MuteReason,
    },
}

/// Why is the thing being muted?
#[derive(Deserialize, Serialize, Debug, Clone)]
pub enum MuteReason {
    /// Some quota was exceeded.
    Overquota,
    /// The node's chain is not one we accept.
    ChainNotAllowed,
}

27
common/src/lib.rs Normal file
View File

@ -0,0 +1,27 @@
// Public modules:
pub mod byte_size;
pub mod http_utils;
pub mod id_type;
pub mod internal_messages;
pub mod node_message;
pub mod node_types;
pub mod ready_chunks_all;
pub mod rolling_total;
pub mod time;
pub mod ws_client;

// Private modules whose primary types are re-exported below:
mod assign_id;
mod dense_map;
mod either_sink;
mod mean_list;
mod most_seen;
mod multi_map_unique;
mod num_stats;

// Export a bunch of common bits at the top level for ease of import:
pub use assign_id::AssignId;
pub use dense_map::DenseMap;
pub use either_sink::EitherSink;
pub use mean_list::MeanList;
pub use most_seen::MostSeen;
pub use multi_map_unique::MultiMapUnique;
pub use num_stats::NumStats;

79
common/src/mean_list.rs Normal file
View File

@ -0,0 +1,79 @@
use num_traits::{Float, Zero};
use std::ops::AddAssign;
/// Accumulates pushed values and keeps a window of up to 20 means.
/// Each stored mean covers `ticks_per_mean` values; when the window fills,
/// adjacent means are merged and `ticks_per_mean` doubles (up to 32, after
/// which the oldest mean is dropped to make room instead).
pub struct MeanList<T>
where
    T: Float + AddAssign + Zero + From<u8>,
{
    /// Running sum of values pushed during the current period.
    period_sum: T,
    /// How many values have been pushed during the current period.
    period_count: u8,
    /// Number of means currently stored (index of the next free slot).
    mean_index: u8,
    /// The stored means, oldest first.
    means: [T; 20],
    /// How many pushed values make up one mean (starts at 1, doubles to 32).
    ticks_per_mean: u8,
}
impl<T> Default for MeanList<T>
where
    T: Float + AddAssign + Zero + From<u8>,
{
    /// An empty list: no means recorded yet, one value per mean.
    fn default() -> MeanList<T> {
        MeanList {
            period_sum: T::zero(),
            period_count: 0,
            mean_index: 0,
            means: [T::zero(); 20],
            ticks_per_mean: 1,
        }
    }
}
impl<T> MeanList<T>
where
    T: Float + AddAssign + Zero + From<u8>,
{
    /// The means recorded so far, oldest first.
    pub fn slice(&self) -> &[T] {
        &self.means[..usize::from(self.mean_index)]
    }

    /// Accumulate one value. Returns true if this completed a period and a
    /// new mean was stored.
    pub fn push(&mut self, val: T) -> bool {
        // Window full but resolution can still be halved: merge adjacent
        // pairs of means, doubling `ticks_per_mean` (this must happen before
        // accumulating so the new value counts toward the coarser period).
        if self.mean_index == 20 && self.ticks_per_mean < 32 {
            self.squash_means();
        }

        self.period_sum += val;
        self.period_count += 1;

        if self.period_count == self.ticks_per_mean {
            self.push_mean();
            true
        } else {
            false
        }
    }

    /// Finish the current period: store its mean and reset the accumulator.
    fn push_mean(&mut self) {
        let mean = self.period_sum / std::convert::From::from(self.period_count);

        if self.mean_index == 20 && self.ticks_per_mean == 32 {
            // At max resolution and capacity: drop the oldest mean and
            // append the newest at the end.
            self.means.rotate_left(1);
            self.means[19] = mean;
        } else {
            self.means[usize::from(self.mean_index)] = mean;
            self.mean_index += 1;
        }

        self.period_sum = T::zero();
        self.period_count = 0;
    }

    /// Halve the resolution: average each adjacent pair of the 20 stored
    /// means into 10, and double the number of values per mean.
    fn squash_means(&mut self) {
        self.ticks_per_mean *= 2;
        self.mean_index = 10;

        for i in 0..10 {
            let i2 = i * 2;
            self.means[i] = (self.means[i2] + self.means[i2 + 1]) / std::convert::From::from(2)
        }
    }
}

236
common/src/most_seen.rs Normal file
View File

@ -0,0 +1,236 @@
use std::collections::HashMap;
use std::hash::Hash;

/// Add items to this, and it will keep track of what the item
/// seen the most is.
#[derive(Debug)]
pub struct MostSeen<T> {
    /// The item seen most often so far.
    current_best: T,
    /// How many times the current best item has been seen.
    current_count: usize,
    /// Seen-counts for every other item.
    others: HashMap<T, usize>,
}

impl<T: Default> Default for MostSeen<T> {
    fn default() -> Self {
        // This sets the "most seen item" to the default value for the type,
        // and notes that nobody has actually seen it yet (current_count is 0).
        Self {
            current_best: T::default(),
            current_count: 0,
            others: HashMap::new(),
        }
    }
}

impl<T> MostSeen<T> {
    /// Start off with `item` as the best, seen once (by virtue of being
    /// provided here).
    pub fn new(item: T) -> Self {
        Self {
            current_best: item,
            current_count: 1,
            others: HashMap::new(),
        }
    }

    /// The item seen the most so far.
    pub fn best(&self) -> &T {
        &self.current_best
    }

    /// How many times the best item has been seen.
    pub fn best_count(&self) -> usize {
        self.current_count
    }
}

impl<T: Hash + Eq + Clone> MostSeen<T> {
    /// Record another sighting of `item`, reporting whether the best item
    /// changed as a result.
    pub fn insert(&mut self, item: &T) -> ChangeResult {
        if &self.current_best == item {
            // Item already the best one; bump count.
            self.current_count += 1;
            return ChangeResult::NoChange;
        }

        // Item not the best; increment count in map
        let item_count = self.others.entry(item.clone()).or_default();
        *item_count += 1;

        // Is item now the best?
        if *item_count > self.current_count {
            let (mut item, mut count) = self.others.remove_entry(item).expect("item added above");

            // Swap the current best for the new best:
            std::mem::swap(&mut item, &mut self.current_best);
            std::mem::swap(&mut count, &mut self.current_count);

            // Insert the old best back into the map:
            self.others.insert(item, count);
            ChangeResult::NewMostSeenItem
        } else {
            ChangeResult::NoChange
        }
    }

    /// Record that `item` was seen one time fewer, reporting whether the
    /// best item changed as a result.
    pub fn remove(&mut self, item: &T) -> ChangeResult {
        if &self.current_best == item {
            // Item already the best one; reduce count (don't allow to drop below 0)
            self.current_count = self.current_count.saturating_sub(1);

            // Is there a new best?
            let other_best = self.others.iter().max_by_key(|f| f.1);
            let (other_item, &other_count) = match other_best {
                Some(item) => item,
                None => return ChangeResult::NoChange,
            };

            if other_count > self.current_count {
                // Clone item to unborrow self.others so that we can remove
                // the item from it. We could pre-emptively remove and reinsert
                // instead, but most of the time there is no change, so I'm
                // aiming to keep that path cheaper.
                let other_item = other_item.clone();
                let (mut other_item, mut other_count) = self
                    .others
                    .remove_entry(&other_item)
                    .expect("item returned above, so def exists");

                // Swap the current best for the new best:
                std::mem::swap(&mut other_item, &mut self.current_best);
                std::mem::swap(&mut other_count, &mut self.current_count);

                // Insert the old best back into the map:
                self.others.insert(other_item, other_count);
                return ChangeResult::NewMostSeenItem;
            } else {
                return ChangeResult::NoChange;
            }
        }

        // Item is in the map and not the best; decrement its count.
        // BUG FIX: this previously *incremented* the count, meaning that
        // removing a non-best item repeatedly could eventually (and wrongly)
        // promote it to the best item on a subsequent insert.
        if let Some(count) = self.others.get_mut(item) {
            *count = count.saturating_sub(1);
            if *count == 0 {
                // Drop empty entries so the map can't grow without bound.
                self.others.remove(item);
            }
        }
        ChangeResult::NoChange
    }
}

/// Record the result of adding/removing an entry
#[derive(Clone, Copy)]
pub enum ChangeResult {
    /// The best item has remained the same.
    NoChange,
    /// There is a new best item now.
    NewMostSeenItem,
}

impl ChangeResult {
    /// True if the best item changed.
    pub fn has_changed(self) -> bool {
        matches!(self, ChangeResult::NewMostSeenItem)
    }
}
#[cfg(test)]
mod test {
    use super::*;

    // `default()` starts with a count of 0, so the first insert wins outright.
    #[test]
    fn default_renames_instantly() {
        let mut a: MostSeen<&str> = MostSeen::default();
        let res = a.insert(&"Hello");

        assert_eq!(*a.best(), "Hello");
        assert!(res.has_changed());
    }

    // `new(item)` counts as one sighting, so a challenger needs two.
    #[test]
    fn new_renames_on_second_change() {
        let mut a: MostSeen<&str> = MostSeen::new("First");
        a.insert(&"Second");

        assert_eq!(*a.best(), "First");

        a.insert(&"Second");

        assert_eq!(*a.best(), "Second");
    }

    // Counts saturate at zero; removing absent items is a no-op.
    #[test]
    fn removing_doesnt_underflow() {
        let mut a: MostSeen<&str> = MostSeen::new("First");
        a.remove(&"First");
        a.remove(&"First");
        a.remove(&"Second");
        a.remove(&"Third");
    }

    #[test]
    fn keeps_track_of_best_count() {
        let mut a: MostSeen<&str> = MostSeen::default();

        a.insert(&"First");
        assert_eq!(a.best_count(), 1);
        a.insert(&"First");
        assert_eq!(a.best_count(), 2);
        a.insert(&"First");
        assert_eq!(a.best_count(), 3);

        a.remove(&"First");
        assert_eq!(a.best_count(), 2);
        a.remove(&"First");
        assert_eq!(a.best_count(), 1);
        a.remove(&"First");
        assert_eq!(a.best_count(), 0);
        a.remove(&"First");
        assert_eq!(a.best_count(), 0);
    }

    // The best item only changes when a challenger strictly exceeds it.
    #[test]
    fn it_tracks_best_on_insert() {
        let mut a: MostSeen<&str> = MostSeen::default();

        a.insert(&"First");
        assert_eq!(*a.best(), "First", "1");

        a.insert(&"Second");
        assert_eq!(*a.best(), "First", "2");

        a.insert(&"Second");
        assert_eq!(*a.best(), "Second", "3");

        a.insert(&"First");
        assert_eq!(*a.best(), "Second", "4");

        a.insert(&"First");
        assert_eq!(*a.best(), "First", "5");
    }

    #[test]
    fn it_tracks_best() {
        let mut a: MostSeen<&str> = MostSeen::default();

        a.insert(&"First");
        a.insert(&"Second");
        a.insert(&"Third"); // 1
        a.insert(&"Second");
        a.insert(&"Second"); // 3
        a.insert(&"First"); // 2

        assert_eq!(*a.best(), "Second");
        assert_eq!(a.best_count(), 3);

        let res = a.remove(&"Second");
        assert!(!res.has_changed());
        assert_eq!(a.best_count(), 2);
        assert_eq!(*a.best(), "Second"); // Tied with "First"

        let res = a.remove(&"Second");
        assert!(res.has_changed());
        assert_eq!(a.best_count(), 2);
        assert_eq!(*a.best(), "First"); // First is now ahead
    }
}

View File

@ -0,0 +1,151 @@
use std::collections::{HashMap, HashSet};
use std::hash::Hash;

/// A map where each key can contain multiple values. We enforce that a value
/// only ever belongs to one key at a time (the latest key it was inserted
/// against).
pub struct MultiMapUnique<K, V> {
    /// Which key does each value currently live under?
    value_to_key: HashMap<V, K>,
    /// The set of values currently stored against each key.
    key_to_values: HashMap<K, HashSet<V>>,
}

impl<K, V> Default for MultiMapUnique<K, V> {
    fn default() -> Self {
        Self::new()
    }
}

impl<K, V> MultiMapUnique<K, V> {
    /// Construct a new MultiMap
    pub fn new() -> Self {
        Self {
            value_to_key: HashMap::new(),
            key_to_values: HashMap::new(),
        }
    }

    /// Return the set of values associated with a key.
    pub fn get_values(&self, key: &K) -> Option<&HashSet<V>>
    where
        K: Eq + Hash,
    {
        self.key_to_values.get(key)
    }

    /// Remove a value from the MultiMap, returning the key it was found
    /// under, if it was found at all.
    ///
    /// ```
    /// let mut m = ghost_telemetry_common::MultiMapUnique::new();
    ///
    /// m.insert("a", 1);
    /// m.insert("a", 2);
    ///
    /// m.insert("b", 3);
    /// m.insert("b", 4);
    ///
    /// assert_eq!(m.num_keys(), 2);
    /// assert_eq!(m.num_values(), 4);
    ///
    /// m.remove_value(&1);
    ///
    /// assert_eq!(m.num_keys(), 2);
    /// assert_eq!(m.num_values(), 3);
    ///
    /// m.remove_value(&2);
    ///
    /// assert_eq!(m.num_keys(), 1);
    /// assert_eq!(m.num_values(), 2);
    /// ```
    pub fn remove_value(&mut self, value: &V) -> Option<K>
    where
        V: Eq + Hash,
        K: Eq + Hash,
    {
        if let Some(key) = self.value_to_key.remove(value) {
            if let Some(m) = self.key_to_values.get_mut(&key) {
                m.remove(value);
                // Drop the key once its last value is gone, so that
                // `num_keys` only counts keys that still hold values.
                if m.is_empty() {
                    self.key_to_values.remove(&key);
                }
            }
            return Some(key);
        }
        None
    }

    /// Insert a key+value pair into the multimap. Multiple different
    /// values can exist for a single key, but only one of each value can
    /// exist in the MultiMap.
    ///
    /// If a previous value did exist, the old key it was inserted against
    /// is returned.
    ///
    /// ```
    /// let mut m = ghost_telemetry_common::MultiMapUnique::new();
    ///
    /// let old_key = m.insert("a", 1);
    /// assert_eq!(old_key, None);
    ///
    /// let old_key = m.insert("b", 1);
    /// assert_eq!(old_key, Some("a"));
    ///
    /// let old_key = m.insert("c", 1);
    /// assert_eq!(old_key, Some("b"));
    ///
    /// assert_eq!(m.num_keys(), 1);
    /// assert_eq!(m.num_values(), 1);
    ///
    /// // The value `1` must be unique in the map, so it only exists
    /// // in the last location it was inserted:
    /// assert!(m.get_values(&"a").is_none());
    /// assert!(m.get_values(&"b").is_none());
    /// assert_eq!(m.get_values(&"c").unwrap().iter().collect::<Vec<_>>(), vec![&1]);
    /// ```
    pub fn insert(&mut self, key: K, value: V) -> Option<K>
    where
        V: Clone + Eq + Hash,
        K: Clone + Eq + Hash,
    {
        // Ensure that the value doesn't exist elsewhere already;
        // values must be unique and only belong to one key:
        let old_key = self.remove_value(&value);

        self.value_to_key.insert(value.clone(), key.clone());
        self.key_to_values.entry(key).or_default().insert(value);

        old_key
    }

    /// Number of values stored in the map
    pub fn num_values(&self) -> usize {
        self.value_to_key.len()
    }

    /// Number of keys stored in the map
    pub fn num_keys(&self) -> usize {
        self.key_to_values.len()
    }

    /// True if the map holds no values at all.
    pub fn is_empty(&self) -> bool {
        self.value_to_key.is_empty()
    }
}
#[cfg(test)]
mod test {
    use super::*;

    // Distinct values may share a key; uniqueness is only enforced per value.
    #[test]
    fn multiple_values_allowed_per_key() {
        let mut m = MultiMapUnique::new();

        m.insert("a", 1);
        m.insert("a", 2);
        m.insert("b", 3);
        m.insert("b", 4);

        assert_eq!(m.num_keys(), 2);
        assert_eq!(m.num_values(), 4);

        let a_vals = m.get_values(&"a").expect("a vals");
        assert!(a_vals.contains(&1));
        assert!(a_vals.contains(&2));

        let b_vals = m.get_values(&"b").expect("b vals");
        assert!(b_vals.contains(&3));
        assert!(b_vals.contains(&4));
    }
}

208
common/src/node_message.rs Normal file
View File

@ -0,0 +1,208 @@
//! This is the internal representation of telemetry messages sent from nodes.
//! There is a separate JSON representation of these types, because internally we want to be
//! able to serialize these messages to bincode, and various serde attributes aren't compatible
//! with this, hence this separate internal representation.
use crate::node_types::{Block, BlockHash, BlockNumber, NodeDetails};
use serde::{Deserialize, Serialize};
/// Identifier attached to `V2` node messages (`V1` messages carry none).
pub type NodeMessageId = u64;

/// A telemetry message sent from a node. `V1` messages carry only a payload;
/// `V2` messages also carry a message ID.
#[derive(Serialize, Deserialize, Debug)]
pub enum NodeMessage {
    V1 { payload: Payload },
    V2 { id: NodeMessageId, payload: Payload },
}
impl NodeMessage {
    /// Returns the ID associated with the node message, or 0
    /// if the message has no ID.
    pub fn id(&self) -> NodeMessageId {
        if let NodeMessage::V2 { id, .. } = self {
            *id
        } else {
            0
        }
    }

    /// Return the payload associated with the message.
    pub fn into_payload(self) -> Payload {
        match self {
            NodeMessage::V2 { payload, .. } | NodeMessage::V1 { payload, .. } => payload,
        }
    }
}
// Allow extracting the payload via `.into()`, discarding any message ID.
impl From<NodeMessage> for Payload {
    fn from(msg: NodeMessage) -> Payload {
        msg.into_payload()
    }
}
/// The different payloads a node message can carry.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum Payload {
    /// Sent when a node connects: identifies the node and its chain.
    SystemConnected(SystemConnected),
    /// Periodic report of node statistics.
    SystemInterval(SystemInterval),
    /// A block was imported.
    BlockImport(Block),
    /// A block was finalized.
    NotifyFinalized(Finalized),
    /// Authority id reported by the node.
    AfgAuthoritySet(AfgAuthoritySet),
    /// Hardware benchmark scores reported by the node.
    HwBench(NodeHwBench),
}

/// Payload sent when a node first connects.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SystemConnected {
    /// Genesis hash of the chain the node is following.
    pub genesis_hash: BlockHash,
    /// Static details about the node itself.
    pub node: NodeDetails,
}

/// Periodic node statistics; every field is optional.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SystemInterval {
    pub peers: Option<u64>,
    pub txcount: Option<u64>,
    // NOTE(review): units for the bandwidth fields are not stated here —
    // confirm against the node implementation that reports them.
    pub bandwidth_upload: Option<f64>,
    pub bandwidth_download: Option<f64>,
    pub finalized_height: Option<BlockNumber>,
    pub finalized_hash: Option<BlockHash>,
    pub block: Option<Block>,
    pub used_state_cache_size: Option<f32>,
}

/// Notification that a block was finalized. The height arrives as a string
/// and is parsed on use (see `Payload::finalized_block`).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Finalized {
    pub hash: BlockHash,
    pub height: Box<str>,
}

/// Authority id reported by the node.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct AfgAuthoritySet {
    pub authority_id: Box<str>,
}

/// Hardware benchmark scores reported by a node; disk scores may be absent.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct NodeHwBench {
    pub cpu_hashrate_score: u64,
    pub memory_memcpy_score: u64,
    pub disk_sequential_write_score: Option<u64>,
    pub disk_random_write_score: Option<u64>,
}
impl Payload {
    /// The best block mentioned by this payload, if it carries one
    /// (block imports always do; interval reports sometimes do).
    pub fn best_block(&self) -> Option<&Block> {
        match self {
            Payload::BlockImport(block) => Some(block),
            Payload::SystemInterval(SystemInterval { block, .. }) => block.as_ref(),
            _ => None,
        }
    }

    /// The finalized block mentioned by this payload. Returns `None` unless
    /// both hash and height are present; `NotifyFinalized` heights arrive as
    /// strings and yield `None` if they fail to parse.
    pub fn finalized_block(&self) -> Option<Block> {
        match self {
            Payload::SystemInterval(ref interval) => Some(Block {
                hash: interval.finalized_hash?,
                height: interval.finalized_height?,
            }),
            Payload::NotifyFinalized(ref finalized) => Some(Block {
                hash: finalized.hash,
                height: finalized.height.parse().ok()?,
            }),
            _ => None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrayvec::ArrayString;
    use bincode::Options;
    // Without adding a derive macro and marker trait (and enforcing their use), we don't really
    // know whether things can (de)serialize to bincode or not at runtime without failing unless
    // we test the different types we want to (de)serialize ourselves. We just need to test each
    // type, not each variant.
    /// Round-trip `item` through bincode, panicking on any failure.
    fn bincode_can_serialize_and_deserialize<'de, T>(item: T)
    where
        T: Serialize + serde::de::DeserializeOwned,
    {
        let bytes = bincode::serialize(&item).expect("Serialization should work");
        let _: T = bincode::deserialize(&bytes).expect("Deserialization should work");
    }
    #[test]
    fn bincode_can_serialize_and_deserialize_node_message_system_connected() {
        bincode_can_serialize_and_deserialize(NodeMessage::V1 {
            payload: Payload::SystemConnected(SystemConnected {
                genesis_hash: BlockHash::zero(),
                node: NodeDetails {
                    chain: "foo".into(),
                    name: "foo".into(),
                    implementation: "foo".into(),
                    version: "foo".into(),
                    target_arch: Some("x86_64".into()),
                    target_os: Some("linux".into()),
                    target_env: Some("env".into()),
                    validator: None,
                    network_id: ArrayString::new(),
                    startup_time: None,
                    sysinfo: None,
                    ip: Some("127.0.0.1".into()),
                },
            }),
        });
    }
    #[test]
    fn bincode_can_serialize_and_deserialize_node_message_system_interval() {
        // All-`None` is the minimal valid interval report.
        bincode_can_serialize_and_deserialize(NodeMessage::V1 {
            payload: Payload::SystemInterval(SystemInterval {
                peers: None,
                txcount: None,
                bandwidth_upload: None,
                bandwidth_download: None,
                finalized_height: None,
                finalized_hash: None,
                block: None,
                used_state_cache_size: None,
            }),
        });
    }
    #[test]
    fn bincode_can_serialize_and_deserialize_node_message_block_import() {
        bincode_can_serialize_and_deserialize(NodeMessage::V1 {
            payload: Payload::BlockImport(Block {
                hash: BlockHash([0; 32]),
                height: 0,
            }),
        });
    }
    #[test]
    fn bincode_can_serialize_and_deserialize_node_message_notify_finalized() {
        bincode_can_serialize_and_deserialize(NodeMessage::V1 {
            payload: Payload::NotifyFinalized(Finalized {
                hash: BlockHash::zero(),
                height: "foo".into(),
            }),
        });
    }
    #[test]
    fn bincode_can_serialize_and_deserialize_node_message_afg_authority_set() {
        bincode_can_serialize_and_deserialize(NodeMessage::V1 {
            payload: Payload::AfgAuthoritySet(AfgAuthoritySet {
                authority_id: "foo".into(),
            }),
        });
    }
    #[test]
    fn bincode_block_zero() {
        // Uses `bincode::options()` (varint encoding) rather than the
        // default `serialize`, to cover that configuration too.
        let raw = Block::zero();
        let bytes = bincode::options().serialize(&raw).unwrap();
        let deserialized: Block = bincode::options().deserialize(&bytes).unwrap();
        assert_eq!(raw.hash, deserialized.hash);
        assert_eq!(raw.height, deserialized.height);
    }
}

245
common/src/node_types.rs Normal file
View File

@ -0,0 +1,245 @@
//! These types are partly used in [`crate::node_message`], but also stored and used
//! more generally through the application.
use arrayvec::ArrayString;
use serde::ser::{SerializeTuple, Serializer};
use serde::{Deserialize, Serialize};
use crate::{time, MeanList};
/// Block height.
pub type BlockNumber = u64;
/// Timestamp in milliseconds (presumably produced by [`time::now`] — confirm at call sites).
pub type Timestamp = u64;
/// 256-bit block hash.
pub use primitive_types::H256 as BlockHash;
/// Network (peer) identity: up to 64 bytes, stored inline without allocation.
pub type NetworkId = ArrayString<64>;
/// Basic node details.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct NodeDetails {
    /// Name of the chain the node is on.
    pub chain: Box<str>,
    /// Human-readable node name.
    pub name: Box<str>,
    /// Node implementation name.
    pub implementation: Box<str>,
    /// Node implementation version.
    pub version: Box<str>,
    /// Validator identity, if the node reported one.
    pub validator: Option<Box<str>>,
    /// Network (peer) id.
    pub network_id: NetworkId,
    /// Startup time, if reported (kept as the raw string the node sent).
    pub startup_time: Option<Box<str>>,
    /// Target operating system, if reported.
    pub target_os: Option<Box<str>>,
    /// Target architecture, if reported.
    pub target_arch: Option<Box<str>>,
    /// Target environment, if reported.
    pub target_env: Option<Box<str>>,
    /// Hardware/software info, if reported.
    pub sysinfo: Option<NodeSysInfo>,
    /// IP address, if known.
    pub ip: Option<Box<str>>,
}
/// Hardware and software information for the node.
/// All fields are optional: nodes may omit any of them.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct NodeSysInfo {
    /// The exact CPU model.
    pub cpu: Option<Box<str>>,
    /// The total amount of memory, in bytes.
    pub memory: Option<u64>,
    /// The number of physical CPU cores.
    pub core_count: Option<u32>,
    /// The Linux kernel version.
    pub linux_kernel: Option<Box<str>>,
    /// The exact Linux distribution used.
    pub linux_distro: Option<Box<str>>,
    /// Whether the node's running under a virtual machine.
    pub is_virtual_machine: Option<bool>,
}
/// Hardware benchmark results for the node.
// NOTE(review): `node_message.rs` declares a struct with the same name and
// shape — consider unifying the two definitions.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct NodeHwBench {
    /// The CPU speed, as measured in how many MB/s it can hash using the BLAKE2b-256 hash.
    pub cpu_hashrate_score: u64,
    /// Memory bandwidth in MB/s, calculated by measuring the throughput of `memcpy`.
    pub memory_memcpy_score: u64,
    /// Sequential disk write speed in MB/s.
    pub disk_sequential_write_score: Option<u64>,
    /// Random disk write speed in MB/s.
    pub disk_random_write_score: Option<u64>,
}
/// A couple of node statistics.
/// Serialized as a compact `(peers, txcount)` tuple — see the custom
/// `Serialize`/`Deserialize` impls below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct NodeStats {
    /// Current peer count.
    pub peers: u64,
    /// Transaction count.
    pub txcount: u64,
}
// # A note about serialization/deserialization of types in this file:
//
// Some of the types here are sent to UI feeds. In an effort to keep the
// amount of bytes sent to a minimum, we have written custom serializers
// for those types.
//
// For testing purposes, it's useful to be able to deserialize from some
// of these types so that we can test message feed things, so custom
// deserializers exist to undo the work of the custom serializers.
impl Serialize for NodeStats {
    /// Serialize as a compact two-element tuple to keep feed bytes to a
    /// minimum.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // A Rust tuple serializes via `serialize_tuple(2)` plus one element
        // per field, so this emits exactly what the hand-rolled
        // `SerializeTuple` version did.
        (self.peers, self.txcount).serialize(serializer)
    }
}
impl<'de> Deserialize<'de> for NodeStats {
    /// Inverse of the custom serializer: read back a `(peers, txcount)` tuple.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let pair = <(u64, u64)>::deserialize(deserializer)?;
        Ok(NodeStats {
            peers: pair.0,
            txcount: pair.1,
        })
    }
}
/// Node IO details.
#[derive(Default)]
pub struct NodeIO {
    // Rolling means of the reported state cache sizes.
    pub used_state_cache_size: MeanList<f32>,
}
impl Serialize for NodeIO {
    /// Serialize as a one-element tuple containing the mean-list slice.
    ///
    /// This is "one-way": the `MeanList` cannot be reconstructed from the
    /// serialized slice, so there is deliberately no `Deserialize` impl.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // A one-element Rust tuple serializes via `serialize_tuple(1)`,
        // matching the original hand-written form.
        (self.used_state_cache_size.slice(),).serialize(serializer)
    }
}
/// Concise block details
#[derive(Deserialize, Serialize, Debug, Clone, Copy, PartialEq)]
pub struct Block {
    /// Block hash.
    pub hash: BlockHash,
    /// Block height.
    pub height: BlockNumber,
}
impl Block {
    /// The all-zero block: zero hash and height 0. Used as a placeholder
    /// before any real block has been seen.
    pub fn zero() -> Self {
        Self {
            hash: BlockHash::zero(),
            height: 0,
        }
    }
}
/// Node hardware details.
#[derive(Default)]
pub struct NodeHardware {
    /// Rolling means of upload bandwidth samples.
    pub upload: MeanList<f64>,
    /// Rolling means of download bandwidth samples.
    pub download: MeanList<f64>,
    /// Rolling means of the timestamps the samples were taken at.
    pub chart_stamps: MeanList<f64>,
}
impl Serialize for NodeHardware {
    /// Serialize as a three-element tuple of mean-list slices.
    ///
    /// These are "one-way": the `MeanList`s cannot be reconstructed from
    /// the serialized slices, so there is deliberately no `Deserialize` impl.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // A three-element Rust tuple serializes via `serialize_tuple(3)`,
        // matching the original hand-written form.
        (
            self.upload.slice(),
            self.download.slice(),
            self.chart_stamps.slice(),
        )
            .serialize(serializer)
    }
}
/// Node location details
#[derive(Debug, Clone, PartialEq)]
pub struct NodeLocation {
    /// Latitude in degrees.
    pub latitude: f32,
    /// Longitude in degrees.
    pub longitude: f32,
    /// City name.
    pub city: Box<str>,
}
impl Serialize for NodeLocation {
    /// Serialize as a compact `(latitude, longitude, city)` tuple.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Tuple serialization emits `serialize_tuple(3)` plus the three
        // elements, identical to the original hand-rolled version; the
        // city is serialized as a borrowed `&str`.
        (self.latitude, self.longitude, &*self.city).serialize(serializer)
    }
}
impl<'de> Deserialize<'de> for NodeLocation {
    /// Inverse of the custom serializer: read back the
    /// `(latitude, longitude, city)` tuple.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let fields = <(f32, f32, Box<str>)>::deserialize(deserializer)?;
        Ok(NodeLocation {
            latitude: fields.0,
            longitude: fields.1,
            city: fields.2,
        })
    }
}
/// Verbose block details
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BlockDetails {
    /// The block itself (hash + height).
    pub block: Block,
    /// Time the block took (ms — presumably; confirm against producers).
    pub block_time: u64,
    /// Wall-clock timestamp (ms) when the block was recorded.
    pub block_timestamp: u64,
    /// Propagation time, if known.
    pub propagation_time: Option<u64>,
}
impl Default for BlockDetails {
    /// The zero block, stamped with the current wall-clock time.
    fn default() -> Self {
        Self {
            block: Block::zero(),
            block_time: 0,
            block_timestamp: time::now(),
            propagation_time: None,
        }
    }
}
impl Serialize for BlockDetails {
    /// Serialize as a flat five-element tuple:
    /// `(height, hash, block_time, block_timestamp, propagation_time)`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // A five-element Rust tuple serializes via `serialize_tuple(5)`,
        // emitting the same bytes as the original hand-rolled version.
        (
            self.block.height,
            self.block.hash,
            self.block_time,
            self.block_timestamp,
            self.propagation_time,
        )
            .serialize(serializer)
    }
}
impl<'de> Deserialize<'de> for BlockDetails {
    /// Inverse of the custom serializer: read back the flat five-element
    /// tuple and reassemble the nested [`Block`].
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let (height, hash, block_time, block_timestamp, propagation_time) =
            <(u64, BlockHash, u64, u64, Option<u64>)>::deserialize(deserializer)?;
        Ok(BlockDetails {
            block: Block { hash, height },
            block_time,
            block_timestamp,
            propagation_time,
        })
    }
}

104
common/src/num_stats.rs Normal file
View File

@ -0,0 +1,104 @@
use num_traits::{Bounded, NumOps, Zero};
use std::convert::TryFrom;
use std::iter::Sum;
/// Keep track of last N numbers pushed onto internal stack.
/// Provides means to get an average of said numbers.
pub struct NumStats<T> {
    // Fixed-size ring buffer holding the last N pushed values.
    stack: Box<[T]>,
    // Total number of pushes so far; `index % stack.len()` is the next slot.
    index: usize,
    // Running sum of the values currently in `stack`, kept incrementally.
    sum: T,
}
impl<T: NumOps + Zero + Bounded + Copy + Sum + TryFrom<usize>> NumStats<T> {
    /// Create a `NumStats` tracking the last `size` values.
    ///
    /// # Panics
    ///
    /// Panics if `size` is zero: a zero-capacity buffer would otherwise
    /// cause an obscure divide-by-zero panic on the first `push`.
    pub fn new(size: usize) -> Self {
        assert!(size > 0, "NumStats capacity must be non-zero");
        NumStats {
            stack: vec![T::zero(); size].into_boxed_slice(),
            index: 0,
            sum: T::zero(),
        }
    }
    /// Push a value, evicting the oldest one once the buffer is full.
    /// The running sum is maintained incrementally so `average` is O(1).
    pub fn push(&mut self, val: T) {
        let slot = &mut self.stack[self.index % self.stack.len()];
        // Subtract whatever we're overwriting, then add the new value.
        self.sum = (self.sum + val) - *slot;
        *slot = val;
        self.index += 1;
    }
    /// Average of the values currently tracked, or zero if nothing has
    /// been pushed yet.
    pub fn average(&self) -> T {
        let cap = std::cmp::min(self.index, self.stack.len());
        if cap == 0 {
            return T::zero();
        }
        // If the element count doesn't fit into T (very unlikely for the
        // numeric types used here), saturate rather than panic.
        let cap = T::try_from(cap).unwrap_or_else(|_| T::max_value());
        self.sum / cap
    }
    /// Clear all recorded values and reset the running sum.
    pub fn reset(&mut self) {
        self.index = 0;
        self.sum = T::zero();
        for val in self.stack.iter_mut() {
            *val = T::zero();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn calculates_correct_average() {
        let mut stats: NumStats<u64> = NumStats::new(10);
        stats.push(3);
        stats.push(7);
        // Average over the 2 values pushed so far, not over capacity 10.
        assert_eq!(stats.average(), 5);
    }
    #[test]
    fn calculates_correct_average_over_bounds() {
        let mut stats: NumStats<u64> = NumStats::new(10);
        stats.push(100);
        for _ in 0..9 {
            stats.push(0);
        }
        assert_eq!(stats.average(), 10);
        // The 11th push evicts the initial 100 from the ring buffer.
        stats.push(0);
        assert_eq!(stats.average(), 0);
    }
    #[test]
    fn resets_properly() {
        let mut stats: NumStats<u64> = NumStats::new(10);
        for _ in 0..10 {
            stats.push(100);
        }
        assert_eq!(stats.average(), 100);
        // After reset, old values must not contribute to new averages.
        stats.reset();
        assert_eq!(stats.average(), 0);
        stats.push(7);
        stats.push(3);
        assert_eq!(stats.average(), 5);
    }
}

View File

@ -0,0 +1,105 @@
//! [`futures::StreamExt::ready_chunks()`] internally stores a vec with a certain capacity, and will buffer up
//! up to that many items that are ready from the underlying stream before returning either when we run out of
//! Poll::Ready items, or we hit the capacity.
//!
//! This variation has no fixed capacity, and will buffer everything it can up at each point to return. This is
//! better when the amount of items varies a bunch (and we don't want to allocate a fixed capacity every time),
//! and can help ensure that we process as many items as possible each time (rather than only up to capacity items).
//!
//! Code is adapted from the futures implementation
//! (see [ready_chunks.rs](https://docs.rs/futures-util/0.3.15/src/futures_util/stream/stream/ready_chunks.rs.html)).
use core::mem;
use core::pin::Pin;
use futures::stream::Fuse;
use futures::stream::{FusedStream, Stream};
use futures::task::{Context, Poll};
use futures::StreamExt;
use pin_project_lite::pin_project;
pin_project! {
    /// Buffer up all Ready items in the underlying stream each time
    /// we attempt to retrieve items from it, and return a Vec of those
    /// items.
    #[derive(Debug)]
    #[must_use = "streams do nothing unless polled"]
    pub struct ReadyChunksAll<St: Stream> {
        // Fused so that polling after the inner stream ends is safe.
        #[pin]
        stream: Fuse<St>,
        // Items buffered since the last chunk was handed out.
        items: Vec<St::Item>,
    }
}
impl<St: Stream> ReadyChunksAll<St> {
    /// Wrap `stream` so that each successful poll yields a `Vec` of every
    /// item that was immediately ready, with no upper bound on chunk size.
    ///
    /// (The previous version repeated the `St: Stream` bound in both the
    /// impl header and a `where` clause; the duplicate is removed.)
    pub fn new(stream: St) -> Self {
        Self {
            stream: stream.fuse(),
            items: Vec::new(),
        }
    }
}
impl<St: Stream> Stream for ReadyChunksAll<St> {
    type Item = Vec<St::Item>;
    /// Drain everything that's immediately ready from the inner stream,
    /// returning it as one Vec; pending with an empty buffer stays pending.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();
        loop {
            match this.stream.as_mut().poll_next(cx) {
                // Flush all collected data if underlying stream doesn't contain
                // more ready values
                Poll::Pending => {
                    return if this.items.is_empty() {
                        Poll::Pending
                    } else {
                        // `mem::take` hands out the buffer and leaves a fresh
                        // empty Vec in its place for the next round.
                        Poll::Ready(Some(mem::take(this.items)))
                    }
                }
                // Push the ready item into the buffer
                Poll::Ready(Some(item)) => {
                    this.items.push(item);
                }
                // Since the underlying stream ran out of values, return what we
                // have buffered, if we have anything.
                Poll::Ready(None) => {
                    let last = if this.items.is_empty() {
                        None
                    } else {
                        let full_buf = mem::take(this.items);
                        Some(full_buf)
                    };
                    return Poll::Ready(last);
                }
            }
        }
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        // Look at the underlying stream's size_hint. If we've
        // buffered some items, we'll return at least that Vec,
        // giving us a lower bound 1 greater than the underlying.
        // The upper bound is, worst case, our vec + each individual
        // item in the underlying stream.
        let chunk_len = if self.items.is_empty() { 0 } else { 1 };
        let (lower, upper) = self.stream.size_hint();
        let lower = lower.saturating_add(chunk_len);
        let upper = match upper {
            Some(x) => x.checked_add(chunk_len),
            None => None,
        };
        (lower, upper)
    }
}
impl<St: FusedStream> FusedStream for ReadyChunksAll<St> {
    /// Terminated only once the inner stream is done *and* any buffered
    /// items have been handed out.
    fn is_terminated(&self) -> bool {
        self.stream.is_terminated() && self.items.is_empty()
    }
}

342
common/src/rolling_total.rs Normal file
View File

@ -0,0 +1,342 @@
use num_traits::{SaturatingAdd, SaturatingSub, Zero};
use std::collections::VecDeque;
use std::time::{Duration, Instant};
/// Build an object responsible for keeping track of a rolling total.
/// It does this in constant time and using memory proportional to the
/// granularity * window size multiple that we set.
pub struct RollingTotalBuilder<Time: TimeSource = SystemTimeSource> {
    // How many granularity-sized buckets make up the window.
    window_size_multiple: usize,
    // Width of each bucket.
    granularity: Duration,
    // Where "now" comes from (swappable for tests).
    time_source: Time,
}
impl RollingTotalBuilder {
    /// Build a [`RollingTotal`] struct. By default,
    /// the window_size is 10s, the granularity is 1s,
    /// and system time is used.
    pub fn new() -> RollingTotalBuilder<SystemTimeSource> {
        Self {
            window_size_multiple: 10,
            granularity: Duration::from_secs(1),
            time_source: SystemTimeSource,
        }
    }
    /// Set the source of time we'll use. By default, we use system time.
    // Changing the time source changes the builder's type parameter, so
    // this consumes and rebuilds the whole builder rather than mutating it.
    pub fn time_source<Time: TimeSource>(self, val: Time) -> RollingTotalBuilder<Time> {
        RollingTotalBuilder {
            window_size_multiple: self.window_size_multiple,
            granularity: self.granularity,
            time_source: val,
        }
    }
    /// Set the size of the window of time that we'll look back on
    /// to sum up values over to give us the current total. The size
    /// is set as a multiple of the granularity; a granularity of 1s
    /// and a size of 10 means the window size will be 10 seconds.
    // NOTE(review): a multiple of 0 will underflow in `RollingTotal::push`
    // — confirm callers never pass 0.
    pub fn window_size_multiple(mut self, val: usize) -> Self {
        self.window_size_multiple = val;
        self
    }
    /// What is the granularity of our windows of time. For example, a
    /// granularity of 5 seconds means that every 5 seconds the window
    /// that we look at shifts forward to the next 5 seconds worth of data.
    /// A larger granularity is more efficient but less accurate than a
    /// smaller one.
    pub fn granularity(mut self, val: Duration) -> Self {
        self.granularity = val;
        self
    }
}
impl<Time: TimeSource> RollingTotalBuilder<Time> {
    /// Create a [`RollingTotal`] with these settings. The first (empty)
    /// bucket is stamped with the time source's current time.
    pub fn start<T>(self) -> RollingTotal<T, Time>
    where
        T: Zero + SaturatingAdd + SaturatingSub,
    {
        // Seed with one zero-valued bucket so `averages` is never empty;
        // `RollingTotal::push` relies on `back_mut()` always succeeding.
        let mut averages = VecDeque::new();
        averages.push_back((self.time_source.now(), T::zero()));
        RollingTotal {
            window_size_multiple: self.window_size_multiple,
            time_source: self.time_source,
            granularity: self.granularity,
            averages,
            total: T::zero(),
        }
    }
}
/// A rolling total over a sliding window of time, bucketed by granularity.
/// Built via [`RollingTotalBuilder`].
pub struct RollingTotal<Val, Time = SystemTimeSource> {
    // How many granularity-sized buckets make up the window.
    window_size_multiple: usize,
    // Where "now" comes from.
    time_source: Time,
    // Width of each bucket.
    granularity: Duration,
    // (bucket start time, summed value) pairs; never empty.
    averages: VecDeque<(Instant, Val)>,
    // Cached sum of all bucket values, kept in step with `averages`.
    total: Val,
}
impl<Val, Time: TimeSource> RollingTotal<Val, Time>
where
    Val: SaturatingAdd + SaturatingSub + Copy + std::fmt::Debug,
    Time: TimeSource,
{
    /// Add a new value at some time.
    // Time is taken from the time source; pushing after moving the time
    // source backwards is unsupported (duration_since would panic).
    pub fn push(&mut self, value: Val) {
        let time = self.time_source.now();
        let (last_time, last_val) = self.averages.back_mut().expect("always 1 value");
        let since_last_nanos = time.duration_since(*last_time).as_nanos();
        let granularity_nanos = self.granularity.as_nanos();
        if since_last_nanos >= granularity_nanos {
            // New time doesn't fit into last bucket; create a new bucket with a time
            // that is some number of granularity steps from the last, and add the
            // value to that.
            // This rounds down, eg 7 / 5 = 1. Find the number of granularity steps
            // to jump from the last time such that the jump can fit this new value.
            let steps = since_last_nanos / granularity_nanos;
            // Create a new time this number of jumps forward, and push it.
            // NOTE(review): `steps as u32` / `granularity_nanos as u64` can
            // truncate for extreme gaps or granularities — confirm acceptable.
            let new_time =
                *last_time + Duration::from_nanos(granularity_nanos as u64) * steps as u32;
            self.total = self.total.saturating_add(&value);
            self.averages.push_back((new_time, value));
            // Remove any old times/values no longer within our window size. If window_size_multiple
            // is 1, then we only keep the just-pushed time, hence the "-1". Remember to keep our
            // cached total up to date if we remove things.
            let oldest_time_in_window =
                new_time - (self.granularity * (self.window_size_multiple - 1) as u32);
            while self.averages.front().expect("always 1 value").0 < oldest_time_in_window {
                let value = self.averages.pop_front().expect("always 1 value").1;
                self.total = self.total.saturating_sub(&value);
            }
        } else {
            // New time fits into our last bucket, so just add it on. We don't need to worry
            // about bucket cleanup since number/times of buckets hasn't changed.
            *last_val = last_val.saturating_add(&value);
            self.total = self.total.saturating_add(&value);
        }
    }
    /// Fetch the current rolling total that we've accumulated. Note that this
    /// is based on the last seen times and values, and is not affected by the time
    /// that it is called.
    pub fn total(&self) -> Val {
        self.total
    }
    /// Fetch the current time source, in case we need to modify it.
    pub fn time_source(&mut self) -> &mut Time {
        &mut self.time_source
    }
    // Test-only view of the raw buckets.
    #[cfg(test)]
    pub fn averages(&self) -> &VecDeque<(Instant, Val)> {
        &self.averages
    }
}
/// A source of time that we can use in our rolling total.
/// This allows us to avoid explicitly mentioning time when pushing
/// new values, and makes it a little harder to accidentally pass
/// an older time and cause a panic.
pub trait TimeSource {
    /// The current instant according to this source.
    fn now(&self) -> Instant;
}
/// The default time source: the system monotonic clock.
pub struct SystemTimeSource;
impl TimeSource for SystemTimeSource {
    fn now(&self) -> Instant {
        Instant::now()
    }
}
/// A manually-driven time source: "now" is whatever instant was last set.
/// Used by tests to step time forward deterministically.
pub struct UserTimeSource(Instant);
impl UserTimeSource {
    /// Create a source frozen at `time`.
    pub fn new(time: Instant) -> Self {
        UserTimeSource(time)
    }
    /// Jump to an absolute instant.
    pub fn set_time(&mut self, time: Instant) {
        self.0 = time;
    }
    /// Advance the current instant by `duration`.
    pub fn increment_by(&mut self, duration: Duration) {
        self.0 += duration;
    }
}
impl TimeSource for UserTimeSource {
    // Returns the manually-set instant; never advances on its own.
    fn now(&self) -> Instant {
        self.0
    }
}
#[cfg(test)]
mod test {
    use super::*;
    // Fixed typo in the test name: "deosnt" -> "doesnt".
    #[test]
    fn doesnt_grow_beyond_window_size() {
        let start_time = Instant::now();
        let granularity = Duration::from_secs(1);
        let mut rolling_total = RollingTotalBuilder::new()
            .granularity(granularity)
            .window_size_multiple(3) // There should be no more than 3 buckets ever,
            .time_source(UserTimeSource(start_time))
            .start();
        for n in 0..1_000 {
            rolling_total.push(n);
            rolling_total
                .time_source()
                .increment_by(Duration::from_millis(300)); // multiple values per granularity.
        }
        assert_eq!(rolling_total.averages().len(), 3);
        assert!(rolling_total.averages().capacity() < 10); // Just to show that its capacity is bounded.
    }
    #[test]
    fn times_grouped_by_granularity_spacing() {
        let start_time = Instant::now();
        let granularity = Duration::from_secs(1);
        let mut rolling_total = RollingTotalBuilder::new()
            .granularity(granularity)
            .window_size_multiple(10)
            .time_source(UserTimeSource(start_time))
            .start();
        rolling_total.push(1);
        rolling_total
            .time_source()
            .increment_by(Duration::from_millis(1210)); // 1210; bucket 1
        rolling_total.push(2);
        rolling_total
            .time_source()
            .increment_by(Duration::from_millis(2500)); // 3710: bucket 3
        rolling_total.push(3);
        rolling_total
            .time_source()
            .increment_by(Duration::from_millis(1100)); // 4810: bucket 4
        rolling_total.push(4);
        rolling_total
            .time_source()
            .increment_by(Duration::from_millis(190)); // 5000: bucket 5
        rolling_total.push(5);
        // Regardless of the exact time that's elapsed, we'll end up with buckets that
        // are exactly granularity spacing (or multiples of) apart.
        assert_eq!(
            rolling_total
                .averages()
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            vec![
                (start_time, 1),
                (start_time + granularity, 2),
                (start_time + granularity * 3, 3),
                (start_time + granularity * 4, 4),
                (start_time + granularity * 5, 5),
            ]
        )
    }
    #[test]
    fn gets_correct_total_within_granularity() {
        let start_time = Instant::now();
        let mut rolling_total = RollingTotalBuilder::new()
            .granularity(Duration::from_secs(1))
            .window_size_multiple(10)
            .time_source(UserTimeSource(start_time))
            .start();
        rolling_total
            .time_source()
            .increment_by(Duration::from_millis(300));
        rolling_total.push(1);
        rolling_total
            .time_source()
            .increment_by(Duration::from_millis(300));
        rolling_total.push(10);
        rolling_total
            .time_source()
            .increment_by(Duration::from_millis(300));
        rolling_total.push(-5);
        assert_eq!(rolling_total.total(), 6);
        assert_eq!(rolling_total.averages().len(), 1);
    }
    #[test]
    fn gets_correct_total_within_window() {
        let start_time = Instant::now();
        let mut rolling_total = RollingTotalBuilder::new()
            .granularity(Duration::from_secs(1))
            .window_size_multiple(10)
            .time_source(UserTimeSource(start_time))
            .start();
        rolling_total.push(4);
        assert_eq!(rolling_total.averages().len(), 1);
        assert_eq!(rolling_total.total(), 4);
        rolling_total
            .time_source()
            .increment_by(Duration::from_secs(3));
        rolling_total.push(1);
        assert_eq!(rolling_total.averages().len(), 2);
        assert_eq!(rolling_total.total(), 5);
        rolling_total
            .time_source()
            .increment_by(Duration::from_secs(1));
        rolling_total.push(10);
        assert_eq!(rolling_total.averages().len(), 3);
        assert_eq!(rolling_total.total(), 15);
        // Jump precisely to the end of the window. Now, pushing a
        // value will displace the first one (4). Note: if no value
        // is pushed, this time change will have no effect.
        rolling_total
            .time_source()
            .increment_by(Duration::from_secs(8));
        rolling_total.push(20);
        assert_eq!(rolling_total.averages().len(), 3);
        assert_eq!(rolling_total.total(), 15 + 20 - 4);
        // Jump so that only the last value is still within the window:
        rolling_total
            .time_source()
            .increment_by(Duration::from_secs(9));
        rolling_total.push(1);
        assert_eq!(rolling_total.averages().len(), 2);
        assert_eq!(rolling_total.total(), 21);
        // Jump so that everything is out of scope (just about!):
        rolling_total
            .time_source()
            .increment_by(Duration::from_secs(10));
        rolling_total.push(1);
        assert_eq!(rolling_total.averages().len(), 1);
        assert_eq!(rolling_total.total(), 1);
    }
}

9
common/src/time.rs Normal file
View File

@ -0,0 +1,9 @@
/// Returns current unix time in ms (compatible with JS Date.now())
pub fn now() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    // Truncating `as u64` is safe: u64 milliseconds cover ~584 million years.
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("System time must be configured to be post Unix Epoch start; qed");
    since_epoch.as_millis() as u64
}

View File

@ -0,0 +1,279 @@
use super::on_close::OnClose;
use futures::{channel, StreamExt};
use soketto::handshake::{Client, ServerResponse};
use std::io;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio_rustls::rustls::{OwnedTrustAnchor, ServerName};
use tokio_rustls::{rustls, TlsConnector};
use tokio_util::compat::TokioAsyncReadCompatExt;
use super::{
receiver::{Receiver, RecvMessage},
sender::{Sender, SentMessage},
};
/// Alias trait for any transport a websocket can run over: async read +
/// write, plus the bounds needed to box it and move it across tasks.
pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Unpin + Send {}
// Blanket impl: anything satisfying the bounds is an `AsyncReadWrite`.
impl<T: AsyncRead + AsyncWrite + Unpin + Send> AsyncReadWrite for T {}
/// The send side of a Soketto WebSocket connection
/// (over a boxed, possibly-TLS transport).
pub type RawSender =
    soketto::connection::Sender<tokio_util::compat::Compat<Box<dyn AsyncReadWrite>>>;
/// The receive side of a Soketto WebSocket connection
/// (over a boxed, possibly-TLS transport).
pub type RawReceiver =
    soketto::connection::Receiver<tokio_util::compat::Compat<Box<dyn AsyncReadWrite>>>;
/// A websocket connection. From this, we can either expose the raw connection
/// or expose a cancel-safe interface to it.
pub struct Connection {
    // Send half of the underlying Soketto connection.
    tx: RawSender,
    // Receive half of the underlying Soketto connection.
    rx: RawReceiver,
}
impl Connection {
    /// Get hold of the raw send/receive interface for this connection.
    /// These are not cancel-safe, but can be more performant than the
    /// cancel-safe channel based interface.
    pub fn into_raw(self) -> (RawSender, RawReceiver) {
        (self.tx, self.rx)
    }
    /// Get hold of send and receive channels for this connection.
    /// These channels are cancel-safe.
    ///
    /// This spawns a couple of tasks for pulling/pushing messages onto the
    /// connection, and so messages will be pushed onto the receiving channel
    /// without any further polling. use [`Connection::into_raw`] if you need
    /// more precise control over when messages are pulled from the socket.
    ///
    /// # Panics
    ///
    /// This will panic if not called within the context of a tokio runtime.
    ///
    pub fn into_channels(self) -> (Sender, Receiver) {
        let (mut ws_to_connection, mut ws_from_connection) = (self.tx, self.rx);
        // Shut everything down when we're told to close, which will be either when
        // we hit an error trying to receive data on the socket, or when both the send
        // and recv channels that we hand out are dropped. Notably, we allow either recv or
        // send alone to be dropped and still keep the socket open (we may only care about
        // one way communication).
        // Capacity 1 is enough: the signal is fire-once and receivers only
        // need to observe that it fired.
        let (tx_closed1, mut rx_closed1) = tokio::sync::broadcast::channel::<()>(1);
        let tx_closed2 = tx_closed1.clone();
        let mut rx_closed2 = tx_closed1.subscribe();
        // Receive messages from the socket:
        let (tx_to_external, rx_from_ws) = channel::mpsc::unbounded();
        tokio::spawn(async move {
            let mut send_to_external = true;
            loop {
                // Fresh buffer each iteration; `receive_data` appends into it.
                let mut data = Vec::new();
                // Wait for messages, or bail entirely if asked to close.
                let message_data = tokio::select! {
                    msg_data = ws_from_connection.receive_data(&mut data) => { msg_data },
                    _ = rx_closed1.recv() => { break }
                };
                let message_data = match message_data {
                    Err(e) => {
                        // The socket had an error, so notify interested parties that we should
                        // shut the connection down and bail out of this receive loop.
                        log::error!(
                            "Shutting down websocket connection: Failed to receive data: {}",
                            e
                        );
                        let _ = tx_closed1.send(());
                        break;
                    }
                    Ok(data) => data,
                };
                // if we hit an error sending, we keep receiving messages and reacting
                // to recv issues, but we stop trying to send them anywhere.
                if !send_to_external {
                    continue;
                }
                // `message_data` only tells us the frame type; the bytes
                // themselves were written into `data` above.
                let msg = match message_data {
                    soketto::Data::Binary(_) => Ok(RecvMessage::Binary(data)),
                    soketto::Data::Text(_) => String::from_utf8(data)
                        .map(|s| RecvMessage::Text(s))
                        .map_err(|e| e.into()),
                };
                if let Err(e) = tx_to_external.unbounded_send(msg) {
                    // Our external channel may have closed or errored, but the socket hasn't
                    // been closed, so keep receiving in order to allow the socket to continue to
                    // function properly (we may be happy just sending messages to it), but stop
                    // trying to hand back messages we've received from the socket.
                    log::warn!("Failed to send data out: {}", e);
                    send_to_external = false;
                }
            }
        });
        // Send messages to the socket:
        let (tx_to_ws, mut rx_from_external) = channel::mpsc::unbounded::<SentMessage>();
        tokio::spawn(async move {
            loop {
                // Wait for messages, or bail entirely if asked to close.
                let msg = tokio::select! {
                    msg = rx_from_external.next() => { msg },
                    _ = rx_closed2.recv() => {
                        // attempt to gracefully end the connection.
                        let _ = ws_to_connection.close().await;
                        break
                    }
                };
                // No more messages; channel closed. End this loop. Unlike the recv side which
                // needs to keep receiving data for the WS connection to stay open, there's no
                // reason to keep this side of the loop open if our channel is closed.
                let msg = match msg {
                    Some(msg) => msg,
                    None => break,
                };
                // We don't explicitly shut down the channel if we hit send errors. Why? Because the
                // receive side of the channel will react to socket errors as well, and close things
                // down from there.
                match msg {
                    SentMessage::Text(s) => {
                        if let Err(e) = ws_to_connection.send_text_owned(s).await {
                            log::error!(
                                "Shutting down websocket connection: Failed to send text data: {}",
                                e
                            );
                            break;
                        }
                    }
                    SentMessage::Binary(bytes) => {
                        if let Err(e) = ws_to_connection.send_binary_mut(bytes).await {
                            log::error!(
                                "Shutting down websocket connection: Failed to send binary data: {}",
                                e
                            );
                            break;
                        }
                    }
                    SentMessage::StaticText(s) => {
                        if let Err(e) = ws_to_connection.send_text(s).await {
                            log::error!(
                                "Shutting down websocket connection: Failed to send text data: {}",
                                e
                            );
                            break;
                        }
                    }
                    SentMessage::StaticBinary(bytes) => {
                        if let Err(e) = ws_to_connection.send_binary(bytes).await {
                            log::error!(
                                "Shutting down websocket connection: Failed to send binary data: {}",
                                e
                            );
                            break;
                        }
                    }
                }
                // Flush after every message so it actually hits the wire.
                if let Err(e) = ws_to_connection.flush().await {
                    log::error!(
                        "Shutting down websocket connection: Failed to flush data: {}",
                        e
                    );
                    break;
                }
            }
        });
        // Keep track of whether one of sender or received have
        // been dropped. If both have, we close the socket connection.
        let on_close = Arc::new(OnClose(tx_closed2));
        (
            Sender {
                inner: tx_to_ws,
                closer: Arc::clone(&on_close),
            },
            Receiver {
                inner: rx_from_ws,
                closer: on_close,
            },
        )
    }
}
/// Errors that can occur while establishing a websocket connection.
#[derive(thiserror::Error, Debug)]
pub enum ConnectError {
    /// TCP/TLS level failure.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// Websocket handshake failure.
    #[error("Handshake error: {0}")]
    Handshake(#[from] soketto::handshake::Error),
    /// Server asked us to redirect; we don't follow redirects.
    #[error("Redirect not supported (status code: {status_code})")]
    ConnectionFailedRedirect { status_code: u16 },
    /// Server rejected the websocket upgrade.
    #[error("Connection rejected (status code: {status_code})")]
    ConnectionFailedRejected { status_code: u16 },
}
/// Establish a websocket connection that you can send and receive messages from.
///
/// The host defaults to `127.0.0.1` and the scheme to `ws`. The port defaults
/// to 80, or 443 when the scheme is `https`/`wss` (in which case the TCP
/// stream is additionally wrapped in TLS).
///
/// # Errors
///
/// Returns [`ConnectError::Io`] for TCP/TLS failures (including a failing
/// `set_nodelay`, which previously panicked), [`ConnectError::Handshake`] for
/// websocket handshake failures, and the redirect/reject variants when the
/// server declines the upgrade.
pub async fn connect(uri: &http::Uri) -> Result<Connection, ConnectError> {
    let host = uri.host().unwrap_or("127.0.0.1");
    let scheme = uri.scheme_str().unwrap_or("ws");
    let path = uri.path();

    // Default the port from the scheme, but prefer an explicit port in the URI.
    let secure = scheme == "https" || scheme == "wss";
    let default_port = if secure { 443 } else { 80 };
    let port = uri.port_u16().unwrap_or(default_port);

    let socket = TcpStream::connect((host, port)).await?;
    // Propagate failure rather than panicking: this is an ordinary I/O error
    // and surfaces to the caller as `ConnectError::Io`.
    socket.set_nodelay(true)?;
    // wrap TCP stream with TLS if schema is https or wss
    let socket = may_connect_tls(socket, host, secure).await?;
    // Establish a WS connection:
    let mut client = Client::new(socket.compat(), host, path);
    let (ws_to_connection, ws_from_connection) = match client.handshake().await? {
        ServerResponse::Accepted { .. } => client.into_builder().finish(),
        ServerResponse::Redirect { status_code, .. } => {
            return Err(ConnectError::ConnectionFailedRedirect { status_code })
        }
        ServerResponse::Rejected { status_code } => {
            return Err(ConnectError::ConnectionFailedRejected { status_code })
        }
    };
    Ok(Connection {
        tx: ws_to_connection,
        rx: ws_from_connection,
    })
}
/// Wrap `socket` in TLS (verifying against the webpki root store) when
/// `use_https` is set; otherwise return the plain TCP stream, boxed either way.
async fn may_connect_tls(
    socket: TcpStream,
    host: &str,
    use_https: bool,
) -> io::Result<Box<dyn AsyncReadWrite>> {
    if !use_https {
        return Ok(Box::new(socket));
    };
    // Trust the bundled webpki roots (no client certificate).
    let mut root_cert_store = rustls::RootCertStore::empty();
    root_cert_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
        OwnedTrustAnchor::from_subject_spki_name_constraints(
            ta.subject,
            ta.spki,
            ta.name_constraints,
        )
    }));
    let config = rustls::ClientConfig::builder()
        .with_safe_defaults()
        .with_root_certificates(root_cert_store)
        .with_no_client_auth();
    let connector = TlsConnector::from(Arc::new(config));
    // SNI/verification name must be a valid DNS name (or IP, per rustls rules).
    let domain = ServerName::try_from(host)
        .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dns name"))?;
    let socket = connector.connect(domain, socket).await?;
    Ok(Box::new(socket))
}

View File

@ -0,0 +1,12 @@
/// Functionality to establish a connection.
mod connect;
/// A close helper that we use in sender/receiver.
mod on_close;
/// The channel based receive interface.
mod receiver;
/// The channel based send interface.
mod sender;
// Re-export the public surface of this websocket-client module so callers
// don't need to reach into the submodules directly.
pub use connect::{connect, ConnectError, Connection, RawReceiver, RawSender};
pub use receiver::{Receiver, RecvError, RecvMessage};
pub use sender::{SendError, Sender, SentMessage};

View File

@ -0,0 +1,10 @@
use tokio::sync::broadcast;
/// A small helper that fires the "close" broadcast channel when it is
/// dropped, so whoever holds the receiving end learns that the owning
/// sender/receiver has gone away.
pub struct OnClose(pub broadcast::Sender<()>);

impl Drop for OnClose {
    fn drop(&mut self) {
        // There may be nobody listening any more; that's fine to ignore.
        self.0.send(()).ok();
    }
}

View File

@ -0,0 +1,55 @@
use super::on_close::OnClose;
use futures::{channel, Stream, StreamExt};
use std::sync::Arc;
/// Receive messages out of a connection
pub struct Receiver {
    // Stream of messages (or receive errors) forwarded from the socket task.
    pub(super) inner: channel::mpsc::UnboundedReceiver<Result<RecvMessage, RecvError>>,
    // Shared close handle; firing it (or dropping the last clone) asks the
    // underlying connection to shut down.
    pub(super) closer: Arc<OnClose>,
}
/// Errors that can surface while receiving messages from the connection.
#[derive(thiserror::Error, Debug)]
pub enum RecvError {
    /// A text frame arrived whose bytes were not valid UTF-8.
    #[error("Text message contains invalid UTF8: {0}")]
    InvalidUtf8(#[from] std::string::FromUtf8Error),
    /// The underlying stream of messages has ended.
    #[error("Stream finished")]
    StreamFinished,
    /// Asking the connection to close failed (close channel gone).
    #[error("Failed to send close message")]
    CloseError,
}
impl Receiver {
    /// Ask the underlying Websocket connection to close.
    pub async fn close(&mut self) -> Result<(), RecvError> {
        // A send error means the close channel has no listeners left.
        match self.closer.0.send(()) {
            Ok(_) => Ok(()),
            Err(_) => Err(RecvError::CloseError),
        }
    }
}
impl Stream for Receiver {
    type Item = Result<RecvMessage, RecvError>;
    /// Poll the inner channel for the next message.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // The inner channel already yields `Result<RecvMessage, RecvError>`,
        // which is exactly our `Item`, so its poll result can be returned
        // as-is; the previous `.map_err(|e| e.into())` was an identity
        // conversion with no effect.
        self.inner.poll_next_unpin(cx)
    }
}
/// A message that can be received from the channel interface
#[derive(Debug, Clone)]
pub enum RecvMessage {
    /// An owned string received from the socket.
    Text(String),
    /// Owned bytes received from the socket.
    Binary(Vec<u8>),
}

impl RecvMessage {
    /// The length of the message payload, in bytes.
    pub fn len(&self) -> usize {
        match self {
            RecvMessage::Binary(b) => b.len(),
            RecvMessage::Text(s) => s.len(),
        }
    }

    /// Whether the message payload is empty. Provided alongside [`Self::len`]
    /// per convention (clippy's `len_without_is_empty`).
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

View File

@ -0,0 +1,63 @@
use super::on_close::OnClose;
use futures::channel;
use std::sync::Arc;
/// A message that can be sent into the channel interface
#[derive(Debug, Clone)]
pub enum SentMessage {
    /// Being able to send static text is primarily useful for benchmarking,
    /// so that we can avoid cloning an owned string and pass a static reference
    /// (one such option here is using [`Box::leak`] to generate strings with
    /// static lifetimes).
    StaticText(&'static str),
    /// Being able to send static bytes is primarily useful for benchmarking,
    /// so that we can avoid cloning an owned string and pass a static reference
    /// (one such option here is using [`Box::leak`] to generate bytes with
    /// static lifetimes).
    StaticBinary(&'static [u8]),
    /// Send an owned string into the socket.
    Text(String),
    /// Send owned bytes into the socket.
    Binary(Vec<u8>),
}
/// Send messages into the connection
#[derive(Clone)]
pub struct Sender {
    // Unbounded queue feeding the socket task; sends never block.
    pub(super) inner: channel::mpsc::UnboundedSender<SentMessage>,
    // Shared close handle; firing it (or dropping the last clone) asks the
    // underlying connection to shut down.
    pub(super) closer: Arc<OnClose>,
}
impl Sender {
    /// Ask the underlying Websocket connection to close.
    pub async fn close(&mut self) -> Result<(), SendError<SentMessage>> {
        // A send error means the close channel has no listeners left.
        match self.closer.0.send(()) {
            Ok(_) => Ok(()),
            Err(_) => Err(SendError::CloseError),
        }
    }

    /// Returns whether this channel is closed.
    pub fn is_closed(&self) -> bool {
        self.inner.is_closed()
    }

    /// Unbounded send will always queue the message and doesn't
    /// need to be awaited.
    pub fn unbounded_send(&self, msg: SentMessage) -> Result<(), channel::mpsc::SendError> {
        // Strip the rejected message off the error, keeping only the cause.
        self.inner
            .unbounded_send(msg)
            .map_err(|err| err.into_send_error())
    }

    /// Convert this sender into a Sink
    pub fn into_sink(
        self,
    ) -> impl futures::Sink<SentMessage> + std::marker::Unpin + Clone + 'static {
        self.inner
    }
}
/// Errors that can surface while sending messages into the connection.
// NOTE(review): `ChannelError` wraps a `flume::SendError`, but
// `Sender::unbounded_send` above returns the futures-mpsc error type
// instead — confirm which channel errors are actually routed here.
#[derive(thiserror::Error, Debug, Clone)]
pub enum SendError<T: std::fmt::Debug + 'static> {
    #[error("Failed to send message: {0}")]
    ChannelError(#[from] flume::SendError<T>),
    #[error("Failed to send close message")]
    CloseError,
}

41
telemetry-core/Cargo.toml Normal file
View File

@ -0,0 +1,41 @@
[package]
name = "ghost-telemetry-core"
version = "0.1.0"
authors = ["Uncle Stinky <uncle.stinky@ghostchain.io>"]
edition = "2021"
[dependencies]
anyhow = "1.0.42"
bimap = "0.6.1"
bincode = "1.3.3"
bytes = "1.0.1"
flume = "0.10.8"
futures = "0.3.15"
common = { package = "ghost-telemetry-common", path = "../common" }
hex = "0.4.3"
http = "0.2.4"
hyper = "0.14.11"
log = "0.4.14"
maxminddb = "0.23.0"
num_cpus = "1.13.0"
once_cell = "1.8.0"
parking_lot = "0.12.1"
primitive-types = { version = "0.12.1", features = ["serde"] }
rayon = "1.5.0"
reqwest = { version = "0.11.4", features = ["json"] }
rustc-hash = "1.1.0"
serde = { version = "1.0.126", features = ["derive"] }
serde_json = "1.0.64"
simple_logger = "4.0.0"
smallvec = "1.6.1"
soketto = "0.7.1"
structopt = "0.3.21"
thiserror = "1.0.25"
tokio = { version = "1.10.1", features = ["full"] }
tokio-util = { version = "0.7.4", features = ["compat"] }
[target.'cfg(not(target_env = "msvc"))'.dependencies]
jemallocator = "0.5.0"
[dev-dependencies]
shellwords = "1.1.0"

Binary file not shown.

After

Width:  |  Height:  |  Size: 55 MiB

View File

@ -0,0 +1,148 @@
use super::inner_loop;
use crate::find_location::find_location;
use crate::state::NodeId;
use common::id_type;
use futures::{future, Sink, SinkExt};
use std::net::IpAddr;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
id_type! {
    /// A unique Id is assigned per websocket connection (or more accurately,
    /// per feed socket and per shard socket). This can be combined with the
    /// [`LocalId`] of messages to give us a global ID.
    ///
    /// Generated by the `id_type!` macro as a newtype around a `u64`.
    pub struct ConnId(u64)
}
/// A cheaply-clonable handle to an aggregator loop; every clone shares the
/// same [`AggregatorInternal`] state through the `Arc`.
#[derive(Clone)]
pub struct Aggregator(Arc<AggregatorInternal>);
/// Options to configure the aggregator loop(s)
#[derive(Debug, Clone)]
pub struct AggregatorOpts {
    /// Any node from these chains is muted
    pub denylist: Vec<String>,
    /// If our incoming message queue exceeds this length, we start
    /// dropping non-essential messages.
    pub max_queue_len: usize,
    /// Flag to expose the node's details (IP address, SysInfo, HwBench) of all connected
    /// nodes to the feed subscribers.
    pub expose_node_details: bool,
}
/// State shared by every clone of an [`Aggregator`] handle.
struct AggregatorInternal {
    /// Shards that connect are each assigned a unique connection ID.
    /// This helps us know who to send messages back to (especially in
    /// conjunction with the `ShardNodeId` that messages will come with).
    shard_conn_id: AtomicU64,
    /// Feeds that connect have their own unique connection ID, too.
    feed_conn_id: AtomicU64,
    /// Send messages in to the aggregator from the outside via this. This is
    /// stored here so that anybody holding an `Aggregator` handle can
    /// make use of it.
    tx_to_aggregator: flume::Sender<inner_loop::ToAggregator>,
}
impl Aggregator {
    /// Spawn a new Aggregator. This connects to the telemetry backend
    pub async fn spawn(opts: AggregatorOpts) -> anyhow::Result<Aggregator> {
        let (tx_to_aggregator, rx_from_external) = flume::unbounded();
        // Kick off a locator task to locate nodes, which hands back a channel to make location requests.
        // Location results are re-wrapped as `ToAggregator::FromFindLocation` messages and fed back
        // into the same aggregator queue as everything else.
        let tx_to_locator =
            find_location(tx_to_aggregator.clone().into_sink().with(|(node_id, msg)| {
                future::ok::<_, flume::SendError<_>>(inner_loop::ToAggregator::FromFindLocation(
                    node_id, msg,
                ))
            }));
        // Handle any incoming messages in our handler loop:
        tokio::spawn(Aggregator::handle_messages(
            rx_from_external,
            tx_to_locator,
            opts,
        ));
        // Return a handle to our aggregator:
        // (connection ID counters start at 1, as visible below)
        Ok(Aggregator(Arc::new(AggregatorInternal {
            shard_conn_id: AtomicU64::new(1),
            feed_conn_id: AtomicU64::new(1),
            tx_to_aggregator,
        })))
    }
    /// This is spawned into a separate task and handles any messages coming
    /// in to the aggregator. If nobody is holding the tx side of the channel
    /// any more, this task will gracefully end.
    async fn handle_messages(
        rx_from_external: flume::Receiver<inner_loop::ToAggregator>,
        tx_to_aggregator: flume::Sender<(NodeId, IpAddr)>,
        opts: AggregatorOpts,
    ) {
        inner_loop::InnerLoop::new(tx_to_aggregator, opts)
            .handle(rx_from_external)
            .await;
    }
    /// Gather metrics from our aggregator loop
    pub async fn gather_metrics(&self) -> anyhow::Result<inner_loop::Metrics> {
        // One-shot style request/response over a fresh flume channel.
        let (tx, rx) = flume::unbounded();
        let msg = inner_loop::ToAggregator::GatherMetrics(tx);
        self.0.tx_to_aggregator.send_async(msg).await?;
        let metrics = rx.recv_async().await?;
        Ok(metrics)
    }
    /// Return a sink that a shard can send messages into to be handled by the aggregator.
    pub fn subscribe_shard(
        &self,
    ) -> impl Sink<inner_loop::FromShardWebsocket, Error = anyhow::Error> + Send + Sync + Unpin + 'static
    {
        // Assign a unique aggregator-local ID to each connection that subscribes, and pass
        // that along with every message to the aggregator loop:
        let shard_conn_id = self
            .0
            .shard_conn_id
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let tx_to_aggregator = self.0.tx_to_aggregator.clone();
        // Calling `send` on this Sink requires Unpin. There may be a nicer way than this,
        // but pinning by boxing is the easy solution for now:
        Box::pin(tx_to_aggregator.into_sink().with(move |msg| async move {
            Ok(inner_loop::ToAggregator::FromShardWebsocket(
                shard_conn_id.into(),
                msg,
            ))
        }))
    }
    /// Return a sink that a feed can send messages into to be handled by the aggregator.
    ///
    /// Also returns the raw connection ID assigned to the feed.
    pub fn subscribe_feed(
        &self,
    ) -> (
        u64,
        impl Sink<inner_loop::FromFeedWebsocket, Error = anyhow::Error> + Send + Sync + Unpin + 'static,
    ) {
        // Assign a unique aggregator-local ID to each connection that subscribes, and pass
        // that along with every message to the aggregator loop:
        let feed_conn_id = self
            .0
            .feed_conn_id
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let tx_to_aggregator = self.0.tx_to_aggregator.clone();
        // Calling `send` on this Sink requires Unpin. There may be a nicer way than this,
        // but pinning by boxing is the easy solution for now:
        (
            feed_conn_id,
            Box::pin(tx_to_aggregator.into_sink().with(move |msg| async move {
                Ok(inner_loop::ToAggregator::FromFeedWebsocket(
                    feed_conn_id.into(),
                    msg,
                ))
            })),
        )
    }
}

View File

@ -0,0 +1,135 @@
use super::aggregator::{Aggregator, AggregatorOpts};
use super::inner_loop;
use common::EitherSink;
use futures::{Sink, SinkExt};
use inner_loop::{FromShardWebsocket, Metrics};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
/// A cheaply-clonable handle to a set of aggregators; shard traffic is
/// fanned out to all of them, feeds are assigned round-robin to one.
#[derive(Clone)]
pub struct AggregatorSet(Arc<AggregatorSetInner>);
/// State shared by every clone of an [`AggregatorSet`].
pub struct AggregatorSetInner {
    // The individual aggregator loops in this set.
    aggregators: Vec<Aggregator>,
    // Round-robin cursor used to pick an aggregator for each new feed.
    next_idx: AtomicUsize,
    // Latest metrics snapshot per aggregator, refreshed by background loops.
    metrics: Mutex<Vec<Metrics>>,
}
impl AggregatorSet {
    /// Spawn the number of aggregators we're asked to.
    ///
    /// # Panics
    ///
    /// Panics if `num_aggregators` is 0.
    pub async fn spawn(
        num_aggregators: usize,
        opts: AggregatorOpts,
    ) -> anyhow::Result<AggregatorSet> {
        assert_ne!(num_aggregators, 0, "You must have 1 or more aggregator");
        let aggregators = futures::future::try_join_all(
            (0..num_aggregators).map(|_| Aggregator::spawn(opts.clone())),
        )
        .await?;
        // Seed the metrics store with defaults so `latest_metrics` always
        // returns one entry per aggregator.
        let initial_metrics = (0..num_aggregators).map(|_| Metrics::default()).collect();
        let this = AggregatorSet(Arc::new(AggregatorSetInner {
            aggregators,
            next_idx: AtomicUsize::new(0),
            metrics: Mutex::new(initial_metrics),
        }));
        // Start asking for metrics:
        this.spawn_metrics_loops();
        Ok(this)
    }
    /// Spawn loops which periodically ask for metrics from each internal aggregator.
    /// Depending on how busy the aggregators are, these metrics won't necessarily be in
    /// sync with each other.
    fn spawn_metrics_loops(&self) {
        let aggregators = self.0.aggregators.clone();
        for (idx, a) in aggregators.into_iter().enumerate() {
            let inner = Arc::clone(&self.0);
            tokio::spawn(async move {
                loop {
                    let now = tokio::time::Instant::now();
                    let metrics = match a.gather_metrics().await {
                        Ok(metrics) => metrics,
                        // Any error here is unlikely and probably means that the aggregator
                        // loop has failed completely.
                        Err(e) => {
                            log::error!("Error obtaining metrics (bailing): {}", e);
                            return;
                        }
                    };
                    // Lock, update the stored metrics and drop the lock immediately.
                    // We discard any error; if something went wrong talking to the inner loop,
                    // it's probably a fatal error
                    {
                        inner.metrics.lock().unwrap()[idx] = metrics;
                    }
                    // Sleep *at least* 10 seconds. If it takes a while to get metrics back, we'll
                    // end up waiting longer between requests.
                    tokio::time::sleep_until(now + tokio::time::Duration::from_secs(10)).await;
                }
            });
        }
    }
    /// Return the latest metrics we've gathered so far from each internal aggregator.
    pub fn latest_metrics(&self) -> Vec<Metrics> {
        self.0.metrics.lock().unwrap().clone()
    }
    /// Return a sink that a shard can send messages into to be handled by all aggregators.
    pub fn subscribe_shard(
        &self,
    ) -> impl Sink<inner_loop::FromShardWebsocket, Error = anyhow::Error> + Send + Sync + Unpin + 'static
    {
        // Special case 1 aggregator to avoid the extra indirection and so on
        // if we don't actually need it.
        if self.0.aggregators.len() == 1 {
            let sub = self.0.aggregators[0].subscribe_shard();
            return EitherSink::a(sub);
        }
        let mut conns: Vec<_> = self
            .0
            .aggregators
            .iter()
            .map(|a| a.subscribe_shard())
            .collect();
        let (tx, rx) = flume::unbounded::<FromShardWebsocket>();
        // Send every incoming message to all aggregators.
        tokio::spawn(async move {
            while let Ok(msg) = rx.recv_async().await {
                for conn in &mut conns {
                    // Unbounded channel under the hood, so this await
                    // shouldn't ever need to yield.
                    if let Err(e) = conn.send(msg.clone()).await {
                        log::error!("Aggregator connection has failed: {}", e);
                        return;
                    }
                }
            }
        });
        EitherSink::b(tx.into_sink().sink_map_err(|e| anyhow::anyhow!("{}", e)))
    }
    /// Return a sink that a feed can send messages into to be handled by a single aggregator.
    ///
    /// Aggregators are picked round-robin across calls.
    pub fn subscribe_feed(
        &self,
    ) -> (
        u64,
        impl Sink<inner_loop::FromFeedWebsocket, Error = anyhow::Error> + Send + Sync + Unpin + 'static,
    ) {
        let last_val = self.0.next_idx.fetch_add(1, Ordering::Relaxed);
        // NOTE(review): the `+ 1` makes the first feed land on aggregator 1
        // rather than 0; the rotation is still a valid round-robin, but
        // confirm the offset is intentional.
        let this_idx = (last_val + 1) % self.0.aggregators.len();
        self.0.aggregators[this_idx].subscribe_feed()
    }
}

View File

@ -0,0 +1,666 @@
use super::aggregator::ConnId;
use crate::feed_message::{self, FeedMessageSerializer};
use crate::state::{self, NodeId, State};
use crate::{find_location, AggregatorOpts};
use bimap::BiMap;
use common::{
internal_messages::{self, MuteReason, ShardNodeId},
node_message,
node_types::BlockHash,
time, MultiMapUnique,
};
use std::collections::HashMap;
use std::sync::{
atomic::{AtomicU64, Ordering},
Arc,
};
use std::{net::IpAddr, str::FromStr};
/// Incoming messages come via subscriptions, and end up looking like this.
#[derive(Clone, Debug)]
pub enum ToAggregator {
    /// A message forwarded from a shard connection, tagged with its [`ConnId`].
    FromShardWebsocket(ConnId, FromShardWebsocket),
    /// A message forwarded from a feed connection, tagged with its [`ConnId`].
    FromFeedWebsocket(ConnId, FromFeedWebsocket),
    /// A geographical-location result for the given node.
    FromFindLocation(NodeId, find_location::Location),
    /// Hand back some metrics. The provided sender is expected not to block when
    /// a message is sent into it.
    GatherMetrics(flume::Sender<Metrics>),
}
/// An incoming shard connection can send these messages to the aggregator.
#[derive(Clone, Debug)]
pub enum FromShardWebsocket {
    /// When the socket is opened, it'll send this first
    /// so that we have a way to communicate back to it.
    Initialize {
        channel: flume::Sender<ToShardWebsocket>,
    },
    /// Tell the aggregator about a new node.
    Add {
        // The node ID local to the shard connection.
        local_id: ShardNodeId,
        // The node's IP address, used for geolocation (and optionally exposed to feeds).
        ip: std::net::IpAddr,
        node: common::node_types::NodeDetails,
        genesis_hash: common::node_types::BlockHash,
    },
    /// Update/pass through details about a node.
    Update {
        local_id: ShardNodeId,
        payload: node_message::Payload,
    },
    /// Tell the aggregator that a node has been removed when it disconnects.
    Remove { local_id: ShardNodeId },
    /// The shard is disconnected.
    Disconnected,
}
/// The aggregator can send these messages back to a shard connection.
#[derive(Debug)]
pub enum ToShardWebsocket {
    /// Mute messages to the core by passing the shard-local ID of them.
    Mute {
        local_id: ShardNodeId,
        // Why the node was muted (e.g. denied chain, chain over quota).
        reason: internal_messages::MuteReason,
    },
}
/// An incoming feed connection can send these messages to the aggregator.
#[derive(Clone, Debug)]
pub enum FromFeedWebsocket {
    /// When the socket is opened, it'll send this first
    /// so that we have a way to communicate back to it.
    /// Unbounded so that slow feeds don't block aggregator
    /// progress.
    Initialize {
        channel: flume::Sender<ToFeedWebsocket>,
    },
    /// The feed can subscribe to a chain to receive
    /// messages relating to it.
    Subscribe { chain: BlockHash },
    /// An explicit ping message.
    Ping { value: Box<str> },
    /// The feed is disconnected.
    Disconnected,
}
/// A set of metrics returned when we ask for metrics.
/// All counters are per-aggregator, not global.
#[derive(Clone, Debug, Default)]
pub struct Metrics {
    /// When in unix MS from epoch were these metrics obtained
    pub timestamp_unix_ms: u64,
    /// How many chains are feeds currently subscribed to.
    pub chains_subscribed_to: usize,
    /// Number of subscribed feeds.
    pub subscribed_feeds: usize,
    /// How many messages are currently queued up in internal channels
    /// waiting to be sent out to feeds.
    pub total_messages_to_feeds: usize,
    /// How many messages are currently queued waiting to be handled by this aggregator.
    pub current_messages_to_aggregator: usize,
    /// The total number of messages sent to the aggregator.
    pub total_messages_to_aggregator: u64,
    /// How many (non-critical) messages have been dropped by the aggregator because it was overwhelmed.
    pub dropped_messages_to_aggregator: u64,
    /// How many nodes are currently known to this aggregator.
    pub connected_nodes: usize,
    /// How many feeds are currently connected to this aggregator.
    pub connected_feeds: usize,
    /// How many shards are currently connected to this aggregator.
    pub connected_shards: usize,
}
// The frontend sends text based commands; parse them into these messages:
impl FromStr for FromFeedWebsocket {
    type Err = anyhow::Error;
    /// Parse a frontend command of the form `CMD:VALUE` (e.g. `ping:hi`,
    /// `subscribe:0x…`) into a [`FromFeedWebsocket`] message.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // `split_once` replaces the previous manual `find(':')` + slicing.
        let (cmd, value) = match s.split_once(':') {
            Some(parts) => parts,
            None => return Err(anyhow::anyhow!("Expecting format `CMD:CHAIN_NAME`")),
        };
        match cmd {
            "ping" => Ok(FromFeedWebsocket::Ping {
                value: value.into(),
            }),
            "subscribe" => Ok(FromFeedWebsocket::Subscribe {
                chain: value.parse()?,
            }),
            _ => Err(anyhow::anyhow!("Command {} not recognised", cmd)),
        }
    }
}
/// The aggregator can send these messages back to a feed connection.
#[derive(Clone, Debug)]
pub enum ToFeedWebsocket {
    /// A pre-serialized batch of feed messages, ready to write to the socket.
    Bytes(bytes::Bytes),
}
/// Instances of this are responsible for handling incoming and
/// outgoing messages in the main aggregator loop.
pub struct InnerLoop {
    /// The state of our chains and nodes lives here:
    node_state: State,
    /// We maintain a mapping between NodeId and ConnId+LocalId, so that we know
    /// which messages are about which nodes.
    node_ids: BiMap<NodeId, (ConnId, ShardNodeId)>,
    /// Keep track of how to send messages out to feeds.
    feed_channels: HashMap<ConnId, flume::Sender<ToFeedWebsocket>>,
    /// Keep track of how to send messages out to shards.
    shard_channels: HashMap<ConnId, flume::Sender<ToShardWebsocket>>,
    /// Which feeds are subscribed to a given chain?
    /// (a feed is subscribed to at most one chain at a time)
    chain_to_feed_conn_ids: MultiMapUnique<BlockHash, ConnId>,
    /// Send messages here to make geographical location requests.
    tx_to_locator: flume::Sender<(NodeId, IpAddr)>,
    /// How big can the queue of messages coming in to the aggregator get before messages
    /// are prioritised and dropped to try and get back on track.
    max_queue_len: usize,
    /// Flag to expose the node's details (IP address, SysInfo, HwBench) of all connected
    /// nodes to the feed subscribers.
    expose_node_details: bool,
}
impl InnerLoop {
/// Create a new inner loop handler with the various state it needs.
pub fn new(tx_to_locator: flume::Sender<(NodeId, IpAddr)>, opts: AggregatorOpts) -> Self {
InnerLoop {
node_state: State::new(opts.denylist),
node_ids: BiMap::new(),
feed_channels: HashMap::new(),
shard_channels: HashMap::new(),
chain_to_feed_conn_ids: MultiMapUnique::new(),
tx_to_locator,
max_queue_len: opts.max_queue_len,
expose_node_details: opts.expose_node_details,
}
}
    /// Start handling and responding to incoming messages.
    ///
    /// This splits into two halves: a spawned task that owns `self` and
    /// applies every message, and this loop, which meters the queue and
    /// drops non-essential `Update` messages when it grows past
    /// `max_queue_len`.
    pub async fn handle(mut self, rx_from_external: flume::Receiver<ToAggregator>) {
        let max_queue_len = self.max_queue_len;
        let (metered_tx, metered_rx) = flume::unbounded();
        // Keep count of the number of dropped/total messages for the sake of metric reporting
        let dropped_messages = Arc::new(AtomicU64::new(0));
        let total_messages = Arc::new(AtomicU64::new(0));
        // Actually handle all of our messages, but before we get here, we
        // check the length of the queue below to decide whether or not to
        // pass the message on to this.
        let dropped_messages2 = Arc::clone(&dropped_messages);
        let total_messages2 = Arc::clone(&total_messages);
        tokio::spawn(async move {
            while let Ok(msg) = metered_rx.recv_async().await {
                match msg {
                    ToAggregator::FromFeedWebsocket(feed_conn_id, msg) => {
                        self.handle_from_feed(feed_conn_id, msg)
                    }
                    ToAggregator::FromShardWebsocket(shard_conn_id, msg) => {
                        self.handle_from_shard(shard_conn_id, msg)
                    }
                    ToAggregator::FromFindLocation(node_id, location) => {
                        self.handle_from_find_location(node_id, location)
                    }
                    // Metrics include the live queue length and counters shared
                    // with the metering loop below.
                    ToAggregator::GatherMetrics(tx) => self.handle_gather_metrics(
                        tx,
                        metered_rx.len(),
                        dropped_messages2.load(Ordering::Relaxed),
                        total_messages2.load(Ordering::Relaxed),
                    ),
                }
            }
        });
        while let Ok(msg) = rx_from_external.recv_async().await {
            total_messages.fetch_add(1, Ordering::Relaxed);
            // ignore node updates if we have too many messages to handle, in an attempt
            // to reduce the queue length back to something reasonable, lest it get out of
            // control and start consuming a load of memory.
            if metered_tx.len() > max_queue_len {
                if matches!(
                    msg,
                    ToAggregator::FromShardWebsocket(.., FromShardWebsocket::Update { .. })
                ) {
                    // Note: this wraps on overflow (which is probably the best
                    // behaviour for graphing it anyway)
                    dropped_messages.fetch_add(1, Ordering::Relaxed);
                    continue;
                }
            }
            if let Err(e) = metered_tx.send(msg) {
                log::error!("Cannot send message into aggregator: {e}");
                break;
            }
        }
    }
/// Gather and return some metrics.
fn handle_gather_metrics(
&mut self,
rx: flume::Sender<Metrics>,
current_messages_to_aggregator: usize,
dropped_messages_to_aggregator: u64,
total_messages_to_aggregator: u64,
) {
let timestamp_unix_ms = time::now();
let connected_nodes = self.node_ids.len();
let subscribed_feeds = self.chain_to_feed_conn_ids.num_values();
let chains_subscribed_to = self.chain_to_feed_conn_ids.num_keys();
let connected_shards = self.shard_channels.len();
let connected_feeds = self.feed_channels.len();
let total_messages_to_feeds: usize = self.feed_channels.values().map(|c| c.len()).sum();
// Ignore error sending; assume the receiver stopped caring and dropped the channel:
let _ = rx.send(Metrics {
timestamp_unix_ms,
chains_subscribed_to,
subscribed_feeds,
total_messages_to_feeds,
current_messages_to_aggregator,
total_messages_to_aggregator,
dropped_messages_to_aggregator,
connected_nodes,
connected_feeds,
connected_shards,
});
}
/// Handle messages that come from the node geographical locator.
fn handle_from_find_location(&mut self, node_id: NodeId, location: find_location::Location) {
self.node_state
.update_node_location(node_id, location.clone());
if let Some(loc) = location {
let mut feed_message_serializer = FeedMessageSerializer::new();
feed_message_serializer.push(feed_message::LocatedNode(
node_id.get_chain_node_id().into(),
loc.latitude,
loc.longitude,
&loc.city,
));
let chain_genesis_hash = self
.node_state
.get_chain_by_node_id(node_id)
.map(|chain| chain.genesis_hash());
if let Some(chain_genesis_hash) = chain_genesis_hash {
self.finalize_and_broadcast_to_chain_feeds(
&chain_genesis_hash,
feed_message_serializer,
);
}
}
}
    /// Handle messages coming from shards.
    ///
    /// This keeps `node_ids` and `node_state` in sync with the shard's view
    /// of its nodes, and broadcasts the resulting changes to feeds.
    fn handle_from_shard(&mut self, shard_conn_id: ConnId, msg: FromShardWebsocket) {
        match msg {
            FromShardWebsocket::Initialize { channel } => {
                // Remember how to talk back to this shard (e.g. to mute nodes).
                self.shard_channels.insert(shard_conn_id, channel);
            }
            FromShardWebsocket::Add {
                local_id,
                ip,
                mut node,
                genesis_hash,
            } => {
                // Conditionally modify the node's details to include the IP address.
                node.ip = self.expose_node_details.then_some(ip.to_string().into());
                match self.node_state.add_node(genesis_hash, node) {
                    // The chain is denied or over quota: tell the shard to stop
                    // forwarding messages for this node.
                    state::AddNodeResult::ChainOnDenyList => {
                        if let Some(shard_conn) = self.shard_channels.get_mut(&shard_conn_id) {
                            let _ = shard_conn.send(ToShardWebsocket::Mute {
                                local_id,
                                reason: MuteReason::ChainNotAllowed,
                            });
                        }
                    }
                    state::AddNodeResult::ChainOverQuota => {
                        if let Some(shard_conn) = self.shard_channels.get_mut(&shard_conn_id) {
                            let _ = shard_conn.send(ToShardWebsocket::Mute {
                                local_id,
                                reason: MuteReason::Overquota,
                            });
                        }
                    }
                    state::AddNodeResult::NodeAddedToChain(details) => {
                        let node_id = details.id;
                        // Record ID <-> (shardId,localId) for future messages:
                        self.node_ids.insert(node_id, (shard_conn_id, local_id));
                        // Don't hold onto details too long because we want &mut self later:
                        let new_chain_label = details.new_chain_label.to_owned();
                        let chain_node_count = details.chain_node_count;
                        let has_chain_label_changed = details.has_chain_label_changed;
                        // Tell chain subscribers about the node we've just added:
                        let mut feed_messages_for_chain = FeedMessageSerializer::new();
                        feed_messages_for_chain.push(feed_message::AddedNode(
                            node_id.get_chain_node_id().into(),
                            &details.node,
                            self.expose_node_details,
                        ));
                        self.finalize_and_broadcast_to_chain_feeds(
                            &genesis_hash,
                            feed_messages_for_chain,
                        );
                        // Tell everybody about the new node count and potential rename:
                        let mut feed_messages_for_all = FeedMessageSerializer::new();
                        if has_chain_label_changed {
                            feed_messages_for_all.push(feed_message::RemovedChain(genesis_hash));
                        }
                        feed_messages_for_all.push(feed_message::AddedChain(
                            &new_chain_label,
                            genesis_hash,
                            chain_node_count,
                        ));
                        self.finalize_and_broadcast_to_all_feeds(feed_messages_for_all);
                        // Ask for the geographical location of the node.
                        let _ = self.tx_to_locator.send((node_id, ip));
                    }
                }
            }
            FromShardWebsocket::Remove { local_id } => {
                let node_id = match self.node_ids.remove_by_right(&(shard_conn_id, local_id)) {
                    Some((node_id, _)) => node_id,
                    None => {
                        // It's possible that some race between removing and disconnecting shards might lead to
                        // more than one remove message for the same node. This isn't really a problem, but we
                        // hope it won't happen so make a note if it does:
                        log::debug!(
                            "Remove: Cannot find ID for node with shard/connectionId of {shard_conn_id:?}/{local_id:?}"
                        );
                        return;
                    }
                };
                self.remove_nodes_and_broadcast_result(Some(node_id));
            }
            FromShardWebsocket::Update { local_id, payload } => {
                let node_id = match self.node_ids.get_by_right(&(shard_conn_id, local_id)) {
                    Some(id) => *id,
                    None => {
                        log::error!(
                            "Update: Cannot find ID for node with shard/connectionId of {shard_conn_id:?}/{local_id:?}"
                        );
                        return;
                    }
                };
                // `update_node` pushes any resulting feed messages into the serializer.
                let mut feed_message_serializer = FeedMessageSerializer::new();
                self.node_state.update_node(
                    node_id,
                    payload,
                    &mut feed_message_serializer,
                    self.expose_node_details,
                );
                if let Some(chain) = self.node_state.get_chain_by_node_id(node_id) {
                    let genesis_hash = chain.genesis_hash();
                    self.finalize_and_broadcast_to_chain_feeds(
                        &genesis_hash,
                        feed_message_serializer,
                    );
                }
            }
            FromShardWebsocket::Disconnected => {
                self.shard_channels.remove(&shard_conn_id);
                // Find all nodes associated with this shard connection ID:
                let node_ids_to_remove: Vec<NodeId> = self
                    .node_ids
                    .iter()
                    .filter(|(_, &(this_shard_conn_id, _))| shard_conn_id == this_shard_conn_id)
                    .map(|(&node_id, _)| node_id)
                    .collect();
                // ... and remove them:
                self.remove_nodes_and_broadcast_result(node_ids_to_remove);
            }
        }
    }
    /// Handle messages coming from feeds.
    ///
    /// Feeds must `Initialize` first; `Subscribe` switches the feed's single
    /// chain subscription and replays the chain's current state to it.
    fn handle_from_feed(&mut self, feed_conn_id: ConnId, msg: FromFeedWebsocket) {
        match msg {
            FromFeedWebsocket::Initialize { channel } => {
                self.feed_channels.insert(feed_conn_id, channel.clone());
                // Tell the new feed subscription some basic things to get it going:
                let mut feed_serializer = FeedMessageSerializer::new();
                // Feed protocol version; the UI checks this for compatibility.
                feed_serializer.push(feed_message::Version(32));
                for chain in self.node_state.iter_chains() {
                    feed_serializer.push(feed_message::AddedChain(
                        chain.label(),
                        chain.genesis_hash(),
                        chain.node_count(),
                    ));
                }
                // Send this to the channel that subscribed:
                if let Some(bytes) = feed_serializer.into_finalized() {
                    let _ = channel.send(ToFeedWebsocket::Bytes(bytes));
                }
            }
            FromFeedWebsocket::Ping { value } => {
                let feed_channel = match self.feed_channels.get_mut(&feed_conn_id) {
                    Some(chan) => chan,
                    None => return,
                };
                // Pong!
                let mut feed_serializer = FeedMessageSerializer::new();
                feed_serializer.push(feed_message::Pong(&value));
                if let Some(bytes) = feed_serializer.into_finalized() {
                    let _ = feed_channel.send(ToFeedWebsocket::Bytes(bytes));
                }
            }
            FromFeedWebsocket::Subscribe { chain } => {
                let feed_channel = match self.feed_channels.get_mut(&feed_conn_id) {
                    Some(chan) => chan,
                    None => return,
                };
                // Unsubscribe from previous chain if subscribed to one:
                let old_genesis_hash = self.chain_to_feed_conn_ids.remove_value(&feed_conn_id);
                // Get old chain if there was one:
                let node_state = &self.node_state;
                let old_chain =
                    old_genesis_hash.and_then(|hash| node_state.get_chain_by_genesis_hash(&hash));
                // Get new chain, ignoring the rest if it doesn't exist.
                let new_chain = match self.node_state.get_chain_by_genesis_hash(&chain) {
                    Some(chain) => chain,
                    None => return,
                };
                // Send messages to the feed about this subscription:
                let mut feed_serializer = FeedMessageSerializer::new();
                if let Some(old_chain) = old_chain {
                    feed_serializer.push(feed_message::UnsubscribedFrom(old_chain.genesis_hash()));
                }
                feed_serializer.push(feed_message::SubscribedTo(new_chain.genesis_hash()));
                feed_serializer.push(feed_message::TimeSync(time::now()));
                feed_serializer.push(feed_message::BestBlock(
                    new_chain.best_block().height,
                    new_chain.timestamp(),
                    new_chain.average_block_time(),
                ));
                feed_serializer.push(feed_message::BestFinalized(
                    new_chain.finalized_block().height,
                    new_chain.finalized_block().hash,
                ));
                feed_serializer.push(feed_message::ChainStatsUpdate(new_chain.stats()));
                if let Some(bytes) = feed_serializer.into_finalized() {
                    let _ = feed_channel.send(ToFeedWebsocket::Bytes(bytes));
                }
                // If many (eg 10k) nodes are connected, serializing all of their info takes time.
                // So, parallelise this with Rayon, but we still send out messages for each node in order
                // (which is helpful for the UI as it tries to maintain a sorted list of nodes). The chunk
                // size is the max number of node info we fit into 1 message; smaller messages allow the UI
                // to react a little faster and not have to wait for a larger update to come in. A chunk size
                // of 64 means each message is ~32k.
                use rayon::prelude::*;
                let all_feed_messages: Vec<_> = new_chain
                    .nodes_slice()
                    .par_iter()
                    .enumerate()
                    .chunks(64)
                    .filter_map(|nodes| {
                        let mut feed_serializer = FeedMessageSerializer::new();
                        // Empty slots in the nodes slice are skipped.
                        for (node_id, node) in nodes
                            .iter()
                            .filter_map(|&(idx, n)| n.as_ref().map(|n| (idx, n)))
                        {
                            feed_serializer.push(feed_message::AddedNode(
                                node_id,
                                node,
                                self.expose_node_details,
                            ));
                            feed_serializer.push(feed_message::FinalizedBlock(
                                node_id,
                                node.finalized().height,
                                node.finalized().hash,
                            ));
                            if node.stale() {
                                feed_serializer.push(feed_message::StaleNode(node_id));
                            }
                        }
                        feed_serializer.into_finalized()
                    })
                    .collect();
                for bytes in all_feed_messages {
                    let _ = feed_channel.send(ToFeedWebsocket::Bytes(bytes));
                }
                // Actually make a note of the new chain subscription:
                let new_genesis_hash = new_chain.genesis_hash();
                self.chain_to_feed_conn_ids
                    .insert(new_genesis_hash, feed_conn_id);
            }
            FromFeedWebsocket::Disconnected => {
                // The feed has disconnected; clean up references to it:
                self.chain_to_feed_conn_ids.remove_value(&feed_conn_id);
                self.feed_channels.remove(&feed_conn_id);
            }
        }
    }
/// Remove all of the node IDs provided and broadcast messages to feeds as needed.
///
/// Nodes are grouped by the genesis hash of the chain they belong to, so that we
/// produce one serialized batch of feed messages per affected chain, plus a single
/// batch destined for every connected feed.
fn remove_nodes_and_broadcast_result(&mut self, node_ids: impl IntoIterator<Item = NodeId>) {
    // Group by chain to simplify the handling of feed messages:
    let mut node_ids_per_chain: HashMap<BlockHash, Vec<NodeId>> = HashMap::new();
    for node_id in node_ids.into_iter() {
        if let Some(chain) = self.node_state.get_chain_by_node_id(node_id) {
            node_ids_per_chain
                .entry(chain.genesis_hash())
                .or_default()
                .push(node_id);
        }
    }
    // Remove the nodes for each chain. Note: the grouping key is the chain's
    // genesis hash (a BlockHash), not its label — it was previously named
    // `chain_label`, which was misleading.
    let mut feed_messages_for_all = FeedMessageSerializer::new();
    for (genesis_hash, node_ids) in node_ids_per_chain {
        let mut feed_messages_for_chain = FeedMessageSerializer::new();
        for node_id in node_ids {
            self.remove_node(
                node_id,
                &mut feed_messages_for_chain,
                &mut feed_messages_for_all,
            );
        }
        self.finalize_and_broadcast_to_chain_feeds(&genesis_hash, feed_messages_for_chain);
    }
    self.finalize_and_broadcast_to_all_feeds(feed_messages_for_all);
}
/// Remove a single node by its ID, pushing any messages we'd want to send
/// out to feeds onto the provided feed serializers. Doesn't actually send
/// anything to the feeds; just updates state as needed.
fn remove_node(
    &mut self,
    node_id: NodeId,
    feed_for_chain: &mut FeedMessageSerializer,
    feed_for_all: &mut FeedMessageSerializer,
) {
    // Drop the top level association first (it may already be gone).
    self.node_ids.remove_by_left(&node_id);
    let Some(removed_details) = self.node_state.remove_node(node_id) else {
        log::error!("Could not find node {node_id:?}");
        return;
    };
    let remaining_nodes = removed_details.chain_node_count;
    // The old chain entry goes away if no nodes are left on it, or if its label changed:
    if remaining_nodes == 0 || removed_details.has_chain_label_changed {
        feed_for_all.push(feed_message::RemovedChain(
            removed_details.chain_genesis_hash,
        ));
    }
    if remaining_nodes != 0 {
        // The chain still exists: announce its (possibly new) label and node count...
        feed_for_all.push(feed_message::AddedChain(
            &removed_details.new_chain_label,
            removed_details.chain_genesis_hash,
            remaining_nodes,
        ));
        // ...and tell this chain's subscribers that the node went away.
        feed_for_chain.push(feed_message::RemovedNode(
            node_id.get_chain_node_id().into(),
        ));
    }
}
/// Finalize a [`FeedMessageSerializer`] and broadcast the result to feeds for the chain.
fn finalize_and_broadcast_to_chain_feeds(
    &mut self,
    genesis_hash: &BlockHash,
    serializer: FeedMessageSerializer,
) {
    // An empty serializer finalizes to None; nothing to send in that case.
    let Some(bytes) = serializer.into_finalized() else {
        return;
    };
    self.broadcast_to_chain_feeds(genesis_hash, ToFeedWebsocket::Bytes(bytes));
}
/// Send a message to all chain feeds.
fn broadcast_to_chain_feeds(&mut self, genesis_hash: &BlockHash, message: ToFeedWebsocket) {
    let feeds = match self.chain_to_feed_conn_ids.get_values(genesis_hash) {
        Some(feeds) => feeds,
        // Nobody is subscribed to this chain.
        None => return,
    };
    for &feed_id in feeds {
        if let Some(chan) = self.feed_channels.get_mut(&feed_id) {
            // Ignore send failures; dead feeds are cleaned up on disconnect.
            let _ = chan.send(message.clone());
        }
    }
}
/// Finalize a [`FeedMessageSerializer`] and broadcast the result to all feeds
fn finalize_and_broadcast_to_all_feeds(&mut self, serializer: FeedMessageSerializer) {
    match serializer.into_finalized() {
        Some(bytes) => self.broadcast_to_all_feeds(ToFeedWebsocket::Bytes(bytes)),
        // Nothing was serialized, so there is nothing to broadcast.
        None => {}
    }
}
/// Send a message to everybody.
fn broadcast_to_all_feeds(&mut self, message: ToFeedWebsocket) {
    // Ignore individual send failures; dead feeds are pruned on disconnect.
    self.feed_channels.values_mut().for_each(|chan| {
        let _ = chan.send(message.clone());
    });
}
}

View File

@ -0,0 +1,9 @@
mod aggregator;
mod aggregator_set;
mod inner_loop;
// Expose the various message types that can be worked with externally:
pub use aggregator::AggregatorOpts;
pub use inner_loop::{FromFeedWebsocket, FromShardWebsocket, ToFeedWebsocket, ToShardWebsocket};
pub use aggregator_set::*;

View File

@ -0,0 +1,234 @@
use serde::Serialize;
use crate::state::Node;
use common::node_types::{
BlockDetails, BlockHash, BlockNumber, NodeHardware, NodeIO, NodeStats, Timestamp,
};
use serde_json::to_writer;
// Feed-side node identifier (an index; see how node ids are converted at the push sites).
type FeedNodeId = usize;
/// A message that can appear in the feed stream. Each message type carries a
/// unique `ACTION` tag (see the `actions!` table below) that is written ahead
/// of the payload so the consumer knows how to decode what follows.
pub trait FeedMessage {
    const ACTION: u8;
}
/// Allows a message type to control exactly how its payload is written to the
/// serializer. Most types just serialize themselves via the blanket impl below.
pub trait FeedMessageWrite: FeedMessage {
    fn write_to_feed(&self, ser: &mut FeedMessageSerializer);
}
// Blanket impl: any serde-serializable feed message is written by simply
// serializing `self` in place.
impl<T> FeedMessageWrite for T
where
    T: FeedMessage + Serialize,
{
    fn write_to_feed(&self, ser: &mut FeedMessageSerializer) {
        ser.write(self)
    }
}
/// Incrementally serializes feed messages into a single JSON array of
/// alternating action tags and payloads: `[action,payload,action,payload,...]`.
pub struct FeedMessageSerializer {
    /// Current buffer.
    buffer: Vec<u8>,
}
// Initial buffer capacity; avoids early reallocations for small messages.
const BUFCAP: usize = 128;
impl FeedMessageSerializer {
pub fn new() -> Self {
Self {
buffer: Vec::with_capacity(BUFCAP),
}
}
pub fn push<Message>(&mut self, msg: Message)
where
Message: FeedMessageWrite,
{
let glue = match self.buffer.len() {
0 => b'[',
_ => b',',
};
self.buffer.push(glue);
self.write(&Message::ACTION);
self.buffer.push(b',');
msg.write_to_feed(self);
}
fn write<S>(&mut self, value: &S)
where
S: Serialize,
{
let _ = to_writer(&mut self.buffer, value);
}
/// Return the bytes that we've serialized so far, consuming the serializer.
pub fn into_finalized(mut self) -> Option<bytes::Bytes> {
if self.buffer.is_empty() {
return None;
}
self.buffer.push(b']');
Some(self.buffer.into())
}
}
// Generates a `FeedMessage` impl for each `tag: Type` entry, associating the
// type with its numeric wire-format action tag.
macro_rules! actions {
    ($($action:literal: $t:ty,)*) => {
        $(
            impl FeedMessage for $t {
                const ACTION: u8 = $action;
            }
        )*
    }
}
// Wire-format action tags for every feed message type. These numbers are part
// of the protocol with feed consumers: never renumber or reuse an existing ID.
actions! {
    0: Version,
    1: BestBlock,
    2: BestFinalized,
    3: AddedNode<'_>,
    4: RemovedNode,
    5: LocatedNode<'_>,
    6: ImportedBlock<'_>,
    7: FinalizedBlock,
    8: NodeStatsUpdate<'_>,
    9: Hardware<'_>,
    10: TimeSync,
    11: AddedChain<'_>,
    12: RemovedChain,
    13: SubscribedTo,
    14: UnsubscribedFrom,
    15: Pong<'_>,
    // Note; some now-unused messages were removed between IDs 15 and 20.
    // We maintain existing IDs for backward compatibility.
    20: StaleNode,
    21: NodeIOUpdate<'_>,
    22: ChainStatsUpdate<'_>,
}
/// Feed protocol version.
#[derive(Serialize)]
pub struct Version(pub usize);
/// New best block: height, timestamp, and (presumably) average block time — confirm at push site.
#[derive(Serialize)]
pub struct BestBlock(pub BlockNumber, pub Timestamp, pub Option<u64>);
/// New best finalized block: height and hash.
#[derive(Serialize)]
pub struct BestFinalized(pub BlockNumber, pub BlockHash);
/// A node was added: feed node id, node state, and whether sensitive details
/// (IP, hwbench) are exposed. Custom serialization below, hence no derive.
pub struct AddedNode<'a>(pub FeedNodeId, pub &'a Node, pub bool);
/// A node was removed.
#[derive(Serialize)]
pub struct RemovedNode(pub FeedNodeId);
/// A node's location: latitude, longitude, city (matching `NodeLocation`'s fields).
#[derive(Serialize)]
pub struct LocatedNode<'a>(pub FeedNodeId, pub f32, pub f32, pub &'a str);
/// A node imported a block.
#[derive(Serialize)]
pub struct ImportedBlock<'a>(pub FeedNodeId, pub &'a BlockDetails);
/// A node finalized a block: node id, height, hash.
#[derive(Serialize)]
pub struct FinalizedBlock(pub FeedNodeId, pub BlockNumber, pub BlockHash);
/// Updated stats for a node.
#[derive(Serialize)]
pub struct NodeStatsUpdate<'a>(pub FeedNodeId, pub &'a NodeStats);
/// Updated I/O counters for a node.
#[derive(Serialize)]
pub struct NodeIOUpdate<'a>(pub FeedNodeId, pub &'a NodeIO);
/// Updated hardware measurements for a node.
#[derive(Serialize)]
pub struct Hardware<'a>(pub FeedNodeId, pub &'a NodeHardware);
/// Time synchronisation payload (a millisecond timestamp — confirm units at push site).
#[derive(Serialize)]
pub struct TimeSync(pub u64);
/// A chain was added or its details changed: label, genesis hash, node count.
#[derive(Serialize)]
pub struct AddedChain<'a>(pub &'a str, pub BlockHash, pub usize);
/// A chain was removed, identified by its genesis hash.
#[derive(Serialize)]
pub struct RemovedChain(pub BlockHash);
/// Acknowledges subscription to a chain.
#[derive(Serialize)]
pub struct SubscribedTo(pub BlockHash);
/// Acknowledges unsubscription from a chain.
#[derive(Serialize)]
pub struct UnsubscribedFrom(pub BlockHash);
/// Pong response carrying a string payload (presumably echoing the ping).
#[derive(Serialize)]
pub struct Pong<'a>(pub &'a str);
/// A node is considered stale (stopped reporting).
#[derive(Serialize)]
pub struct StaleNode(pub FeedNodeId);
// Custom serialization for AddedNode: sensitive details (IP, hwbench) are only
// included when `expose_node_details` is set; sysinfo is always included.
impl FeedMessageWrite for AddedNode<'_> {
    fn write_to_feed(&self, ser: &mut FeedMessageSerializer) {
        let AddedNode(nid, node, expose_node_details) = self;
        let details = node.details();
        // Always include sysinfo, conditionally include ip and hwbench based on expose_node_details.
        let node_hwbench = node.hwbench();
        let ip = if *expose_node_details {
            &details.ip
        } else {
            &None
        };
        let sys_info = &details.sysinfo;
        let hwbench = if *expose_node_details {
            &node_hwbench
        } else {
            &None
        };
        // NOTE(review): these tuples serialize positionally, so the element
        // order below looks like part of the wire format — don't reorder
        // without checking the feed consumer.
        let details = (
            &details.name,
            &details.implementation,
            &details.version,
            &details.validator,
            &details.network_id,
            &details.target_os,
            &details.target_arch,
            &details.target_env,
            &ip,
            &sys_info,
            &hwbench,
        );
        ser.write(&(
            nid,
            details,
            node.stats(),
            node.io(),
            node.hardware(),
            node.block_details(),
            &node.location(),
            &node.startup_time(),
        ));
    }
}
/// Aggregate statistics for a chain, pushed to feed subscribers.
#[derive(Serialize)]
pub struct ChainStatsUpdate<'a>(pub &'a ChainStats);
/// A tally of values of type `K`: an explicit `(value, count)` list, plus
/// counts of values binned as "other" and of nodes where the value is unknown.
#[derive(Serialize, PartialEq, Eq, Default)]
pub struct Ranking<K> {
    pub list: Vec<(K, u64)>,
    pub other: u64,
    pub unknown: u64,
}
/// Rankings of various node properties across a chain's nodes.
/// The `(u32, Option<u32>)` keys appear to be bucketed ranges — confirm
/// against the stats collator.
#[derive(Serialize, PartialEq, Eq, Default)]
pub struct ChainStats {
    pub version: Ranking<String>,
    pub target_os: Ranking<String>,
    pub target_arch: Ranking<String>,
    pub cpu: Ranking<String>,
    pub memory: Ranking<(u32, Option<u32>)>,
    pub core_count: Ranking<u32>,
    pub linux_kernel: Ranking<String>,
    pub linux_distro: Ranking<String>,
    pub is_virtual_machine: Ranking<bool>,
    pub cpu_hashrate_score: Ranking<(u32, Option<u32>)>,
    pub memory_memcpy_score: Ranking<(u32, Option<u32>)>,
    pub disk_sequential_write_score: Ranking<(u32, Option<u32>)>,
    pub disk_random_write_score: Ranking<(u32, Option<u32>)>,
    pub cpu_vendor: Ranking<String>,
}

View File

@ -0,0 +1,135 @@
use std::net::{IpAddr, Ipv4Addr};
use std::sync::Arc;
use futures::{Sink, SinkExt};
use maxminddb::{geoip2::City, Reader as GeoIpReader};
use parking_lot::RwLock;
use rustc_hash::FxHashMap;
use common::node_types::NodeLocation;
/// The result of a location lookup; `None` when the IP address could not be
/// resolved to a geographical location.
pub type Location = Option<Arc<NodeLocation>>;
/// This is responsible for taking an IP address and attempting
/// to find a geographical location from this.
///
/// Returns a channel that `(id, ip)` requests can be sent to; each result is
/// delivered as `(id, Option<Arc<NodeLocation>>)` on the given `response_chan`.
pub fn find_location<Id, R>(response_chan: R) -> flume::Sender<(Id, IpAddr)>
where
    R: Sink<(Id, Option<Arc<NodeLocation>>)> + Unpin + Send + Clone + 'static,
    Id: Clone + Send + 'static,
{
    let (tx, rx) = flume::unbounded();
    // cache entries
    let mut cache: FxHashMap<IpAddr, Arc<NodeLocation>> = FxHashMap::default();
    // Default entry for localhost
    cache.insert(
        Ipv4Addr::new(127, 0, 0, 1).into(),
        Arc::new(NodeLocation {
            latitude: 52.516_6667,
            longitude: 13.4,
            city: "Berlin".into(),
        }),
    );
    // Create a locator with our cache. This is used to obtain locations.
    let locator = Locator::new(cache);
    // Spawn a loop to handle location requests. Note: previously this `while`
    // was wrapped in an outer `loop`, which busy-spun forever once every sender
    // was dropped (recv_async keeps returning Err on a disconnected channel).
    // Without the outer loop, the task simply ends when all senders are gone.
    tokio::spawn(async move {
        while let Ok((id, ip_address)) = rx.recv_async().await {
            let mut response_chan = response_chan.clone();
            let locator = locator.clone();
            tokio::spawn(async move {
                // GeoIP lookups are synchronous, so run them off the async threads:
                let location = tokio::task::spawn_blocking(move || locator.locate(ip_address))
                    .await
                    .expect("Locate never panics");
                let _ = response_chan.send((id, location)).await;
            });
        }
    });
    tx
}
/// This struct can be used to make location requests, given
/// an IPV4 or IPV6 address.
#[derive(Debug, Clone)]
struct Locator {
    // Reader over the embedded MaxMind city database; Arc so clones are cheap.
    city: Arc<maxminddb::Reader<&'static [u8]>>,
    // ip -> location cache shared between all clones of this Locator.
    cache: Arc<RwLock<FxHashMap<IpAddr, Arc<NodeLocation>>>>,
}
impl Locator {
    /// GeoLite database release data: 2024-03-29
    /// Database and Contents Copyright (c) 2024 MaxMind, Inc.
    /// To download the latest version visit: https://dev.maxmind.com/geoip/geolite2-free-geolocation-data.
    ///
    /// Use of this MaxMind product is governed by MaxMind's GeoLite2 End User License Agreement,
    /// which can be viewed at https://www.maxmind.com/en/geolite2/eula.
    /// This database incorporates GeoNames [https://www.geonames.org] geographical data,
    /// which is made available under the Creative Commons Attribution 4.0 License.
    /// To view a copy of this license, visit https://creativecommons.org/licenses/by/4.0/.
    const CITY_DATA: &'static [u8] = include_bytes!("GeoLite2-City.mmdb");
    /// Build a locator over the embedded database, pre-seeded with `cache`.
    pub fn new(cache: FxHashMap<IpAddr, Arc<NodeLocation>>) -> Self {
        let reader = GeoIpReader::from_source(Self::CITY_DATA).expect("City data is always valid");
        Self {
            city: Arc::new(reader),
            cache: Arc::new(RwLock::new(cache)),
        }
    }
    /// Look up the location of `ip`, consulting (and on success, filling) the cache.
    pub fn locate(&self, ip: IpAddr) -> Option<Arc<NodeLocation>> {
        // Fast path: hand back a cached entry if we have one.
        if let Some(cached) = self.cache.read().get(&ip).cloned() {
            return Some(cached);
        }
        // Slow path: query the embedded GeoIP database. Any missing piece of
        // data short-circuits to None (and is not cached).
        let City { city, location, .. } = self.city.lookup(ip.into()).ok()?;
        let city = city
            .as_ref()?
            .names
            .as_ref()?
            .get("en")?
            .to_string()
            .into_boxed_str();
        let coords = location.as_ref()?;
        let latitude = coords.latitude? as f32;
        let longitude = coords.longitude? as f32;
        let found = Arc::new(NodeLocation {
            city,
            latitude,
            longitude,
        });
        self.cache.write().insert(ip, Arc::clone(&found));
        Some(found)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Smoke test: the embedded .mmdb parses and the reader constructs.
    #[test]
    fn locator_construction() {
        Locator::new(Default::default());
    }
    // NOTE(review): asserts against the bundled GeoLite2 data; the expected
    // city may need updating whenever the .mmdb file is refreshed.
    #[test]
    fn locate_random_ip() {
        let ip = "12.5.56.25".parse().unwrap();
        let node_location = Locator::new(Default::default()).locate(ip).unwrap();
        assert_eq!(&*node_location.city, "Gardena");
    }
}

545
telemetry-core/src/main.rs Normal file
View File

@ -0,0 +1,545 @@
mod aggregator;
mod feed_message;
mod find_location;
mod state;
use std::str::FromStr;
use tokio::time::{Duration, Instant};
use aggregator::{
AggregatorOpts, AggregatorSet, FromFeedWebsocket, FromShardWebsocket, ToFeedWebsocket,
ToShardWebsocket,
};
use bincode::Options;
use common::http_utils;
use common::internal_messages;
use common::ready_chunks_all::ReadyChunksAll;
use futures::{SinkExt, StreamExt};
use hyper::{Method, Response};
use simple_logger::SimpleLogger;
use structopt::StructOpt;
#[cfg(not(target_env = "msvc"))]
use jemallocator::Jemalloc;
#[cfg(not(target_env = "msvc"))]
#[global_allocator]
static GLOBAL: Jemalloc = Jemalloc;
const VERSION: &str = env!("CARGO_PKG_VERSION");
const AUTHORS: &str = env!("CARGO_PKG_AUTHORS");
const NAME: &str = "Ghost Telemetry Backend Core";
const ABOUT: &str = "This is the Telemetry Backend Core that receives telemetry messages \
from Ghost/Casper nodes and provides the data to a subscribed feed";
// NOTE: the `///` doc comments on the fields below double as the CLI help text
// that structopt emits, so edit them with end users in mind.
#[derive(StructOpt, Debug)]
#[structopt(name = NAME, version = VERSION, author = AUTHORS, about = ABOUT)]
struct Opts {
    /// This is the socket address that Telemetry is listening to. This is restricted to
    /// localhost (127.0.0.1) by default and should be fine for most use cases. If
    /// you are using Telemetry in a container, you likely want to set this to '0.0.0.0:8000'
    #[structopt(short = "l", long = "listen", default_value = "127.0.0.1:8000")]
    socket: std::net::SocketAddr,
    /// The desired log level; one of 'error', 'warn', 'info', 'debug' or 'trace', where
    /// 'error' only logs errors and 'trace' logs everything.
    #[structopt(long = "log", default_value = "info")]
    log_level: log::LevelFilter,
    /// Space delimited list of the names of chains that are not allowed to connect to
    /// telemetry. Case sensitive.
    #[structopt(long, required = false)]
    denylist: Vec<String>,
    /// If it takes longer than this number of seconds to send the current batch of messages
    /// to a feed, the feed connection will be closed.
    #[structopt(long, default_value = "10")]
    feed_timeout: u64,
    /// Number of worker threads to spawn. If "0" is given, use the number of CPUs available
    /// on the machine. If no value is given, use an internal default that we have deemed sane.
    #[structopt(long)]
    worker_threads: Option<usize>,
    /// Each aggregator keeps track of the entire node state. Feed subscriptions are split across
    /// aggregators.
    #[structopt(long)]
    num_aggregators: Option<usize>,
    /// How big can the message queue for each aggregator grow before we start dropping non-essential
    /// messages in an attempt to let it reduce?
    #[structopt(long)]
    aggregator_queue_len: Option<usize>,
    /// Flag to expose the node's details (IP address, SysInfo, HwBench) of all connected
    /// nodes to the feed subscribers.
    #[structopt(long)]
    pub expose_node_details: bool,
}
fn main() {
    let opts = Opts::from_args();
    SimpleLogger::new()
        .with_level(opts.log_level)
        .init()
        .expect("Must be able to start a logger");
    log::info!("Starting Telemetry Core version: {}", VERSION);
    // Resolve an optional CLI count: an explicit 0 means "one per available
    // CPU", and no value at all falls back to `default`.
    let resolve_count = |requested: Option<usize>, default: usize| match requested {
        Some(0) => num_cpus::get(),
        Some(n) => n,
        None => default,
    };
    // By default, use a max of 8 worker threads, as perf
    // testing has found that to be a good sweet spot.
    let worker_threads = resolve_count(opts.worker_threads, usize::min(num_cpus::get(), 8));
    // For now, we just have 1 aggregator loop by default,
    // but we may want to be smarter here eventually.
    let num_aggregators = resolve_count(opts.num_aggregators, 1);
    let runtime = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .worker_threads(worker_threads)
        .thread_name("telemetry_core_worker")
        .build()
        .unwrap();
    runtime.block_on(async {
        if let Err(e) = start_server(num_aggregators, opts).await {
            log::error!("Error starting server: {}", e);
        }
    });
}
/// Declare our routes and start the server.
///
/// Routes:
/// - `GET /health`       — liveness check, returns "OK"
/// - `GET /feed`         — websocket upgrade for UI feed subscribers
/// - `GET /shard_submit` — websocket upgrade for telemetry shards
/// - `GET /metrics`      — prometheus-compatible text metrics
async fn start_server(num_aggregators: usize, opts: Opts) -> anyhow::Result<()> {
    let aggregator_queue_len = opts.aggregator_queue_len.unwrap_or(10_000);
    // Spawn the aggregator loops; these hold the node/feed state.
    let aggregator = AggregatorSet::spawn(
        num_aggregators,
        AggregatorOpts {
            max_queue_len: aggregator_queue_len,
            denylist: opts.denylist,
            expose_node_details: opts.expose_node_details,
        },
    )
    .await?;
    let socket_addr = opts.socket;
    let feed_timeout = opts.feed_timeout;
    let server = http_utils::start_server(socket_addr, move |addr, req| {
        // Each request closure gets its own handle onto the shared aggregator set.
        let aggregator = aggregator.clone();
        async move {
            match (req.method(), req.uri().path().trim_end_matches('/')) {
                // Check that the server is up and running:
                (&Method::GET, "/health") => Ok(Response::new("OK".into())),
                // Subscribe to feed messages:
                (&Method::GET, "/feed") => {
                    log::info!("Opening /feed connection from {:?}", addr);
                    Ok(http_utils::upgrade_to_websocket(
                        req,
                        move |ws_send, ws_recv| async move {
                            let (feed_id, tx_to_aggregator) = aggregator.subscribe_feed();
                            let (mut tx_to_aggregator, mut ws_send) =
                                handle_feed_websocket_connection(
                                    ws_send,
                                    ws_recv,
                                    tx_to_aggregator,
                                    feed_timeout,
                                    feed_id,
                                )
                                .await;
                            log::info!("Closing /feed connection from {:?}", addr);
                            // Tell the aggregator that this connection has closed, so it can tidy up.
                            let _ = tx_to_aggregator.send(FromFeedWebsocket::Disconnected).await;
                            let _ = ws_send.close().await;
                        },
                    ))
                }
                // Subscribe to shard messages:
                (&Method::GET, "/shard_submit") => {
                    Ok(http_utils::upgrade_to_websocket(
                        req,
                        move |ws_send, ws_recv| async move {
                            log::info!("Opening /shard_submit connection from {:?}", addr);
                            let tx_to_aggregator = aggregator.subscribe_shard();
                            let (mut tx_to_aggregator, mut ws_send) =
                                handle_shard_websocket_connection(
                                    ws_send,
                                    ws_recv,
                                    tx_to_aggregator,
                                )
                                .await;
                            log::info!("Closing /shard_submit connection from {:?}", addr);
                            // Tell the aggregator that this connection has closed, so it can tidy up.
                            let _ = tx_to_aggregator
                                .send(FromShardWebsocket::Disconnected)
                                .await;
                            let _ = ws_send.close().await;
                        },
                    ))
                }
                // Return metrics in a prometheus-friendly text based format:
                (&Method::GET, "/metrics") => Ok(return_prometheus_metrics(aggregator).await),
                // 404 for anything else:
                _ => Ok(Response::builder()
                    .status(404)
                    .body("Not found".into())
                    .unwrap()),
            }
        }
    });
    server.await?;
    Ok(())
}
/// This handles messages coming to/from a shard connection
async fn handle_shard_websocket_connection<S>(
mut ws_send: http_utils::WsSender,
mut ws_recv: http_utils::WsReceiver,
mut tx_to_aggregator: S,
) -> (S, http_utils::WsSender)
where
S: futures::Sink<FromShardWebsocket, Error = anyhow::Error> + Unpin + Send + 'static,
{
let (tx_to_shard_conn, rx_from_aggregator) = flume::unbounded();
// Tell the aggregator about this new connection, and give it a way to send messages to us:
let init_msg = FromShardWebsocket::Initialize {
channel: tx_to_shard_conn,
};
if let Err(e) = tx_to_aggregator.send(init_msg).await {
log::error!("Error sending message to aggregator: {}", e);
return (tx_to_aggregator, ws_send);
}
// Channels to notify each loop if the other closes:
let (recv_closer_tx, mut recv_closer_rx) = tokio::sync::oneshot::channel::<()>();
let (send_closer_tx, mut send_closer_rx) = tokio::sync::oneshot::channel::<()>();
// Receive messages from a shard:
let recv_handle = tokio::spawn(async move {
loop {
let mut bytes = Vec::new();
// Receive a message, or bail if closer called. We don't care about cancel safety;
// if we're halfway through receiving a message, no biggie since we're closing the
// connection anyway.
let msg_info = tokio::select! {
msg_info = ws_recv.receive_data(&mut bytes) => msg_info,
_ = &mut recv_closer_rx => break
};
// Handle the socket closing, or errors receiving the message.
if let Err(soketto::connection::Error::Closed) = msg_info {
break;
}
if let Err(e) = msg_info {
log::error!("Shutting down websocket connection: Failed to receive data: {e}");
break;
}
let msg: internal_messages::FromShardAggregator =
match bincode::options().deserialize(&bytes) {
Ok(msg) => msg,
Err(e) => {
log::error!("Failed to deserialize message from shard; booting it: {e}");
break;
}
};
// Convert and send to the aggregator:
let aggregator_msg = match msg {
internal_messages::FromShardAggregator::AddNode {
ip,
node,
local_id,
genesis_hash,
} => FromShardWebsocket::Add {
ip,
node,
genesis_hash,
local_id,
},
internal_messages::FromShardAggregator::UpdateNode { payload, local_id } => {
FromShardWebsocket::Update { local_id, payload }
}
internal_messages::FromShardAggregator::RemoveNode { local_id } => {
FromShardWebsocket::Remove { local_id }
}
};
if let Err(e) = tx_to_aggregator.send(aggregator_msg).await {
log::error!("Failed to send message to aggregator; closing shard: {e}");
break;
}
}
drop(send_closer_tx); // Kill the send task if this recv task ends
tx_to_aggregator
});
// Send messages to the shard:
let send_handle = tokio::spawn(async move {
loop {
let msg = tokio::select! {
msg = rx_from_aggregator.recv_async() => msg,
_ = &mut send_closer_rx => { break }
};
let msg = match msg {
Ok(msg) => msg,
Err(flume::RecvError::Disconnected) => break,
};
let internal_msg = match msg {
ToShardWebsocket::Mute { local_id, reason } => {
internal_messages::FromTelemetryCore::Mute { local_id, reason }
}
};
let bytes = bincode::options()
.serialize(&internal_msg)
.expect("message to shard should serialize");
if let Err(e) = ws_send.send_binary(bytes).await {
log::error!("Failed to send message to aggregator; closing shard: {e}")
}
if let Err(e) = ws_send.flush().await {
log::error!("Failed to flush message to aggregator; closing shard: {e}")
}
}
drop(recv_closer_tx); // Kill the recv task if this send task ends
ws_send
});
// If our send/recv tasks are stopped (if one of them dies, they both will),
// collect the bits we need to hand back from them:
let ws_send = send_handle.await.unwrap();
let tx_to_aggregator = recv_handle.await.unwrap();
// loop ended; give socket back to parent:
(tx_to_aggregator, ws_send)
}
/// This handles messages coming from a feed connection
///
/// One task forwards parsed feed commands to the aggregator; another batches
/// aggregator output and writes it to the websocket, dropping the feed if it
/// cannot keep up within `feed_timeout` seconds.
async fn handle_feed_websocket_connection<S>(
    mut ws_send: http_utils::WsSender,
    mut ws_recv: http_utils::WsReceiver,
    mut tx_to_aggregator: S,
    feed_timeout: u64,
    _feed_id: u64, // <- can be useful for debugging purposes.
) -> (S, http_utils::WsSender)
where
    S: futures::Sink<FromFeedWebsocket, Error = anyhow::Error> + Unpin + Send + 'static,
{
    // unbounded channel so that slow feeds don't block aggregator progress:
    let (tx_to_feed_conn, rx_from_aggregator) = flume::unbounded();
    // `Receiver::into_stream()` is currently problematic at the time of writing
    // (see https://github.com/zesterer/flume/issues/88). If this stream is polled lots
    // and isn't ready, it'll leak memory. In this case, since we only select from it or
    // a close channel, we shouldn't poll the thing more than once before it's ready (and
    // when it's ready, it cleans up after itself properly). So, I hope it won't leak!
    let mut rx_from_aggregator_chunks = ReadyChunksAll::new(rx_from_aggregator.into_stream());
    // Tell the aggregator about this new connection, and give it a way to send messages to us:
    let init_msg = FromFeedWebsocket::Initialize {
        channel: tx_to_feed_conn,
    };
    if let Err(e) = tx_to_aggregator.send(init_msg).await {
        log::error!("Error sending message to aggregator: {e}");
        return (tx_to_aggregator, ws_send);
    }
    // Channels to notify each loop if the other closes:
    let (recv_closer_tx, mut recv_closer_rx) = tokio::sync::oneshot::channel::<()>();
    let (send_closer_tx, mut send_closer_rx) = tokio::sync::oneshot::channel::<()>();
    // Receive messages from the feed:
    let recv_handle = tokio::spawn(async move {
        loop {
            let mut bytes = Vec::new();
            // Receive a message, or bail if closer called. We don't care about cancel safety;
            // if we're halfway through receiving a message, no biggie since we're closing the
            // connection anyway.
            let msg_info = tokio::select! {
                msg_info = ws_recv.receive_data(&mut bytes) => msg_info,
                _ = &mut recv_closer_rx => { break }
            };
            // Handle the socket closing, or errors receiving the message.
            if let Err(soketto::connection::Error::Closed) = msg_info {
                break;
            }
            if let Err(e) = msg_info {
                log::error!("Shutting down websocket connection: Failed to receive data: {e}");
                break;
            }
            // We ignore all but valid UTF8 text messages from the frontend:
            let text = match String::from_utf8(bytes) {
                Ok(s) => s,
                Err(_) => continue,
            };
            // Parse the message into a command we understand and send it to the aggregator:
            let cmd = match FromFeedWebsocket::from_str(&text) {
                Ok(cmd) => cmd,
                Err(e) => {
                    log::warn!("Ignoring invalid command '{text}' from the frontend: {e}");
                    continue;
                }
            };
            if let Err(e) = tx_to_aggregator.send(cmd).await {
                log::error!("Failed to send message to aggregator; closing feed: {e}");
                break;
            }
        }
        drop(send_closer_tx); // Kill the send task if this recv task ends
        tx_to_aggregator
    });
    // Send messages to the feed:
    let send_handle = tokio::spawn(async move {
        'outer: loop {
            // Debounce: each pass of the loop lasts at least 75ms (awaited at
            // the bottom), so bursts of messages coalesce into fewer sends.
            let debounce = tokio::time::sleep_until(Instant::now() + Duration::from_millis(75));
            let msgs = tokio::select! {
                msgs = rx_from_aggregator_chunks.next() => msgs,
                _ = &mut send_closer_rx => { break }
            };
            // End the loop when connection from aggregator ends:
            let msgs = match msgs {
                Some(msgs) => msgs,
                None => break,
            };
            // There is only one message type at the mo; bytes to send
            // to the websocket. collect them all up to dispatch in one shot.
            let all_msg_bytes = msgs.into_iter().map(|msg| match msg {
                ToFeedWebsocket::Bytes(bytes) => bytes,
            });
            // If the feed is too slow to receive the current batch of messages, we'll drop it.
            let message_send_deadline = Instant::now() + Duration::from_secs(feed_timeout);
            for bytes in all_msg_bytes {
                match tokio::time::timeout_at(message_send_deadline, ws_send.send_binary(&bytes))
                    .await
                {
                    Err(_) => {
                        log::debug!("Closing feed websocket that was too slow to keep up (too slow to send messages)");
                        break 'outer;
                    }
                    Ok(Err(soketto::connection::Error::Closed)) => {
                        break 'outer;
                    }
                    Ok(Err(e)) => {
                        log::debug!("Closing feed websocket due to error sending data: {}", e);
                        break 'outer;
                    }
                    Ok(_) => {}
                }
            }
            // Flush under the same deadline so a stalled socket can't hang us:
            match tokio::time::timeout_at(message_send_deadline, ws_send.flush()).await {
                Err(_) => {
                    log::debug!("Closing feed websocket that was too slow to keep up (too slow to flush messages)");
                    break;
                }
                Ok(Err(soketto::connection::Error::Closed)) => {
                    break;
                }
                Ok(Err(e)) => {
                    log::debug!("Closing feed websocket due to error flushing data: {}", e);
                    break;
                }
                Ok(_) => {}
            }
            debounce.await;
        }
        drop(recv_closer_tx); // Kill the recv task if this send task ends
        ws_send
    });
    // If our send/recv tasks are stopped (if one of them dies, they both will),
    // collect the bits we need to hand back from them:
    let ws_send = send_handle.await.unwrap();
    let tx_to_aggregator = recv_handle.await.unwrap();
    // loop ended; give socket back to parent:
    (tx_to_aggregator, ws_send)
}
/// Render the latest per-aggregator metrics in the Prometheus text exposition format.
async fn return_prometheus_metrics(aggregator: AggregatorSet) -> Response<hyper::Body> {
    let metrics = aggregator.latest_metrics();
    // Instead of using the rust prometheus library (which is optimised around global variables updated across a codebase),
    // we just split out the text format that prometheus expects ourselves, and use the latest metrics that we've
    // captured so far from the aggregators. See:
    //
    // https://github.com/prometheus/docs/blob/master/content/docs/instrumenting/exposition_formats.md#text-format-details
    //
    // For an example and explanation of this text based format. The minimal output we produce here seems to
    // be handled correctly when pointing a current version of prometheus at it.
    use std::fmt::Write;
    let mut s = String::new();
    for (idx, m) in metrics.iter().enumerate() {
        // One `name{aggregator="idx"} value timestamp` line per metric. Each
        // line ends with exactly one newline; previously a few of these lines
        // accidentally emitted "\n\n", scattering blank lines mid-block.
        let _ = writeln!(
            &mut s,
            "telemetry_core_connected_feeds{{aggregator=\"{}\"}} {} {}",
            idx, m.connected_feeds, m.timestamp_unix_ms
        );
        let _ = writeln!(
            &mut s,
            "telemetry_core_connected_nodes{{aggregator=\"{}\"}} {} {}",
            idx, m.connected_nodes, m.timestamp_unix_ms
        );
        let _ = writeln!(
            &mut s,
            "telemetry_core_connected_shards{{aggregator=\"{}\"}} {} {}",
            idx, m.connected_shards, m.timestamp_unix_ms
        );
        let _ = writeln!(
            &mut s,
            "telemetry_core_chains_subscribed_to{{aggregator=\"{}\"}} {} {}",
            idx, m.chains_subscribed_to, m.timestamp_unix_ms
        );
        let _ = writeln!(
            &mut s,
            "telemetry_core_subscribed_feeds{{aggregator=\"{}\"}} {} {}",
            idx, m.subscribed_feeds, m.timestamp_unix_ms
        );
        let _ = writeln!(
            &mut s,
            "telemetry_core_total_messages_to_feeds{{aggregator=\"{}\"}} {} {}",
            idx, m.total_messages_to_feeds, m.timestamp_unix_ms
        );
        let _ = writeln!(
            &mut s,
            "telemetry_core_current_messages_to_aggregator{{aggregator=\"{}\"}} {} {}",
            idx, m.current_messages_to_aggregator, m.timestamp_unix_ms
        );
        let _ = writeln!(
            &mut s,
            "telemetry_core_total_messages_to_aggregator{{aggregator=\"{}\"}} {} {}",
            idx, m.total_messages_to_aggregator, m.timestamp_unix_ms
        );
        let _ = writeln!(
            &mut s,
            "telemetry_core_dropped_messages_to_aggregator{{aggregator=\"{}\"}} {} {}",
            idx, m.dropped_messages_to_aggregator, m.timestamp_unix_ms
        );
        // A single blank line separates one aggregator's block from the next.
        let _ = writeln!(&mut s);
    }
    Response::builder()
        // The version number here tells prometheus which version of the text format we're using:
        .header(http::header::CONTENT_TYPE, "text/plain; version=0.0.4")
        .body(s.into())
        .unwrap()
}

View File

@ -0,0 +1,393 @@
use common::node_message::Payload;
use common::node_types::BlockHash;
use common::node_types::{Block, Timestamp};
use common::{id_type, time, DenseMap, MostSeen, NumStats};
use once_cell::sync::Lazy;
use std::collections::HashSet;
use std::str::FromStr;
use std::time::{Duration, Instant};
use crate::feed_message::{self, ChainStats, FeedMessageSerializer};
use crate::find_location;
use super::chain_stats::ChainStatsCollator;
use super::counter::CounterValue;
use super::node::Node;
id_type! {
    /// A Node ID that is unique to the chain it's in.
    pub struct ChainNodeId(usize)
}
/// The label (name) that a node reports for its chain.
pub type Label = Box<str>;
// NOTE(review): appears to be the staleness window in milliseconds — confirm at the use site.
const STALE_TIMEOUT: u64 = 2 * 60 * 1000; // 2 minutes
// Minimum interval between stat regenerations (see `stats_last_regenerated`).
const STATS_UPDATE_INTERVAL: Duration = Duration::from_secs(5);
/// Everything tracked about a single chain: its member nodes, best/finalized
/// blocks, block-time history, and aggregated node statistics.
pub struct Chain {
    /// Labels that nodes use for this chain. We keep track of
    /// the most commonly used label as nodes are added/removed.
    labels: MostSeen<Label>,
    /// Set of nodes that are in this chain
    nodes: DenseMap<ChainNodeId, Node>,
    /// Best block
    best: Block,
    /// Finalized block
    finalized: Block,
    /// Block times history, stored so we can calculate averages
    block_times: NumStats<u64>,
    /// Calculated average block time
    average_block_time: Option<u64>,
    /// When the best block first arrived
    timestamp: Option<Timestamp>,
    /// Genesis hash of this chain
    genesis_hash: BlockHash,
    /// Maximum number of nodes allowed to connect from this chain
    max_nodes: usize,
    /// Collator for the stats.
    stats_collator: ChainStatsCollator,
    /// Stats for this chain.
    stats: ChainStats,
    /// Timestamp of when the stats were last regenerated.
    stats_last_regenerated: Instant,
}
/// Outcome of attempting to add a node to a chain.
pub enum AddNodeResult {
    /// The chain is already at its `max_nodes` quota; the node was rejected.
    Overquota,
    Added {
        /// The chain-local id assigned to the new node.
        id: ChainNodeId,
        /// True if this node's label changed the chain's most-seen label.
        chain_renamed: bool,
    },
}
/// Outcome of removing a node from a chain.
pub struct RemoveNodeResult {
    /// True if removing the node's label changed the chain's most-seen label.
    pub chain_renamed: bool,
}
/// Genesis hashes of chains we consider "first party". These chains allow any
/// number of nodes to connect.
// Parsed lazily on first access; a malformed hardcoded hash would panic here.
static FIRST_PARTY_NETWORKS: Lazy<HashSet<BlockHash>> = Lazy::new(|| {
    let genesis_hash_strs = &[
        "0xce321292c998085ec1c5f5fab1add59fc163a28a5762fa49080215ce6bc040c7", // Casper v0.0.2
    ];
    genesis_hash_strs
        .iter()
        .map(|h| BlockHash::from_str(h).expect("hardcoded hash str should be valid"))
        .collect()
});
/// When we construct a chain, we want to check to see whether or not it's a "first party"
/// network first, and assign a `max_nodes` accordingly. This helps us do that.
pub fn is_first_party_network(genesis_hash: &BlockHash) -> bool {
    FIRST_PARTY_NETWORKS.contains(genesis_hash)
}
impl Chain {
/// Create a new chain with an initial label.
pub fn new(genesis_hash: BlockHash, max_nodes: usize) -> Self {
Chain {
labels: MostSeen::default(),
nodes: DenseMap::new(),
best: Block::zero(),
finalized: Block::zero(),
block_times: NumStats::new(50),
average_block_time: None,
timestamp: None,
genesis_hash,
max_nodes,
stats_collator: Default::default(),
stats: Default::default(),
stats_last_regenerated: Instant::now(),
}
}
/// Is the chain the node belongs to overquota?
/// (i.e. already holding `max_nodes` or more nodes)
pub fn is_overquota(&self) -> bool {
    self.nodes.len() >= self.max_nodes
}
/// Assign a node to this chain.
pub fn add_node(&mut self, node: Node) -> AddNodeResult {
    if self.is_overquota() {
        return AddNodeResult::Overquota;
    }
    // Count the node into the chain stats (hwbench isn't available at add time).
    self.stats_collator
        .add_or_remove_node(node.details(), None, CounterValue::Increment);
    // Record this node's vote for the chain label; the chain is "renamed"
    // whenever the most common label changes.
    let label_result = self.labels.insert(&node.details().chain);
    AddNodeResult::Added {
        id: self.nodes.add(node),
        chain_renamed: label_result.has_changed(),
    }
}
/// Remove a node from this chain.
pub fn remove_node(&mut self, node_id: ChainNodeId) -> RemoveNodeResult {
let node = match self.nodes.remove(node_id) {
Some(node) => node,
None => {
return RemoveNodeResult {
chain_renamed: false,
}
}
};
let details = node.details();
self.stats_collator
.add_or_remove_node(details, node.hwbench(), CounterValue::Decrement);
let node_chain_label = &node.details().chain;
let label_result = self.labels.remove(node_chain_label);
RemoveNodeResult {
chain_renamed: label_result.has_changed(),
}
}
/// Attempt to update the best block seen in this chain.
pub fn update_node(
&mut self,
nid: ChainNodeId,
payload: Payload,
feed: &mut FeedMessageSerializer,
expose_node_details: bool,
) {
if let Some(block) = payload.best_block() {
self.handle_block(block, nid, feed);
}
if let Some(node) = self.nodes.get_mut(nid) {
match payload {
Payload::SystemInterval(ref interval) => {
// Send a feed message if any of the relevant node details change:
if node.update_hardware(interval) {
feed.push(feed_message::Hardware(nid.into(), node.hardware()));
}
if let Some(stats) = node.update_stats(interval) {
feed.push(feed_message::NodeStatsUpdate(nid.into(), stats));
}
if let Some(io) = node.update_io(interval) {
feed.push(feed_message::NodeIOUpdate(nid.into(), io));
}
}
Payload::AfgAuthoritySet(authority) => {
// If our node validator address (and thus details) change, send an
// updated "add node" feed message:
if node.set_validator_address(authority.authority_id.clone()) {
feed.push(feed_message::AddedNode(
nid.into(),
&node,
expose_node_details,
));
}
return;
}
Payload::HwBench(ref hwbench) => {
let new_hwbench = common::node_types::NodeHwBench {
cpu_hashrate_score: hwbench.cpu_hashrate_score,
memory_memcpy_score: hwbench.memory_memcpy_score,
disk_sequential_write_score: hwbench.disk_sequential_write_score,
disk_random_write_score: hwbench.disk_random_write_score,
};
let old_hwbench = node.update_hwbench(new_hwbench);
// The `hwbench` for this node has changed, send an updated "add node".
// Note: There is no need to send this message if the details
// will not be serialized over the wire.
if expose_node_details {
feed.push(feed_message::AddedNode(
nid.into(),
&node,
expose_node_details,
));
}
self.stats_collator
.update_hwbench(old_hwbench.as_ref(), CounterValue::Decrement);
self.stats_collator
.update_hwbench(node.hwbench(), CounterValue::Increment);
}
_ => {}
}
if let Some(block) = payload.finalized_block() {
if let Some(finalized) = node.update_finalized(block) {
feed.push(feed_message::FinalizedBlock(
nid.into(),
finalized.height,
finalized.hash,
));
if finalized.height > self.finalized.height {
self.finalized = *finalized;
feed.push(feed_message::BestFinalized(
finalized.height,
finalized.hash,
));
}
}
}
}
}
fn handle_block(&mut self, block: &Block, nid: ChainNodeId, feed: &mut FeedMessageSerializer) {
let mut propagation_time = None;
let now = time::now();
let nodes_len = self.nodes.len();
self.update_stale_nodes(now, feed);
self.regenerate_stats_if_necessary(feed);
let node = match self.nodes.get_mut(nid) {
Some(node) => node,
None => return,
};
if node.update_block(*block) {
if block.height > self.best.height {
self.best = *block;
log::debug!(
"[{}] [nodes={}] new best block={}/{:?}",
self.labels.best(),
nodes_len,
self.best.height,
self.best.hash,
);
if let Some(timestamp) = self.timestamp {
self.block_times.push(now.saturating_sub(timestamp));
self.average_block_time = Some(self.block_times.average());
}
self.timestamp = Some(now);
feed.push(feed_message::BestBlock(
self.best.height,
now,
self.average_block_time,
));
propagation_time = Some(0);
} else if block.height == self.best.height {
if let Some(timestamp) = self.timestamp {
propagation_time = Some(now.saturating_sub(timestamp));
}
}
if let Some(details) = node.update_details(now, propagation_time) {
feed.push(feed_message::ImportedBlock(nid.into(), details));
}
}
}
/// Check if the chain is stale (has not received a new best block in a while).
/// If so, find a new best block, ignoring any stale nodes and marking them as such.
fn update_stale_nodes(&mut self, now: u64, feed: &mut FeedMessageSerializer) {
let threshold = now - STALE_TIMEOUT;
let timestamp = match self.timestamp {
Some(ts) => ts,
None => return,
};
if timestamp > threshold {
// Timestamp is in range, nothing to do
return;
}
let mut best = Block::zero();
let mut finalized = Block::zero();
let mut timestamp = None;
for (nid, node) in self.nodes.iter_mut() {
if !node.update_stale(threshold) {
if node.best().height > best.height {
best = *node.best();
timestamp = Some(node.best_timestamp());
}
if node.finalized().height > finalized.height {
finalized = *node.finalized();
}
} else {
feed.push(feed_message::StaleNode(nid.into()));
}
}
if self.best.height != 0 || self.finalized.height != 0 {
self.best = best;
self.finalized = finalized;
self.block_times.reset();
self.timestamp = timestamp;
feed.push(feed_message::BestBlock(
self.best.height,
timestamp.unwrap_or(now),
None,
));
feed.push(feed_message::BestFinalized(
finalized.height,
finalized.hash,
));
}
}
fn regenerate_stats_if_necessary(&mut self, feed: &mut FeedMessageSerializer) {
let now = Instant::now();
let elapsed = now - self.stats_last_regenerated;
if elapsed < STATS_UPDATE_INTERVAL {
return;
}
self.stats_last_regenerated = now;
let new_stats = self.stats_collator.generate();
if new_stats != self.stats {
self.stats = new_stats;
feed.push(feed_message::ChainStatsUpdate(&self.stats));
}
}
pub fn update_node_location(
&mut self,
node_id: ChainNodeId,
location: find_location::Location,
) -> bool {
if let Some(node) = self.nodes.get_mut(node_id) {
node.update_location(location);
true
} else {
false
}
}
pub fn get_node(&self, id: ChainNodeId) -> Option<&Node> {
self.nodes.get(id)
}
pub fn nodes_slice(&self) -> &[Option<Node>] {
self.nodes.as_slice()
}
pub fn label(&self) -> &str {
&self.labels.best()
}
pub fn node_count(&self) -> usize {
self.nodes.len()
}
pub fn best_block(&self) -> &Block {
&self.best
}
pub fn timestamp(&self) -> Option<Timestamp> {
self.timestamp
}
pub fn average_block_time(&self) -> Option<u64> {
self.average_block_time
}
pub fn finalized_block(&self) -> &Block {
&self.finalized
}
pub fn genesis_hash(&self) -> BlockHash {
self.genesis_hash
}
pub fn stats(&self) -> &ChainStats {
&self.stats
}
}

View File

@ -0,0 +1,251 @@
use super::counter::{Counter, CounterValue};
use crate::feed_message::ChainStats;
// These are the benchmark scores generated on our reference hardware.
// They define the 100% mark that `bucket_score` measures raw scores against.
const REFERENCE_CPU_SCORE: u64 = 257;
const REFERENCE_MEMORY_SCORE: u64 = 6070;
const REFERENCE_DISK_SEQUENTIAL_WRITE_SCORE: u64 = 425;
const REFERENCE_DISK_RANDOM_WRITE_SCORE: u64 = 210;
/// Given a value and an ascending list of bucket boundaries, finds the
/// half-open bucket `[min, max)` the value falls into; the last boundary
/// starts an open-ended bucket, signalled by a `None` upper bound.
///
/// Note: the non-terminal arms expand to `return` statements, so this macro
/// must form the tail of a function returning the bucket tuple.
macro_rules! buckets {
    (@try $value:expr, $bucket_min:expr, $bucket_max:expr,) => {
        // Buckets are tried in ascending order, so the first upper bound that
        // exceeds the value identifies its bucket.
        if $value < $bucket_max {
            return ($bucket_min, Some($bucket_max));
        }
    };
    ($value:expr, $bucket_min:expr, $bucket_max:expr, $($remaining:expr,)*) => {
        buckets! { @try $value, $bucket_min, $bucket_max, }
        // Recurse: this bucket's upper bound is the next bucket's lower bound.
        buckets! { $value, $bucket_max, $($remaining,)* }
    };
    ($value:expr, $bucket_last:expr,) => {
        ($bucket_last, None)
    }
}
/// Translates a given raw benchmark score into a relative measure
/// of how the score compares to the reference score.
///
/// The value returned is the range (in percent) within which the given score
/// falls into. For example, a value of `(90, Some(110))` means that the score
/// is between 90% and 110% of the reference score, with the lower bound being
/// inclusive and the upper bound being exclusive.
fn bucket_score(score: u64, reference_score: u64) -> (u32, Option<u32>) {
    // Express the raw score as a percentage of the reference score.
    let relative_score = ((score as f64 / reference_score as f64) * 100.0) as u32;
    // Ascending bucket boundaries, in percent. A value belongs to the
    // half-open range [pair[0], pair[1]); anything at or above the last
    // boundary falls into the open-ended final bucket.
    const BOUNDS: &[u32] = &[0, 10, 30, 50, 70, 90, 110, 130, 150, 200, 300, 400, 500];
    for pair in BOUNDS.windows(2) {
        if relative_score < pair[1] {
            return (pair[0], Some(pair[1]));
        }
    }
    (*BOUNDS.last().expect("BOUNDS is non-empty"), None)
}
#[test]
fn test_bucket_score() {
    // (score, reference, expected bucket) — covers both edges of a few
    // buckets plus the open-ended final bucket.
    let cases = [
        (0, 100, (0, Some(10))),
        (9, 100, (0, Some(10))),
        (10, 100, (10, Some(30))),
        (29, 100, (10, Some(30))),
        (30, 100, (30, Some(50))),
        (100, 100, (90, Some(110))),
        (500, 100, (500, None)),
    ];
    for (score, reference, expected) in cases {
        assert_eq!(bucket_score(score, reference), expected);
    }
}
fn bucket_memory(memory: u64) -> (u32, Option<u32>) {
    // Convert bytes to thousands of MiB (roughly GB).
    let memory = memory / (1024 * 1024) / 1000;
    // Ascending bucket boundaries; the value maps to the half-open range
    // [pair[0], pair[1]), with everything >= 128 in the open-ended bucket.
    // Note values below 2 (including 0) all land in the (1, Some(2)) bucket.
    const BOUNDS: &[u64] = &[1, 2, 4, 6, 8, 10, 16, 24, 32, 48, 56, 64, 128];
    for pair in BOUNDS.windows(2) {
        if memory < pair[1] {
            return (pair[0] as u32, Some(pair[1] as u32));
        }
    }
    (128, None)
}
fn kernel_version_number(version: &Box<str>) -> &str {
    // Cut the version string at the first `-` (distro/build suffix) or,
    // only if there is no `-` at all, at the first `+` (locally modified
    // kernel). If neither separator is present, keep the whole string.
    let cut = match version.find('-') {
        Some(idx) => idx,
        None => version.find('+').unwrap_or(version.len()),
    };
    &version[..cut]
}
#[test]
fn test_kernel_version_number() {
    // A plus sign indicates a kernel built from modified sources; it should
    // only appear at the end of the version string.
    let cases: [(&str, &str); 3] = [
        ("5.10.0-8-amd64", "5.10.0"),
        ("5.10.0+82453", "5.10.0"),
        ("5.10.0", "5.10.0"),
    ];
    for (raw, expected) in cases {
        assert_eq!(kernel_version_number(&raw.into()), expected);
    }
}
fn cpu_vendor(cpu: &Box<str>) -> &str {
    // Case-insensitive substring match; the first matching entry, in this
    // order, decides the vendor label.
    const VENDORS: [(&str, &str); 4] = [
        ("intel", "Intel"),
        ("amd", "AMD"),
        ("arm", "ARM"),
        ("apple", "Apple"),
    ];
    let lowered = cpu.to_ascii_lowercase();
    VENDORS
        .iter()
        .copied()
        .find(|&(needle, _)| lowered.contains(needle))
        .map(|(_, label)| label)
        .unwrap_or("Other")
}
/// Accumulates per-category occurrence counters over all nodes on a chain;
/// `ChainStats` snapshots are generated from these via `generate`.
#[derive(Default)]
pub struct ChainStatsCollator {
    version: Counter<String>,
    target_os: Counter<String>,
    target_arch: Counter<String>,
    cpu: Counter<String>,
    // Memory and benchmark scores are counted as bucketed `(min, max)` ranges
    // (see `bucket_memory` / `bucket_score`) rather than raw values.
    memory: Counter<(u32, Option<u32>)>,
    core_count: Counter<u32>,
    linux_kernel: Counter<String>,
    linux_distro: Counter<String>,
    is_virtual_machine: Counter<bool>,
    cpu_hashrate_score: Counter<(u32, Option<u32>)>,
    memory_memcpy_score: Counter<(u32, Option<u32>)>,
    disk_sequential_write_score: Counter<(u32, Option<u32>)>,
    disk_random_write_score: Counter<(u32, Option<u32>)>,
    cpu_vendor: Counter<String>,
}
impl ChainStatsCollator {
    /// Count (`Increment`) or uncount (`Decrement`) a node's details — and
    /// optionally its hardware benchmark — in every per-category counter.
    pub fn add_or_remove_node(
        &mut self,
        details: &common::node_types::NodeDetails,
        hwbench: Option<&common::node_types::NodeHwBench>,
        op: CounterValue,
    ) {
        self.version.modify(Some(&*details.version), op);
        self.target_os
            .modify(details.target_os.as_ref().map(|value| &**value), op);
        self.target_arch
            .modify(details.target_arch.as_ref().map(|value| &**value), op);
        // Sysinfo is optional; each absent field below counts as "unknown".
        let sysinfo = details.sysinfo.as_ref();
        self.cpu.modify(
            sysinfo
                .and_then(|sysinfo| sysinfo.cpu.as_ref())
                .map(|value| &**value),
            op,
        );
        // Memory is bucketed into a coarse range before counting.
        let memory = sysinfo.and_then(|sysinfo| sysinfo.memory.map(bucket_memory));
        self.memory.modify(memory.as_ref(), op);
        self.core_count
            .modify(sysinfo.and_then(|sysinfo| sysinfo.core_count.as_ref()), op);
        self.linux_kernel.modify(
            sysinfo
                .and_then(|sysinfo| sysinfo.linux_kernel.as_ref())
                // Group kernels by base version, dropping distro/build suffixes.
                .map(kernel_version_number),
            op,
        );
        self.linux_distro.modify(
            sysinfo
                .and_then(|sysinfo| sysinfo.linux_distro.as_ref())
                .map(|value| &**value),
            op,
        );
        self.is_virtual_machine.modify(
            sysinfo.and_then(|sysinfo| sysinfo.is_virtual_machine.as_ref()),
            op,
        );
        self.cpu_vendor.modify(
            sysinfo.and_then(|sysinfo| sysinfo.cpu.as_ref().map(cpu_vendor)),
            op,
        );
        self.update_hwbench(hwbench, op);
    }
    /// Count (or uncount) a hardware benchmark's scores, each bucketed
    /// relative to the reference hardware via `bucket_score`.
    pub fn update_hwbench(
        &mut self,
        hwbench: Option<&common::node_types::NodeHwBench>,
        op: CounterValue,
    ) {
        self.cpu_hashrate_score.modify(
            hwbench
                .map(|hwbench| bucket_score(hwbench.cpu_hashrate_score, REFERENCE_CPU_SCORE))
                .as_ref(),
            op,
        );
        self.memory_memcpy_score.modify(
            hwbench
                .map(|hwbench| bucket_score(hwbench.memory_memcpy_score, REFERENCE_MEMORY_SCORE))
                .as_ref(),
            op,
        );
        // The two disk scores are themselves optional within a benchmark.
        self.disk_sequential_write_score.modify(
            hwbench
                .and_then(|hwbench| hwbench.disk_sequential_write_score)
                .map(|score| bucket_score(score, REFERENCE_DISK_SEQUENTIAL_WRITE_SCORE))
                .as_ref(),
            op,
        );
        self.disk_random_write_score.modify(
            hwbench
                .and_then(|hwbench| hwbench.disk_random_write_score)
                .map(|score| bucket_score(score, REFERENCE_DISK_RANDOM_WRITE_SCORE))
                .as_ref(),
            op,
        );
    }
    /// Produce a `ChainStats` snapshot: top-10 tables for the open-ended
    /// categories, fully-ordered tables for the bucketed/boolean ones.
    pub fn generate(&self) -> ChainStats {
        ChainStats {
            version: self.version.generate_ranking_top(10),
            target_os: self.target_os.generate_ranking_top(10),
            target_arch: self.target_arch.generate_ranking_top(10),
            cpu: self.cpu.generate_ranking_top(10),
            memory: self.memory.generate_ranking_ordered(),
            core_count: self.core_count.generate_ranking_top(10),
            linux_kernel: self.linux_kernel.generate_ranking_top(10),
            linux_distro: self.linux_distro.generate_ranking_top(10),
            is_virtual_machine: self.is_virtual_machine.generate_ranking_ordered(),
            cpu_hashrate_score: self.cpu_hashrate_score.generate_ranking_top(10),
            memory_memcpy_score: self.memory_memcpy_score.generate_ranking_ordered(),
            disk_sequential_write_score: self
                .disk_sequential_write_score
                .generate_ranking_ordered(),
            disk_random_write_score: self.disk_random_write_score.generate_ranking_ordered(),
            cpu_vendor: self.cpu_vendor.generate_ranking_top(10),
        }
    }
}

View File

@ -0,0 +1,103 @@
use crate::feed_message::Ranking;
use std::collections::HashMap;
/// A data structure which counts how many occurrences of a given key we've seen.
#[derive(Default)]
pub struct Counter<K> {
    /// A map containing the number of occurrences of a given key.
    ///
    /// If there are none then the entry is removed.
    map: HashMap<K, u64>,
    /// The number of occurrences where the key is `None`.
    ///
    /// Reported as `unknown` in generated rankings.
    empty: u64,
}
/// Whether `Counter::modify` should add or remove one occurrence.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum CounterValue {
    /// Add one occurrence of the key.
    Increment,
    /// Remove one occurrence of the key.
    Decrement,
}
impl<K> Counter<K>
where
    K: Sized + std::hash::Hash + Eq,
{
    /// Either adds or removes a single occurrence of a given `key`.
    ///
    /// A `key` of `None` adjusts the `empty` ("unknown") tally instead.
    /// Decrementing a key that was never incremented is a logic error: it
    /// panics via the assert below (or underflows `empty` for `None` keys).
    pub fn modify<'a, Q>(&mut self, key: Option<&'a Q>, op: CounterValue)
    where
        Q: ?Sized + std::hash::Hash + Eq,
        K: std::borrow::Borrow<Q>,
        Q: std::borrow::ToOwned<Owned = K>,
    {
        if let Some(key) = key {
            if let Some(entry) = self.map.get_mut(key) {
                match op {
                    CounterValue::Increment => {
                        *entry += 1;
                    }
                    CounterValue::Decrement => {
                        *entry -= 1;
                        if *entry == 0 {
                            // Don't keep entries for which there are no hits.
                            self.map.remove(key);
                        }
                    }
                }
            } else {
                // First sighting of this key must be an increment.
                assert_eq!(op, CounterValue::Increment);
                self.map.insert(key.to_owned(), 1);
            }
        } else {
            match op {
                CounterValue::Increment => {
                    self.empty += 1;
                }
                CounterValue::Decrement => {
                    self.empty -= 1;
                }
            }
        }
    }
    /// Generates a top-N table of the most common keys.
    ///
    /// Keys beyond `max_count` are folded into the `other` total.
    pub fn generate_ranking_top(&self, max_count: usize) -> Ranking<K>
    where
        K: Clone,
    {
        let mut all: Vec<(&K, u64)> = self.map.iter().map(|(key, count)| (key, *count)).collect();
        // Sorting by `!count` (bitwise NOT) orders the counts descending.
        all.sort_unstable_by_key(|&(_, count)| !count);
        let list = all
            .iter()
            .take(max_count)
            .map(|&(key, count)| (key.clone(), count))
            .collect();
        let other = all
            .iter()
            .skip(max_count)
            .fold(0, |sum, (_, count)| sum + *count);
        Ranking {
            list,
            other,
            unknown: self.empty,
        }
    }
    /// Generates a sorted table of all of the keys.
    pub fn generate_ranking_ordered(&self) -> Ranking<K>
    where
        K: Copy + Clone + Ord,
    {
        let mut list: Vec<(K, u64)> = self.map.iter().map(|(key, count)| (*key, *count)).collect();
        // Ascending by key; the `!count` tiebreak is descending by count
        // (keys are unique in a map, so it is effectively inert).
        list.sort_unstable_by_key(|&(key, count)| (key, !count));
        Ranking {
            list,
            other: 0,
            unknown: self.empty,
        }
    }
}

View File

@ -0,0 +1,9 @@
mod chain;
mod chain_stats;
mod counter;
mod node;
mod state;
pub use node::Node;
pub use state::*;

View File

@ -0,0 +1,224 @@
use crate::find_location;
use common::node_message::SystemInterval;
use common::node_types::{
Block, BlockDetails, NodeDetails, NodeHardware, NodeHwBench, NodeIO, NodeLocation, NodeStats,
Timestamp,
};
use common::time;
/// If two block updates arrive within this many ms of each other, further
/// updates broadcast to the browser get throttled.
const THROTTLE_THRESHOLD: u64 = 100;
/// Minimum interval between block updates sent to the browser while
/// throttled, in ms.
const THROTTLE_INTERVAL: u64 = 1000;
/// Live state tracked for a single connected node.
pub struct Node {
    /// Static details
    details: NodeDetails,
    /// Basic stats
    stats: NodeStats,
    /// Node IO stats
    io: NodeIO,
    /// Best block
    best: BlockDetails,
    /// Finalized block
    finalized: Block,
    /// Timer for throttling block updates (ms timestamp until which updates
    /// are suppressed; see `update_details`)
    throttle: u64,
    /// Hardware stats over time
    hardware: NodeHardware,
    /// Physical location details
    location: find_location::Location,
    /// Flag marking if the node is stale (not syncing or producing blocks)
    stale: bool,
    /// Unix timestamp for when node started up (falls back to connection time)
    startup_time: Option<Timestamp>,
    /// Hardware benchmark results for the node
    hwbench: Option<NodeHwBench>,
}
impl Node {
    /// Build a node from its reported details.
    ///
    /// The textual `startup_time` is taken out of the details and parsed;
    /// an absent or unparseable value becomes `None`.
    pub fn new(mut details: NodeDetails) -> Self {
        let startup_time = details
            .startup_time
            .take()
            .and_then(|time| time.parse().ok());
        Node {
            details,
            stats: NodeStats::default(),
            io: NodeIO::default(),
            best: BlockDetails::default(),
            finalized: Block::zero(),
            throttle: 0,
            hardware: NodeHardware::default(),
            location: None,
            stale: false,
            startup_time,
            hwbench: None,
        }
    }
    /// Static details reported by the node.
    pub fn details(&self) -> &NodeDetails {
        &self.details
    }
    /// Latest basic stats (peers, txcount).
    pub fn stats(&self) -> &NodeStats {
        &self.stats
    }
    /// Latest IO stats.
    pub fn io(&self) -> &NodeIO {
        &self.io
    }
    /// Best block this node has reported.
    pub fn best(&self) -> &Block {
        &self.best.block
    }
    /// Timestamp at which the current best block was recorded.
    pub fn best_timestamp(&self) -> u64 {
        self.best.block_timestamp
    }
    /// Best finalized block this node has reported.
    pub fn finalized(&self) -> &Block {
        &self.finalized
    }
    /// Hardware (bandwidth) stats over time.
    pub fn hardware(&self) -> &NodeHardware {
        &self.hardware
    }
    /// Resolved physical location, if any.
    pub fn location(&self) -> Option<&NodeLocation> {
        self.location.as_deref()
    }
    /// Set/replace the node's resolved location.
    pub fn update_location(&mut self, location: find_location::Location) {
        self.location = location;
    }
    /// Full details of the current best block.
    pub fn block_details(&self) -> &BlockDetails {
        &self.best
    }
    /// Hardware benchmark results, if the node has reported any.
    pub fn hwbench(&self) -> Option<&NodeHwBench> {
        self.hwbench.as_ref()
    }
    /// Store new benchmark results, returning the previous ones (if any).
    pub fn update_hwbench(&mut self, hwbench: NodeHwBench) -> Option<NodeHwBench> {
        self.hwbench.replace(hwbench)
    }
    /// Record a new best block; returns `true` (and clears the stale flag)
    /// only if it improves on the current best height.
    pub fn update_block(&mut self, block: Block) -> bool {
        if block.height > self.best.block.height {
            self.stale = false;
            self.best.block = block;
            true
        } else {
            false
        }
    }
    /// Update block timing details, returning them unless throttled.
    ///
    /// If a block arrives within `THROTTLE_THRESHOLD` ms of the previous one,
    /// subsequent updates are suppressed for `THROTTLE_INTERVAL` ms.
    pub fn update_details(
        &mut self,
        timestamp: u64,
        propagation_time: Option<u64>,
    ) -> Option<&BlockDetails> {
        self.best.block_time = timestamp - self.best.block_timestamp;
        self.best.block_timestamp = timestamp;
        self.best.propagation_time = propagation_time;
        if self.throttle < timestamp {
            if self.best.block_time <= THROTTLE_THRESHOLD {
                self.throttle = timestamp + THROTTLE_INTERVAL;
            }
            Some(&self.best)
        } else {
            None
        }
    }
    /// Fold an interval's bandwidth readings into the hardware charts;
    /// `true` if either bandwidth series changed. Note the chart timestamp
    /// is pushed unconditionally.
    pub fn update_hardware(&mut self, interval: &SystemInterval) -> bool {
        let mut changed = false;
        if let Some(upload) = interval.bandwidth_upload {
            changed |= self.hardware.upload.push(upload);
        }
        if let Some(download) = interval.bandwidth_download {
            changed |= self.hardware.download.push(download);
        }
        self.hardware.chart_stamps.push(time::now() as f64);
        changed
    }
    /// Update peer/transaction counts; returns the stats only if they changed.
    pub fn update_stats(&mut self, interval: &SystemInterval) -> Option<&NodeStats> {
        let mut changed = false;
        if let Some(peers) = interval.peers {
            if peers != self.stats.peers {
                self.stats.peers = peers;
                changed = true;
            }
        }
        if let Some(txcount) = interval.txcount {
            if txcount != self.stats.txcount {
                self.stats.txcount = txcount;
                changed = true;
            }
        }
        if changed {
            Some(&self.stats)
        } else {
            None
        }
    }
    /// Update IO stats; returns them only if something changed.
    pub fn update_io(&mut self, interval: &SystemInterval) -> Option<&NodeIO> {
        let mut changed = false;
        if let Some(size) = interval.used_state_cache_size {
            changed |= self.io.used_state_cache_size.push(size);
        }
        if changed {
            Some(&self.io)
        } else {
            None
        }
    }
    /// Record a newly finalized block; returns it only if it advances the
    /// node's finalized height.
    pub fn update_finalized(&mut self, block: Block) -> Option<&Block> {
        if block.height > self.finalized.height {
            self.finalized = block;
            Some(self.finalized())
        } else {
            None
        }
    }
    /// Mark the node stale if its last best-block update is older than
    /// `threshold`; returns the resulting stale flag.
    pub fn update_stale(&mut self, threshold: u64) -> bool {
        if self.best.block_timestamp < threshold {
            self.stale = true;
        }
        self.stale
    }
    /// Whether the node is currently considered stale.
    pub fn stale(&self) -> bool {
        self.stale
    }
    /// Set the node's validator address; `true` if it actually changed.
    pub fn set_validator_address(&mut self, addr: Box<str>) -> bool {
        if self.details.validator.as_ref() == Some(&addr) {
            false
        } else {
            self.details.validator = Some(addr);
            true
        }
    }
    /// Unix timestamp for when the node started up, if known.
    pub fn startup_time(&self) -> Option<Timestamp> {
        self.startup_time
    }
}

View File

@ -0,0 +1,396 @@
use super::node::Node;
use crate::feed_message::{ChainStats, FeedMessageSerializer};
use crate::find_location;
use common::node_message::Payload;
use common::node_types::{Block, BlockHash, NodeDetails, Timestamp};
use common::{id_type, DenseMap};
use std::collections::{HashMap, HashSet};
use std::iter::IntoIterator;
use super::chain::{self, Chain, ChainNodeId};
// NOTE(review): `id_type!` comes from `common` and presumably generates a
// newtype wrapper around the inner `usize` plus conversions — see its
// definition for the exact impls produced.
id_type! {
    /// A globally unique Chain ID.
    pub struct ChainId(usize)
}
/// A "global" Node ID is a composite of the ID of the chain it's
/// on, and its chain-local ID.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct NodeId(ChainId, ChainNodeId);
impl NodeId {
    /// The chain-local part of this node's ID.
    pub fn get_chain_node_id(&self) -> ChainNodeId {
        self.1
    }
}
/// Our state contains node and chain information
pub struct State {
    /// All chains, keyed by their assigned [`ChainId`].
    chains: DenseMap<ChainId, Chain>,
    /// Find the right chain given various details.
    chains_by_genesis_hash: HashMap<BlockHash, ChainId>,
    /// Chain labels that we do not want to allow connecting.
    denylist: HashSet<String>,
}
/// Adding a node to a chain leads to this result.
///
/// The success variant borrows details out of the [`State`].
pub enum AddNodeResult<'a> {
    /// The chain is on the "deny list", so we can't add the node
    ChainOnDenyList,
    /// The chain is over quota (too many nodes connected), so can't add the node
    ChainOverQuota,
    /// The node was added to the chain
    NodeAddedToChain(NodeAddedToChain<'a>),
}
#[cfg(test)]
impl<'a> AddNodeResult<'a> {
    /// Test helper: the assigned node ID, panicking unless the node was
    /// actually added to a chain.
    pub fn unwrap_id(&self) -> NodeId {
        if let AddNodeResult::NodeAddedToChain(d) = self {
            d.id
        } else {
            panic!("Attempt to unwrap_id on AddNodeResult that did not succeed")
        }
    }
}
/// Details returned when a node is successfully added to a chain.
pub struct NodeAddedToChain<'a> {
    /// The ID assigned to this node.
    pub id: NodeId,
    /// The new label of the chain.
    pub new_chain_label: &'a str,
    /// The node that was added.
    pub node: &'a Node,
    /// Number of nodes in the chain. If 1, the chain was just added.
    pub chain_node_count: usize,
    /// Has the chain label been updated?
    pub has_chain_label_changed: bool,
}
/// If removing a node is successful, we get this information back.
pub struct RemovedNode {
    /// How many nodes remain on the chain (0 if the chain was removed)
    pub chain_node_count: usize,
    /// Has the chain label been updated?
    pub has_chain_label_changed: bool,
    /// Genesis hash of the chain to be updated.
    pub chain_genesis_hash: BlockHash,
    /// The new label of the chain.
    pub new_chain_label: Box<str>,
}
impl State {
    /// Create a new empty state; `denylist` is the set of chain labels that
    /// are refused connections.
    pub fn new<T: IntoIterator<Item = String>>(denylist: T) -> State {
        State {
            chains: DenseMap::new(),
            chains_by_genesis_hash: HashMap::new(),
            denylist: denylist.into_iter().collect(),
        }
    }
    /// Iterate over all known chains.
    pub fn iter_chains(&self) -> impl Iterator<Item = StateChain<'_>> {
        self.chains
            .iter()
            .map(move |(_, chain)| StateChain { chain })
    }
    /// Look up the chain a given node belongs to.
    pub fn get_chain_by_node_id(&self, node_id: NodeId) -> Option<StateChain<'_>> {
        self.chains.get(node_id.0).map(|chain| StateChain { chain })
    }
    /// Look up a chain by its genesis hash.
    pub fn get_chain_by_genesis_hash(&self, genesis_hash: &BlockHash) -> Option<StateChain<'_>> {
        self.chains_by_genesis_hash
            .get(genesis_hash)
            .and_then(|&chain_id| self.chains.get(chain_id))
            .map(|chain| StateChain { chain })
    }
    /// Add a node, creating its chain entry on first sight of the genesis hash.
    pub fn add_node(
        &mut self,
        genesis_hash: BlockHash,
        node_details: NodeDetails,
    ) -> AddNodeResult<'_> {
        if self.denylist.contains(&*node_details.chain) {
            return AddNodeResult::ChainOnDenyList;
        }
        // Get the chain ID, creating a new empty chain if one doesn't exist.
        // If we create a chain here, we are expecting that it will allow at
        // least this node to be added, because we don't currently try and clean it up
        // if the add fails.
        let chain_id = match self.chains_by_genesis_hash.get(&genesis_hash) {
            Some(id) => *id,
            None => {
                // NOTE(review): any genesis hash that is not a first-party
                // network is rejected here, reusing the `ChainOnDenyList`
                // variant — confirm this is intended (vs. a per-chain quota).
                let max_nodes = match chain::is_first_party_network(&genesis_hash) {
                    true => usize::MAX,
                    false => return AddNodeResult::ChainOnDenyList,
                };
                let chain_id = self.chains.add(Chain::new(genesis_hash, max_nodes));
                self.chains_by_genesis_hash.insert(genesis_hash, chain_id);
                chain_id
            }
        };
        // Get the chain.
        let chain = self.chains.get_mut(chain_id).expect(
            "should be known to exist after the above (unless chains_by_genesis_hash out of sync)",
        );
        let node = Node::new(node_details);
        match chain.add_node(node) {
            chain::AddNodeResult::Overquota => AddNodeResult::ChainOverQuota,
            chain::AddNodeResult::Added { id, chain_renamed } => {
                let chain = &*chain;
                AddNodeResult::NodeAddedToChain(NodeAddedToChain {
                    id: NodeId(chain_id, id),
                    node: chain.get_node(id).expect("node added above"),
                    new_chain_label: chain.label(),
                    chain_node_count: chain.node_count(),
                    has_chain_label_changed: chain_renamed,
                })
            }
        }
    }
    /// Remove a node
    ///
    /// The chain itself is dropped (and unindexed) when its last node goes.
    pub fn remove_node(&mut self, NodeId(chain_id, chain_node_id): NodeId) -> Option<RemovedNode> {
        let chain = self.chains.get_mut(chain_id)?;
        // Actually remove the node
        let remove_result = chain.remove_node(chain_node_id);
        // Get updated chain details.
        let new_chain_label: Box<str> = chain.label().into();
        let chain_node_count = chain.node_count();
        let chain_genesis_hash = chain.genesis_hash();
        // Is the chain empty? Remove if so and clean up indexes to it
        if chain_node_count == 0 {
            let genesis_hash = chain.genesis_hash();
            self.chains_by_genesis_hash.remove(&genesis_hash);
            self.chains.remove(chain_id);
        }
        Some(RemovedNode {
            new_chain_label,
            chain_node_count,
            chain_genesis_hash,
            has_chain_label_changed: remove_result.chain_renamed,
        })
    }
    /// Attempt to update the best block seen, given a node and block.
    pub fn update_node(
        &mut self,
        NodeId(chain_id, chain_node_id): NodeId,
        payload: Payload,
        feed: &mut FeedMessageSerializer,
        expose_node_details: bool,
    ) {
        let chain = match self.chains.get_mut(chain_id) {
            Some(chain) => chain,
            None => {
                log::error!("Cannot find chain for node with ID {:?}", chain_id);
                return;
            }
        };
        chain.update_node(chain_node_id, payload, feed, expose_node_details)
    }
    /// Update the location for a node. Return `false` if the node was not found.
    pub fn update_node_location(
        &mut self,
        NodeId(chain_id, chain_node_id): NodeId,
        location: find_location::Location,
    ) -> bool {
        if let Some(chain) = self.chains.get_mut(chain_id) {
            chain.update_node_location(chain_node_id, location)
        } else {
            false
        }
    }
}
/// When we ask for a chain, we get this struct back. This ensures that we have
/// a consistent public interface, and don't expose methods on [`Chain`] that
/// aren't really intended for use outside of [`State`] methods. Any modification
/// of a chain needs to go through [`State`].
pub struct StateChain<'a> {
    // Read-only view; all methods below simply delegate to the chain.
    chain: &'a Chain,
}
impl<'a> StateChain<'a> {
    /// Most common label reported by the chain's nodes.
    pub fn label(&self) -> &'a str {
        self.chain.label()
    }
    /// Genesis hash identifying the chain.
    pub fn genesis_hash(&self) -> BlockHash {
        self.chain.genesis_hash()
    }
    /// Number of nodes currently on the chain.
    pub fn node_count(&self) -> usize {
        self.chain.node_count()
    }
    /// Best block seen on the chain.
    pub fn best_block(&self) -> &'a Block {
        self.chain.best_block()
    }
    /// Timestamp of the last best-block update (0 if none yet).
    pub fn timestamp(&self) -> Timestamp {
        self.chain.timestamp().unwrap_or(0)
    }
    /// Rolling average block time, in ms, once known.
    pub fn average_block_time(&self) -> Option<u64> {
        self.chain.average_block_time()
    }
    /// Best finalized block seen on the chain.
    pub fn finalized_block(&self) -> &'a Block {
        self.chain.finalized_block()
    }
    /// Raw slice of node slots (`None` marks removed entries).
    pub fn nodes_slice(&self) -> &[Option<Node>] {
        self.chain.nodes_slice()
    }
    /// Latest generated chain stats.
    pub fn stats(&self) -> &ChainStats {
        self.chain.stats()
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use common::node_types::NetworkId;
    use std::str::FromStr;
    /// Build a minimal `NodeDetails` for use in tests.
    fn node(name: &str, chain: &str) -> NodeDetails {
        NodeDetails {
            chain: chain.into(),
            name: name.into(),
            implementation: "Bar".into(),
            target_arch: Some("x86_64".into()),
            target_os: Some("linux".into()),
            target_env: Some("env".into()),
            version: "0.1".into(),
            validator: None,
            network_id: NetworkId::new(),
            startup_time: None,
            sysinfo: None,
            ip: None,
        }
    }
    /// A genesis hash on the first-party list, so nodes can actually be added.
    fn get_valid_genesis() -> BlockHash {
        BlockHash::from_str("0xce321292c998085ec1c5f5fab1add59fc163a28a5762fa49080215ce6bc040c7")
            .expect("Should be good genesis")
    }
    #[test]
    fn adding_a_node_returns_expected_response() {
        // `None::<String>` acts as an empty denylist.
        let mut state = State::new(None);
        let chain1_genesis = get_valid_genesis();
        let add_result = state.add_node(chain1_genesis, node("A", "Chain One"));
        let add_node_result = match add_result {
            AddNodeResult::ChainOnDenyList => panic!("Chain not on deny list"),
            AddNodeResult::ChainOverQuota => panic!("Chain not Overquota"),
            AddNodeResult::NodeAddedToChain(details) => details,
        };
        assert_eq!(add_node_result.id, NodeId(0.into(), 0.into()));
        assert_eq!(&*add_node_result.new_chain_label, "Chain One");
        assert_eq!(add_node_result.chain_node_count, 1);
        assert_eq!(add_node_result.has_chain_label_changed, true);
        // A second node on the same chain gets the next chain-local ID.
        let add_result = state.add_node(chain1_genesis, node("A", "Chain One"));
        let add_node_result = match add_result {
            AddNodeResult::ChainOnDenyList => panic!("Chain not on deny list"),
            AddNodeResult::ChainOverQuota => panic!("Chain not Overquota"),
            AddNodeResult::NodeAddedToChain(details) => details,
        };
        assert_eq!(add_node_result.id, NodeId(0.into(), 1.into()));
        assert_eq!(&*add_node_result.new_chain_label, "Chain One");
        assert_eq!(add_node_result.chain_node_count, 2);
        assert_eq!(add_node_result.has_chain_label_changed, false);
    }
    #[test]
    fn adding_and_removing_nodes_updates_chain_label_mapping() {
        let mut state = State::new(None);
        let chain1_genesis = get_valid_genesis();
        let node_id0 = state
            .add_node(chain1_genesis, node("A", "Chain One")) // 0
            .unwrap_id();
        assert_eq!(
            state
                .get_chain_by_node_id(node_id0)
                .expect("Chain should exist")
                .label(),
            "Chain One"
        );
        assert!(state.get_chain_by_genesis_hash(&chain1_genesis).is_some());
        let node_id1 = state
            .add_node(chain1_genesis, node("B", "Chain Two")) // 1
            .unwrap_id();
        // Chain name hasn't changed yet; "Chain One" as common as "Chain Two"..
        assert_eq!(
            state
                .get_chain_by_node_id(node_id0)
                .expect("Chain should exist")
                .label(),
            "Chain One"
        );
        assert!(state.get_chain_by_genesis_hash(&chain1_genesis).is_some());
        let node_id2 = state
            .add_node(chain1_genesis, node("B", "Chain Two"))
            .unwrap_id(); // 2
        // Chain name has changed; "Chain Two" the winner now..
        assert_eq!(
            state
                .get_chain_by_node_id(node_id0)
                .expect("Chain should exist")
                .label(),
            "Chain Two"
        );
        assert!(state.get_chain_by_genesis_hash(&chain1_genesis).is_some());
        state.remove_node(node_id1).expect("Removal OK (id: 1)");
        state.remove_node(node_id2).expect("Removal OK (id: 2)");
        // Removed both "Chain Two" nodes; dominant name now "Chain One" again..
        assert_eq!(
            state
                .get_chain_by_node_id(node_id0)
                .expect("Chain should exist")
                .label(),
            "Chain One"
        );
        assert!(state.get_chain_by_genesis_hash(&chain1_genesis).is_some());
    }
    #[test]
    fn chain_removed_when_last_node_is() {
        let mut state = State::new(None);
        let chain1_genesis = get_valid_genesis();
        let node_id = state
            .add_node(chain1_genesis, node("A", "Chain One")) // 0
            .unwrap_id();
        assert!(state.get_chain_by_genesis_hash(&chain1_genesis).is_some());
        assert_eq!(state.iter_chains().count(), 1);
        state.remove_node(node_id);
        // Removing the last node drops the chain and its genesis-hash index.
        assert!(state.get_chain_by_genesis_hash(&chain1_genesis).is_none());
        assert_eq!(state.iter_chains().count(), 0);
    }
}

View File

@ -0,0 +1,29 @@
[package]
name = "ghost-telemetry-shard"
version = "0.1.0"
authors = ["Uncle Stinky uncle.stinky@ghostchain.io"]
edition = "2021"
[dependencies]
anyhow = "1.0.42"
bincode = "1.3.3"
flume = "0.10.8"
futures = "0.3.15"
hex = "0.4.3"
http = "0.2.4"
hyper = "0.14.11"
log = "0.4.14"
num_cpus = "1.13.0"
primitive-types = { version = "0.12.1", features = ["serde"] }
serde = { version = "1.0.126", features = ["derive"] }
serde_json = "1.0.64"
simple_logger = "4.0.0"
soketto = "0.7.1"
thiserror = "1.0.25"
tokio = { version = "1.10.1", features = ["full"] }
tokio-util = { version = "0.7.4", features = ["compat"] }
common = { package = "ghost-telemetry-common", path = "../common" }
structopt = "0.3.21"
[target.'cfg(not(target_env = "msvc"))'.dependencies]
jemallocator = "0.5.0"

View File

@ -0,0 +1,311 @@
use crate::connection::{create_ws_connection_to_core, Message};
use common::{
internal_messages::{self, ShardNodeId},
node_message,
node_types::BlockHash,
AssignId,
};
use futures::{Sink, SinkExt};
use std::collections::{HashMap, HashSet};
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
/// A unique Id is assigned per websocket connection (or more accurately,
/// per thing-that-subscribes-to-the-aggregator). That connection might send
/// data on behalf of multiple chains, so this ID is local to the aggregator,
/// and a unique ID is assigned per batch of data too ([`internal_messages::ShardNodeId`]).
type ConnId = u64;
/// Incoming messages are either from websocket connections or
/// from the telemetry core. This can be private since the only
/// external messages are via subscriptions that take
/// [`FromWebsocket`] instances.
#[derive(Clone, Debug)]
enum ToAggregator {
    /// Sent when the telemetry core is disconnected.
    DisconnectedFromTelemetryCore,
    /// Sent when the telemetry core (re)connects.
    ConnectedToTelemetryCore,
    /// Sent when a message comes in from a ghost node.
    FromWebsocket(ConnId, FromWebsocket),
    /// Sent when a message comes in from the telemetry core.
    FromTelemetryCore(internal_messages::FromTelemetryCore),
}
/// An incoming socket connection can provide these messages.
/// Until a node has been Added via [`FromWebsocket::Add`],
/// messages from it will be ignored.
#[derive(Clone, Debug)]
pub enum FromWebsocket {
    /// Fire this when the connection is established.
    Initialize {
        /// When a message is sent back up this channel, we terminate
        /// the websocket connection and force the node to reconnect
        /// so that it sends its system info again in case the telemetry
        /// core has restarted.
        close_connection: flume::Sender<()>,
    },
    /// Tell the aggregator about a new node.
    Add {
        /// ID local to the sending connection; one connection may submit
        /// data for several nodes, each with its own message ID.
        message_id: node_message::NodeMessageId,
        /// The IP address the connection came from.
        ip: std::net::IpAddr,
        /// Static details the node reported about itself.
        node: common::node_types::NodeDetails,
        /// Which chain the node belongs to.
        genesis_hash: BlockHash,
    },
    /// Update/pass through details about a node.
    Update {
        message_id: node_message::NodeMessageId,
        payload: node_message::Payload,
    },
    /// Remove a node with the given message ID.
    Remove {
        message_id: node_message::NodeMessageId,
    },
    /// Make a note when the node disconnects.
    Disconnected,
}
/// Messages sent from this shard aggregator up to the telemetry core.
pub type FromAggregator = internal_messages::FromShardAggregator;

/// The aggregator loop handles incoming messages from nodes, or from the telemetry core.
/// This is where we decide what effect messages will have. Cheap to clone; all
/// clones share the same internal state.
#[derive(Clone)]
pub struct Aggregator(Arc<AggregatorInternal>);

/// State shared between all clones of an [`Aggregator`] handle.
struct AggregatorInternal {
    /// Nodes that connect are each assigned a unique connection ID. Nodes
    /// can send messages on behalf of more than one chain, and so this ID is
    /// only really used inside the Aggregator in conjunction with a per-message
    /// ID.
    conn_id: AtomicU64,
    /// Send messages to the aggregator from websockets via this. This is
    /// stored here so that anybody holding an `Aggregator` handle can
    /// make use of it.
    tx_to_aggregator: flume::Sender<ToAggregator>,
}
impl Aggregator {
    /// Spawn a new Aggregator. This connects to the telemetry backend at the
    /// given URI (retrying internally as needed) and starts the background
    /// tasks that shuttle messages between nodes and the core. The returned
    /// handle is used to subscribe node connections via [`Aggregator::subscribe_node`].
    pub async fn spawn(telemetry_uri: http::Uri) -> anyhow::Result<Aggregator> {
        let (tx_to_aggregator, rx_from_external) = flume::bounded(10);

        // Establish a resilient connection to the core (this retries as needed):
        let (tx_to_telemetry_core, rx_from_telemetry_core) =
            create_ws_connection_to_core(telemetry_uri).await;

        // Forward messages from the telemetry core into the aggregator:
        let tx_to_aggregator2 = tx_to_aggregator.clone();
        tokio::spawn(async move {
            while let Ok(msg) = rx_from_telemetry_core.recv_async().await {
                // Translate connection-level events into aggregator messages:
                let msg_to_aggregator = match msg {
                    Message::Connected => ToAggregator::ConnectedToTelemetryCore,
                    Message::Disconnected => ToAggregator::DisconnectedFromTelemetryCore,
                    Message::Data(data) => ToAggregator::FromTelemetryCore(data),
                };
                if let Err(_) = tx_to_aggregator2.send_async(msg_to_aggregator).await {
                    // This will close the ws channels, which themselves log messages.
                    break;
                }
            }
        });

        // Start our aggregator loop, handling any incoming messages:
        tokio::spawn(Aggregator::handle_messages(
            rx_from_external,
            tx_to_telemetry_core,
        ));

        // Return a handle to our aggregator so that we can send in messages to it:
        Ok(Aggregator(Arc::new(AggregatorInternal {
            // Connection IDs start at 1; each subscriber bumps this counter.
            conn_id: AtomicU64::new(1),
            tx_to_aggregator,
        })))
    }

    // This is spawned into a separate task and handles any messages coming
    // in to the aggregator. If nobody is holding the tx side of the channel
    // any more, this task will gracefully end.
    async fn handle_messages(
        rx_from_external: flume::Receiver<ToAggregator>,
        tx_to_telemetry_core: flume::Sender<FromAggregator>,
    ) {
        use internal_messages::{FromShardAggregator, FromTelemetryCore};

        // Just as an optimisation, we can keep track of whether we're connected to the backend
        // or not, and ignore incoming messages while we aren't.
        let mut connected_to_telemetry_core = false;

        // A list of close channels for the currently connected ghost nodes. Send an empty
        // tuple to these to ask the connections to be closed.
        let mut close_connections: HashMap<ConnId, flume::Sender<()>> = HashMap::new();

        // Maintain mappings from the connection ID and node message ID to the "local ID" which we
        // broadcast to the telemetry core.
        let mut to_local_id = AssignId::new();

        // Any messages coming from nodes that have been muted are ignored:
        let mut muted: HashSet<ShardNodeId> = HashSet::new();

        // Now, loop and receive messages to handle.
        while let Ok(msg) = rx_from_external.recv_async().await {
            match msg {
                ToAggregator::ConnectedToTelemetryCore => {
                    // Take hold of the connection closers and run them all.
                    let closers = close_connections;

                    for (_, closer) in closers {
                        // if this fails, it probably means the connection has died already anyway.
                        let _ = closer.send_async(()).await;
                    }

                    // We've told everything to disconnect. Now, reset our state:
                    close_connections = HashMap::new();
                    to_local_id.clear();
                    muted.clear();

                    connected_to_telemetry_core = true;
                    log::info!("Connected to telemetry core");
                }
                ToAggregator::DisconnectedFromTelemetryCore => {
                    connected_to_telemetry_core = false;
                    log::info!("Disconnected from telemetry core");
                }
                ToAggregator::FromWebsocket(
                    conn_id,
                    FromWebsocket::Initialize { close_connection },
                ) => {
                    // We boot all connections on a reconnect-to-core to force new systemconnected
                    // messages to be sent. We could boot on muting, but need to be careful not to boot
                    // connections where we mute one set of messages it sends and not others.
                    close_connections.insert(conn_id, close_connection);
                }
                ToAggregator::FromWebsocket(
                    conn_id,
                    FromWebsocket::Add {
                        message_id,
                        ip,
                        node,
                        genesis_hash,
                    },
                ) => {
                    // Don't bother doing anything else if we're disconnected, since we'll force the
                    // node to reconnect anyway when the backend does:
                    if !connected_to_telemetry_core {
                        continue;
                    }

                    // Generate a new "local ID" for messages from this connection:
                    let local_id = to_local_id.assign_id((conn_id, message_id));

                    // Send the message to the telemetry core with this local ID:
                    let _ = tx_to_telemetry_core
                        .send_async(FromShardAggregator::AddNode {
                            ip,
                            node,
                            genesis_hash,
                            local_id,
                        })
                        .await;
                }
                ToAggregator::FromWebsocket(
                    conn_id,
                    FromWebsocket::Update {
                        message_id,
                        payload,
                    },
                ) => {
                    // Ignore incoming messages if we're not connected to the backend:
                    if !connected_to_telemetry_core {
                        continue;
                    }

                    // Get the local ID, ignoring the message if none match:
                    let local_id = match to_local_id.get_id(&(conn_id, message_id)) {
                        Some(id) => id,
                        None => continue,
                    };

                    // ignore the message if this node has been muted:
                    if muted.contains(&local_id) {
                        continue;
                    }

                    // Send the message to the telemetry core with this local ID:
                    let _ = tx_to_telemetry_core
                        .send_async(FromShardAggregator::UpdateNode { local_id, payload })
                        .await;
                }
                ToAggregator::FromWebsocket(conn_id, FromWebsocket::Remove { message_id }) => {
                    // Get the local ID, ignoring the message if none match:
                    let local_id = match to_local_id.get_id(&(conn_id, message_id)) {
                        Some(id) => id,
                        None => continue,
                    };

                    // Remove references to this single node:
                    to_local_id.remove_by_id(local_id);
                    muted.remove(&local_id);

                    // If we're not connected to the core, don't buffer up remove messages. The core will remove
                    // all nodes associated with this shard anyway, so the remove message would be redundant.
                    if connected_to_telemetry_core {
                        let _ = tx_to_telemetry_core
                            .send_async(FromShardAggregator::RemoveNode { local_id })
                            .await;
                    }
                }
                ToAggregator::FromWebsocket(disconnected_conn_id, FromWebsocket::Disconnected) => {
                    // Find all of the local IDs corresponding to the disconnected connection ID and
                    // remove them, telling Telemetry Core about them too. This could be more efficient,
                    // but the mapping isn't currently cached and it's not a super frequent op.
                    let local_ids_disconnected: Vec<_> = to_local_id
                        .iter()
                        .filter(|(_, &(conn_id, _))| disconnected_conn_id == conn_id)
                        .map(|(local_id, _)| local_id)
                        .collect();

                    close_connections.remove(&disconnected_conn_id);

                    for local_id in local_ids_disconnected {
                        to_local_id.remove_by_id(local_id);
                        muted.remove(&local_id);

                        // If we're not connected to the core, don't buffer up remove messages. The core will remove
                        // all nodes associated with this shard anyway, so the remove message would be redundant.
                        if connected_to_telemetry_core {
                            let _ = tx_to_telemetry_core
                                .send_async(FromShardAggregator::RemoveNode { local_id })
                                .await;
                        }
                    }
                }
                ToAggregator::FromTelemetryCore(FromTelemetryCore::Mute {
                    local_id,
                    reason: _,
                }) => {
                    // Mute the local ID we've been told to:
                    muted.insert(local_id);
                }
            }
        }
    }

    /// Return a sink that a node can send messages into to be handled by the aggregator.
    pub fn subscribe_node(&self) -> impl Sink<FromWebsocket, Error = anyhow::Error> + Unpin {
        // Assign a unique aggregator-local ID to each connection that subscribes, and pass
        // that along with every message to the aggregator loop:
        let conn_id: ConnId = self
            .0
            .conn_id
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let tx_to_aggregator = self.0.tx_to_aggregator.clone();

        // Calling `send` on this Sink requires Unpin. There may be a nicer way than this,
        // but pinning by boxing is the easy solution for now:
        Box::pin(
            tx_to_aggregator
                .into_sink()
                .with(move |msg| async move { Ok(ToAggregator::FromWebsocket(conn_id, msg)) }),
        )
    }
}

View File

@ -0,0 +1,50 @@
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// Keep track of IP addresses that have been blocked, and why.
#[derive(Debug, Clone)]
pub struct BlockedAddrs(Arc<BlockAddrsInner>);

/// Shared state behind [`BlockedAddrs`] handles.
#[derive(Debug)]
struct BlockAddrsInner {
    block_duration: Duration,
    inner: Mutex<HashMap<IpAddr, (&'static str, Instant)>>,
}

impl BlockedAddrs {
    /// Create a new block list. Every address added to it remains blocked
    /// for the `block_duration` provided here.
    pub fn new(block_duration: Duration) -> BlockedAddrs {
        let inner = BlockAddrsInner {
            block_duration,
            inner: Mutex::new(HashMap::new()),
        };
        BlockedAddrs(Arc::new(inner))
    }

    /// Block an address, recording the reason and the moment it was blocked.
    /// Blocking an already-blocked address restarts its timer.
    pub fn block_addr(&self, addr: IpAddr, reason: &'static str) {
        self.0
            .inner
            .lock()
            .unwrap()
            .insert(addr, (reason, Instant::now()));
    }

    /// Find out whether an address has been blocked. If it has, a reason
    /// will be returned. Else, we'll get None back. This function may also
    /// perform cleanup if the item was blocked and the block has expired.
    pub fn blocked_reason(&self, addr: &IpAddr) -> Option<&'static str> {
        let mut map = self.0.inner.lock().unwrap();
        let &(reason, blocked_at) = map.get(addr)?;

        if Instant::now() > blocked_at + self.0.block_duration {
            // The block has expired; tidy the entry away and report "not blocked".
            map.remove(addr);
            None
        } else {
            Some(reason)
        }
    }
}

View File

@ -0,0 +1,125 @@
use bincode::Options;
use common::ws_client;
use futures::StreamExt;
/// Events surfaced to the aggregator by the resilient core connection.
#[derive(Clone, Debug)]
pub enum Message<Out> {
    /// The websocket connection to the core has (re)connected.
    Connected,
    /// The connection went down; we'll keep retrying in the background.
    Disconnected,
    /// A message decoded from the core.
    Data(Out),
}
/// Connect to the telemetry core, retrying the connection if we're disconnected.
/// - Sends `Message::Connected` and `Message::Disconnected` when the connection goes up/down.
/// - Returns a channel that allows you to send messages to the connection.
/// - Messages are all encoded/decoded to/from bincode, and so need to support being (de)serialized from
///   a non self-describing encoding.
///
/// Note: have a look at [`common::internal_messages`] to see the different message types exchanged
/// between aggregator and core.
pub async fn create_ws_connection_to_core<In, Out>(
    telemetry_uri: http::Uri,
) -> (flume::Sender<In>, flume::Receiver<Message<Out>>)
where
    In: serde::Serialize + Send + 'static,
    Out: serde::de::DeserializeOwned + Send + 'static,
{
    let (tx_in, rx_in) = flume::bounded::<In>(10);
    let (tx_out, rx_out) = flume::bounded(10);

    let mut is_connected = false;

    tokio::spawn(async move {
        loop {
            // Throw away any pending messages from the incoming channel so that it
            // doesn't get filled up and begin blocking while we're looping and waiting
            // for a reconnection.
            while rx_in.try_recv().is_ok() {}

            // Try to connect. If connection established, we serialize and forward messages
            // to/from the core. If the external channels break, we end for good. If the internal
            // channels break, we loop around and try connecting again.
            match ws_client::connect(&telemetry_uri).await {
                Ok(connection) => {
                    let (tx_to_core, mut rx_from_core) = connection.into_channels();
                    is_connected = true;
                    if let Err(e) = tx_out.send_async(Message::Connected).await {
                        // If receiving end is closed, bail now.
                        log::warn!("Aggregator is no longer receiving messages from core; disconnecting (permanently): {}", e);
                        return;
                    }

                    // Loop, forwarding messages to and from the core until something goes wrong.
                    loop {
                        tokio::select! {
                            // A message (or error/termination) arriving from the core:
                            msg = rx_from_core.next() => {
                                let msg = match msg {
                                    Some(Ok(msg)) => msg,
                                    // No more messages from core? core WS is disconnected.
                                    _ => {
                                        log::warn!("No more messages from core: shutting down connection (will reconnect)");
                                        break
                                    }
                                };

                                let bytes = match msg {
                                    ws_client::RecvMessage::Binary(bytes) => bytes,
                                    ws_client::RecvMessage::Text(s) => s.into_bytes()
                                };
                                // Shard<->core messages are a private bincode protocol, so a
                                // decode failure means incompatible versions; treat as fatal.
                                let msg = bincode::options()
                                    .deserialize(&bytes)
                                    .expect("internal messages must be deserializable");

                                if let Err(e) = tx_out.send_async(Message::Data(msg)).await {
                                    log::error!("Aggregator is no longer receiving messages from core; disconnecting (permanently): {}", e);
                                    return;
                                }
                            },
                            // A message from the aggregator to forward to the core:
                            msg = rx_in.recv_async() => {
                                let msg = match msg {
                                    Ok(msg) => msg,
                                    Err(flume::RecvError::Disconnected) => {
                                        log::error!("Aggregator is no longer sending messages to core; disconnecting (permanently)");
                                        return
                                    }
                                };

                                let bytes = bincode::options()
                                    .serialize(&msg)
                                    .expect("internal messages must be serializable");
                                let ws_msg = ws_client::SentMessage::Binary(bytes);

                                if let Err(e) = tx_to_core.unbounded_send(ws_msg) {
                                    log::warn!("Unable to send message to core; shutting down connection (will reconnect): {}", e);
                                    break;
                                }
                            }
                        };
                    }
                }
                Err(connect_err) => {
                    // Issue connecting? Wait and try again on the next loop iteration.
                    log::error!(
                        "Error connecting to websocket server (will reconnect): {}",
                        connect_err
                    );
                }
            }

            // Only report a disconnect if we ever managed to connect:
            if is_connected {
                is_connected = false;
                if let Err(e) = tx_out.send_async(Message::Disconnected).await {
                    log::error!("Aggregator is no longer receiving messages from core; disconnecting (permanently): {}", e);
                    return;
                }
            }

            // Wait a little before we try to connect again.
            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
        }
    });

    (tx_in, rx_out)
}

View File

@ -0,0 +1,219 @@
//! A hash wrapper which can be deserialized from a hex string as well as from an array of bytes,
//! so that it can deal with the sort of inputs we expect from ghost nodes.
use serde::de::{self, Deserialize, Deserializer, SeqAccess, Unexpected, Visitor};
use serde::ser::{Serialize, Serializer};
use std::fmt::{self, Debug, Display};
use std::str::FromStr;
/// We assume that hashes are 32 bytes long, and in practise that's currently true,
/// but in theory it doesn't need to be. We may need to be more dynamic here.
const HASH_BYTES: usize = 32;

/// Newtype wrapper for 32-byte hash values, implementing readable `Debug` and `serde::Deserialize`.
/// This can deserialize from a JSON string or array.
#[derive(Hash, PartialEq, Eq, Clone, Copy)]
pub struct Hash([u8; HASH_BYTES]);
// `Hash` and the internal `BlockHash` are both 32 raw bytes, so conversion in
// either direction is a plain copy of the byte array.
impl From<Hash> for common::node_types::BlockHash {
    fn from(hash: Hash) -> Self {
        hash.0.into()
    }
}

impl From<common::node_types::BlockHash> for Hash {
    fn from(hash: common::node_types::BlockHash) -> Self {
        Hash(hash.0)
    }
}
/// Serde visitor accepting the three encodings nodes use for hashes:
/// a "0x"-prefixed hex string, a raw 32-byte blob, or a JSON array of bytes.
struct HashVisitor;

impl<'de> Visitor<'de> for HashVisitor {
    type Value = Hash;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str(
            "byte array of length 32, or hexadecimal string of 32 bytes beginning with 0x",
        )
    }

    /// Hex string input, e.g. "0xdeadbeef…"; delegates to `Hash::from_str`.
    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        value
            .parse()
            .map_err(|_| de::Error::invalid_value(Unexpected::Str(value), &self))
    }

    /// Raw bytes: either exactly 32 hash bytes, or the ASCII bytes of a
    /// "0x…" hex string (some encodings hand strings to us this way).
    fn visit_bytes<E>(self, value: &[u8]) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        if value.len() == HASH_BYTES {
            let mut hash = [0; HASH_BYTES];
            hash.copy_from_slice(value);
            return Ok(Hash(hash));
        }

        Hash::from_ascii(value)
            .map_err(|_| de::Error::invalid_value(Unexpected::Bytes(value), &self))
    }

    /// Sequence input (e.g. a JSON array of numbers); must be exactly 32 elements.
    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
    where
        A: SeqAccess<'de>,
    {
        let mut hash = [0u8; HASH_BYTES];

        for (i, byte) in hash.iter_mut().enumerate() {
            match seq.next_element()? {
                Some(b) => *byte = b,
                None => return Err(de::Error::invalid_length(i, &"an array of 32 bytes")),
            }
        }

        // Reject a 33rd element so over-long arrays fail rather than truncate.
        if seq.next_element::<u8>()?.is_some() {
            return Err(de::Error::invalid_length(33, &"an array of 32 bytes"));
        }

        Ok(Hash(hash))
    }
}
impl Hash {
    /// Parse a hash from its ASCII representation: the two bytes "0x"
    /// followed by 64 hexadecimal characters.
    pub fn from_ascii(value: &[u8]) -> Result<Self, HashParseError> {
        let hex_part = value
            .strip_prefix(b"0x")
            .ok_or(HashParseError::InvalidPrefix)?;

        let mut bytes = [0; HASH_BYTES];
        hex::decode_to_slice(hex_part, &mut bytes).map_err(HashParseError::HexError)?;
        Ok(Hash(bytes))
    }
}
impl FromStr for Hash {
    type Err = HashParseError;

    /// Parse from a "0x"-prefixed hex string; see [`Hash::from_ascii`].
    fn from_str(value: &str) -> Result<Self, Self::Err> {
        Hash::from_ascii(value.as_bytes())
    }
}

impl<'de> Deserialize<'de> for Hash {
    fn deserialize<D>(deserializer: D) -> Result<Hash, D::Error>
    where
        D: Deserializer<'de>,
    {
        // We hint "bytes", but the visitor also accepts strings and sequences;
        // self-describing formats (like JSON) dispatch on the actual input.
        deserializer.deserialize_bytes(HashVisitor)
    }
}

impl Serialize for Hash {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Serialize compactly as raw bytes (bincode emits a length prefix
        // followed by the 32 bytes; see the tests below).
        serializer.serialize_bytes(&self.0)
    }
}
impl Display for Hash {
    /// Render as "0x" followed by 64 lowercase hex characters.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("0x")?;
        for byte in &self.0 {
            write!(f, "{:02x}", byte)?;
        }
        Ok(())
    }
}

impl Debug for Hash {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // The hex form is readable enough that Debug can just defer to Display.
        write!(f, "{}", self)
    }
}
/// Reasons a textual hash can fail to parse.
#[derive(thiserror::Error, Debug)]
pub enum HashParseError {
    /// The hex payload after "0x" was invalid (bad digit or wrong length).
    #[error("Error parsing string into hex: {0}")]
    HexError(hex::FromHexError),
    /// The input did not begin with the required "0x" prefix.
    #[error("Invalid hex prefix: expected '0x'")]
    InvalidPrefix,
}
#[cfg(test)]
mod tests {
    use super::Hash;
    use bincode::Options;

    /// 0xdeadbeef followed by 28 zero bytes.
    const DUMMY: Hash = {
        let mut hash = [0; 32];
        hash[0] = 0xDE;
        hash[1] = 0xAD;
        hash[2] = 0xBE;
        hash[3] = 0xEF;
        Hash(hash)
    };

    #[test]
    fn deserialize_json_hash_str() {
        // Mixed-case hex should be accepted.
        let json = r#""0xdeadBEEF00000000000000000000000000000000000000000000000000000000""#;
        let hash: Hash = serde_json::from_str(json).unwrap();
        assert_eq!(hash, DUMMY);
    }

    #[test]
    fn deserialize_json_array() {
        let json = r#"[222,173,190,239,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]"#;
        let hash: Hash = serde_json::from_str(json).unwrap();
        assert_eq!(hash, DUMMY);
    }

    #[test]
    fn deserialize_json_array_too_short() {
        // 31 bytes: must be rejected, not zero-padded.
        let json = r#"[222,173,190,239,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]"#;
        let res = serde_json::from_str::<Hash>(json);
        assert!(res.is_err());
    }

    #[test]
    fn deserialize_json_array_too_long() {
        // 33 bytes: must be rejected, not truncated.
        let json = r#"[222,173,190,239,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]"#;
        let res = serde_json::from_str::<Hash>(json);
        assert!(res.is_err());
    }

    #[test]
    fn bincode() {
        // bincode's serialize_bytes output is a length prefix then the raw bytes;
        // check the exact wire shape and that it round-trips.
        let bytes = bincode::options().serialize(&DUMMY).unwrap();

        let mut expected = [0; 33];
        expected[0] = 32; // length
        expected[1..].copy_from_slice(&DUMMY.0);
        assert_eq!(bytes, &expected);

        let deserialized: Hash = bincode::options().deserialize(&bytes).unwrap();
        assert_eq!(DUMMY, deserialized);
    }
}

View File

@ -0,0 +1,6 @@
//! This module contains the types we need to deserialize JSON messages from nodes
mod hash;
mod node_message;
pub use node_message::*;

View File

@ -0,0 +1,432 @@
//! The structs and enums defined in this module are largely identical to those
//! we'll use elsewhere internally, but are kept separate so that the JSON structure
//! is defined (almost) from just this file, and we don't have to worry about breaking
//! compatibility with the input data when we make changes to our internal data
//! structures (for example, to support bincode better).
use super::hash::Hash;
use common::node_message as internal;
use common::node_types;
use serde::Deserialize;
/// This struct represents a telemetry message sent from a node as
/// a JSON payload. Since JSON is self describing, we can use attributes
/// like serde(untagged) and serde(flatten) without issue.
///
/// Internally, we want to minimise the amount of data sent from shards to
/// the core node. For that reason, we use a non-self-describing serialization
/// format like bincode, which doesn't support things like `[serde(flatten)]` (which
/// internally wants to serialize to a map of unknown length) or `[serde(tag/untagged)]`
/// (which relies on the data to know which variant to deserialize to.)
///
/// So, this can be converted fairly cheaply into an enum we'll use internally
/// which is compatible with formats like bincode.
#[derive(Deserialize, Debug)]
#[serde(untagged)]
pub enum NodeMessage {
    /// Legacy shape: no message ID; the payload fields sit at the top level.
    V1 {
        #[serde(flatten)]
        payload: Payload,
    },
    /// Current shape: an explicit ID distinguishes multiple nodes that
    /// submit over the same connection.
    V2 {
        id: NodeMessageId,
        payload: Payload,
    },
}

/// Convert the JSON-shaped message into the internal, bincode-friendly
/// representation, converting the payload along the way.
impl From<NodeMessage> for internal::NodeMessage {
    fn from(msg: NodeMessage) -> Self {
        match msg {
            NodeMessage::V1 { payload } => internal::NodeMessage::V1 {
                payload: payload.into(),
            },
            NodeMessage::V2 { id, payload } => internal::NodeMessage::V2 {
                id,
                payload: payload.into(),
            },
        }
    }
}
/// The payload of a telemetry message. The JSON "msg" field selects the
/// variant to deserialize into.
#[derive(Deserialize, Debug)]
#[serde(tag = "msg")]
pub enum Payload {
    /// Sent once when a node connects, carrying its static details.
    #[serde(rename = "system.connected")]
    SystemConnected(SystemConnected),
    /// Periodic runtime statistics.
    #[serde(rename = "system.interval")]
    SystemInterval(SystemInterval),
    /// A block was imported.
    #[serde(rename = "block.import")]
    BlockImport(Block),
    /// A block was finalized.
    #[serde(rename = "notify.finalized")]
    NotifyFinalized(Finalized),
    /// GRANDPA authority set information.
    #[serde(rename = "afg.authority_set")]
    AfgAuthoritySet(AfgAuthoritySet),
    /// Hardware benchmark results reported by the node.
    #[serde(rename = "sysinfo.hwbench")]
    HwBench(NodeHwBench),
}

/// Map each JSON payload variant onto its internal counterpart.
impl From<Payload> for internal::Payload {
    fn from(msg: Payload) -> Self {
        match msg {
            Payload::SystemConnected(m) => internal::Payload::SystemConnected(m.into()),
            Payload::SystemInterval(m) => internal::Payload::SystemInterval(m.into()),
            Payload::BlockImport(m) => internal::Payload::BlockImport(m.into()),
            Payload::NotifyFinalized(m) => internal::Payload::NotifyFinalized(m.into()),
            Payload::AfgAuthoritySet(m) => internal::Payload::AfgAuthoritySet(m.into()),
            Payload::HwBench(m) => internal::Payload::HwBench(m.into()),
        }
    }
}
/// First message a node sends on connecting: the chain it belongs to
/// plus its static details (flattened into the same JSON object).
#[derive(Deserialize, Debug)]
pub struct SystemConnected {
    pub genesis_hash: Hash,
    #[serde(flatten)]
    pub node: NodeDetails,
}

impl From<SystemConnected> for internal::SystemConnected {
    fn from(msg: SystemConnected) -> Self {
        internal::SystemConnected {
            genesis_hash: msg.genesis_hash.into(),
            node: msg.node.into(),
        }
    }
}
/// Periodic statistics from a node. All fields are optional: nodes only
/// include the values they have on a given tick.
#[derive(Deserialize, Debug)]
pub struct SystemInterval {
    pub peers: Option<u64>,
    pub txcount: Option<u64>,
    pub bandwidth_upload: Option<f64>,
    pub bandwidth_download: Option<f64>,
    pub finalized_height: Option<BlockNumber>,
    pub finalized_hash: Option<Hash>,
    // Flattened, so the best block's fields ("best"/"height") appear at the
    // top level of the JSON object, and are absent as a unit.
    #[serde(flatten)]
    pub block: Option<Block>,
    pub used_state_cache_size: Option<f32>,
}

impl From<SystemInterval> for internal::SystemInterval {
    fn from(msg: SystemInterval) -> Self {
        internal::SystemInterval {
            peers: msg.peers,
            txcount: msg.txcount,
            bandwidth_upload: msg.bandwidth_upload,
            bandwidth_download: msg.bandwidth_download,
            finalized_height: msg.finalized_height,
            finalized_hash: msg.finalized_hash.map(|h| h.into()),
            block: msg.block.map(|b| b.into()),
            used_state_cache_size: msg.used_state_cache_size,
        }
    }
}
/// A block-finalized notification. Note that, unlike [`Block`], the height
/// arrives as a JSON string here, so it is kept as text.
#[derive(Deserialize, Debug)]
pub struct Finalized {
    #[serde(rename = "best")]
    pub hash: Hash,
    pub height: Box<str>,
}

impl From<Finalized> for internal::Finalized {
    fn from(msg: Finalized) -> Self {
        internal::Finalized {
            hash: msg.hash.into(),
            height: msg.height,
        }
    }
}
/// GRANDPA ("afg") authority set message; we only retain the node's own
/// authority id from it.
#[derive(Deserialize, Debug)]
pub struct AfgAuthoritySet {
    pub authority_id: Box<str>,
}

impl From<AfgAuthoritySet> for internal::AfgAuthoritySet {
    fn from(msg: AfgAuthoritySet) -> Self {
        internal::AfgAuthoritySet {
            authority_id: msg.authority_id,
        }
    }
}
/// A best block: its hash (JSON field "best") and numeric height.
#[derive(Deserialize, Debug, Clone, Copy)]
pub struct Block {
    #[serde(rename = "best")]
    pub hash: Hash,
    pub height: BlockNumber,
}

impl From<Block> for node_types::Block {
    fn from(block: Block) -> Self {
        node_types::Block {
            hash: block.hash.into(),
            height: block.height,
        }
    }
}
/// Hardware/OS details a node reports about the machine it runs on.
/// Everything is optional; nodes report what they can detect.
#[derive(Deserialize, Debug, Clone)]
pub struct NodeSysInfo {
    pub cpu: Option<Box<str>>,
    pub memory: Option<u64>,
    pub core_count: Option<u32>,
    pub linux_kernel: Option<Box<str>>,
    pub linux_distro: Option<Box<str>>,
    pub is_virtual_machine: Option<bool>,
}

impl From<NodeSysInfo> for node_types::NodeSysInfo {
    fn from(sysinfo: NodeSysInfo) -> Self {
        node_types::NodeSysInfo {
            cpu: sysinfo.cpu,
            memory: sysinfo.memory,
            core_count: sysinfo.core_count,
            linux_kernel: sysinfo.linux_kernel,
            linux_distro: sysinfo.linux_distro,
            is_virtual_machine: sysinfo.is_virtual_machine,
        }
    }
}
/// Hardware benchmark scores reported by a node. Disk scores are optional
/// (the benchmark may be skipped); CPU and memory scores always arrive.
#[derive(Deserialize, Debug, Clone)]
pub struct NodeHwBench {
    pub cpu_hashrate_score: u64,
    pub memory_memcpy_score: u64,
    pub disk_sequential_write_score: Option<u64>,
    pub disk_random_write_score: Option<u64>,
}

impl From<NodeHwBench> for node_types::NodeHwBench {
    fn from(hwbench: NodeHwBench) -> Self {
        node_types::NodeHwBench {
            cpu_hashrate_score: hwbench.cpu_hashrate_score,
            memory_memcpy_score: hwbench.memory_memcpy_score,
            disk_sequential_write_score: hwbench.disk_sequential_write_score,
            disk_random_write_score: hwbench.disk_random_write_score,
        }
    }
}

impl From<NodeHwBench> for internal::NodeHwBench {
    fn from(msg: NodeHwBench) -> Self {
        internal::NodeHwBench {
            cpu_hashrate_score: msg.cpu_hashrate_score,
            memory_memcpy_score: msg.memory_memcpy_score,
            disk_sequential_write_score: msg.disk_sequential_write_score,
            disk_random_write_score: msg.disk_random_write_score,
        }
    }
}
/// Static details a node reports about itself on connection.
#[derive(Deserialize, Debug, Clone)]
pub struct NodeDetails {
    pub chain: Box<str>,
    pub name: Box<str>,
    pub implementation: Box<str>,
    pub version: Box<str>,
    pub validator: Option<Box<str>>,
    pub network_id: node_types::NetworkId,
    pub startup_time: Option<Box<str>>,
    // Newer nodes report target platform separately; older nodes bake it
    // into `version` (see the migration in the From impl below).
    pub target_os: Option<Box<str>>,
    pub target_arch: Option<Box<str>>,
    pub target_env: Option<Box<str>>,
    pub sysinfo: Option<NodeSysInfo>,
    pub ip: Option<Box<str>>,
}

impl From<NodeDetails> for node_types::NodeDetails {
    fn from(mut details: NodeDetails) -> Self {
        // Migrate old-style `version` to the split metrics.
        // TODO: Remove this once everyone updates their nodes.
        if details.target_os.is_none()
            && details.target_arch.is_none()
            && details.target_env.is_none()
        {
            // Only rewrite when the version actually has the old
            // "$version-$arch-$os[-$env]" shape; otherwise leave it be.
            if let Some((version, target_arch, target_os, target_env)) =
                split_old_style_version(&details.version)
            {
                details.target_arch = Some(target_arch.into());
                details.target_os = Some(target_os.into());
                details.target_env = Some(target_env.into());
                details.version = version.into();
            }
        }

        node_types::NodeDetails {
            chain: details.chain,
            name: details.name,
            implementation: details.implementation,
            version: details.version,
            validator: details.validator,
            network_id: details.network_id,
            startup_time: details.startup_time,
            target_os: details.target_os,
            target_arch: details.target_arch,
            target_env: details.target_env,
            sysinfo: details.sysinfo.map(|sysinfo| sysinfo.into()),
            ip: details.ip,
        }
    }
}
type NodeMessageId = u64;
type BlockNumber = u64;

/// Return true if `name` looks like a version number or commit hash — i.e. it
/// consists only of digits, dots and lowercase hex letters — as opposed to a
/// target architecture/OS name like "x86_64" or "linux".
///
/// Note: the empty string vacuously matches, which `split_old_style_version`
/// relies on for inputs containing consecutive dashes.
fn is_version_or_hash(name: &str) -> bool {
    name.bytes()
        .all(|byte| matches!(byte, b'.' | b'0'..=b'9' | b'a'..=b'f'))
}

/// Split an old style version string into its version + target_arch + target_os + target_env parts.
///
/// Old style versions are composed of the following parts:
///    $version-$commit_hash-$arch-$os-$env
/// where $commit_hash and $env are optional.
///
/// For example these are all valid:
///    0.9.17-75dd6c7d0-x86_64-linux-gnu
///    0.9.17-75dd6c7d0-x86_64-linux
///    0.9.17-x86_64-linux-gnu
///    0.9.17-x86_64-linux
///    2.0.0-alpha.5-da487d19d-x86_64-linux
///
/// Returns `None` if the string does not have that shape (fewer than three
/// dash-separated chunks, or no version part at all).
fn split_old_style_version(version_and_target: &str) -> Option<(&str, &str, &str, &str)> {
    // Byte positions of every '-' separator.
    let dashes: Vec<usize> = version_and_target
        .match_indices('-')
        .map(|(idx, _)| idx)
        .collect();

    // We need at least three chunks (version, arch, os), i.e. two separators.
    if dashes.len() < 2 {
        return None;
    }

    // Look at the third chunk from the right. If it looks like a version or
    // commit hash, the target section is the last two chunks ($arch-$os);
    // otherwise it is the last three ($arch-$os-$env).
    let chunk_start = match dashes.len() {
        2 => 0,
        n => dashes[n - 3] + 1,
    };
    let chunk_end = dashes[dashes.len() - 2];
    let chunk = &version_and_target[chunk_start..chunk_end];

    let target_start = if is_version_or_hash(chunk) {
        chunk_end + 1
    } else {
        chunk_start
    };

    // Everything before the '-' joining version to target is the version.
    // `checked_sub` bails out when there is no version part at all (a
    // previous pointer-arithmetic implementation underflowed here in debug
    // builds for inputs like "x86_64-linux-gnu").
    let version = version_and_target.get(..target_start.checked_sub(1)?)?;
    let mut target = version_and_target.get(target_start..)?.split('-');

    let target_arch = target.next()?;
    let target_os = target.next()?;
    let target_env = target.next().unwrap_or("");

    Some((version, target_arch, target_os, target_env))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn message_v1() {
        // V1 messages carry no "id"; payload fields sit at the top level.
        let json = r#"{
            "msg":"notify.finalized",
            "level":"INFO",
            "ts":"2021-01-13T12:38:25.410794650+01:00",
            "best":"0x031c3521ca2f9c673812d692fc330b9a18e18a2781e3f9976992f861fd3ea0cb",
            "height":"50"
        }"#;
        assert!(
            matches!(
                serde_json::from_str::<NodeMessage>(json).unwrap(),
                NodeMessage::V1 { .. },
            ),
            "message did not match variant V1",
        );
    }

    #[test]
    fn message_v2() {
        // V2 messages have an explicit "id" and a nested "payload" object.
        let json = r#"{
            "id":1,
            "ts":"2021-01-13T12:22:20.053527101+01:00",
            "payload":{
                "best":"0xcc41708573f2acaded9dd75e07dac2d4163d136ca35b3061c558d7a35a09dd8d",
                "height":"209",
                "msg":"notify.finalized"
            }
        }"#;
        assert!(
            matches!(
                serde_json::from_str::<NodeMessage>(json).unwrap(),
                NodeMessage::V2 { .. },
            ),
            "message did not match variant V2",
        );
    }

    #[test]
    fn message_v2_tx_pool_import() {
        // We should happily ignore any fields we don't care about.
        let json = r#"{
            "id":1,
            "ts":"2021-01-13T12:22:20.053527101+01:00",
            "payload":{
                "foo":"Something",
                "bar":123,
                "wibble":"wobble",
                "msg":"block.import",
                "best":"0xcc41708573f2acaded9dd75e07dac2d4163d136ca35b3061c558d7a35a09dd8d",
                "height": 1234
            }
        }"#;
        assert!(
            matches!(
                serde_json::from_str::<NodeMessage>(json).unwrap(),
                NodeMessage::V2 {
                    payload: Payload::BlockImport(Block { .. }),
                    ..
                },
            ),
            "message did not match the expected output",
        );
    }

    #[test]
    fn split_old_style_version_works() {
        // Full shape: $version-$commit_hash-$arch-$os-$env.
        let (version, target_arch, target_os, target_env) =
            split_old_style_version("0.9.17-75dd6c7d0-x86_64-linux-gnu").unwrap();
        assert_eq!(version, "0.9.17-75dd6c7d0");
        assert_eq!(target_arch, "x86_64");
        assert_eq!(target_os, "linux");
        assert_eq!(target_env, "gnu");

        // No $env.
        let (version, target_arch, target_os, target_env) =
            split_old_style_version("0.9.17-75dd6c7d0-x86_64-linux").unwrap();
        assert_eq!(version, "0.9.17-75dd6c7d0");
        assert_eq!(target_arch, "x86_64");
        assert_eq!(target_os, "linux");
        assert_eq!(target_env, "");

        // No $commit_hash.
        let (version, target_arch, target_os, target_env) =
            split_old_style_version("0.9.17-x86_64-linux-gnu").unwrap();
        assert_eq!(version, "0.9.17");
        assert_eq!(target_arch, "x86_64");
        assert_eq!(target_os, "linux");
        assert_eq!(target_env, "gnu");

        // Neither $commit_hash nor $env.
        let (version, target_arch, target_os, target_env) =
            split_old_style_version("0.9.17-x86_64-linux").unwrap();
        assert_eq!(version, "0.9.17");
        assert_eq!(target_arch, "x86_64");
        assert_eq!(target_os, "linux");
        assert_eq!(target_env, "");

        // Pre-release version containing extra dashes.
        let (version, target_arch, target_os, target_env) =
            split_old_style_version("2.0.0-alpha.5-da487d19d-x86_64-linux").unwrap();
        assert_eq!(version, "2.0.0-alpha.5-da487d19d");
        assert_eq!(target_arch, "x86_64");
        assert_eq!(target_os, "linux");
        assert_eq!(target_env, "");

        // Strings without enough parts are rejected.
        assert_eq!(split_old_style_version(""), None);
        assert_eq!(split_old_style_version("a"), None);
        assert_eq!(split_old_style_version("a-b"), None);
    }
}

377
telemetry-shard/src/main.rs Normal file
View File

@ -0,0 +1,377 @@
// NOTE(review): this was an outer `#[warn(missing_docs)]` attribute, which
// only applied to the `aggregator` module; the crate-level inner form below
// applies the lint to the whole binary, which looks like the original intent.
#![warn(missing_docs)]

mod aggregator;
mod blocked_addrs;
mod connection;
mod json_message;
mod real_ip;
use std::{
collections::HashMap,
net::IpAddr,
time::{Duration, Instant},
};
use aggregator::{Aggregator, FromWebsocket};
use blocked_addrs::BlockedAddrs;
use common::byte_size::ByteSize;
use common::http_utils;
use common::node_message;
use common::node_message::NodeMessageId;
use common::rolling_total::RollingTotalBuilder;
use futures::{SinkExt, StreamExt};
use http::Uri;
use hyper::{Method, Response};
use simple_logger::SimpleLogger;
use structopt::StructOpt;
#[cfg(not(target_env = "msvc"))]
use jemallocator::Jemalloc;
#[cfg(not(target_env = "msvc"))]
#[global_allocator]
static GLOBAL: Jemalloc = Jemalloc;
/// Crate version, injected by cargo at compile time.
const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Crate authors, injected by cargo at compile time.
const AUTHORS: &str = env!("CARGO_PKG_AUTHORS");
/// Binary name shown in the CLI help output (see the structopt attribute on `Opts`).
const NAME: &str = "Ghost Telemetry Backend Shard";
/// Short description shown in the CLI help output.
const ABOUT: &str = "This is the Telemetry Backend Shard that forwards the \
data sent by Ghost/Casper nodes to the Backend Core";
// Command line options for the shard binary, parsed by structopt.
//
// NOTE: the `///` doc comments on each field double as the `--help` text that
// structopt renders, so editing them changes the CLI output.
#[derive(StructOpt, Debug)]
#[structopt(name = NAME, version = VERSION, author = AUTHORS, about = ABOUT)]
struct Opts {
    /// This is the socket address that this shard is listening to. This is restricted to
    /// localhost (127.0.0.1) by default and should be fine for most use cases. If
    /// you are using Telemetry in a container, you likely want to set this to '0.0.0.0:8000'
    #[structopt(short = "l", long = "listen", default_value = "127.0.0.1:8001")]
    socket: std::net::SocketAddr,
    /// The desired log level; one of 'error', 'warn', 'info', 'debug' or 'trace', where
    /// 'error' only logs errors and 'trace' logs everything.
    #[structopt(long = "log", default_value = "info")]
    log_level: log::LevelFilter,
    /// Url to the Backend Core endpoint accepting shard connections
    #[structopt(
        short = "c",
        long = "core",
        default_value = "ws://127.0.0.1:8000/shard_submit/"
    )]
    core_url: Uri,
    /// How many different nodes is a given connection to the /submit endpoint allowed to
    /// tell us about before we ignore the rest?
    ///
    /// This is important because without a limit, a single connection could exhaust
    /// RAM by suggesting that it accounts for billions of nodes.
    #[structopt(long, default_value = "20")]
    max_nodes_per_connection: usize,
    /// What is the maximum number of bytes per second, on average, that a connection from a
    /// node is allowed to send to a shard before it gets booted. This is averaged over a
    /// rolling window of 10 seconds, and so spikes beyond this limit are allowed as long as
    /// the average traffic in the last 10 seconds falls below this value.
    ///
    /// As a reference point, syncing a new Ghost node leads to a maximum of about 25k of
    /// traffic on average (at least initially).
    #[structopt(long, default_value = "256k")]
    max_node_data_per_second: ByteSize,
    /// How many seconds is a "/feed" connection that violates the '--max-node-data-per-second'
    /// value prevented from reconnecting to this shard for, in seconds.
    #[structopt(long, default_value = "600")]
    node_block_seconds: u64,
    /// Number of worker threads to spawn. If "0" is given, use the number of CPUs available
    /// on the machine. If no value is given, use an internal default that we have deemed sane.
    #[structopt(long)]
    worker_threads: Option<usize>,
    /// Roughly how long to wait in seconds for new telemetry data to arrive from a node. If
    /// telemetry for a node does not arrive in this time frame, we remove the corresponding node
    /// state, and if no messages are received on the connection at all in this time, it will be
    /// dropped.
    #[structopt(long, default_value = "60")]
    stale_node_timeout: u64,
}
/// Entry point: parse CLI options, set up logging, build a tokio runtime and
/// run the server on it until it exits.
fn main() {
    let opts = Opts::from_args();

    SimpleLogger::new()
        .with_level(opts.log_level)
        .init()
        .expect("Must be able to start a logger");

    log::info!("Starting Telemetry Shard version: {}", VERSION);

    // Work out how many tokio worker threads to spawn. An explicit `0` means
    // "one per available CPU"; when nothing was given we cap at 4, since we
    // don't expect a shard to need much parallelism.
    let worker_threads = opts
        .worker_threads
        .map(|n| if n == 0 { num_cpus::get() } else { n })
        .unwrap_or_else(|| usize::min(num_cpus::get(), 4));

    let runtime = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .worker_threads(worker_threads)
        .thread_name("telemetry_shard_worker")
        .build()
        .unwrap();

    runtime.block_on(async {
        if let Err(e) = start_server(opts).await {
            log::error!("Error starting server: {}", e);
        }
    });
}
/// Declare our routes and start the server.
///
/// Two endpoints are exposed:
/// - `GET /health` — liveness check; always responds "OK".
/// - `GET /submit` — websocket upgrade endpoint on which nodes send telemetry;
///   requests from addresses on the block list are rejected with a 403.
/// Anything else receives a 404.
async fn start_server(opts: Opts) -> anyhow::Result<()> {
    // Addresses that misbehave are barred from reconnecting for this long:
    let block_list = BlockedAddrs::new(Duration::from_secs(opts.node_block_seconds));
    // The aggregator owns the connection to the Backend Core and relays node messages to it:
    let aggregator = Aggregator::spawn(opts.core_url).await?;
    let socket_addr = opts.socket;
    let max_nodes_per_connection = opts.max_nodes_per_connection;
    let bytes_per_second = opts.max_node_data_per_second;
    let stale_node_timeout = Duration::from_secs(opts.stale_node_timeout);

    let server = http_utils::start_server(socket_addr, move |addr, req| {
        // Each incoming request gets its own clone of these shared handles:
        let aggregator = aggregator.clone();
        let block_list = block_list.clone();
        async move {
            match (req.method(), req.uri().path().trim_end_matches('/')) {
                // Check that the server is up and running:
                (&Method::GET, "/health") => Ok(Response::new("OK".into())),
                // Nodes send messages here:
                (&Method::GET, "/submit") => {
                    // Prefer proxy-supplied headers over the raw socket address:
                    let (real_addr, real_addr_source) = real_ip::real_ip(addr, req.headers());
                    if let Some(reason) = block_list.blocked_reason(&real_addr) {
                        return Ok(Response::builder().status(403).body(reason.into()).unwrap());
                    }
                    Ok(http_utils::upgrade_to_websocket(
                        req,
                        move |ws_send, ws_recv| async move {
                            log::info!(
                                "Opening /submit connection from {:?} (address source: {})",
                                real_addr,
                                real_addr_source
                            );
                            let tx_to_aggregator = aggregator.subscribe_node();
                            // Runs until the connection is closed or booted:
                            let (mut tx_to_aggregator, mut ws_send) =
                                handle_node_websocket_connection(
                                    real_addr,
                                    ws_send,
                                    ws_recv,
                                    tx_to_aggregator,
                                    max_nodes_per_connection,
                                    bytes_per_second,
                                    block_list,
                                    stale_node_timeout,
                                )
                                .await;
                            log::info!(
                                "Closing /submit connection from {:?} (address source: {})",
                                real_addr,
                                real_addr_source
                            );
                            // Tell the aggregator that this connection has closed, so it can tidy up.
                            let _ = tx_to_aggregator.send(FromWebsocket::Disconnected).await;
                            let _ = ws_send.close().await;
                        },
                    ))
                }
                // 404 for anything else:
                _ => Ok(Response::builder()
                    .status(404)
                    .body("Not found".into())
                    .unwrap()),
            }
        }
    });

    server.await?;
    Ok(())
}
/// This takes care of handling messages from an established socket connection.
///
/// Incoming bytes are read on a separate task (receiving isn't cancel safe) and
/// relayed into the main select loop, which:
/// - rate limits traffic, blocking the address and dropping the connection when the
///   10s rolling average exceeds `bytes_per_second`;
/// - deserializes telemetry messages and forwards them to the aggregator;
/// - tracks up to `max_nodes_per_connection` node message IDs, expiring IDs (and
///   eventually the whole connection) that go quiet for longer than `stale_node_timeout`.
///
/// Returns the aggregator sink and websocket sender so the caller can finish
/// closing the connection gracefully.
async fn handle_node_websocket_connection<S>(
    real_addr: IpAddr,
    ws_send: http_utils::WsSender,
    mut ws_recv: http_utils::WsReceiver,
    mut tx_to_aggregator: S,
    max_nodes_per_connection: usize,
    bytes_per_second: ByteSize,
    block_list: BlockedAddrs,
    stale_node_timeout: Duration,
) -> (S, http_utils::WsSender)
where
    S: futures::Sink<FromWebsocket, Error = anyhow::Error> + Unpin + Send + 'static,
{
    // Keep track of the message Ids that have been "granted access". We allow a maximum
    // of `max_nodes_per_connection` before ignoring others. The value is the time a
    // message for that ID was last seen, so stale entries can be expired.
    let mut allowed_message_ids = HashMap::<NodeMessageId, Instant>::new();

    // Limit the number of bytes based on a rolling total and the incoming bytes per second
    // that has been configured via the CLI opts.
    let bytes_per_second = bytes_per_second.num_bytes();
    let mut rolling_total_bytes = RollingTotalBuilder::new()
        .granularity(Duration::from_secs(1))
        .window_size_multiple(10)
        .start();

    // This could be a oneshot channel, but it's useful to be able to clone
    // messages, and we can't clone oneshot channel senders.
    let (close_connection_tx, close_connection_rx) = flume::bounded(1);

    // Tell the aggregator about this new connection, and give it a way to close this connection:
    let init_msg = FromWebsocket::Initialize {
        close_connection: close_connection_tx.clone(),
    };
    if let Err(e) = tx_to_aggregator.send(init_msg).await {
        log::error!("Shutting down websocket connection from {real_addr:?}: Error sending message to aggregator: {e}");
        return (tx_to_aggregator, ws_send);
    }

    // Receiving data isn't cancel safe, so let it happen in a separate task.
    // If this loop ends, the outer will receive a `None` message and end too.
    // If the outer loop ends, it fires a msg on `close_connection_rx` to ensure this ends too.
    let (ws_tx_atomic, mut ws_rx_atomic) = futures::channel::mpsc::unbounded();
    tokio::task::spawn(async move {
        loop {
            let mut bytes = Vec::new();
            tokio::select! {
                // The close channel has fired, so end the loop. `ws_recv.receive_data` is
                // *not* cancel safe, but since we're closing the connection we don't care.
                _ = close_connection_rx.recv_async() => {
                    log::info!("connection to {real_addr:?} being closed");
                    break
                },
                // Receive data and relay it on to our main select loop below.
                msg_info = ws_recv.receive_data(&mut bytes) => {
                    if let Err(soketto::connection::Error::Closed) = msg_info {
                        break;
                    }
                    if let Err(e) = msg_info {
                        log::error!("Shutting down websocket connection from {real_addr:?}: Failed to receive data: {e}");
                        break;
                    }
                    if ws_tx_atomic.unbounded_send(bytes).is_err() {
                        // The other end closed; end this loop.
                        break;
                    }
                }
            }
        }
    });

    // A periodic interval to check for stale nodes.
    let mut stale_interval = tokio::time::interval(stale_node_timeout / 2);

    // Our main select loop atomically receives and handles telemetry messages from the node,
    // and periodically checks for stale connections to keep our node state tidy.
    loop {
        tokio::select! {
            // We periodically check for stale message IDs and remove nodes associated with
            // them, to prevent a buildup. We boot the whole connection if no interpretable
            // messages have been sent at all in the time period.
            _ = stale_interval.tick() => {
                let stale_ids: Vec<NodeMessageId> = allowed_message_ids.iter()
                    .filter(|(_, last_seen)| last_seen.elapsed() > stale_node_timeout)
                    .map(|(&id, _)| id)
                    .collect();

                for &message_id in &stale_ids {
                    log::info!("Removing stale node with message ID {message_id} from {real_addr:?}");
                    allowed_message_ids.remove(&message_id);
                    let _ = tx_to_aggregator.send(FromWebsocket::Remove { message_id } ).await;
                }

                if !stale_ids.is_empty() && allowed_message_ids.is_empty() {
                    // End the entire connection if no recent messages came in for any ID.
                    log::info!("Closing stale connection from {real_addr:?}");
                    break;
                }
            },
            // Handle messages received by the connected node.
            msg = ws_rx_atomic.next() => {
                // No more messages? The receive task has ended, so break.
                let bytes = match msg {
                    Some(bytes) => bytes,
                    None => { break; }
                };

                // Keep track of total bytes and bail if average over last 10 secs exceeds preference.
                rolling_total_bytes.push(bytes.len());
                let this_bytes_per_second = rolling_total_bytes.total() / 10;
                if this_bytes_per_second > bytes_per_second {
                    block_list.block_addr(real_addr, "Too much traffic");
                    log::error!("Shutting down websocket connection: Too much traffic ({this_bytes_per_second}bps averaged over last 10s)");
                    break;
                }

                // Deserialize from JSON, warning in debug mode if deserialization fails.
                // BUGFIX: these branches previously used `#[cfg(debug)]`, but `debug` is not
                // a cfg that cargo ever sets, so the warning branch was dead code in every
                // build; `debug_assertions` is the predicate that holds for debug builds.
                let node_message: json_message::NodeMessage = match serde_json::from_slice(&bytes) {
                    Ok(node_message) => node_message,
                    #[cfg(debug_assertions)]
                    Err(e) => {
                        // Log at most the first 512 bytes of the offending message.
                        let bytes: &[u8] = bytes.get(..512).unwrap_or(&bytes);
                        let msg_start = std::str::from_utf8(bytes).unwrap_or("INVALID UTF8");
                        log::warn!("Failed to parse node message ({msg_start}): {e}");
                        continue;
                    },
                    #[cfg(not(debug_assertions))]
                    Err(_) => {
                        continue;
                    }
                };

                // Pull relevant details from the message:
                let node_message: node_message::NodeMessage = node_message.into();
                let message_id = node_message.id();
                let payload = node_message.into_payload();

                // Until the aggregator receives an `Add` message, which we can create once
                // we see one of these SystemConnected ones, it will ignore messages with
                // the corresponding message_id.
                if let node_message::Payload::SystemConnected(info) = payload {
                    // Too many nodes seen on this connection? Ignore this one.
                    if allowed_message_ids.len() >= max_nodes_per_connection {
                        log::info!("Ignoring new node with ID {message_id} from {real_addr:?} (we've hit the max of {max_nodes_per_connection} nodes per connection)");
                        continue;
                    }

                    // Note of the message ID, allowing telemetry for it.
                    let prev_join_time = allowed_message_ids.insert(message_id, Instant::now());
                    if prev_join_time.is_some() {
                        log::info!("Ignoring duplicate new node with ID {message_id} from {real_addr:?}");
                        continue;
                    }

                    // Tell the aggregator loop about the new node.
                    log::info!("Adding node with message ID {message_id} from {real_addr:?}");
                    let _ = tx_to_aggregator.send(FromWebsocket::Add {
                        message_id,
                        ip: real_addr,
                        node: info.node,
                        genesis_hash: info.genesis_hash,
                    }).await;
                }
                // Anything that's not an "Add" is an Update. The aggregator will ignore
                // updates against a message_id that hasn't first been Added, above.
                else if let Some(last_seen) = allowed_message_ids.get_mut(&message_id) {
                    *last_seen = Instant::now();
                    if let Err(e) = tx_to_aggregator.send(FromWebsocket::Update { message_id, payload } ).await {
                        log::error!("Failed to send node message to aggregator: {e}");
                        continue;
                    }
                } else {
                    // BUGFIX: this log line previously claimed we'd hit the
                    // max-nodes-per-connection limit (copy-pasted from the branch above);
                    // in fact the ID was never added, or was already removed as stale.
                    log::info!("Ignoring message with unknown ID {message_id} from {real_addr:?} (no SystemConnected message was seen for it, or it was removed as stale)");
                    continue;
                }
            }
        }
    }

    // Make sure to kill off the receive-messages task if the main select loop ends:
    let _ = close_connection_tx.send(());

    // Return what we need to close the connection gracefully:
    (tx_to_aggregator, ws_send)
}

View File

@ -0,0 +1,163 @@
use std::net::{IpAddr, SocketAddr};
/// Work out the "real" IP address of a connection.
///
/// Proxies in front of us may record the original client address in headers
/// (this mirrors how Actix Web resolves the peer address). We consult, in order:
///
/// 1. The standardised `Forwarded` header, e.g.
///    `Forwarded: for=12.34.56.78;host=example.com;proto=https, for=23.45.67.89`,
///    where each proxy may append to the comma separated list; we take the first
///    `for` entry.
/// 2. The non-standard but common `X-Forwarded-For` header, a comma separated
///    list of addresses; we take the first.
/// 3. The `X-Real-IP` header, expected to hold a single address.
/// 4. Failing all of those, the socket address of the connection itself.
pub fn real_ip(addr: SocketAddr, headers: &hyper::HeaderMap) -> (IpAddr, Source) {
    let forwarded_header = headers.get("forwarded").and_then(header_as_str);
    let x_forwarded_for = headers.get("x-forwarded-for").and_then(header_as_str);
    let x_real_ip = headers.get("x-real-ip").and_then(header_as_str);
    pick_best_ip_from_options(forwarded_header, x_forwarded_for, x_real_ip, addr)
}
/// The source of the address returned
// Public types should carry the standard derives; `Copy` is free for a
// fieldless enum and `Debug`/`PartialEq` make logging and testing easier.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Source {
    /// Taken from the standardised `Forwarded` header.
    ForwardedHeader,
    /// Taken from the non-standard `X-Forwarded-For` header.
    XForwardedForHeader,
    /// Taken from the `X-Real-IP` header.
    XRealIpHeader,
    /// Fallback: the socket address of the connection itself.
    SocketAddr,
}
impl std::fmt::Display for Source {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Human readable description of where the address came from.
        let text = match self {
            Source::ForwardedHeader => "'Forwarded' header",
            Source::XForwardedForHeader => "'X-Forwarded-For' header",
            Source::XRealIpHeader => "'X-Real-Ip' header",
            Source::SocketAddr => "Socket address",
        };
        f.write_str(text)
    }
}
/// Borrow a header value as a `&str`, or `None` if it is not valid UTF-8.
fn header_as_str(value: &hyper::header::HeaderValue) -> Option<&str> {
    match std::str::from_utf8(value.as_bytes()) {
        Ok(text) => Some(text),
        Err(_) => None,
    }
}
/// Pick the most trustworthy client address available.
///
/// Preference order: `Forwarded` header, then `X-Forwarded-For`, then `X-Real-IP`,
/// finally falling back to the raw socket address. The chosen header value is
/// parsed first as `ip:port` and then as a bare `ip`; values that fail to parse
/// are treated as absent.
// Cleanups vs the original: `Option<&str>` is `Copy`, so the `.as_ref()` calls
// were redundant, and the trailing `let realip = …; realip` binding is gone —
// the expression itself is the return value.
fn pick_best_ip_from_options(
    // Forwarded header value (if present)
    forwarded: Option<&str>,
    // X-Forwarded-For header value (if present)
    forwarded_for: Option<&str>,
    // X-Real-IP header value (if present)
    real_ip: Option<&str>,
    // socket address (if known)
    addr: SocketAddr,
) -> (IpAddr, Source) {
    forwarded
        .and_then(|val| {
            // First "for=" entry of the standardised Forwarded header.
            let addr = get_first_addr_from_forwarded_header(val)?;
            Some((addr, Source::ForwardedHeader))
        })
        .or_else(|| {
            // fall back to X-Forwarded-For
            let addr = get_first_addr_from_x_forwarded_for_header(forwarded_for?)?;
            Some((addr, Source::XForwardedForHeader))
        })
        .or_else(|| {
            // fall back to X-Real-IP
            Some((real_ip?.trim(), Source::XRealIpHeader))
        })
        .and_then(|(ip, source)| {
            // Try parsing assuming it may have a port first,
            // and then assuming it doesn't.
            let addr = ip
                .parse::<SocketAddr>()
                .map(|s| s.ip())
                .or_else(|_| ip.parse::<IpAddr>())
                .ok()?;
            Some((addr, source))
        })
        // Fall back to local IP address if the above fails
        .unwrap_or((addr.ip(), Source::SocketAddr))
}
/// Follow <https://datatracker.ietf.org/doc/html/rfc7239> to decode the Forwarded header value.
/// Roughly, proxies can add new sets of values by appending a comma to the existing list
/// (so we have something like "values1, values2, values3" from proxy1, proxy2 and proxy3 for
/// instance) and then the values themselves are ';' separated name=value pairs. The value in each
/// pair may or may not be surrounded in double quotes.
///
/// Examples from the RFC:
///
/// ```text
/// Forwarded: for="_gazonk"
/// Forwarded: For="[2001:db8:cafe::17]:4711"
/// Forwarded: for=192.0.2.60;proto=http;by=203.0.113.43
/// Forwarded: for=192.0.2.43, for=198.51.100.17
/// ```
fn get_first_addr_from_forwarded_header(value: &str) -> Option<&str> {
    // Only the first comma separated group (the earliest proxy's) is of interest.
    let first_values = value.split(',').next()?;

    // Each group is a ';' separated list of `name=value` pairs.
    for pair in first_values.split(';') {
        let mut parts = pair.trim().splitn(2, '=');
        let key = parts.next()?;
        let value = parts.next()?;

        // `eq_ignore_ascii_case` avoids the per-pair String allocation that
        // `key.to_lowercase() == "for"` incurred ("for" is ASCII per the RFC).
        if key.eq_ignore_ascii_case("for") {
            // Trim double quotes if they surround the value. strip_prefix/strip_suffix
            // only succeed when both quotes are present, matching the old behavior,
            // and cannot panic on the degenerate single-`"` input that the manual
            // `&value[1..value.len() - 1]` slicing would have sliced out of order.
            let value = value
                .strip_prefix('"')
                .and_then(|v| v.strip_suffix('"'))
                .unwrap_or(value);
            return Some(value);
        }
    }
    None
}
/// Take the first (client-most) entry of a comma separated X-Forwarded-For value,
/// with surrounding whitespace removed.
fn get_first_addr_from_x_forwarded_for_header(value: &str) -> Option<&str> {
    let first_entry = value.split(',').next()?;
    Some(first_entry.trim())
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn get_addr_from_forwarded_rfc_examples() {
        // Header values and expected "for" entries, taken from RFC 7239.
        let examples = [
            (r#"for="_gazonk""#, "_gazonk"),
            (
                r#"For="[2001:db8:cafe::17]:4711""#,
                "[2001:db8:cafe::17]:4711",
            ),
            (r#"for=192.0.2.60;proto=http;by=203.0.113.43"#, "192.0.2.60"),
            (r#"for=192.0.2.43, for=198.51.100.17"#, "192.0.2.43"),
        ];

        for (value, expected) in examples {
            let parsed = get_first_addr_from_forwarded_header(value);
            assert_eq!(parsed, Some(expected), "Header value: {}", value);
        }
    }
}