Is it possible to calculate the return type of a Rust function or trait method based on its arguments?

Can I achieve something similar to boost::math::tools::promote_args in Rust? See also Idiomatic C++11 type promotion
To be more specific: is it possible to calculate the return type of a function or trait method based on its arguments and ensure that the return type is the same as the type of one of the arguments?
Consider the following case. I have two structs:
#[derive(Debug, Clone, Copy)]
struct MySimpleType(f64);
#[derive(Debug, Clone, Copy)]
struct MyComplexType(f64, f64);
where MySimpleType can be promoted to MyComplexType via the From trait.
impl From<MySimpleType> for MyComplexType {
fn from(src: MySimpleType) -> MyComplexType {
let MySimpleType(x1) = src;
MyComplexType(x1, 0.0)
}
}
I want to write a function that takes two arguments of type MySimpleType or MyComplexType and returns a value of type MySimpleType if all arguments are of type MySimpleType; otherwise it should return a value of type MyComplexType. Assuming I have implemented Add<Output=Self> for both types, I could do something like this:
trait Foo<S, T> {
fn foo(s: S, t: T) -> Self;
}
impl<S, T, O> Foo<S, T> for O
where O: From<S> + From<T> + Add<Output = Self>
{
fn foo(s: S, t: T) -> Self {
let s: O = From::from(s);
let t: O = From::from(t);
s + t
}
}
but then the compiler doesn't know that O should be either S or T and I have to annotate most method calls.
My second attempt is to use a slightly different trait and write two implementations:
trait Foo<S, T> {
fn foo(s: S, t: T) -> Self;
}
impl Foo<MySimpleType, MySimpleType> for MySimpleType {
fn foo(s: MySimpleType, t: MySimpleType) -> Self {
s + t
}
}
impl<S, T> Foo<S, T> for MyComplexType
where MyComplexType: From<S> + From<T>
{
fn foo(s: S, t: T) -> Self {
let s: MyComplexType = From::from(s);
let t: MyComplexType = From::from(t);
s + t
}
}
but again, the compiler isn't able to figure out the return type of
Foo::foo(MySimpleType(1.0), MySimpleType(1.0))
The third attempt is something similar to std::ops::{Add, Mul, ...}: use an associated type and write a specific implementation for each possible combination of argument types:
trait Foo<T> {
type Output;
fn foo(self, t: T) -> Self::Output;
}
impl<T: Add<Output=T>> Foo<T> for T {
type Output = Self;
fn foo(self, t: T) -> Self::Output {
self + t
}
}
impl Foo<MySimpleType> for MyComplexType {
type Output = Self;
fn foo(self, t: MySimpleType) -> Self::Output {
let t: Self = From::from(t);
self + t
}
}
impl Foo<MyComplexType> for MySimpleType {
type Output = MyComplexType;
fn foo(self, t: MyComplexType) -> Self::Output {
let s: MyComplexType = From::from(self);
s + t
}
}
This seems to be the best solution until one needs a function with n arguments, because then one has to write 2^n - n + 1 impl statements. Of course, this gets even worse if more than two types are considered.
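The number of impls in the associated-type approach cannot be avoided this way, but the repetition of writing them can be reduced. The following is only a sketch under assumptions: the macro name impl_foo_promoting and the Add impls are made up for illustration, and the macro assumes the output type implements From for both argument types.

```rust
use std::ops::Add;

#[derive(Debug, Clone, Copy, PartialEq)]
struct MySimpleType(f64);
#[derive(Debug, Clone, Copy, PartialEq)]
struct MyComplexType(f64, f64);

impl From<MySimpleType> for MyComplexType {
    fn from(src: MySimpleType) -> MyComplexType {
        MyComplexType(src.0, 0.0)
    }
}

// Assumed Add impls; the question only states that they exist.
impl Add for MySimpleType {
    type Output = Self;
    fn add(self, rhs: Self) -> Self {
        MySimpleType(self.0 + rhs.0)
    }
}
impl Add for MyComplexType {
    type Output = Self;
    fn add(self, rhs: Self) -> Self {
        MyComplexType(self.0 + rhs.0, self.1 + rhs.1)
    }
}

trait Foo<T> {
    type Output;
    fn foo(self, t: T) -> Self::Output;
}

// Same-type case: no promotion needed.
impl<T: Add<Output = T>> Foo<T> for T {
    type Output = T;
    fn foo(self, t: T) -> T {
        self + t
    }
}

// Hypothetical helper: generates one mixed-type impl per listed line.
macro_rules! impl_foo_promoting {
    ($($lhs:ty, $rhs:ty => $out:ty);+ $(;)?) => {
        $(
            impl Foo<$rhs> for $lhs {
                type Output = $out;
                fn foo(self, t: $rhs) -> $out {
                    // Promote both sides to the output type, then add.
                    <$out>::from(self) + <$out>::from(t)
                }
            }
        )+
    };
}

impl_foo_promoting! {
    MyComplexType, MySimpleType => MyComplexType;
    MySimpleType, MyComplexType => MyComplexType;
}

fn main() {
    // The return type follows from the argument types alone.
    println!("{:?}", MySimpleType(1.0).foo(MySimpleType(2.0)));
    println!("{:?}", MySimpleType(1.0).foo(MyComplexType(2.0, 1.0)));
}
```

Each mixed combination still has to be listed once, so this shortens the code but does not change how the number of combinations grows.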
===
Edit:
In my code I have multiple nested function calls, and I want to avoid unnecessary type promotion, since evaluating the functions is cheap for the simple type and expensive for the complex type. @MatthieuM.'s proposed solution does not achieve this. Please consider the following example:
// std::any::type_name is stable, so no nightly feature or unsafe is needed.
use std::any::type_name;
use std::ops::Add;
trait Promote<Target> {
fn promote(self) -> Target;
}
impl<T> Promote<T> for T {
fn promote(self) -> T {
self
}
}
impl Promote<u64> for u32 {
fn promote(self) -> u64 {
self as u64
}
}
fn foo<Result, Left, Right>(left: Left, right: Right) -> Result
where Left: Promote<Result>,
Right: Promote<Result>,
Result: Add<Output = Result>
{
println!("============\nFoo called");
println!("Left: {}", type_name::<Left>());
println!("Right: {}", type_name::<Right>());
println!("Result: {}", type_name::<Result>());
left.promote() + right.promote()
}
fn bar<Result, Left, Right>(left: Left, right: Right) -> Result
where Left: Promote<Result>,
Right: Promote<Result>,
Result: Add<Output = Result>
{
left.promote() + right.promote()
}
fn baz<Result, A, B, C, D>(a: A, b: B, c: C, d: D) -> Result
where A: Promote<Result>,
B: Promote<Result>,
C: Promote<Result>,
D: Promote<Result>,
Result: Add<Output = Result>
{
let lhs = foo(a, b).promote();
let rhs = bar(c, d).promote();
lhs + rhs
}
fn main() {
let one = baz(1u32, 1u32, 1u64, 1u32);
println!("{}", one);
}

I would expect the simplest way to implement promotion is to create a Promote trait:
trait Promote<Target> {
fn promote(self) -> Target;
}
impl<T> Promote<T> for T {
fn promote(self) -> T { self }
}
Note: I provide a blanket implementation as all types can be promoted to themselves.
Using associated types is NOT an option here, because a single type can be promoted to multiple types; thus we just use a regular type parameter.
Using this, a simple example is:
impl Promote<u64> for u32 {
fn promote(self) -> u64 { self as u64 }
}
fn add<Result, Left, Right>(left: Left, right: Right) -> Result
where
Left: Promote<Result>,
Right: Promote<Result>,
Result: Add<Output = Result>
{
left.promote() + right.promote()
}
fn main() {
let one: u32 = add(1u32, 1u32);
let two: u64 = add(1u32, 2u64);
let three: u64 = add(2u64, 1u32);
let four: u64 = add(2u64, 2u64);
println!("{} {} {} {}", one, two, three, four);
}
The only issue is that in the case of two u32 arguments, the result type must be specified; otherwise the compiler cannot choose which Promote implementation to use: Promote<u32> or Promote<u64>.
I am not sure if this is an issue in practice, however, since at some point you should have a concrete type to anchor type inference. For example:
fn main() {
let v = vec![add(1u32, 1u32), add(1u32, 2u64)];
println!("{:?}", v);
}
compiles without type hint, because add(1u32, 2u64) can only be u64, and therefore since a Vec is a homogeneous collection, add(1u32, 1u32) has to return a u64 here.
As you experienced, though, sometimes you need the ability to direct the result beyond what type inference can handle. It's fine, you just need another trait for it:
trait PromoteTarget {
type Output;
}
impl<T> PromoteTarget for (T, T) {
type Output = T;
}
And then a little implementation:
impl PromoteTarget for (u32, u64) {
type Output = u64;
}
impl PromoteTarget for (u64, u32) {
type Output = u64;
}
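To make the role of PromoteTarget concrete, here is a minimal two-argument sketch that combines it with Promote so the return type is computed from the argument types alone. The function name add_promoted is made up for this illustration; the traits and impls are the ones shown above.

```rust
use std::ops::Add;

trait Promote<Target> {
    fn promote(self) -> Target;
}

// Every type promotes to itself.
impl<T> Promote<T> for T {
    fn promote(self) -> T {
        self
    }
}

impl Promote<u64> for u32 {
    fn promote(self) -> u64 {
        u64::from(self)
    }
}

// Maps a pair of argument types to the type both are promoted to.
trait PromoteTarget {
    type Output;
}

impl<T> PromoteTarget for (T, T) {
    type Output = T;
}
impl PromoteTarget for (u32, u64) {
    type Output = u64;
}
impl PromoteTarget for (u64, u32) {
    type Output = u64;
}

// Hypothetical function: the result type is looked up through
// PromoteTarget instead of being a free type parameter.
fn add_promoted<L, R>(left: L, right: R) -> <(L, R) as PromoteTarget>::Output
where
    (L, R): PromoteTarget,
    L: Promote<<(L, R) as PromoteTarget>::Output>,
    R: Promote<<(L, R) as PromoteTarget>::Output>,
    <(L, R) as PromoteTarget>::Output: Add<Output = <(L, R) as PromoteTarget>::Output>,
{
    left.promote() + right.promote()
}

fn main() {
    let small = add_promoted(1u32, 1u32); // stays u32, no annotation needed
    let big = add_promoted(1u32, 2u64); // promoted to u64
    println!("{} {}", small, big);
}
```

Because the output is an associated type of the argument pair, the call with two u32 values no longer needs a type hint.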
With that out of the way, we can rewrite baz's signature to correctly account for all intermediate types. Unfortunately I don't know of any way to introduce aliases in a where clause, so brace yourself:
fn baz<Result, A, B, C, D>(a: A, b: B, c: C, d: D) -> Result
where
A: Promote<<(A, B) as PromoteTarget>::Output>,
B: Promote<<(A, B) as PromoteTarget>::Output>,
C: Promote<<(C, D) as PromoteTarget>::Output>,
D: Promote<<(C, D) as PromoteTarget>::Output>,
(A, B): PromoteTarget,
(C, D): PromoteTarget,
<(A, B) as PromoteTarget>::Output: Promote<Result> + Add<Output = <(A, B) as PromoteTarget>::Output>,
<(C, D) as PromoteTarget>::Output: Promote<Result> + Add<Output = <(C, D) as PromoteTarget>::Output>,
Result: Add<Output = Result>
{
let lhs = foo(a, b).promote();
let rhs = bar(c, d).promote();
lhs + rhs
}
Running this on the playground gives the following result:
============
Foo called
Left: u32
Right: u32
Result: u32
4


Building clean and flexible binary trees in Rust

I'm using binary trees to create a simple computation graph. I understand that linked lists are a pain in Rust, but it's a very convenient data structure for what I'm doing. I tried using Box and Rc<RefCell> for the children nodes, but it didn't work out how I wanted, so I used unsafe:
use std::ops::{Add, Mul};
#[derive(Debug, Copy, Clone)]
struct MyStruct {
value: i32,
lchild: Option<*mut MyStruct>,
rchild: Option<*mut MyStruct>,
}
impl MyStruct {
unsafe fn print_tree(&mut self, set_to_zero: bool) {
if set_to_zero {
self.value = 0;
}
println!("{:?}", self);
let mut nodes = vec![self.lchild, self.rchild];
while nodes.len() > 0 {
let child;
match nodes.pop() {
Some(popped_child) => child = popped_child.unwrap(),
None => continue,
}
if set_to_zero {
(*child).value = 0;
}
println!("{:?}", *child);
if !(*child).lchild.is_none() {
nodes.push((*child).lchild);
}
if !(*child).rchild.is_none() {
nodes.push((*child).rchild);
}
}
println!("");
}
}
impl Add for MyStruct {
type Output = Self;
fn add(self, other: Self) -> MyStruct {
MyStruct{
value: self.value + other.value,
lchild: Some(&self as *const _ as *mut _),
rchild: Some(&other as *const _ as *mut _),
}
}
}
impl Mul for MyStruct {
type Output = Self;
fn mul(self, other: Self) -> Self {
MyStruct{
value: self.value * other.value,
lchild: Some(&self as *const _ as *mut _),
rchild: Some(&other as *const _ as *mut _),
}
}
}
fn main() {
let mut tree: MyStruct;
{
let a = MyStruct{ value: 10, lchild: None, rchild: None };
let b = MyStruct{ value: 20, lchild: None, rchild: None };
let c = a + b;
println!("c.value: {}", c.value); // 30
let mut d = a + b;
println!("d.value: {}", d.value); // 30
d.value = 40;
println!("d.value: {}", d.value); // 40
let mut e = c * d;
println!("e.value: {}", e.value); // 1200
unsafe {
e.print_tree(false); // correct values
e.print_tree(true); // all zeros
e.print_tree(false); // all zeros, everything is set correctly
}
tree = e;
}
unsafe { tree.print_tree(false); } // same here, only zeros
}
I honestly don't mind using unsafe that much, but is there a safe way of doing it? How bad is the use of unsafe here?
You can just box both of the children, since you have a unidirectional tree:
use std::ops::{Add, Mul};
use std::fmt;
#[derive(Clone)]
struct MyStruct {
value: i32,
lchild: Option<Box<MyStruct>>,
rchild: Option<Box<MyStruct>>,
}
impl fmt::Debug for MyStruct {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
f.debug_struct("MyStruct")
.field("value", &self.value)
.field("lchild", &self.lchild.as_deref())
.field("rchild", &self.rchild.as_deref())
.finish()
}
}
impl MyStruct {
fn print_tree(&mut self, set_to_zero: bool) {
if set_to_zero {
self.value = 0;
}
println!("MyStruct {{ value: {:?}, lchild: {:?}, rchild: {:?} }}", self.value, &self.lchild as *const _, &self.rchild as *const _);
if let Some(child) = &mut self.lchild {
child.print_tree(set_to_zero);
}
if let Some(child) = &mut self.rchild {
child.print_tree(set_to_zero);
}
}
}
impl Add for MyStruct {
type Output = Self;
fn add(self, other: Self) -> MyStruct {
MyStruct {
value: self.value + other.value,
lchild: Some(Box::new(self)),
rchild: Some(Box::new(other)),
}
}
}
impl Mul for MyStruct {
type Output = Self;
fn mul(self, other: Self) -> Self {
MyStruct {
value: self.value * other.value,
lchild: Some(Box::new(self)),
rchild: Some(Box::new(other)),
}
}
}
fn main() {
let tree = {
let a = MyStruct {
value: 10,
lchild: None,
rchild: None,
};
let b = MyStruct {
value: 20,
lchild: None,
rchild: None,
};
let c = a.clone() + b.clone();
println!("c.value: {}", c.value); // 30
let mut d = a.clone() + b.clone();
println!("d.value: {}", d.value); // 30
d.value = 40;
println!("d.value: {}", d.value); // 40
let mut e = c * d;
println!("e.value: {}", e.value); // 1200
println!("");
e.print_tree(false); // correct values
println!("");
e.print_tree(true); // all zeros
println!("");
e.print_tree(false); // all zeros, everything is set correctly
println!("");
e
};
dbg!(tree);
}
I implemented Debug manually and reimplemented print_tree recursively. I don't know if there is a way to implement print_tree as mutable like that without recursion, but it's certainly possible if you take &self instead (removing the set_to_zero stuff).
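For the read-only case mentioned above, a non-recursive traversal over &self could look roughly like this. It is a sketch only: the constructors leaf and node and the method values are names invented for the example, and the visiting order (node, then left, then right) is one arbitrary choice.

```rust
#[derive(Debug)]
struct MyStruct {
    value: i32,
    lchild: Option<Box<MyStruct>>,
    rchild: Option<Box<MyStruct>>,
}

impl MyStruct {
    // Hypothetical constructors, added only to keep the example short.
    fn leaf(value: i32) -> MyStruct {
        MyStruct { value, lchild: None, rchild: None }
    }

    fn node(value: i32, l: MyStruct, r: MyStruct) -> MyStruct {
        MyStruct { value, lchild: Some(Box::new(l)), rchild: Some(Box::new(r)) }
    }

    // Collects all values with an explicit stack instead of recursion.
    fn values(&self) -> Vec<i32> {
        let mut out = Vec::new();
        let mut stack = vec![self];
        while let Some(node) = stack.pop() {
            out.push(node.value);
            // Push the right child first so the left one is popped first.
            if let Some(child) = node.rchild.as_deref() {
                stack.push(child);
            }
            if let Some(child) = node.lchild.as_deref() {
                stack.push(child);
            }
        }
        out
    }
}

fn main() {
    let tree = MyStruct::node(
        1200,
        MyStruct::node(30, MyStruct::leaf(10), MyStruct::leaf(20)),
        MyStruct::node(40, MyStruct::leaf(10), MyStruct::leaf(20)),
    );
    println!("{:?}", tree.values());
}
```

Shared references can be copied freely onto the stack, which is why this version needs no special care with the borrow checker.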
Edit: Turns out it is possible to mutably iterate over the tree values without recursion. The following code is derived from the playground in this comment by @Shepmaster.
impl MyStruct {
fn zero_tree(&mut self) {
let mut node_stack = vec![self];
let mut value_stack = vec![];
// collect mutable references to each value
while let Some(MyStruct { value, lchild, rchild }) = node_stack.pop() {
value_stack.push(value);
if let Some(child) = lchild {
node_stack.push(child);
}
if let Some(child) = rchild {
node_stack.push(child);
}
}
// iterate over mutable references to values
for value in value_stack {
*value = 0;
}
}
}

How to Sort a Vector of Structs by a Chrono DateTime Field?

I have this struct
pub struct Items {
pub symbol: String,
pub price: f64,
pub date: DateTime<Utc>,
}
I have a vector of these structs. I would like to sort them by date. How would I go about doing that? I tried deriving PartialEq, Ord, Eq, etc... but Rust complains about the float fields.
The easiest way is to use one of the provided sort functions available on Vec, such as sort_by, sort_by_key, or sort_by_cached_key.
// Using sort_by
foo_items.sort_by(|a, b| a.date.cmp(&b.date));
// Using sort_by_key
foo_items.sort_by_key(|x| x.date);
// Using sort_by_cached_key (faster if the key is expensive to compute)
foo_items.sort_by_cached_key(|x| x.date);
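As a self-contained illustration that no Ord derive is needed on the struct itself, here is a small sketch. It makes one substitution: std::time::SystemTime stands in for chrono's DateTime<Utc> so the example needs no external crate, and the names Item, item and sort_by_date are invented for the example.

```rust
use std::time::{Duration, SystemTime, UNIX_EPOCH};

#[derive(Debug, Clone)]
pub struct Item {
    pub symbol: String,
    pub price: f64, // f64 is not Ord, which is why deriving Ord fails
    pub date: SystemTime, // stand-in for chrono's DateTime<Utc>
}

// Hypothetical constructor: `secs` is seconds since the Unix epoch.
fn item(symbol: &str, price: f64, secs: u64) -> Item {
    Item {
        symbol: symbol.to_string(),
        price,
        date: UNIX_EPOCH + Duration::from_secs(secs),
    }
}

// Only the key needs to be Ord; the struct itself does not.
fn sort_by_date(items: &mut [Item]) {
    items.sort_by_key(|x| x.date);
}

fn main() {
    let mut items = vec![item("B", 2.5, 300), item("A", 1.0, 100), item("C", 0.5, 200)];
    sort_by_date(&mut items);
    let symbols: Vec<&str> = items.iter().map(|x| x.symbol.as_str()).collect();
    println!("{:?}", symbols);
}
```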
And don't forget you always have the option to manually implement traits that are normally derived.
use std::cmp::Ordering;
impl PartialEq for Items {
fn eq(&self, other: &Self) -> bool {
// Equality must agree with `cmp` below, so compare the same fields.
self.date == other.date && self.symbol == other.symbol
}
}
impl Eq for Items {}
impl PartialOrd for Items {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
// Delegate to `cmp` so the two orderings can never disagree.
Some(self.cmp(other))
}
}
impl Ord for Items {
fn cmp(&self, other: &Self) -> Ordering {
// Order by date first; break ties by symbol to stay consistent with `eq`.
self.date.cmp(&other.date).then_with(|| self.symbol.cmp(&other.symbol))
}
}
#[derive(Debug)]
pub struct S {
pub s: f64,
}
fn main() {
let mut v = vec![S{s:0.3}, S{s:1.3}, S{s:7.3}];
v.sort_by(|a, b| a.s.partial_cmp(&b.s).unwrap());
println!("{:#?}",v);
}
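One caveat with the partial_cmp(...).unwrap() approach: partial_cmp returns None when either value is NaN, so the unwrap panics on such input. If NaN values are possible, f64::total_cmp gives a total order and avoids the panic. A minimal sketch, with sort_total as an invented helper name:

```rust
#[derive(Debug)]
pub struct S {
    pub s: f64,
}

// Sorts by the f64 field using the IEEE 754 total order.
fn sort_total(v: &mut [S]) {
    v.sort_by(|a, b| a.s.total_cmp(&b.s));
}

fn main() {
    let mut v = vec![S { s: 7.3 }, S { s: f64::NAN }, S { s: 0.3 }, S { s: 1.3 }];
    sort_total(&mut v); // does not panic even though a NaN is present
    println!("{:#?}", v);
}
```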

Idiomatic way to create static iterable collection of named structs?

What is the idiomatic way to create static iterable collection of named structs? I have n instances of a struct, where n is known at compile time and is less than 20. I would like to be able to iterate over all the entries and also be able to refer to each entry by a name instead of an index. All the data is known at compile time.
I could use an array or enum, along with hand written constants which map the labels to indexes; but this seems finicky.
fn common_behaviour(x: f64) {
print!("{}", x);
}
const ADD: usize = 0;
const SUBTRACT: usize = 1;
fn main () {
let mut foos: [f64; 2] = [0.0; 2];
foos[ADD] = 4.0;
foos[SUBTRACT] = 2.0;
for foo in &foos {
common_behaviour(*foo);
}
foos[ADD] += 1.0;
foos[SUBTRACT] -= 1.0;
}
Alternatively, I could just pay the performance cost and use a HashMap as the hashing overhead might not actually matter that much, but this seems suboptimal as well.
Perhaps I could refactor my code to use function pointers instead of special-casing the different cases.
fn common_behaviour(x: f64) {
print!("{}", x);
}
fn add(x: f64) -> f64 {
x + 1.0
}
fn subtract(x: f64) -> f64 {
x - 1.0
}
struct Foo {
data: f64,
special: fn(f64) -> f64
}
impl Foo {
fn new(data: f64, special: fn(f64) -> f64) -> Foo {
Foo { data, special }
}
}
fn main() {
let mut foos = [Foo::new(4.0, add), Foo::new(2.0, subtract)];
for foo in &mut foos {
common_behaviour(foo.data);
foo.data = (foo.special)(foo.data);
}
}
What is the most idiomatic way to handle this situation?
Looking at:
fn main() {
let mut foos = [Foo::new(4.0, add), Foo::new(2.0, subtract)];
for foo in &mut foos {
common_behaviour(foo.data);
foo.data = (foo.special)(foo.data);
}
}
I see a Command Pattern struggling to emerge, and Rust is great at expressing this pattern, thanks to enum:
enum Foo {
Add(f64),
Sub(f64),
}
impl Foo {
fn apply(&mut self) {
match self {
Foo::Add(x) => {
Self::common(*x);
*x += 1.0;
},
Foo::Sub(x) => {
Self::common(*x);
*x -= 1.0;
},
}
}
fn common(x: f64) {
print!("{}", x);
}
}
And your example becomes:
fn main() {
let mut foos = [Foo::Add(4.0), Foo::Sub(2.0)];
for foo in &mut foos {
foo.apply();
}
}
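If direct access by name is still wanted alongside iteration, one possible alternative (not part of the answer above) is to index a fixed-size array with an enum instead of hand-written usize constants. The names Label and Foos are invented for this sketch, and it assumes the enum variants are declared in the same order as the array slots.

```rust
use std::ops::{Index, IndexMut};

#[derive(Clone, Copy, Debug, PartialEq)]
enum Label {
    Add,
    Subtract,
}

impl Label {
    // All labels, in declaration order, for iteration.
    const ALL: [Label; 2] = [Label::Add, Label::Subtract];
}

// One slot per label; the enum discriminant is the array index.
struct Foos([f64; 2]);

impl Index<Label> for Foos {
    type Output = f64;
    fn index(&self, label: Label) -> &f64 {
        &self.0[label as usize]
    }
}

impl IndexMut<Label> for Foos {
    fn index_mut(&mut self, label: Label) -> &mut f64 {
        &mut self.0[label as usize]
    }
}

fn main() {
    let mut foos = Foos([4.0, 2.0]);
    // Iterate over every entry...
    for label in Label::ALL {
        println!("{:?}: {}", label, foos[label]);
    }
    // ...and still refer to individual entries by name.
    foos[Label::Add] += 1.0;
    foos[Label::Subtract] -= 1.0;
}
```

Compared with bare usize constants, the enum keeps an out-of-range index from being written by accident.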

How do I check if a slice is sorted?

Assuming a function that accepts a slice of i32, is there an idiomatic Rust way of checking if the slice is sorted?
fn is_sorted(data: &[i32]) -> bool {
// ...
}
Would it be possible to generalize the above method so that it would accept an iterator?
fn is_sorted<I>(iter: I) -> bool
where
I: Iterator,
I::Item: Ord,
{
// ...
}
I'd grab pairs of elements and assert they are all in ascending (or descending, depending on what you mean by "sorted") order:
fn is_sorted<T>(data: &[T]) -> bool
where
T: Ord,
{
data.windows(2).all(|w| w[0] <= w[1])
}
fn main() {
assert!(is_sorted::<u8>(&[]));
assert!(is_sorted(&[1]));
assert!(is_sorted(&[1, 2, 3]));
assert!(is_sorted(&[1, 1, 1]));
assert!(!is_sorted(&[1, 3, 2]));
assert!(!is_sorted(&[3, 2, 1]));
}
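Since "sorted" can mean ascending, descending, or strictly increasing, the same windows(2) idea can take the pairwise check as a parameter. This is a sketch; is_sorted_with is an invented name, not a standard library function.

```rust
// Returns true if every pair of neighbours satisfies `in_order`.
fn is_sorted_with<T, F>(data: &[T], mut in_order: F) -> bool
where
    F: FnMut(&T, &T) -> bool,
{
    data.windows(2).all(|w| in_order(&w[0], &w[1]))
}

fn main() {
    // Non-strictly descending.
    println!("{}", is_sorted_with(&[3, 2, 2, 1], |a, b| a >= b));
    // Strictly ascending: equal neighbours are rejected.
    println!("{}", is_sorted_with(&[1, 1, 2], |a, b| a < b));
}
```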
Ditto for generic iterators:
extern crate itertools; // 0.7.8
use itertools::Itertools;
fn is_sorted<I>(data: I) -> bool
where
I: IntoIterator,
I::Item: Ord + Clone,
{
data.into_iter().tuple_windows().all(|(a, b)| a <= b)
}
fn main() {
assert!(is_sorted(&[] as &[u8]));
assert!(is_sorted(&[1]));
assert!(is_sorted(&[1, 2, 3]));
assert!(is_sorted(&[1, 1, 1]));
assert!(!is_sorted(&[1, 3, 2]));
assert!(!is_sorted(&[3, 2, 1]));
}
See also:
Are there equivalents to slice::chunks/windows for iterators to loop over pairs, triplets etc?
These methods were nightly-only when this answer was written; they have since been stabilized (Rust 1.82):
slice::is_sorted
slice::is_sorted_by
slice::is_sorted_by_key
Iterator::is_sorted
Iterator::is_sorted_by
Iterator::is_sorted_by_key
It is not necessary to have Clone for an iterator is_sorted implementation. Here is a no-dependency Rust implementation of is_sorted:
fn is_sorted<I>(data: I) -> bool
where
I: IntoIterator,
I::Item: Ord,
{
let mut it = data.into_iter();
match it.next() {
None => true,
Some(first) => it.scan(first, |state, next| {
let cmp = *state <= next;
*state = next;
Some(cmp)
}).all(|b| b),
}
}
One more using try_fold():
pub fn is_sorted<T: IntoIterator>(t: T) -> bool
where
<T as IntoIterator>::Item: std::cmp::PartialOrd,
{
let mut iter = t.into_iter();
if let Some(first) = iter.next() {
iter.try_fold(first, |previous, current| {
if previous > current {
Err(())
} else {
Ok(current)
}
})
.is_ok()
} else {
true
}
}

How do I return the string of an enum at a specific index?

What would be the easiest way to achieve this?
enum E {
A,
B,
C,
}
// Should return "A" for 0, "B" for 1 and "C" for 2
fn convert(i: u32) -> str {
// ???
}
You cannot return a str, but you can return a &str. Combine ideas from How do I match enum values with an integer? and Get enum as string:
#[macro_use]
extern crate strum_macros;
extern crate strum;
use strum::IntoEnumIterator;
#[derive(EnumIter, AsRefStr)]
enum E {
A,
B,
C,
}
fn main() {
let e = E::iter().nth(2);
assert_eq!(e.as_ref().map(|e| e.as_ref()), Some("C"));
}
It was necessary to add the #[derive(Debug)] attribute and define an implementation on the enum:
#[derive(Debug, Copy, Clone)]
#[repr(u8)] // fixes the representation so the transmute below is sound
enum E {
A,
B,
C,
}
impl E {
fn from(t: u8) -> E {
assert!(t <= enum_to_int(&E::C), "Enum range check failed!");
// Sound only because of #[repr(u8)] and the range check above.
let el: E = unsafe { std::mem::transmute(t) };
return el;
}
fn to_string(&self) -> String {
return format!("{:?}", self);
}
fn string_from_int(t: u8) -> String {
return E::from(t).to_string();
}
}
fn enum_to_int(el: &E) -> u8 {
*el as u8
}
It can then be used like this:
fn main() {
// Valid indices are 0..=2; anything larger fails the range check.
let s = E::string_from_int(2u8);
println!("{}", s); // prints "C"
}
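The same lookup can also be done without unsafe or external crates. The following is a sketch under assumptions: the constant ALL and the methods from_index and name are invented for the example, and the variant list has to be kept in sync with the enum by hand.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
enum E {
    A,
    B,
    C,
}

impl E {
    // All variants in declaration order; must be updated with the enum.
    const ALL: [E; 3] = [E::A, E::B, E::C];

    // Returns None instead of panicking when the index is out of range.
    fn from_index(i: usize) -> Option<E> {
        Self::ALL.get(i).copied()
    }

    fn name(self) -> &'static str {
        match self {
            E::A => "A",
            E::B => "B",
            E::C => "C",
        }
    }
}

fn convert(i: usize) -> Option<&'static str> {
    E::from_index(i).map(E::name)
}

fn main() {
    println!("{:?}", convert(0));
    println!("{:?}", convert(3));
}
```

Returning an Option pushes the out-of-range case to the caller rather than relying on an assertion.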
